Skip to content

Commit e6efa9e

Browse files
l46kokcopybara-github
authored andcommitted
Fix wrapper types to properly unwrap in lists
PiperOrigin-RevId: 862902695
1 parent 83e91cc commit e6efa9e

3 files changed

Lines changed: 122 additions & 9 deletions

File tree

common/src/main/java/dev/cel/common/internal/ProtoAdapter.java

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import com.google.protobuf.InvalidProtocolBufferException;
3131
import com.google.protobuf.MapEntry;
3232
import com.google.protobuf.Message;
33+
import com.google.protobuf.MessageLite;
3334
import com.google.protobuf.MessageOrBuilder;
3435
import dev.cel.common.CelOptions;
3536
import dev.cel.common.annotations.Internal;
@@ -244,28 +245,48 @@ private BidiConverter fieldToValueConverter(FieldDescriptor fieldDescriptor) {
244245
case SFIXED32:
245246
case SINT32:
246247
case INT32:
247-
return INT_CONVERTER;
248+
return unwrapAndConvert(INT_CONVERTER);
248249
case FIXED32:
249250
case UINT32:
250251
if (celOptions.enableUnsignedLongs()) {
251-
return UNSIGNED_UINT32_CONVERTER;
252+
return unwrapAndConvert(UNSIGNED_UINT32_CONVERTER);
252253
}
253-
return SIGNED_UINT32_CONVERTER;
254+
return unwrapAndConvert(SIGNED_UINT32_CONVERTER);
254255
case FIXED64:
255256
case UINT64:
256257
if (celOptions.enableUnsignedLongs()) {
257-
return UNSIGNED_UINT64_CONVERTER;
258+
return unwrapAndConvert(UNSIGNED_UINT64_CONVERTER);
258259
}
259-
return BidiConverter.IDENTITY;
260+
return BidiConverter.of(
261+
BidiConverter.IDENTITY.forwardConverter(),
262+
value -> BidiConverter.IDENTITY.backwardConverter().convert(unwrap(value)));
260263
case FLOAT:
261-
return DOUBLE_CONVERTER;
264+
return unwrapAndConvert(DOUBLE_CONVERTER);
265+
case DOUBLE:
266+
case SFIXED64:
267+
case SINT64:
268+
case INT64:
269+
return BidiConverter.of(
270+
BidiConverter.IDENTITY.forwardConverter(),
271+
value -> BidiConverter.IDENTITY.backwardConverter().convert(unwrap(value)));
262272
case BYTES:
263273
if (celOptions.evaluateCanonicalTypesToNativeValues()) {
264274
return BidiConverter.<Object, Object>of(
265-
ProtoAdapter::adaptProtoByteStringToValue, ProtoAdapter::adaptCelByteStringToProto);
275+
ProtoAdapter::adaptProtoByteStringToValue,
276+
value -> adaptCelByteStringToProto(unwrap(value)));
266277
}
267278

268-
return BidiConverter.IDENTITY;
279+
return BidiConverter.of(
280+
BidiConverter.IDENTITY.forwardConverter(),
281+
value -> BidiConverter.IDENTITY.backwardConverter().convert(unwrap(value)));
282+
case STRING:
283+
return BidiConverter.of(
284+
BidiConverter.IDENTITY.forwardConverter(),
285+
value -> BidiConverter.IDENTITY.backwardConverter().convert(unwrap(value)));
286+
case BOOL:
287+
return BidiConverter.of(
288+
BidiConverter.IDENTITY.forwardConverter(),
289+
value -> BidiConverter.IDENTITY.backwardConverter().convert(unwrap(value)));
269290
case ENUM:
270291
return BidiConverter.<Object, Long>of(
271292
value -> (long) ((EnumValueDescriptor) value).getNumber(),
@@ -371,4 +392,18 @@ private static int unsignedIntCheckedCast(long value) {
371392
throw new CelNumericOverflowException(e);
372393
}
373394
}
395+
396+
private Object unwrap(Object value) {
397+
if (value instanceof MessageLite) {
398+
return adaptProtoToValue((MessageOrBuilder) value);
399+
}
400+
return value;
401+
}
402+
403+
private BidiConverter<Number, Object> unwrapAndConvert(
404+
final BidiConverter<Number, Number> original) {
405+
return BidiConverter.of(
406+
original.forwardConverter()::convert,
407+
value -> original.backwardConverter().convert((Number) unwrap(value)));
408+
}
374409
}

runtime/src/test/resources/wrappers.baseline

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,47 @@ declare dyn_var {
154154
bindings: {dyn_var=NULL_VALUE}
155155
result: NULL_VALUE
156156

157+
Source: TestAllTypes{repeated_int32: int32_list}.repeated_int32 == [1] && TestAllTypes{repeated_int64: int64_list}.repeated_int64 == [2] && TestAllTypes{repeated_uint32: uint32_list}.repeated_uint32 == [3u] && TestAllTypes{repeated_uint64: uint64_list}.repeated_uint64 == [4u] && TestAllTypes{repeated_float: float_list}.repeated_float == [5.5] && TestAllTypes{repeated_double: double_list}.repeated_double == [6.6] && TestAllTypes{repeated_bool: bool_list}.repeated_bool == [true] && TestAllTypes{repeated_string: string_list}.repeated_string == ['hello'] && TestAllTypes{repeated_bytes: bytes_list}.repeated_bytes == [b'world']
158+
declare int32_list {
159+
value list(int)
160+
}
161+
declare int64_list {
162+
value list(int)
163+
}
164+
declare uint32_list {
165+
value list(uint)
166+
}
167+
declare uint64_list {
168+
value list(uint)
169+
}
170+
declare float_list {
171+
value list(double)
172+
}
173+
declare double_list {
174+
value list(double)
175+
}
176+
declare bool_list {
177+
value list(bool)
178+
}
179+
declare string_list {
180+
value list(string)
181+
}
182+
declare bytes_list {
183+
value list(bytes)
184+
}
185+
=====>
186+
bindings: {int32_list=[value: 1
187+
], int64_list=[value: 2
188+
], uint32_list=[value: 3
189+
], uint64_list=[value: 4
190+
], float_list=[value: 5.5
191+
], double_list=[value: 6.6
192+
], bool_list=[value: true
193+
], string_list=[value: "hello"
194+
], bytes_list=[value: "world"
195+
]}
196+
result: true
197+
157198
Source: google.protobuf.Timestamp{ seconds: 253402300800 }
158199
=====>
159200
bindings: {}

testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2058,6 +2058,42 @@ public void wrappers() throws Exception {
20582058
source = "dyn_var";
20592059
runTest(ImmutableMap.of("dyn_var", NullValue.NULL_VALUE));
20602060

2061+
clearAllDeclarations();
2062+
declareVariable("int32_list", ListType.create(SimpleType.INT));
2063+
declareVariable("int64_list", ListType.create(SimpleType.INT));
2064+
declareVariable("uint32_list", ListType.create(SimpleType.UINT));
2065+
declareVariable("uint64_list", ListType.create(SimpleType.UINT));
2066+
declareVariable("float_list", ListType.create(SimpleType.DOUBLE));
2067+
declareVariable("double_list", ListType.create(SimpleType.DOUBLE));
2068+
declareVariable("bool_list", ListType.create(SimpleType.BOOL));
2069+
declareVariable("string_list", ListType.create(SimpleType.STRING));
2070+
declareVariable("bytes_list", ListType.create(SimpleType.BYTES));
2071+
2072+
container = CelContainer.ofName(TestAllTypes.getDescriptor().getFullName());
2073+
source =
2074+
"TestAllTypes{repeated_int32: int32_list}.repeated_int32 == [1] && "
2075+
+ "TestAllTypes{repeated_int64: int64_list}.repeated_int64 == [2] && "
2076+
+ "TestAllTypes{repeated_uint32: uint32_list}.repeated_uint32 == [3u] && "
2077+
+ "TestAllTypes{repeated_uint64: uint64_list}.repeated_uint64 == [4u] && "
2078+
+ "TestAllTypes{repeated_float: float_list}.repeated_float == [5.5] && "
2079+
+ "TestAllTypes{repeated_double: double_list}.repeated_double == [6.6] && "
2080+
+ "TestAllTypes{repeated_bool: bool_list}.repeated_bool == [true] && "
2081+
+ "TestAllTypes{repeated_string: string_list}.repeated_string == ['hello'] && "
2082+
+ "TestAllTypes{repeated_bytes: bytes_list}.repeated_bytes == [b'world']";
2083+
2084+
runTest(
2085+
ImmutableMap.<String, Object>builder()
2086+
.put("int32_list", ImmutableList.of(Int32Value.of(1)))
2087+
.put("int64_list", ImmutableList.of(Int64Value.of(2)))
2088+
.put("uint32_list", ImmutableList.of(UInt32Value.of(3)))
2089+
.put("uint64_list", ImmutableList.of(UInt64Value.of(4)))
2090+
.put("float_list", ImmutableList.of(FloatValue.of(5.5f)))
2091+
.put("double_list", ImmutableList.of(DoubleValue.of(6.6)))
2092+
.put("bool_list", ImmutableList.of(BoolValue.of(true)))
2093+
.put("string_list", ImmutableList.of(StringValue.of("hello")))
2094+
.put("bytes_list", ImmutableList.of(BytesValue.of(ByteString.copyFromUtf8("world"))))
2095+
.build());
2096+
20612097
clearAllDeclarations();
20622098
// Currently allowed, but will be an error
20632099
// See https://github.com/google/cel-spec/pull/501
@@ -2068,7 +2104,8 @@ public void wrappers() throws Exception {
20682104
@Test
20692105
public void longComprehension() {
20702106
ImmutableList<Long> l = LongStream.range(0L, 1000L).boxed().collect(toImmutableList());
2071-
addFunctionBinding(CelFunctionBinding.from("constantLongList", ImmutableList.of(), unused -> l));
2107+
addFunctionBinding(
2108+
CelFunctionBinding.from("constantLongList", ImmutableList.of(), unused -> l));
20722109

20732110
// Comprehension over compile-time constant long list.
20742111
declareFunction(

0 commit comments

Comments
 (0)