Skip to content

Commit ae57712

Browse files
l46kokcopybara-github
authored andcommitted
Fix wrapper types to properly unwrap in lists
PiperOrigin-RevId: 862902695
1 parent 83e91cc commit ae57712

3 files changed

Lines changed: 124 additions & 10 deletions

File tree

common/src/main/java/dev/cel/common/internal/ProtoAdapter.java

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,9 @@ public Optional<Object> adaptFieldToValue(FieldDescriptor fieldDescriptor, Objec
192192
if (bidiConverter == BidiConverter.IDENTITY) {
193193
return Optional.of(fieldValue);
194194
}
195-
return Optional.of(AdaptingTypes.adaptingList((List<?>) fieldValue, bidiConverter));
195+
ArrayList<?> convertedList =
196+
new ArrayList<>(AdaptingTypes.adaptingList((List<?>) fieldValue, bidiConverter));
197+
return Optional.of(convertedList);
196198
}
197199

198200
return Optional.of(
@@ -244,28 +246,48 @@ private BidiConverter fieldToValueConverter(FieldDescriptor fieldDescriptor) {
244246
case SFIXED32:
245247
case SINT32:
246248
case INT32:
247-
return INT_CONVERTER;
249+
return unwrapAndConvert(INT_CONVERTER);
248250
case FIXED32:
249251
case UINT32:
250252
if (celOptions.enableUnsignedLongs()) {
251-
return UNSIGNED_UINT32_CONVERTER;
253+
return unwrapAndConvert(UNSIGNED_UINT32_CONVERTER);
252254
}
253-
return SIGNED_UINT32_CONVERTER;
255+
return unwrapAndConvert(SIGNED_UINT32_CONVERTER);
254256
case FIXED64:
255257
case UINT64:
256258
if (celOptions.enableUnsignedLongs()) {
257-
return UNSIGNED_UINT64_CONVERTER;
259+
return unwrapAndConvert(UNSIGNED_UINT64_CONVERTER);
258260
}
259-
return BidiConverter.IDENTITY;
261+
return BidiConverter.of(
262+
BidiConverter.IDENTITY.forwardConverter(),
263+
value -> BidiConverter.IDENTITY.backwardConverter().convert(maybeUnwrap(value)));
260264
case FLOAT:
261-
return DOUBLE_CONVERTER;
265+
return unwrapAndConvert(DOUBLE_CONVERTER);
266+
case DOUBLE:
267+
case SFIXED64:
268+
case SINT64:
269+
case INT64:
270+
return BidiConverter.of(
271+
BidiConverter.IDENTITY.forwardConverter(),
272+
value -> BidiConverter.IDENTITY.backwardConverter().convert(maybeUnwrap(value)));
262273
case BYTES:
263274
if (celOptions.evaluateCanonicalTypesToNativeValues()) {
264275
return BidiConverter.<Object, Object>of(
265-
ProtoAdapter::adaptProtoByteStringToValue, ProtoAdapter::adaptCelByteStringToProto);
276+
ProtoAdapter::adaptProtoByteStringToValue,
277+
value -> adaptCelByteStringToProto(maybeUnwrap(value)));
266278
}
267279

268-
return BidiConverter.IDENTITY;
280+
return BidiConverter.of(
281+
BidiConverter.IDENTITY.forwardConverter(),
282+
value -> BidiConverter.IDENTITY.backwardConverter().convert(maybeUnwrap(value)));
283+
case STRING:
284+
return BidiConverter.of(
285+
BidiConverter.IDENTITY.forwardConverter(),
286+
value -> BidiConverter.IDENTITY.backwardConverter().convert(maybeUnwrap(value)));
287+
case BOOL:
288+
return BidiConverter.of(
289+
BidiConverter.IDENTITY.forwardConverter(),
290+
value -> BidiConverter.IDENTITY.backwardConverter().convert(maybeUnwrap(value)));
269291
case ENUM:
270292
return BidiConverter.<Object, Long>of(
271293
value -> (long) ((EnumValueDescriptor) value).getNumber(),
@@ -371,4 +393,18 @@ private static int unsignedIntCheckedCast(long value) {
371393
throw new CelNumericOverflowException(e);
372394
}
373395
}
396+
397+
private Object maybeUnwrap(Object value) {
398+
if (value instanceof Message) {
399+
return adaptProtoToValue((MessageOrBuilder) value);
400+
}
401+
return value;
402+
}
403+
404+
private BidiConverter<Number, Object> unwrapAndConvert(
405+
final BidiConverter<Number, Number> original) {
406+
return BidiConverter.of(
407+
original.forwardConverter()::convert,
408+
value -> original.backwardConverter().convert((Number) maybeUnwrap(value)));
409+
}
374410
}

runtime/src/test/resources/wrappers.baseline

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,47 @@ declare dyn_var {
154154
bindings: {dyn_var=NULL_VALUE}
155155
result: NULL_VALUE
156156

157+
Source: TestAllTypes{repeated_int32: int32_list}.repeated_int32 == [1] && TestAllTypes{repeated_int64: int64_list}.repeated_int64 == [2] && TestAllTypes{repeated_uint32: uint32_list}.repeated_uint32 == [3u] && TestAllTypes{repeated_uint64: uint64_list}.repeated_uint64 == [4u] && TestAllTypes{repeated_float: float_list}.repeated_float == [5.5] && TestAllTypes{repeated_double: double_list}.repeated_double == [6.6] && TestAllTypes{repeated_bool: bool_list}.repeated_bool == [true] && TestAllTypes{repeated_string: string_list}.repeated_string == ['hello'] && TestAllTypes{repeated_bytes: bytes_list}.repeated_bytes == [b'world']
158+
declare int32_list {
159+
value list(int)
160+
}
161+
declare int64_list {
162+
value list(int)
163+
}
164+
declare uint32_list {
165+
value list(uint)
166+
}
167+
declare uint64_list {
168+
value list(uint)
169+
}
170+
declare float_list {
171+
value list(double)
172+
}
173+
declare double_list {
174+
value list(double)
175+
}
176+
declare bool_list {
177+
value list(bool)
178+
}
179+
declare string_list {
180+
value list(string)
181+
}
182+
declare bytes_list {
183+
value list(bytes)
184+
}
185+
=====>
186+
bindings: {int32_list=[value: 1
187+
], int64_list=[value: 2
188+
], uint32_list=[value: 3
189+
], uint64_list=[value: 4
190+
], float_list=[value: 5.5
191+
], double_list=[value: 6.6
192+
], bool_list=[value: true
193+
], string_list=[value: "hello"
194+
], bytes_list=[value: "world"
195+
]}
196+
result: true
197+
157198
Source: google.protobuf.Timestamp{ seconds: 253402300800 }
158199
=====>
159200
bindings: {}

testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2058,6 +2058,42 @@ public void wrappers() throws Exception {
20582058
source = "dyn_var";
20592059
runTest(ImmutableMap.of("dyn_var", NullValue.NULL_VALUE));
20602060

2061+
clearAllDeclarations();
2062+
declareVariable("int32_list", ListType.create(SimpleType.INT));
2063+
declareVariable("int64_list", ListType.create(SimpleType.INT));
2064+
declareVariable("uint32_list", ListType.create(SimpleType.UINT));
2065+
declareVariable("uint64_list", ListType.create(SimpleType.UINT));
2066+
declareVariable("float_list", ListType.create(SimpleType.DOUBLE));
2067+
declareVariable("double_list", ListType.create(SimpleType.DOUBLE));
2068+
declareVariable("bool_list", ListType.create(SimpleType.BOOL));
2069+
declareVariable("string_list", ListType.create(SimpleType.STRING));
2070+
declareVariable("bytes_list", ListType.create(SimpleType.BYTES));
2071+
2072+
container = CelContainer.ofName(TestAllTypes.getDescriptor().getFullName());
2073+
source =
2074+
"TestAllTypes{repeated_int32: int32_list}.repeated_int32 == [1] && "
2075+
+ "TestAllTypes{repeated_int64: int64_list}.repeated_int64 == [2] && "
2076+
+ "TestAllTypes{repeated_uint32: uint32_list}.repeated_uint32 == [3u] && "
2077+
+ "TestAllTypes{repeated_uint64: uint64_list}.repeated_uint64 == [4u] && "
2078+
+ "TestAllTypes{repeated_float: float_list}.repeated_float == [5.5] && "
2079+
+ "TestAllTypes{repeated_double: double_list}.repeated_double == [6.6] && "
2080+
+ "TestAllTypes{repeated_bool: bool_list}.repeated_bool == [true] && "
2081+
+ "TestAllTypes{repeated_string: string_list}.repeated_string == ['hello'] && "
2082+
+ "TestAllTypes{repeated_bytes: bytes_list}.repeated_bytes == [b'world']";
2083+
2084+
runTest(
2085+
ImmutableMap.<String, Object>builder()
2086+
.put("int32_list", ImmutableList.of(Int32Value.of(1)))
2087+
.put("int64_list", ImmutableList.of(Int64Value.of(2)))
2088+
.put("uint32_list", ImmutableList.of(UInt32Value.of(3)))
2089+
.put("uint64_list", ImmutableList.of(UInt64Value.of(4)))
2090+
.put("float_list", ImmutableList.of(FloatValue.of(5.5f)))
2091+
.put("double_list", ImmutableList.of(DoubleValue.of(6.6)))
2092+
.put("bool_list", ImmutableList.of(BoolValue.of(true)))
2093+
.put("string_list", ImmutableList.of(StringValue.of("hello")))
2094+
.put("bytes_list", ImmutableList.of(BytesValue.of(ByteString.copyFromUtf8("world"))))
2095+
.buildOrThrow());
2096+
20612097
clearAllDeclarations();
20622098
// Currently allowed, but will be an error
20632099
// See https://github.com/google/cel-spec/pull/501
@@ -2068,7 +2104,8 @@ public void wrappers() throws Exception {
20682104
@Test
20692105
public void longComprehension() {
20702106
ImmutableList<Long> l = LongStream.range(0L, 1000L).boxed().collect(toImmutableList());
2071-
addFunctionBinding(CelFunctionBinding.from("constantLongList", ImmutableList.of(), unused -> l));
2107+
addFunctionBinding(
2108+
CelFunctionBinding.from("constantLongList", ImmutableList.of(), unused -> l));
20722109

20732110
// Comprehension over compile-time constant long list.
20742111
declareFunction(

0 commit comments

Comments
 (0)