Skip to content

Commit 77bd891

Browse files
l46kokcopybara-github
authored andcommitted
Add planner test coverage for policy compilation
PiperOrigin-RevId: 901483162
1 parent 61a01d8 commit 77bd891

6 files changed

Lines changed: 56 additions & 27 deletions

File tree

bundle/src/main/java/dev/cel/bundle/CelBuilder.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,14 @@ public interface CelBuilder {
165165
@CanIgnoreReturnValue
166166
CelBuilder addFunctionBindings(Iterable<CelFunctionBinding> bindings);
167167

168+
/** Adds bindings for functions that are allowed to be late-bound (resolved at execution time). */
169+
@CanIgnoreReturnValue
170+
CelBuilder addLateBoundFunctions(String... lateBoundFunctionNames);
171+
172+
/** Adds bindings for functions that are allowed to be late-bound (resolved at execution time). */
173+
@CanIgnoreReturnValue
174+
CelBuilder addLateBoundFunctions(Iterable<String> lateBoundFunctionNames);
175+
168176
/** Set the expected {@code resultType} for the type-checked expression. */
169177
@CanIgnoreReturnValue
170178
CelBuilder setResultType(CelType resultType);

bundle/src/main/java/dev/cel/bundle/CelImpl.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,18 @@ public CelBuilder addFunctionBindings(Iterable<dev.cel.runtime.CelFunctionBindin
281281
return this;
282282
}
283283

284+
@Override
285+
public CelBuilder addLateBoundFunctions(String... lateBoundFunctionNames) {
286+
runtimeBuilder.addLateBoundFunctions(lateBoundFunctionNames);
287+
return this;
288+
}
289+
290+
@Override
291+
public CelBuilder addLateBoundFunctions(Iterable<String> lateBoundFunctionNames) {
292+
runtimeBuilder.addLateBoundFunctions(lateBoundFunctionNames);
293+
return this;
294+
}
295+
284296
@Override
285297
public CelBuilder setResultType(CelType resultType) {
286298
checkNotNull(resultType);

common/src/main/java/dev/cel/common/values/CelValueConverter.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ protected Object normalizePrimitive(Object value) {
117117
}
118118

119119
/** Adapts a {@link CelValue} to a plain old Java Object. */
120-
private static Object unwrap(CelValue celValue) {
120+
private Object unwrap(CelValue celValue) {
121121
Preconditions.checkNotNull(celValue);
122122

123123
if (celValue instanceof OptionalValue) {
@@ -126,7 +126,7 @@ private static Object unwrap(CelValue celValue) {
126126
return Optional.empty();
127127
}
128128

129-
return Optional.of(optionalValue.value());
129+
return Optional.of(maybeUnwrap(optionalValue.value()));
130130
}
131131

132132
if (celValue instanceof ErrorValue) {

optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
import dev.cel.optimizer.AstMutator;
4747
import dev.cel.optimizer.CelAstOptimizer;
4848
import dev.cel.optimizer.CelOptimizationException;
49-
import dev.cel.runtime.CelAttribute.Qualifier;
5049
import dev.cel.runtime.CelAttributePattern;
5150
import dev.cel.runtime.CelEvaluationException;
5251
import dev.cel.runtime.PartialVars;
@@ -683,8 +682,7 @@ private static Object evaluateExpr(Cel cel, CelNavigableMutableExpr navigableMut
683682
.allNodes()
684683
.filter(node -> node.getKind().equals(Kind.IDENT))
685684
.map(node -> node.expr().ident().name())
686-
.filter(Qualifier::isLegalIdentifier)
687-
.map(CelAttributePattern::create)
685+
.map(CelAttributePattern::fromQualifiedIdentifier)
688686
.collect(toImmutableList());
689687
CelAbstractSyntaxTree ast =
690688
CelAbstractSyntaxTree.newParsedAst(

policy/src/test/java/dev/cel/policy/BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ java_library(
3535
"//policy:validation_exception",
3636
"//runtime",
3737
"//runtime:function_binding",
38-
"//runtime:late_function_binding",
38+
"//testing:cel_runtime_flavor",
3939
"//testing/protos:single_file_java_proto",
4040
"@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto",
4141
"@maven//:com_google_guava_guava",

policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626
import com.google.testing.junit.testparameterinjector.TestParameterValue;
2727
import com.google.testing.junit.testparameterinjector.TestParameterValuesProvider;
2828
import dev.cel.bundle.Cel;
29+
import dev.cel.bundle.CelBuilder;
2930
import dev.cel.bundle.CelEnvironment;
3031
import dev.cel.bundle.CelEnvironmentYamlParser;
31-
import dev.cel.bundle.CelFactory;
3232
import dev.cel.common.CelAbstractSyntaxTree;
3333
import dev.cel.common.CelOptions;
3434
import dev.cel.common.types.OptionalType;
@@ -45,6 +45,7 @@
4545
import dev.cel.policy.PolicyTestHelper.TestYamlPolicy;
4646
import dev.cel.runtime.CelFunctionBinding;
4747
import dev.cel.runtime.CelLateFunctionBindings;
48+
import dev.cel.testing.CelRuntimeFlavor;
4849
import dev.cel.testing.testdata.SingleFile;
4950
import dev.cel.testing.testdata.proto3.StandaloneGlobalEnum;
5051
import java.io.IOException;
@@ -61,7 +62,12 @@ public final class CelPolicyCompilerImplTest {
6162
private static final CelEnvironmentYamlParser ENVIRONMENT_PARSER =
6263
CelEnvironmentYamlParser.newInstance();
6364
private static final CelOptions CEL_OPTIONS =
64-
CelOptions.current().populateMacroCalls(true).build();
65+
CelOptions.current()
66+
.populateMacroCalls(true)
67+
.enableHeterogeneousNumericComparisons(true)
68+
.build();
69+
70+
@TestParameter public CelRuntimeFlavor runtimeFlavor;
6571

6672
@Test
6773
public void compileYamlPolicy_success(@TestParameter TestYamlPolicy yamlPolicy) throws Exception {
@@ -258,7 +264,6 @@ public void evaluateYamlPolicy_nestedRuleProducesOptionalOutput() throws Excepti
258264
CelPolicy policy = POLICY_PARSER.parse(policySource);
259265
CelAbstractSyntaxTree compiledPolicyAst =
260266
CelPolicyCompilerFactory.newPolicyCompiler(cel).build().compile(policy);
261-
262267
Optional<Object> evalResult = (Optional<Object>) cel.createProgram(compiledPolicyAst).eval();
263268

264269
// Result is Optional<Optional<Object>>
@@ -278,7 +283,12 @@ public void evaluateYamlPolicy_lateBoundFunction() throws Exception {
278283
+ " return:\n"
279284
+ " type_name: 'string'\n";
280285
CelEnvironment celEnvironment = ENVIRONMENT_PARSER.parse(configSource);
281-
Cel cel = celEnvironment.extend(newCel(), CelOptions.DEFAULT);
286+
CelBuilder celBuilder = newCel().toCelBuilder();
287+
if (runtimeFlavor == CelRuntimeFlavor.PLANNER) {
288+
celBuilder.addLateBoundFunctions("lateBoundFunc");
289+
}
290+
Cel cel = celEnvironment.extend(celBuilder.build(), CEL_OPTIONS);
291+
282292
String policySource =
283293
"name: late_bound_function_policy\n"
284294
+ "rule:\n"
@@ -298,7 +308,6 @@ public void evaluateYamlPolicy_lateBoundFunction() throws Exception {
298308
(String)
299309
cel.createProgram(compiledPolicyAst)
300310
.eval((unused) -> Optional.empty(), lateFunctionBindings);
301-
302311
assertThat(evalResult).isEqualTo("foo" + exampleValue);
303312
}
304313

@@ -319,7 +328,6 @@ public void evaluateYamlPolicy_withSimpleVariable() throws Exception {
319328

320329
CelAbstractSyntaxTree compiledPolicyAst =
321330
CelPolicyCompilerFactory.newPolicyCompiler(cel).build().compile(policy);
322-
323331
boolean evalResult = (boolean) cel.createProgram(compiledPolicyAst).eval();
324332

325333
assertThat(evalResult).isFalse();
@@ -358,28 +366,31 @@ protected ImmutableList<TestParameterValue> provideValues(Context context) throw
358366
}
359367
}
360368

361-
private static Cel newCel() {
362-
return CelFactory.standardCelBuilder()
369+
private Cel newCel() {
370+
return runtimeFlavor
371+
.builder()
363372
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
364373
.addCompilerLibraries(CelOptionalLibrary.INSTANCE)
365374
.addRuntimeLibraries(CelOptionalLibrary.INSTANCE)
366375
.addFileTypes(StandaloneGlobalEnum.getDescriptor().getFile())
367376
.addMessageTypes(TestAllTypes.getDescriptor(), SingleFile.getDescriptor())
368377
.setOptions(CEL_OPTIONS)
369378
.addFunctionBindings(
370-
CelFunctionBinding.from(
371-
"locationCode_string",
372-
String.class,
373-
(ip) -> {
374-
switch (ip) {
375-
case "10.0.0.1":
376-
return "us";
377-
case "10.0.0.2":
378-
return "de";
379-
default:
380-
return "ir";
381-
}
382-
}))
379+
CelFunctionBinding.fromOverloads(
380+
"locationCode",
381+
CelFunctionBinding.from(
382+
"locationCode_string",
383+
String.class,
384+
(ip) -> {
385+
switch (ip) {
386+
case "10.0.0.1":
387+
return "us";
388+
case "10.0.0.2":
389+
return "de";
390+
default:
391+
return "ir";
392+
}
393+
})))
383394
.build();
384395
}
385396

0 commit comments

Comments
 (0)