diff --git a/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java b/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java index 3c3382ef2..962a9d2e9 100644 --- a/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java +++ b/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java @@ -192,7 +192,9 @@ public Optional adaptFieldToValue(FieldDescriptor fieldDescriptor, Objec if (bidiConverter == BidiConverter.IDENTITY) { return Optional.of(fieldValue); } - return Optional.of(AdaptingTypes.adaptingList((List) fieldValue, bidiConverter)); + ArrayList convertedList = + new ArrayList<>(AdaptingTypes.adaptingList((List) fieldValue, bidiConverter)); + return Optional.of(convertedList); } return Optional.of( @@ -244,28 +246,48 @@ private BidiConverter fieldToValueConverter(FieldDescriptor fieldDescriptor) { case SFIXED32: case SINT32: case INT32: - return INT_CONVERTER; + return unwrapAndConvert(INT_CONVERTER); case FIXED32: case UINT32: if (celOptions.enableUnsignedLongs()) { - return UNSIGNED_UINT32_CONVERTER; + return unwrapAndConvert(UNSIGNED_UINT32_CONVERTER); } - return SIGNED_UINT32_CONVERTER; + return unwrapAndConvert(SIGNED_UINT32_CONVERTER); case FIXED64: case UINT64: if (celOptions.enableUnsignedLongs()) { - return UNSIGNED_UINT64_CONVERTER; + return unwrapAndConvert(UNSIGNED_UINT64_CONVERTER); } - return BidiConverter.IDENTITY; + return BidiConverter.of( + BidiConverter.IDENTITY.forwardConverter(), + value -> BidiConverter.IDENTITY.backwardConverter().convert(maybeUnwrap(value))); case FLOAT: - return DOUBLE_CONVERTER; + return unwrapAndConvert(DOUBLE_CONVERTER); + case DOUBLE: + case SFIXED64: + case SINT64: + case INT64: + return BidiConverter.of( + BidiConverter.IDENTITY.forwardConverter(), + value -> BidiConverter.IDENTITY.backwardConverter().convert(maybeUnwrap(value))); case BYTES: if (celOptions.evaluateCanonicalTypesToNativeValues()) { return BidiConverter.of( - ProtoAdapter::adaptProtoByteStringToValue, ProtoAdapter::adaptCelByteStringToProto); + ProtoAdapter::adaptProtoByteStringToValue, + value -> adaptCelByteStringToProto(maybeUnwrap(value))); } - return BidiConverter.IDENTITY; + return BidiConverter.of( + BidiConverter.IDENTITY.forwardConverter(), + value -> BidiConverter.IDENTITY.backwardConverter().convert(maybeUnwrap(value))); + case STRING: + return BidiConverter.of( + BidiConverter.IDENTITY.forwardConverter(), + value -> BidiConverter.IDENTITY.backwardConverter().convert(maybeUnwrap(value))); + case BOOL: + return BidiConverter.of( + BidiConverter.IDENTITY.forwardConverter(), + value -> BidiConverter.IDENTITY.backwardConverter().convert(maybeUnwrap(value))); case ENUM: return BidiConverter.of( value -> (long) ((EnumValueDescriptor) value).getNumber(), @@ -371,4 +393,18 @@ private static int unsignedIntCheckedCast(long value) { throw new CelNumericOverflowException(e); } } + + private Object maybeUnwrap(Object value) { + if (value instanceof Message) { + return adaptProtoToValue((MessageOrBuilder) value); + } + return value; + } + + private BidiConverter unwrapAndConvert( + final BidiConverter original) { + return BidiConverter.of( + original.forwardConverter()::convert, + value -> original.backwardConverter().convert((Number) maybeUnwrap(value))); + } } diff --git a/runtime/src/test/resources/wrappers.baseline b/runtime/src/test/resources/wrappers.baseline index a971dcb29..d8212059b 100644 --- a/runtime/src/test/resources/wrappers.baseline +++ b/runtime/src/test/resources/wrappers.baseline @@ -154,6 +154,47 @@ declare dyn_var { bindings: {dyn_var=NULL_VALUE} result: NULL_VALUE +Source: TestAllTypes{repeated_int32: int32_list}.repeated_int32 == [1] && TestAllTypes{repeated_int64: int64_list}.repeated_int64 == [2] && TestAllTypes{repeated_uint32: uint32_list}.repeated_uint32 == [3u] && TestAllTypes{repeated_uint64: uint64_list}.repeated_uint64 == [4u] && TestAllTypes{repeated_float: float_list}.repeated_float == [5.5] && TestAllTypes{repeated_double: double_list}.repeated_double == [6.6] && TestAllTypes{repeated_bool: bool_list}.repeated_bool == [true] && TestAllTypes{repeated_string: string_list}.repeated_string == ['hello'] && TestAllTypes{repeated_bytes: bytes_list}.repeated_bytes == [b'world'] +declare int32_list { + value list(int) +} +declare int64_list { + value list(int) +} +declare uint32_list { + value list(uint) +} +declare uint64_list { + value list(uint) +} +declare float_list { + value list(double) +} +declare double_list { + value list(double) +} +declare bool_list { + value list(bool) +} +declare string_list { + value list(string) +} +declare bytes_list { + value list(bytes) +} +=====> +bindings: {int32_list=[value: 1 +], int64_list=[value: 2 +], uint32_list=[value: 3 +], uint64_list=[value: 4 +], float_list=[value: 5.5 +], double_list=[value: 6.6 +], bool_list=[value: true +], string_list=[value: "hello" +], bytes_list=[value: "world" +]} +result: true + Source: google.protobuf.Timestamp{ seconds: 253402300800 } =====> bindings: {} diff --git a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java index 425271e1b..54cd86cfa 100644 --- a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java +++ b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java @@ -2058,6 +2058,42 @@ public void wrappers() throws Exception { source = "dyn_var"; runTest(ImmutableMap.of("dyn_var", NullValue.NULL_VALUE)); + clearAllDeclarations(); + declareVariable("int32_list", ListType.create(SimpleType.INT)); + declareVariable("int64_list", ListType.create(SimpleType.INT)); + declareVariable("uint32_list", ListType.create(SimpleType.UINT)); + declareVariable("uint64_list", ListType.create(SimpleType.UINT)); + declareVariable("float_list", ListType.create(SimpleType.DOUBLE)); + declareVariable("double_list", ListType.create(SimpleType.DOUBLE)); + declareVariable("bool_list", ListType.create(SimpleType.BOOL)); + declareVariable("string_list", ListType.create(SimpleType.STRING)); + declareVariable("bytes_list", ListType.create(SimpleType.BYTES)); + + container = CelContainer.ofName(TestAllTypes.getDescriptor().getFullName()); + source = + "TestAllTypes{repeated_int32: int32_list}.repeated_int32 == [1] && " + + "TestAllTypes{repeated_int64: int64_list}.repeated_int64 == [2] && " + + "TestAllTypes{repeated_uint32: uint32_list}.repeated_uint32 == [3u] && " + + "TestAllTypes{repeated_uint64: uint64_list}.repeated_uint64 == [4u] && " + + "TestAllTypes{repeated_float: float_list}.repeated_float == [5.5] && " + + "TestAllTypes{repeated_double: double_list}.repeated_double == [6.6] && " + + "TestAllTypes{repeated_bool: bool_list}.repeated_bool == [true] && " + + "TestAllTypes{repeated_string: string_list}.repeated_string == ['hello'] && " + + "TestAllTypes{repeated_bytes: bytes_list}.repeated_bytes == [b'world']"; + + runTest( + ImmutableMap.builder() + .put("int32_list", ImmutableList.of(Int32Value.of(1))) + .put("int64_list", ImmutableList.of(Int64Value.of(2))) + .put("uint32_list", ImmutableList.of(UInt32Value.of(3))) + .put("uint64_list", ImmutableList.of(UInt64Value.of(4))) + .put("float_list", ImmutableList.of(FloatValue.of(5.5f))) + .put("double_list", ImmutableList.of(DoubleValue.of(6.6))) + .put("bool_list", ImmutableList.of(BoolValue.of(true))) + .put("string_list", ImmutableList.of(StringValue.of("hello"))) + .put("bytes_list", ImmutableList.of(BytesValue.of(ByteString.copyFromUtf8("world")))) + .buildOrThrow()); + clearAllDeclarations(); // Currently allowed, but will be an error // See https://github.com/google/cel-spec/pull/501 @@ -2068,7 +2104,8 @@ public void wrappers() throws Exception { @Test public void longComprehension() { ImmutableList l = LongStream.range(0L, 1000L).boxed().collect(toImmutableList()); - addFunctionBinding(CelFunctionBinding.from("constantLongList", ImmutableList.of(), unused -> l)); + addFunctionBinding( + CelFunctionBinding.from("constantLongList", ImmutableList.of(), unused -> l)); // Comprehension over compile-time constant long list. declareFunction(