Skip to content

Commit 13a0d52

Browse files
l46kokcopybara-github
authored andcommitted
Fix constant folding to not error when sub-asts contain unbound variables
CEL-Java fix for xref: google/cel-go#1296 PiperOrigin-RevId: 892757312
1 parent 4d00593 commit 13a0d52

File tree

5 files changed

+97
-32
lines changed

5 files changed

+97
-32
lines changed

optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ java_library(
3535
"//optimizer:mutable_ast",
3636
"//optimizer:optimization_exception",
3737
"//runtime",
38+
"//runtime:partial_vars",
39+
"//runtime:unknown_attributes",
3840
"@maven//:com_google_errorprone_error_prone_annotations",
3941
"@maven//:com_google_guava_guava",
4042
],

optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
import dev.cel.common.CelValidationException;
3131
import dev.cel.common.Operator;
3232
import dev.cel.common.ast.CelConstant;
33-
import dev.cel.common.ast.CelExpr;
3433
import dev.cel.common.ast.CelExpr.ExprKind.Kind;
3534
import dev.cel.common.ast.CelMutableExpr;
3635
import dev.cel.common.ast.CelMutableExpr.CelMutableCall;
@@ -47,7 +46,10 @@
4746
import dev.cel.optimizer.AstMutator;
4847
import dev.cel.optimizer.CelAstOptimizer;
4948
import dev.cel.optimizer.CelOptimizationException;
49+
import dev.cel.runtime.CelAttribute.Qualifier;
50+
import dev.cel.runtime.CelAttributePattern;
5051
import dev.cel.runtime.CelEvaluationException;
52+
import dev.cel.runtime.PartialVars;
5153
import java.time.Duration;
5254
import java.time.Instant;
5355
import java.util.ArrayList;
@@ -282,7 +284,7 @@ private Optional<CelMutableAst> maybeFold(
282284
throws CelOptimizationException {
283285
Object result;
284286
try {
285-
result = evaluateExpr(cel, CelMutableExprConverter.fromMutableExpr(node.expr()));
287+
result = evaluateExpr(cel, node);
286288
} catch (CelValidationException | CelEvaluationException e) {
287289
throw new CelOptimizationException(
288290
"Constant folding failure. Failed to evaluate subtree due to: " + e.getMessage(), e);
@@ -674,13 +676,23 @@ private CelMutableAst pruneOptionalStructElements(CelMutableAst ast, CelMutableE
674676
}
675677

676678
@CanIgnoreReturnValue
677-
private static Object evaluateExpr(Cel cel, CelExpr expr)
679+
private static Object evaluateExpr(Cel cel, CelNavigableMutableExpr navigableMutableExpr)
678680
throws CelValidationException, CelEvaluationException {
681+
ImmutableList<CelAttributePattern> attributePatterns =
682+
navigableMutableExpr
683+
.allNodes()
684+
.filter(node -> node.getKind().equals(Kind.IDENT))
685+
.map(node -> node.expr().ident().name())
686+
.filter(Qualifier::isLegalIdentifier)
687+
.map(CelAttributePattern::create)
688+
.collect(toImmutableList());
679689
CelAbstractSyntaxTree ast =
680-
CelAbstractSyntaxTree.newParsedAst(expr, CelSource.newBuilder().build());
690+
CelAbstractSyntaxTree.newParsedAst(
691+
CelMutableExprConverter.fromMutableExpr(navigableMutableExpr.expr()),
692+
CelSource.newBuilder().build());
681693
ast = cel.check(ast).getAst();
682694

683-
return cel.createProgram(ast).eval();
695+
return cel.createProgram(ast).eval(PartialVars.of(attributePatterns));
684696
}
685697

686698
/** Options to configure how Constant Folding behave. */

optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ java_library(
1111
deps = [
1212
# "//java/com/google/testing/testsize:annotations",
1313
"//bundle:cel",
14+
"//bundle:cel_experimental_factory",
1415
"//common:cel_ast",
1516
"//common:cel_source",
1617
"//common:compiler_common",

optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java

Lines changed: 71 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,12 @@
1818
import static org.junit.Assert.assertThrows;
1919

2020
import com.google.common.collect.ImmutableList;
21+
import com.google.testing.junit.testparameterinjector.TestParameter;
2122
import com.google.testing.junit.testparameterinjector.TestParameterInjector;
2223
import com.google.testing.junit.testparameterinjector.TestParameters;
2324
import dev.cel.bundle.Cel;
25+
import dev.cel.bundle.CelBuilder;
26+
import dev.cel.bundle.CelExperimentalFactory;
2427
import dev.cel.bundle.CelFactory;
2528
import dev.cel.common.CelAbstractSyntaxTree;
2629
import dev.cel.common.CelContainer;
@@ -47,9 +50,23 @@
4750
@RunWith(TestParameterInjector.class)
4851
public class ConstantFoldingOptimizerTest {
4952
private static final CelOptions CEL_OPTIONS =
50-
CelOptions.current().populateMacroCalls(true).build();
51-
private static final Cel CEL =
52-
CelFactory.standardCelBuilder()
53+
CelOptions.current()
54+
.populateMacroCalls(true)
55+
.enableHeterogeneousNumericComparisons(true)
56+
.build();
57+
58+
private static final CelUnparser CEL_UNPARSER = CelUnparserFactory.newUnparser();
59+
60+
@SuppressWarnings("ImmutableEnumChecker") // test only
61+
private enum RuntimeEnv {
62+
LEGACY(setupEnv(CelFactory.standardCelBuilder())),
63+
PLANNER(setupEnv(CelExperimentalFactory.plannerCelBuilder()));
64+
65+
private final Cel cel;
66+
private final CelOptimizer celOptimizer;
67+
68+
private static Cel setupEnv(CelBuilder celBuilder) {
69+
return celBuilder
5370
.addVar("x", SimpleType.DYN)
5471
.addVar("y", SimpleType.DYN)
5572
.addVar("list_var", ListType.create(SimpleType.STRING))
@@ -84,13 +101,28 @@ public class ConstantFoldingOptimizerTest {
84101
CelExtensions.sets(CEL_OPTIONS),
85102
CelExtensions.encoders(CEL_OPTIONS))
86103
.build();
104+
}
105+
106+
RuntimeEnv(Cel cel) {
107+
this.cel = cel;
108+
this.celOptimizer =
109+
CelOptimizerFactory.standardCelOptimizerBuilder(cel)
110+
.addAstOptimizers(ConstantFoldingOptimizer.getInstance())
111+
.build();
112+
}
113+
114+
private CelBuilder newCelBuilder() {
115+
switch (this) {
116+
case LEGACY:
117+
return CelFactory.standardCelBuilder();
118+
case PLANNER:
119+
return CelExperimentalFactory.plannerCelBuilder();
120+
}
121+
throw new AssertionError("Unknown RuntimeEnv: " + this);
122+
}
123+
}
87124

88-
private static final CelOptimizer CEL_OPTIMIZER =
89-
CelOptimizerFactory.standardCelOptimizerBuilder(CEL)
90-
.addAstOptimizers(ConstantFoldingOptimizer.getInstance())
91-
.build();
92-
93-
private static final CelUnparser CEL_UNPARSER = CelUnparserFactory.newUnparser();
125+
@TestParameter RuntimeEnv runtimeEnv;
94126

95127
@Test
96128
@TestParameters("{source: 'null', expected: 'null'}")
@@ -238,9 +270,9 @@ public class ConstantFoldingOptimizerTest {
238270
// TODO: Support folding lists with mixed types. This requires mutable lists.
239271
// @TestParameters("{source: 'dyn([1]) + [1.0]'}")
240272
public void constantFold_success(String source, String expected) throws Exception {
241-
CelAbstractSyntaxTree ast = CEL.compile(source).getAst();
273+
CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(source).getAst();
242274

243-
CelAbstractSyntaxTree optimizedAst = CEL_OPTIMIZER.optimize(ast);
275+
CelAbstractSyntaxTree optimizedAst = runtimeEnv.celOptimizer.optimize(ast);
244276

245277
assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(expected);
246278
}
@@ -285,12 +317,13 @@ public void constantFold_success(String source, String expected) throws Exceptio
285317
public void constantFold_macros_macroCallMetadataPopulated(String source, String expected)
286318
throws Exception {
287319
Cel cel =
288-
CelFactory.standardCelBuilder()
320+
runtimeEnv
321+
.newCelBuilder()
289322
.addVar("x", SimpleType.DYN)
290323
.addVar("y", SimpleType.DYN)
291324
.addMessageTypes(TestAllTypes.getDescriptor())
292325
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
293-
.setOptions(CelOptions.current().populateMacroCalls(true).build())
326+
.setOptions(CEL_OPTIONS)
294327
.addCompilerLibraries(
295328
CelExtensions.bindings(), CelExtensions.optional(), CelExtensions.comprehensions())
296329
.addRuntimeLibraries(CelExtensions.optional(), CelExtensions.comprehensions())
@@ -330,12 +363,17 @@ public void constantFold_macros_macroCallMetadataPopulated(String source, String
330363
@TestParameters("{source: 'false ? false : cel.bind(a, true, a)'}")
331364
public void constantFold_macros_withoutMacroCallMetadata(String source) throws Exception {
332365
Cel cel =
333-
CelFactory.standardCelBuilder()
366+
runtimeEnv
367+
.newCelBuilder()
334368
.addVar("x", SimpleType.DYN)
335369
.addVar("y", SimpleType.DYN)
336370
.addMessageTypes(TestAllTypes.getDescriptor())
337371
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
338-
.setOptions(CelOptions.current().populateMacroCalls(false).build())
372+
.setOptions(
373+
CelOptions.current()
374+
.enableHeterogeneousNumericComparisons(true)
375+
.populateMacroCalls(false)
376+
.build())
339377
.addCompilerLibraries(CelExtensions.bindings(), CelOptionalLibrary.INSTANCE)
340378
.addRuntimeLibraries(CelOptionalLibrary.INSTANCE)
341379
.build();
@@ -378,21 +416,22 @@ public void constantFold_macros_withoutMacroCallMetadata(String source) throws E
378416
@TestParameters("{source: 'duration(\"1h\")'}")
379417
@TestParameters("{source: '[true].exists(x, x == get_true())'}")
380418
@TestParameters("{source: 'get_list([1, 2]).map(x, x * 2)'}")
419+
@TestParameters("{source: '[(x - 1 > 3) ? (x - 1) : 5].exists(x, x - 1 > 3)'}")
381420
public void constantFold_noOp(String source) throws Exception {
382-
CelAbstractSyntaxTree ast = CEL.compile(source).getAst();
421+
CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(source).getAst();
383422

384-
CelAbstractSyntaxTree optimizedAst = CEL_OPTIMIZER.optimize(ast);
423+
CelAbstractSyntaxTree optimizedAst = runtimeEnv.celOptimizer.optimize(ast);
385424

386425
assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(source);
387426
}
388427

389428
@Test
390429
public void constantFold_addFoldableFunction_success() throws Exception {
391-
CelAbstractSyntaxTree ast = CEL.compile("get_true() == get_true()").getAst();
430+
CelAbstractSyntaxTree ast = runtimeEnv.cel.compile("get_true() == get_true()").getAst();
392431
ConstantFoldingOptions options =
393432
ConstantFoldingOptions.newBuilder().addFoldableFunctions("get_true").build();
394433
CelOptimizer optimizer =
395-
CelOptimizerFactory.standardCelOptimizerBuilder(CEL)
434+
CelOptimizerFactory.standardCelOptimizerBuilder(runtimeEnv.cel)
396435
.addAstOptimizers(ConstantFoldingOptimizer.newInstance(options))
397436
.build();
398437

@@ -403,7 +442,7 @@ public void constantFold_addFoldableFunction_success() throws Exception {
403442

404443
@Test
405444
public void constantFold_withExpectedResultTypeSet_success() throws Exception {
406-
Cel cel = CelFactory.standardCelBuilder().setResultType(SimpleType.STRING).build();
445+
Cel cel = runtimeEnv.newCelBuilder().setResultType(SimpleType.STRING).build();
407446
CelOptimizer optimizer =
408447
CelOptimizerFactory.standardCelOptimizerBuilder(cel)
409448
.addAstOptimizers(ConstantFoldingOptimizer.getInstance())
@@ -419,10 +458,11 @@ public void constantFold_withExpectedResultTypeSet_success() throws Exception {
419458
public void constantFold_withMacroCallPopulated_comprehensionsAreReplacedWithNotSet()
420459
throws Exception {
421460
Cel cel =
422-
CelFactory.standardCelBuilder()
461+
runtimeEnv
462+
.newCelBuilder()
423463
.addVar("x", SimpleType.DYN)
424464
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
425-
.setOptions(CelOptions.current().populateMacroCalls(true).build())
465+
.setOptions(CEL_OPTIONS)
426466
.build();
427467
CelOptimizer celOptimizer =
428468
CelOptimizerFactory.standardCelOptimizerBuilder(cel)
@@ -492,9 +532,9 @@ public void constantFold_withMacroCallPopulated_comprehensionsAreReplacedWithNot
492532

493533
@Test
494534
public void constantFold_astProducesConsistentlyNumberedIds() throws Exception {
495-
CelAbstractSyntaxTree ast = CEL.compile("[1] + [2] + [3]").getAst();
535+
CelAbstractSyntaxTree ast = runtimeEnv.cel.compile("[1] + [2] + [3]").getAst();
496536

497-
CelAbstractSyntaxTree optimizedAst = CEL_OPTIMIZER.optimize(ast);
537+
CelAbstractSyntaxTree optimizedAst = runtimeEnv.celOptimizer.optimize(ast);
498538

499539
assertThat(optimizedAst.getExpr().toString())
500540
.isEqualTo(
@@ -515,8 +555,13 @@ public void iterationLimitReached_throws() throws Exception {
515555
sb.append(" + ").append(i);
516556
} // 0 + 1 + 2 + 3 + ... 200
517557
Cel cel =
518-
CelFactory.standardCelBuilder()
519-
.setOptions(CelOptions.current().maxParseRecursionDepth(200).build())
558+
runtimeEnv
559+
.newCelBuilder()
560+
.setOptions(
561+
CelOptions.current()
562+
.enableHeterogeneousNumericComparisons(true)
563+
.maxParseRecursionDepth(200)
564+
.build())
520565
.build();
521566
CelAbstractSyntaxTree ast = cel.compile(sb.toString()).getAst();
522567
CelOptimizer optimizer =

runtime/src/main/java/dev/cel/runtime/PartialVars.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,12 @@ public abstract class PartialVars {
3737

3838
/** Constructs a new {@code PartialVars} from one or more {@link CelAttributePattern}s. */
3939
public static PartialVars of(CelAttributePattern... unknownAttributes) {
40-
return of((unused) -> Optional.empty(), ImmutableList.copyOf(unknownAttributes));
40+
return of(ImmutableList.copyOf(unknownAttributes));
41+
}
42+
43+
/** Constructs a new {@code PartialVars} from a list of {@link CelAttributePattern}s. */
44+
public static PartialVars of(Iterable<CelAttributePattern> unknownAttributes) {
45+
return of((unused) -> Optional.empty(), unknownAttributes);
4146
}
4247

4348
/**

0 commit comments

Comments
 (0)