Skip to content

Commit 328fee5

Browse files
committed
Fix wrapper adaptations
1 parent ff8cd30 commit 328fee5

4 files changed

Lines changed: 190 additions & 8 deletions

File tree

bundle/src/test/java/dev/cel/bundle/CelImplTest.java

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import static dev.cel.common.CelOverloadDecl.newMemberOverload;
2222
import static org.junit.Assert.assertThrows;
2323

24+
import com.google.common.primitives.UnsignedLong;
25+
import com.google.protobuf.UInt64Value;
2426
import dev.cel.expr.CheckedExpr;
2527
import dev.cel.expr.Constant;
2628
import dev.cel.expr.Decl;
@@ -2193,6 +2195,74 @@ public void toBuilder_isImmutable() {
21932195
assertThat(newRuntimeBuilder).isNotEqualTo(celImpl.toRuntimeBuilder());
21942196
}
21952197

2198+
@Test
2199+
public void mapSelection_uintWrapper() throws Exception {
2200+
Cel cel = CelFactory.standardCelBuilder()
2201+
.addVar("args", MapType.create(SimpleType.DYN, SimpleType.DYN))
2202+
.setContainer(CelContainer.ofName("cel.expr.conformance.proto3"))
2203+
.addMessageTypes(TestAllTypes.getDescriptor())
2204+
.build();
2205+
2206+
CelAbstractSyntaxTree ast = cel.compile("args.i[1]").getAst();
2207+
2208+
Object result = cel.createProgram(ast).eval(
2209+
ImmutableMap.of("args",
2210+
ImmutableMap.of("i", ImmutableMap.of(1L, UInt64Value.of(123L)))));
2211+
2212+
assertThat(result).isEqualTo(UnsignedLong.valueOf(123L));
2213+
}
2214+
2215+
@Test
2216+
public void messageCreation_listContainsUintWrapperCreation() throws Exception {
2217+
Cel cel = CelFactory.standardCelBuilder()
2218+
.addVar("args", MapType.create(SimpleType.DYN, SimpleType.DYN))
2219+
.setContainer(CelContainer.ofName("cel.expr.conformance.proto3"))
2220+
.addMessageTypes(TestAllTypes.getDescriptor())
2221+
.build();
2222+
2223+
CelAbstractSyntaxTree ast = cel.compile("TestAllTypes{repeated_uint64: [google.protobuf.UInt64Value{value: 123u}]}").getAst();
2224+
2225+
Object result = cel.createProgram(ast).eval(
2226+
ImmutableMap.of("args",
2227+
ImmutableMap.of("i", ImmutableList.of(UInt64Value.of(123L)))));
2228+
2229+
assertThat(result).isEqualTo(TestAllTypes.newBuilder().addRepeatedUint64(123L).build());
2230+
}
2231+
2232+
@Test
2233+
public void messageCreation_listContainsUintWrapper() throws Exception {
2234+
Cel cel = CelFactory.standardCelBuilder()
2235+
.addVar("args", MapType.create(SimpleType.DYN, SimpleType.DYN))
2236+
.setContainer(CelContainer.ofName("cel.expr.conformance.proto3"))
2237+
.addMessageTypes(TestAllTypes.getDescriptor())
2238+
.build();
2239+
2240+
CelAbstractSyntaxTree ast = cel.compile("TestAllTypes{repeated_uint64: args.i}").getAst();
2241+
2242+
Object result = cel.createProgram(ast).eval(
2243+
ImmutableMap.of("args",
2244+
ImmutableMap.of("i", ImmutableList.of(UInt64Value.of(123L)))));
2245+
2246+
assertThat(result).isEqualTo(TestAllTypes.newBuilder().addRepeatedUint64(123L).build());
2247+
}
2248+
2249+
@Test
2250+
public void messageCreation_mapContainsUintWrapper() throws Exception {
2251+
Cel cel = CelFactory.standardCelBuilder()
2252+
.addVar("args", MapType.create(SimpleType.DYN, SimpleType.DYN))
2253+
.setContainer(CelContainer.ofName("cel.expr.conformance.proto3"))
2254+
.addMessageTypes(TestAllTypes.getDescriptor())
2255+
.build();
2256+
2257+
CelAbstractSyntaxTree ast = cel.compile("TestAllTypes{map_int64_uint64 : args.i}").getAst();
2258+
2259+
Object result = cel.createProgram(ast).eval(
2260+
ImmutableMap.of("args",
2261+
ImmutableMap.of("i", ImmutableMap.of(1L, UInt64Value.of(123L)))));
2262+
2263+
assertThat(result).isEqualTo(TestAllTypes.newBuilder().putMapInt64Uint64(1L, 123L).build());
2264+
}
2265+
21962266
private static TypeProvider aliasingProvider(ImmutableMap<String, Type> typeAliases) {
21972267
return new TypeProvider() {
21982268
@Override

common/src/main/java/dev/cel/common/internal/ProtoAdapter.java

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import com.google.protobuf.InvalidProtocolBufferException;
3131
import com.google.protobuf.MapEntry;
3232
import com.google.protobuf.Message;
33+
import com.google.protobuf.MessageLite;
3334
import com.google.protobuf.MessageOrBuilder;
3435
import dev.cel.common.CelOptions;
3536
import dev.cel.common.annotations.Internal;
@@ -244,28 +245,48 @@ private BidiConverter fieldToValueConverter(FieldDescriptor fieldDescriptor) {
244245
case SFIXED32:
245246
case SINT32:
246247
case INT32:
247-
return INT_CONVERTER;
248+
return unwrapAndConvert(INT_CONVERTER);
248249
case FIXED32:
249250
case UINT32:
250251
if (celOptions.enableUnsignedLongs()) {
251-
return UNSIGNED_UINT32_CONVERTER;
252+
return unwrapAndConvert(UNSIGNED_UINT32_CONVERTER);
252253
}
253-
return SIGNED_UINT32_CONVERTER;
254+
return unwrapAndConvert(SIGNED_UINT32_CONVERTER);
254255
case FIXED64:
255256
case UINT64:
256257
if (celOptions.enableUnsignedLongs()) {
257-
return UNSIGNED_UINT64_CONVERTER;
258+
return unwrapAndConvert(UNSIGNED_UINT64_CONVERTER);
258259
}
259-
return BidiConverter.IDENTITY;
260+
return BidiConverter.of(
261+
BidiConverter.IDENTITY.forwardConverter(),
262+
value -> BidiConverter.IDENTITY.backwardConverter().convert(unwrap(value)));
260263
case FLOAT:
261-
return DOUBLE_CONVERTER;
264+
return unwrapAndConvert(DOUBLE_CONVERTER);
265+
case DOUBLE:
266+
case SFIXED64:
267+
case SINT64:
268+
case INT64:
269+
return BidiConverter.of(
270+
BidiConverter.IDENTITY.forwardConverter(),
271+
value -> BidiConverter.IDENTITY.backwardConverter().convert(unwrap(value)));
262272
case BYTES:
263273
if (celOptions.evaluateCanonicalTypesToNativeValues()) {
264274
return BidiConverter.<Object, Object>of(
265-
ProtoAdapter::adaptProtoByteStringToValue, ProtoAdapter::adaptCelByteStringToProto);
275+
ProtoAdapter::adaptProtoByteStringToValue,
276+
value -> adaptCelByteStringToProto(unwrap(value)));
266277
}
267278

268-
return BidiConverter.IDENTITY;
279+
return BidiConverter.of(
280+
BidiConverter.IDENTITY.forwardConverter(),
281+
value -> BidiConverter.IDENTITY.backwardConverter().convert(unwrap(value)));
282+
case STRING:
283+
return BidiConverter.of(
284+
BidiConverter.IDENTITY.forwardConverter(),
285+
value -> BidiConverter.IDENTITY.backwardConverter().convert(unwrap(value)));
286+
case BOOL:
287+
return BidiConverter.of(
288+
BidiConverter.IDENTITY.forwardConverter(),
289+
value -> BidiConverter.IDENTITY.backwardConverter().convert(unwrap(value)));
269290
case ENUM:
270291
return BidiConverter.<Object, Long>of(
271292
value -> (long) ((EnumValueDescriptor) value).getNumber(),
@@ -371,4 +392,18 @@ private static int unsignedIntCheckedCast(long value) {
371392
throw new CelNumericOverflowException(e);
372393
}
373394
}
395+
396+
private Object unwrap(Object value) {
397+
if (value instanceof MessageLite) {
398+
return adaptProtoToValue((MessageOrBuilder) value);
399+
}
400+
return value;
401+
}
402+
403+
private BidiConverter<Number, Object> unwrapAndConvert(
404+
final BidiConverter<Number, Number> original) {
405+
return BidiConverter.of(
406+
original.forwardConverter()::convert,
407+
value -> original.backwardConverter().convert((Number) unwrap(value)));
408+
}
374409
}

runtime/src/test/resources/wrappers.baseline

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,47 @@ declare dyn_var {
154154
bindings: {dyn_var=NULL_VALUE}
155155
result: NULL_VALUE
156156

157+
Source: TestAllTypes{repeated_int32: int32_list}.repeated_int32 == [1] && TestAllTypes{repeated_int64: int64_list}.repeated_int64 == [2] && TestAllTypes{repeated_uint32: uint32_list}.repeated_uint32 == [3u] && TestAllTypes{repeated_uint64: uint64_list}.repeated_uint64 == [4u] && TestAllTypes{repeated_float: float_list}.repeated_float == [5.5] && TestAllTypes{repeated_double: double_list}.repeated_double == [6.6] && TestAllTypes{repeated_bool: bool_list}.repeated_bool == [true] && TestAllTypes{repeated_string: string_list}.repeated_string == ['hello'] && TestAllTypes{repeated_bytes: bytes_list}.repeated_bytes == [b'world']
158+
declare int32_list {
159+
value list(int)
160+
}
161+
declare int64_list {
162+
value list(int)
163+
}
164+
declare uint32_list {
165+
value list(uint)
166+
}
167+
declare uint64_list {
168+
value list(uint)
169+
}
170+
declare float_list {
171+
value list(double)
172+
}
173+
declare double_list {
174+
value list(double)
175+
}
176+
declare bool_list {
177+
value list(bool)
178+
}
179+
declare string_list {
180+
value list(string)
181+
}
182+
declare bytes_list {
183+
value list(bytes)
184+
}
185+
=====>
186+
bindings: {int32_list=[value: 1
187+
], int64_list=[value: 2
188+
], uint32_list=[value: 3
189+
], uint64_list=[value: 4
190+
], float_list=[value: 5.5
191+
], double_list=[value: 6.6
192+
], bool_list=[value: true
193+
], string_list=[value: "hello"
194+
], bytes_list=[value: "world"
195+
]}
196+
result: true
197+
157198
Source: google.protobuf.Timestamp{ seconds: 253402300800 }
158199
=====>
159200
bindings: {}

testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2058,6 +2058,42 @@ public void wrappers() throws Exception {
20582058
source = "dyn_var";
20592059
runTest(ImmutableMap.of("dyn_var", NullValue.NULL_VALUE));
20602060

2061+
clearAllDeclarations();
2062+
declareVariable("int32_list", ListType.create(SimpleType.INT));
2063+
declareVariable("int64_list", ListType.create(SimpleType.INT));
2064+
declareVariable("uint32_list", ListType.create(SimpleType.UINT));
2065+
declareVariable("uint64_list", ListType.create(SimpleType.UINT));
2066+
declareVariable("float_list", ListType.create(SimpleType.DOUBLE));
2067+
declareVariable("double_list", ListType.create(SimpleType.DOUBLE));
2068+
declareVariable("bool_list", ListType.create(SimpleType.BOOL));
2069+
declareVariable("string_list", ListType.create(SimpleType.STRING));
2070+
declareVariable("bytes_list", ListType.create(SimpleType.BYTES));
2071+
2072+
container = CelContainer.ofName(TestAllTypes.getDescriptor().getFullName());
2073+
source =
2074+
"TestAllTypes{repeated_int32: int32_list}.repeated_int32 == [1] && "
2075+
+ "TestAllTypes{repeated_int64: int64_list}.repeated_int64 == [2] && "
2076+
+ "TestAllTypes{repeated_uint32: uint32_list}.repeated_uint32 == [3u] && "
2077+
+ "TestAllTypes{repeated_uint64: uint64_list}.repeated_uint64 == [4u] && "
2078+
+ "TestAllTypes{repeated_float: float_list}.repeated_float == [5.5] && "
2079+
+ "TestAllTypes{repeated_double: double_list}.repeated_double == [6.6] && "
2080+
+ "TestAllTypes{repeated_bool: bool_list}.repeated_bool == [true] && "
2081+
+ "TestAllTypes{repeated_string: string_list}.repeated_string == ['hello'] && "
2082+
+ "TestAllTypes{repeated_bytes: bytes_list}.repeated_bytes == [b'world']";
2083+
2084+
runTest(
2085+
ImmutableMap.<String, Object>builder()
2086+
.put("int32_list", ImmutableList.of(Int32Value.of(1)))
2087+
.put("int64_list", ImmutableList.of(Int64Value.of(2)))
2088+
.put("uint32_list", ImmutableList.of(UInt32Value.of(3)))
2089+
.put("uint64_list", ImmutableList.of(UInt64Value.of(4)))
2090+
.put("float_list", ImmutableList.of(FloatValue.of(5.5f)))
2091+
.put("double_list", ImmutableList.of(DoubleValue.of(6.6)))
2092+
.put("bool_list", ImmutableList.of(BoolValue.of(true)))
2093+
.put("string_list", ImmutableList.of(StringValue.of("hello")))
2094+
.put("bytes_list", ImmutableList.of(BytesValue.of(ByteString.copyFromUtf8("world"))))
2095+
.build());
2096+
20612097
clearAllDeclarations();
20622098
// Currently allowed, but will be an error
20632099
// See https://github.com/google/cel-spec/pull/501

0 commit comments

Comments
 (0)