Skip to content

Commit cad23d1

Browse files
l46kokcopybara-github
authored andcommitted
Allow resolution of proto messages from dyn-typed functions in lite runtime
PiperOrigin-RevId: 761247740
1 parent 56717e4 commit cad23d1

5 files changed

Lines changed: 59 additions & 1 deletion

File tree

common/src/main/java/dev/cel/common/internal/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ java_library(
314314
deps = [
315315
"//protobuf:cel_lite_descriptor",
316316
"@maven//:com_google_errorprone_error_prone_annotations",
317+
"@maven_android//:com_google_protobuf_protobuf_javalite",
317318
],
318319
)
319320

@@ -325,6 +326,7 @@ cel_android_library(
325326
deps = [
326327
"//protobuf:cel_lite_descriptor",
327328
"@maven//:com_google_errorprone_error_prone_annotations",
329+
"@maven_android//:com_google_protobuf_protobuf_javalite",
328330
],
329331
)
330332

common/src/main/java/dev/cel/common/internal/CelLiteDescriptorPool.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package dev.cel.common.internal;
1616

1717
import com.google.errorprone.annotations.Immutable;
18+
import com.google.protobuf.MessageLite;
1819
import dev.cel.protobuf.CelLiteDescriptor.MessageLiteDescriptor;
1920
import java.util.Optional;
2021

@@ -25,5 +26,7 @@
2526
public interface CelLiteDescriptorPool {
2627
Optional<MessageLiteDescriptor> findDescriptor(String protoTypeName);
2728

29+
Optional<MessageLiteDescriptor> findDescriptor(MessageLite messageLite);
30+
2831
MessageLiteDescriptor getDescriptorOrThrow(String protoTypeName);
2932
}

common/src/main/java/dev/cel/common/internal/DefaultLiteDescriptorPool.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
@Internal
5151
public final class DefaultLiteDescriptorPool implements CelLiteDescriptorPool {
5252
private final ImmutableMap<String, MessageLiteDescriptor> protoFqnToMessageInfo;
53+
private final ImmutableMap<Class<?>, MessageLiteDescriptor> classToMessageInfo;
5354

5455
public static DefaultLiteDescriptorPool newInstance(CelLiteDescriptor... descriptors) {
5556
return newInstance(ImmutableSet.copyOf(descriptors));
@@ -59,6 +60,11 @@ public static DefaultLiteDescriptorPool newInstance(ImmutableSet<CelLiteDescript
5960
return new DefaultLiteDescriptorPool(descriptors);
6061
}
6162

63+
@Override
64+
public Optional<MessageLiteDescriptor> findDescriptor(MessageLite messageLite) {
65+
return Optional.ofNullable(classToMessageInfo.get(messageLite.getClass()));
66+
}
67+
6268
@Override
6369
public Optional<MessageLiteDescriptor> findDescriptor(String protoTypeName) {
6470
return Optional.ofNullable(protoFqnToMessageInfo.get(protoTypeName));
@@ -292,15 +298,28 @@ private static FieldLiteDescriptor newPrimitiveFieldDescriptor(
292298

293299
private DefaultLiteDescriptorPool(ImmutableSet<CelLiteDescriptor> descriptors) {
294300
ImmutableMap.Builder<String, MessageLiteDescriptor> protoFqnMapBuilder = ImmutableMap.builder();
301+
ImmutableMap.Builder<Class<?>, MessageLiteDescriptor> classMapBuilder = ImmutableMap.builder();
295302
for (WellKnownProto wellKnownProto : WellKnownProto.values()) {
296303
MessageLiteDescriptor wktMessageInfo = newMessageInfo(wellKnownProto);
297304
protoFqnMapBuilder.put(wellKnownProto.typeName(), wktMessageInfo);
305+
classMapBuilder.put(wellKnownProto.messageClass(), wktMessageInfo);
298306
}
299307

300308
for (CelLiteDescriptor descriptor : descriptors) {
301309
protoFqnMapBuilder.putAll(descriptor.getProtoTypeNamesToDescriptors());
310+
311+
for (MessageLiteDescriptor messageLiteDescriptor :
312+
descriptor.getProtoTypeNamesToDescriptors().values()) {
313+
// Note: message builder is null for proto maps.
314+
Optional.ofNullable(messageLiteDescriptor.newMessageBuilder())
315+
.ifPresent(
316+
builder ->
317+
classMapBuilder.put(
318+
builder.getDefaultInstanceForType().getClass(), messageLiteDescriptor));
319+
}
302320
}
303321

304322
this.protoFqnToMessageInfo = protoFqnMapBuilder.buildOrThrow();
323+
this.classToMessageInfo = classMapBuilder.buildOrThrow();
305324
}
306325
}

common/src/main/java/dev/cel/common/values/ProtoLiteCelValueConverter.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import java.util.LinkedHashMap;
4545
import java.util.List;
4646
import java.util.Map;
47+
import java.util.NoSuchElementException;
4748
import java.util.TreeMap;
4849

4950
/**
@@ -360,7 +361,13 @@ public CelValue fromProtoMessageToCelValue(String protoTypeName, MessageLite msg
360361
checkNotNull(msg);
361362
checkNotNull(protoTypeName);
362363

363-
MessageLiteDescriptor descriptor = descriptorPool.getDescriptorOrThrow(protoTypeName);
364+
MessageLiteDescriptor descriptor =
365+
descriptorPool
366+
.findDescriptor(msg)
367+
.orElseThrow(
368+
() ->
369+
new NoSuchElementException(
370+
"Could not find a descriptor for: " + protoTypeName));
364371
WellKnownProto wellKnownProto =
365372
WellKnownProto.getByTypeName(descriptor.getProtoTypeName()).orElse(null);
366373

runtime/src/test/java/dev/cel/runtime/CelLiteRuntimeTest.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,4 +633,31 @@ public void eval_withLateBoundFunction() throws Exception {
633633

634634
assertThat(result).isEqualTo("hello world");
635635
}
636+
637+
@Test
638+
public void eval_dynFunctionReturnsProto() throws Exception {
639+
CelCompiler celCompiler =
640+
CelCompilerFactory.standardCelCompilerBuilder()
641+
.addFunctionDeclarations(
642+
CelFunctionDecl.newFunctionDeclaration(
643+
"func", CelOverloadDecl.newGlobalOverload("func_identity", SimpleType.DYN)))
644+
.build();
645+
CelLiteRuntime celRuntime =
646+
CelLiteRuntimeFactory.newLiteRuntimeBuilder()
647+
.setValueProvider(
648+
ProtoMessageLiteValueProvider.newInstance(
649+
TestAllTypesCelDescriptor.getDescriptor()))
650+
.addFunctionBindings(
651+
CelFunctionBinding.from(
652+
"func_identity",
653+
ImmutableList.of(),
654+
unused -> TestAllTypes.getDefaultInstance()))
655+
.build();
656+
657+
CelAbstractSyntaxTree ast = celCompiler.compile("func()").getAst();
658+
659+
TestAllTypes result = (TestAllTypes) celRuntime.createProgram(ast).eval();
660+
661+
assertThat(result).isEqualToDefaultInstance();
662+
}
636663
}

0 commit comments

Comments
 (0)