Skip to content

Commit 9d402f1

Browse files
l46kok authored and copybara-github committed
Persist lazily bound variables in the correct scoped resolver
PiperOrigin-RevId: 839860417
1 parent 741ad14 commit 9d402f1

9 files changed

Lines changed: 778 additions & 542 deletions

File tree

extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import dev.cel.runtime.CelRuntime;
3939
import dev.cel.runtime.CelRuntimeFactory;
4040
import java.util.Arrays;
41+
import java.util.List;
4142
import java.util.concurrent.atomic.AtomicInteger;
4243
import org.junit.Test;
4344
import org.junit.runner.RunWith;
@@ -243,4 +244,38 @@ public void lazyBinding_withNestedBinds() throws Exception {
243244
assertThat(result).isTrue();
244245
assertThat(invocation.get()).isEqualTo(2);
245246
}
247+
248+
@Test
249+
@SuppressWarnings({"Immutable", "unchecked"}) // Test only
250+
public void lazyBinding_boundAttributeInComprehension() throws Exception {
251+
CelCompiler celCompiler =
252+
CelCompilerFactory.standardCelCompilerBuilder()
253+
.setStandardMacros(CelStandardMacro.MAP)
254+
.addLibraries(CelExtensions.bindings())
255+
.addFunctionDeclarations(
256+
CelFunctionDecl.newFunctionDeclaration(
257+
"get_true",
258+
CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)))
259+
.build();
260+
AtomicInteger invocation = new AtomicInteger();
261+
CelRuntime celRuntime =
262+
CelRuntimeFactory.standardCelRuntimeBuilder()
263+
.addFunctionBindings(
264+
CelFunctionBinding.from(
265+
"get_true_overload",
266+
ImmutableList.of(),
267+
arg -> {
268+
invocation.getAndIncrement();
269+
return true;
270+
}))
271+
.build();
272+
273+
CelAbstractSyntaxTree ast =
274+
celCompiler.compile("cel.bind(x, get_true(), [1,2,3].map(y, y < 0 || x))").getAst();
275+
276+
List<Boolean> result = (List<Boolean>) celRuntime.createProgram(ast).eval();
277+
278+
assertThat(result).containsExactly(true, true, true);
279+
assertThat(invocation.get()).isEqualTo(1);
280+
}
246281
}

optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java

Lines changed: 114 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616

1717
import static com.google.common.base.Preconditions.checkNotNull;
1818
import static com.google.common.collect.ImmutableList.toImmutableList;
19+
import static com.google.common.collect.ImmutableSet.toImmutableSet;
1920
import static java.util.stream.Collectors.toCollection;
2021

2122
import com.google.auto.value.AutoValue;
2223
import com.google.common.annotations.VisibleForTesting;
2324
import com.google.common.base.Preconditions;
25+
import com.google.common.base.Strings;
2426
import com.google.common.base.Verify;
2527
import com.google.common.collect.ImmutableList;
2628
import com.google.common.collect.ImmutableSet;
@@ -41,8 +43,11 @@
4143
import dev.cel.common.CelVarDecl;
4244
import dev.cel.common.ast.CelExpr;
4345
import dev.cel.common.ast.CelExpr.CelCall;
46+
import dev.cel.common.ast.CelExpr.CelComprehension;
47+
import dev.cel.common.ast.CelExpr.CelList;
4448
import dev.cel.common.ast.CelExpr.ExprKind.Kind;
4549
import dev.cel.common.ast.CelMutableExpr;
50+
import dev.cel.common.ast.CelMutableExpr.CelMutableComprehension;
4651
import dev.cel.common.ast.CelMutableExprConverter;
4752
import dev.cel.common.navigation.CelNavigableExpr;
4853
import dev.cel.common.navigation.CelNavigableMutableAst;
@@ -59,7 +64,9 @@
5964
import java.util.Comparator;
6065
import java.util.HashSet;
6166
import java.util.List;
67+
import java.util.Objects;
6268
import java.util.Set;
69+
import java.util.stream.Stream;
6370

6471
/**
6572
* Performs Common Subexpression Elimination.
@@ -90,14 +97,15 @@ public class SubexpressionOptimizer implements CelAstOptimizer {
9097
private static final SubexpressionOptimizer INSTANCE =
9198
new SubexpressionOptimizer(SubexpressionOptimizerOptions.newBuilder().build());
9299
private static final String BIND_IDENTIFIER_PREFIX = "@r";
93-
private static final String MANGLED_COMPREHENSION_ITER_VAR_PREFIX = "@it";
94-
private static final String MANGLED_COMPREHENSION_ITER_VAR2_PREFIX = "@it2";
95-
private static final String MANGLED_COMPREHENSION_ACCU_VAR_PREFIX = "@ac";
96100
private static final String CEL_BLOCK_FUNCTION = "cel.@block";
97101
private static final String BLOCK_INDEX_PREFIX = "@index";
98102
private static final Extension CEL_BLOCK_AST_EXTENSION_TAG =
99103
Extension.create("cel_block", Version.of(1L, 1L), Component.COMPONENT_RUNTIME);
100104

105+
@VisibleForTesting static final String MANGLED_COMPREHENSION_ITER_VAR_PREFIX = "@it";
106+
@VisibleForTesting static final String MANGLED_COMPREHENSION_ITER_VAR2_PREFIX = "@it2";
107+
@VisibleForTesting static final String MANGLED_COMPREHENSION_ACCU_VAR_PREFIX = "@ac";
108+
101109
private final SubexpressionOptimizerOptions cseOptions;
102110
private final AstMutator astMutator;
103111
private final ImmutableSet<String> cseEliminableFunctions;
@@ -269,6 +277,8 @@ static void verifyOptimizedAstCorrectness(CelAbstractSyntaxTree ast) {
269277
Verify.verify(
270278
resultHasAtLeastOneBlockIndex,
271279
"Expected at least one reference of index in cel.block result");
280+
281+
verifyNoInvalidScopedMangledVariables(celBlockExpr);
272282
}
273283

274284
private static void verifyBlockIndex(CelExpr celExpr, int maxIndexValue) {
@@ -289,6 +299,67 @@ private static void verifyBlockIndex(CelExpr celExpr, int maxIndexValue) {
289299
celExpr);
290300
}
291301

302+
private static void verifyNoInvalidScopedMangledVariables(CelExpr celExpr) {
303+
CelCall celBlockCall = celExpr.call();
304+
CelExpr blockBody = celBlockCall.args().get(1);
305+
306+
ImmutableSet<String> allMangledVariablesInBlockBody =
307+
CelNavigableExpr.fromExpr(blockBody)
308+
.allNodes()
309+
.map(CelNavigableExpr::expr)
310+
.flatMap(SubexpressionOptimizer::extractMangledNames)
311+
.collect(toImmutableSet());
312+
313+
CelList blockIndices = celBlockCall.args().get(0).list();
314+
for (CelExpr blockIndex : blockIndices.elements()) {
315+
ImmutableSet<String> indexDeclaredCompVariables =
316+
CelNavigableExpr.fromExpr(blockIndex)
317+
.allNodes()
318+
.map(CelNavigableExpr::expr)
319+
.filter(expr -> expr.getKind() == Kind.COMPREHENSION)
320+
.map(CelExpr::comprehension)
321+
.flatMap(comp -> Stream.of(comp.iterVar(), comp.iterVar2()))
322+
.filter(iter -> !Strings.isNullOrEmpty(iter))
323+
.collect(toImmutableSet());
324+
325+
boolean containsIllegalDeclaration =
326+
CelNavigableExpr.fromExpr(blockIndex)
327+
.allNodes()
328+
.map(CelNavigableExpr::expr)
329+
.filter(expr -> expr.getKind() == Kind.IDENT)
330+
.map(expr -> expr.ident().name())
331+
.filter(SubexpressionOptimizer::isMangled)
332+
.anyMatch(
333+
ident ->
334+
!indexDeclaredCompVariables.contains(ident)
335+
&& allMangledVariablesInBlockBody.contains(ident));
336+
337+
Verify.verify(
338+
!containsIllegalDeclaration,
339+
"Illegal declared reference to a comprehension variable found in block indices. Expr: %s",
340+
celExpr);
341+
}
342+
}
343+
344+
private static Stream<String> extractMangledNames(CelExpr expr) {
345+
if (expr.getKind() == Kind.IDENT) {
346+
String name = expr.ident().name();
347+
return isMangled(name) ? Stream.of(name) : Stream.empty();
348+
}
349+
if (expr.getKind() == Kind.COMPREHENSION) {
350+
CelComprehension comp = expr.comprehension();
351+
return Stream.of(comp.iterVar(), comp.iterVar2(), comp.accuVar())
352+
.filter(Objects::nonNull) // Handle potential null/empty iterVar2
353+
.filter(SubexpressionOptimizer::isMangled);
354+
}
355+
return Stream.empty();
356+
}
357+
358+
private static boolean isMangled(String name) {
359+
return name.startsWith(MANGLED_COMPREHENSION_ITER_VAR_PREFIX)
360+
|| name.startsWith(MANGLED_COMPREHENSION_ITER_VAR2_PREFIX);
361+
}
362+
292363
private static CelAbstractSyntaxTree tagAstExtension(CelAbstractSyntaxTree ast) {
293364
// Tag the extension
294365
CelSource.Builder celSourceBuilder =
@@ -355,8 +426,8 @@ private List<CelMutableExpr> getCseCandidatesWithRecursionDepth(
355426
navAst
356427
.getRoot()
357428
.descendants(TraversalOrder.PRE_ORDER)
358-
.filter(node -> canEliminate(node, ineligibleExprs))
359429
.filter(node -> node.height() <= recursionLimit)
430+
.filter(node -> canEliminate(node, ineligibleExprs))
360431
.sorted(Comparator.comparingInt(CelNavigableMutableExpr::height).reversed())
361432
.collect(toImmutableList());
362433
if (descendants.isEmpty()) {
@@ -441,7 +512,45 @@ private boolean canEliminate(
441512
&& navigableExpr.expr().list().elements().isEmpty())
442513
&& containsEliminableFunctionOnly(navigableExpr)
443514
&& !ineligibleExprs.contains(navigableExpr.expr())
444-
&& containsComprehensionIdentInSubexpr(navigableExpr);
515+
&& containsComprehensionIdentInSubexpr(navigableExpr)
516+
&& containsProperScopedComprehensionIdents(navigableExpr);
517+
}
518+
519+
private boolean containsProperScopedComprehensionIdents(CelNavigableMutableExpr navExpr) {
520+
if (!navExpr.getKind().equals(Kind.COMPREHENSION)) {
521+
return true;
522+
}
523+
524+
// For nested comprehensions of form [1].exists(x, [2].exists(y, x == y)), the inner
525+
// comprehension [2].exists(y, x == y)
526+
// should not be extracted out into a block index, as it causes issues with scoping.
527+
ImmutableSet<String> mangledIterVars =
528+
navExpr
529+
.descendants()
530+
.filter(x -> x.getKind().equals(Kind.IDENT))
531+
.map(x -> x.expr().ident().name())
532+
.filter(
533+
name ->
534+
name.startsWith(MANGLED_COMPREHENSION_ITER_VAR_PREFIX)
535+
|| name.startsWith(MANGLED_COMPREHENSION_ITER_VAR2_PREFIX))
536+
.collect(toImmutableSet());
537+
538+
CelNavigableMutableExpr parent = navExpr.parent().orElse(null);
539+
while (parent != null) {
540+
if (parent.getKind().equals(Kind.COMPREHENSION)) {
541+
CelMutableComprehension comp = parent.expr().comprehension();
542+
boolean containsParentIterReferences =
543+
mangledIterVars.contains(comp.iterVar()) || mangledIterVars.contains(comp.iterVar2());
544+
545+
if (containsParentIterReferences) {
546+
return false;
547+
}
548+
}
549+
550+
parent = parent.parent().orElse(null);
551+
}
552+
553+
return true;
445554
}
446555

447556
private boolean containsComprehensionIdentInSubexpr(CelNavigableMutableExpr navExpr) {

optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
import dev.cel.runtime.CelFunctionBinding;
5656
import dev.cel.runtime.CelRuntime;
5757
import dev.cel.runtime.CelRuntimeFactory;
58+
import java.util.List;
5859
import java.util.concurrent.atomic.AtomicInteger;
5960
import org.junit.Test;
6061
import org.junit.runner.RunWith;
@@ -381,6 +382,31 @@ public void lazyEval_blockIndexEvaluatedOnlyOnce() throws Exception {
381382
assertThat(invocation.get()).isEqualTo(1);
382383
}
383384

385+
@Test
386+
@SuppressWarnings({"Immutable", "unchecked"}) // Test only
387+
public void lazyEval_withinComprehension_blockIndexEvaluatedOnlyOnce() throws Exception {
388+
AtomicInteger invocation = new AtomicInteger();
389+
CelRuntime celRuntime =
390+
CelRuntimeFactory.standardCelRuntimeBuilder()
391+
.addMessageTypes(TestAllTypes.getDescriptor())
392+
.addFunctionBindings(
393+
CelFunctionBinding.from(
394+
"get_true_overload",
395+
ImmutableList.of(),
396+
arg -> {
397+
invocation.getAndIncrement();
398+
return true;
399+
}))
400+
.build();
401+
CelAbstractSyntaxTree ast =
402+
compileUsingInternalFunctions("cel.block([get_true()], [1,2,3].map(x, x < 0 || index0))");
403+
404+
List<Boolean> result = (List<Boolean>) celRuntime.createProgram(ast).eval();
405+
406+
assertThat(result).containsExactly(true, true, true);
407+
assertThat(invocation.get()).isEqualTo(1);
408+
}
409+
384410
@Test
385411
@SuppressWarnings("Immutable") // Test only
386412
public void lazyEval_multipleBlockIndices_inResultExpr() throws Exception {
@@ -452,9 +478,9 @@ public void lazyEval_nestedComprehension_indexReferencedInNestedScopes() throws
452478
// Equivalent of [true, false, true].map(c0, [c0].map(c1, [c0, c1, true]))
453479
CelAbstractSyntaxTree ast =
454480
compileUsingInternalFunctions(
455-
"cel.block([c0, c1, get_true()], [index2, false, index2].map(c0, [c0].map(c1, [index0,"
456-
+ " index1, index2]))) == [[[true, true, true]], [[false, false, true]], [[true,"
457-
+ " true, true]]]");
481+
"cel.block([true, false, get_true()], [index2, false, index2].map(c0, [c0].map(c1, [c0,"
482+
+ " c1, index2]))) == [[[true, true, true]], [[false, false, true]], [[true, true,"
483+
+ " true]]]");
458484

459485
boolean result = (boolean) celRuntime.createProgram(ast).eval();
460486

0 commit comments

Comments
 (0)