diff --git a/bundle/src/main/java/dev/cel/bundle/CelBuilder.java b/bundle/src/main/java/dev/cel/bundle/CelBuilder.java index 1dadaeb39..f603b479f 100644 --- a/bundle/src/main/java/dev/cel/bundle/CelBuilder.java +++ b/bundle/src/main/java/dev/cel/bundle/CelBuilder.java @@ -165,6 +165,14 @@ public interface CelBuilder { @CanIgnoreReturnValue CelBuilder addFunctionBindings(Iterable bindings); + /** Adds bindings for functions that are allowed to be late-bound (resolved at execution time). */ + @CanIgnoreReturnValue + CelBuilder addLateBoundFunctions(String... lateBoundFunctionNames); + + /** Adds bindings for functions that are allowed to be late-bound (resolved at execution time). */ + @CanIgnoreReturnValue + CelBuilder addLateBoundFunctions(Iterable lateBoundFunctionNames); + /** Set the expected {@code resultType} for the type-checked expression. */ @CanIgnoreReturnValue CelBuilder setResultType(CelType resultType); diff --git a/bundle/src/main/java/dev/cel/bundle/CelImpl.java b/bundle/src/main/java/dev/cel/bundle/CelImpl.java index ae0ab2395..f6b985065 100644 --- a/bundle/src/main/java/dev/cel/bundle/CelImpl.java +++ b/bundle/src/main/java/dev/cel/bundle/CelImpl.java @@ -281,6 +281,18 @@ public CelBuilder addFunctionBindings(Iterable lateBoundFunctionNames) { + runtimeBuilder.addLateBoundFunctions(lateBoundFunctionNames); + return this; + } + @Override public CelBuilder setResultType(CelType resultType) { checkNotNull(resultType); diff --git a/common/src/main/java/dev/cel/common/values/CelValueConverter.java b/common/src/main/java/dev/cel/common/values/CelValueConverter.java index 2af0a76cb..70d04acc8 100644 --- a/common/src/main/java/dev/cel/common/values/CelValueConverter.java +++ b/common/src/main/java/dev/cel/common/values/CelValueConverter.java @@ -117,7 +117,7 @@ protected Object normalizePrimitive(Object value) { } /** Adapts a {@link CelValue} to a plain old Java Object. */ - private static Object unwrap(CelValue celValue) { + private Object unwrap(CelValue celValue) { Preconditions.checkNotNull(celValue); if (celValue instanceof OptionalValue) { @@ -126,7 +126,7 @@ private static Object unwrap(CelValue celValue) { return Optional.empty(); } - return Optional.of(optionalValue.value()); + return Optional.of(maybeUnwrap(optionalValue.value())); } if (celValue instanceof ErrorValue) { diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java index c017911f9..8a8786ce8 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java @@ -46,7 +46,6 @@ import dev.cel.optimizer.AstMutator; import dev.cel.optimizer.CelAstOptimizer; import dev.cel.optimizer.CelOptimizationException; -import dev.cel.runtime.CelAttribute.Qualifier; import dev.cel.runtime.CelAttributePattern; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.PartialVars; @@ -683,8 +682,7 @@ private static Object evaluateExpr(Cel cel, CelNavigableMutableExpr navigableMut .allNodes() .filter(node -> node.getKind().equals(Kind.IDENT)) .map(node -> node.expr().ident().name()) - .filter(Qualifier::isLegalIdentifier) - .map(CelAttributePattern::create) + .map(CelAttributePattern::fromQualifiedIdentifier) .collect(toImmutableList()); CelAbstractSyntaxTree ast = CelAbstractSyntaxTree.newParsedAst( diff --git a/policy/src/test/java/dev/cel/policy/BUILD.bazel b/policy/src/test/java/dev/cel/policy/BUILD.bazel index 9106caf70..8a28caee1 100644 --- a/policy/src/test/java/dev/cel/policy/BUILD.bazel +++ b/policy/src/test/java/dev/cel/policy/BUILD.bazel @@ -35,7 +35,7 @@ java_library( "//policy:validation_exception", "//runtime", "//runtime:function_binding", - "//runtime:late_function_binding", + "//testing:cel_runtime_flavor", "//testing/protos:single_file_java_proto", "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", "@maven//:com_google_guava_guava", diff --git a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java index fec5f9b94..d5254571d 100644 --- a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java +++ b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java @@ -26,9 +26,9 @@ import com.google.testing.junit.testparameterinjector.TestParameterValue; import com.google.testing.junit.testparameterinjector.TestParameterValuesProvider; import dev.cel.bundle.Cel; +import dev.cel.bundle.CelBuilder; import dev.cel.bundle.CelEnvironment; import dev.cel.bundle.CelEnvironmentYamlParser; -import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelOptions; import dev.cel.common.types.OptionalType; @@ -45,6 +45,7 @@ import dev.cel.policy.PolicyTestHelper.TestYamlPolicy; import dev.cel.runtime.CelFunctionBinding; import dev.cel.runtime.CelLateFunctionBindings; +import dev.cel.testing.CelRuntimeFlavor; import dev.cel.testing.testdata.SingleFile; import dev.cel.testing.testdata.proto3.StandaloneGlobalEnum; import java.io.IOException; @@ -61,7 +62,12 @@ public final class CelPolicyCompilerImplTest { private static final CelEnvironmentYamlParser ENVIRONMENT_PARSER = CelEnvironmentYamlParser.newInstance(); private static final CelOptions CEL_OPTIONS = - CelOptions.current().populateMacroCalls(true).build(); + CelOptions.current() + .populateMacroCalls(true) + .enableHeterogeneousNumericComparisons(true) + .build(); + + @TestParameter public CelRuntimeFlavor runtimeFlavor; @Test public void compileYamlPolicy_success(@TestParameter TestYamlPolicy yamlPolicy) throws Exception { @@ -258,7 +264,6 @@ public void evaluateYamlPolicy_nestedRuleProducesOptionalOutput() throws Excepti CelPolicy policy = POLICY_PARSER.parse(policySource); CelAbstractSyntaxTree compiledPolicyAst = CelPolicyCompilerFactory.newPolicyCompiler(cel).build().compile(policy); - Optional evalResult = (Optional) cel.createProgram(compiledPolicyAst).eval(); // Result is Optional> @@ -278,7 +283,12 @@ public void evaluateYamlPolicy_lateBoundFunction() throws Exception { + " return:\n" + " type_name: 'string'\n"; CelEnvironment celEnvironment = ENVIRONMENT_PARSER.parse(configSource); - Cel cel = celEnvironment.extend(newCel(), CelOptions.DEFAULT); + CelBuilder celBuilder = newCel().toCelBuilder(); + if (runtimeFlavor == CelRuntimeFlavor.PLANNER) { + celBuilder.addLateBoundFunctions("lateBoundFunc"); + } + Cel cel = celEnvironment.extend(celBuilder.build(), CEL_OPTIONS); + String policySource = "name: late_bound_function_policy\n" + "rule:\n" @@ -298,7 +308,6 @@ public void evaluateYamlPolicy_lateBoundFunction() throws Exception { (String) cel.createProgram(compiledPolicyAst) .eval((unused) -> Optional.empty(), lateFunctionBindings); - assertThat(evalResult).isEqualTo("foo" + exampleValue); } @@ -319,7 +328,6 @@ public void evaluateYamlPolicy_withSimpleVariable() throws Exception { CelAbstractSyntaxTree compiledPolicyAst = CelPolicyCompilerFactory.newPolicyCompiler(cel).build().compile(policy); - boolean evalResult = (boolean) cel.createProgram(compiledPolicyAst).eval(); assertThat(evalResult).isFalse(); @@ -358,8 +366,9 @@ protected ImmutableList provideValues(Context context) throw } } - private static Cel newCel() { - return CelFactory.standardCelBuilder() + private Cel newCel() { + return runtimeFlavor + .builder() .setStandardMacros(CelStandardMacro.STANDARD_MACROS) .addCompilerLibraries(CelOptionalLibrary.INSTANCE) .addRuntimeLibraries(CelOptionalLibrary.INSTANCE) @@ -367,19 +376,21 @@ private static Cel newCel() { .addMessageTypes(TestAllTypes.getDescriptor(), SingleFile.getDescriptor()) .setOptions(CEL_OPTIONS) .addFunctionBindings( - CelFunctionBinding.from( - "locationCode_string", - String.class, - (ip) -> { - switch (ip) { - case "10.0.0.1": - return "us"; - case "10.0.0.2": - return "de"; - default: - return "ir"; - } - })) + CelFunctionBinding.fromOverloads( + "locationCode", + CelFunctionBinding.from( + "locationCode_string", + String.class, + (ip) -> { + switch (ip) { + case "10.0.0.1": + return "us"; + case "10.0.0.2": + return "de"; + default: + return "ir"; + } + }))) .build(); }