1616
1717import static com .google .common .base .Preconditions .checkNotNull ;
1818import static com .google .common .collect .ImmutableList .toImmutableList ;
19+ import static com .google .common .collect .ImmutableSet .toImmutableSet ;
1920import static java .util .stream .Collectors .toCollection ;
2021
2122import com .google .auto .value .AutoValue ;
2223import com .google .common .annotations .VisibleForTesting ;
2324import com .google .common .base .Preconditions ;
25+ import com .google .common .base .Strings ;
2426import com .google .common .base .Verify ;
2527import com .google .common .collect .ImmutableList ;
2628import com .google .common .collect .ImmutableSet ;
4143import dev .cel .common .CelVarDecl ;
4244import dev .cel .common .ast .CelExpr ;
4345import dev .cel .common .ast .CelExpr .CelCall ;
46+ import dev .cel .common .ast .CelExpr .CelComprehension ;
47+ import dev .cel .common .ast .CelExpr .CelList ;
4448import dev .cel .common .ast .CelExpr .ExprKind .Kind ;
4549import dev .cel .common .ast .CelMutableExpr ;
50+ import dev .cel .common .ast .CelMutableExpr .CelMutableComprehension ;
4651import dev .cel .common .ast .CelMutableExprConverter ;
4752import dev .cel .common .navigation .CelNavigableExpr ;
4853import dev .cel .common .navigation .CelNavigableMutableAst ;
5964import java .util .Comparator ;
6065import java .util .HashSet ;
6166import java .util .List ;
67+ import java .util .Objects ;
6268import java .util .Set ;
69+ import java .util .stream .Stream ;
6370
6471/**
6572 * Performs Common Subexpression Elimination.
@@ -90,14 +97,15 @@ public class SubexpressionOptimizer implements CelAstOptimizer {
9097 private static final SubexpressionOptimizer INSTANCE =
9198 new SubexpressionOptimizer (SubexpressionOptimizerOptions .newBuilder ().build ());
9299 private static final String BIND_IDENTIFIER_PREFIX = "@r" ;
93- private static final String MANGLED_COMPREHENSION_ITER_VAR_PREFIX = "@it" ;
94- private static final String MANGLED_COMPREHENSION_ITER_VAR2_PREFIX = "@it2" ;
95- private static final String MANGLED_COMPREHENSION_ACCU_VAR_PREFIX = "@ac" ;
96100 private static final String CEL_BLOCK_FUNCTION = "cel.@block" ;
97101 private static final String BLOCK_INDEX_PREFIX = "@index" ;
98102 private static final Extension CEL_BLOCK_AST_EXTENSION_TAG =
99103 Extension .create ("cel_block" , Version .of (1L , 1L ), Component .COMPONENT_RUNTIME );
100104
105+ @ VisibleForTesting static final String MANGLED_COMPREHENSION_ITER_VAR_PREFIX = "@it" ;
106+ @ VisibleForTesting static final String MANGLED_COMPREHENSION_ITER_VAR2_PREFIX = "@it2" ;
107+ @ VisibleForTesting static final String MANGLED_COMPREHENSION_ACCU_VAR_PREFIX = "@ac" ;
108+
101109 private final SubexpressionOptimizerOptions cseOptions ;
102110 private final AstMutator astMutator ;
103111 private final ImmutableSet <String > cseEliminableFunctions ;
@@ -269,6 +277,8 @@ static void verifyOptimizedAstCorrectness(CelAbstractSyntaxTree ast) {
269277 Verify .verify (
270278 resultHasAtLeastOneBlockIndex ,
271279 "Expected at least one reference of index in cel.block result" );
280+
281+ verifyNoInvalidScopedMangledVariables (celBlockExpr );
272282 }
273283
274284 private static void verifyBlockIndex (CelExpr celExpr , int maxIndexValue ) {
@@ -289,6 +299,67 @@ private static void verifyBlockIndex(CelExpr celExpr, int maxIndexValue) {
289299 celExpr );
290300 }
291301
302+ private static void verifyNoInvalidScopedMangledVariables (CelExpr celExpr ) {
303+ CelCall celBlockCall = celExpr .call ();
304+ CelExpr blockBody = celBlockCall .args ().get (1 );
305+
306+ ImmutableSet <String > allMangledVariablesInBlockBody =
307+ CelNavigableExpr .fromExpr (blockBody )
308+ .allNodes ()
309+ .map (CelNavigableExpr ::expr )
310+ .flatMap (SubexpressionOptimizer ::extractMangledNames )
311+ .collect (toImmutableSet ());
312+
313+ CelList blockIndices = celBlockCall .args ().get (0 ).list ();
314+ for (CelExpr blockIndex : blockIndices .elements ()) {
315+ ImmutableSet <String > indexDeclaredCompVariables =
316+ CelNavigableExpr .fromExpr (blockIndex )
317+ .allNodes ()
318+ .map (CelNavigableExpr ::expr )
319+ .filter (expr -> expr .getKind () == Kind .COMPREHENSION )
320+ .map (CelExpr ::comprehension )
321+ .flatMap (comp -> Stream .of (comp .iterVar (), comp .iterVar2 ()))
322+ .filter (iter -> !Strings .isNullOrEmpty (iter ))
323+ .collect (toImmutableSet ());
324+
325+ boolean containsIllegalDeclaration =
326+ CelNavigableExpr .fromExpr (blockIndex )
327+ .allNodes ()
328+ .map (CelNavigableExpr ::expr )
329+ .filter (expr -> expr .getKind () == Kind .IDENT )
330+ .map (expr -> expr .ident ().name ())
331+ .filter (SubexpressionOptimizer ::isMangled )
332+ .anyMatch (
333+ ident ->
334+ !indexDeclaredCompVariables .contains (ident )
335+ && allMangledVariablesInBlockBody .contains (ident ));
336+
337+ Verify .verify (
338+ !containsIllegalDeclaration ,
339+ "Illegal declared reference to a comprehension variable found in block indices. Expr: %s" ,
340+ celExpr );
341+ }
342+ }
343+
344+ private static Stream <String > extractMangledNames (CelExpr expr ) {
345+ if (expr .getKind () == Kind .IDENT ) {
346+ String name = expr .ident ().name ();
347+ return isMangled (name ) ? Stream .of (name ) : Stream .empty ();
348+ }
349+ if (expr .getKind () == Kind .COMPREHENSION ) {
350+ CelComprehension comp = expr .comprehension ();
351+ return Stream .of (comp .iterVar (), comp .iterVar2 (), comp .accuVar ())
352+ .filter (Objects ::nonNull ) // Handle potential null/empty iterVar2
353+ .filter (SubexpressionOptimizer ::isMangled );
354+ }
355+ return Stream .empty ();
356+ }
357+
358+ private static boolean isMangled (String name ) {
359+ return name .startsWith (MANGLED_COMPREHENSION_ITER_VAR_PREFIX )
360+ || name .startsWith (MANGLED_COMPREHENSION_ITER_VAR2_PREFIX );
361+ }
362+
292363 private static CelAbstractSyntaxTree tagAstExtension (CelAbstractSyntaxTree ast ) {
293364 // Tag the extension
294365 CelSource .Builder celSourceBuilder =
@@ -355,8 +426,8 @@ private List<CelMutableExpr> getCseCandidatesWithRecursionDepth(
355426 navAst
356427 .getRoot ()
357428 .descendants (TraversalOrder .PRE_ORDER )
358- .filter (node -> canEliminate (node , ineligibleExprs ))
359429 .filter (node -> node .height () <= recursionLimit )
430+ .filter (node -> canEliminate (node , ineligibleExprs ))
360431 .sorted (Comparator .comparingInt (CelNavigableMutableExpr ::height ).reversed ())
361432 .collect (toImmutableList ());
362433 if (descendants .isEmpty ()) {
@@ -441,7 +512,45 @@ private boolean canEliminate(
441512 && navigableExpr .expr ().list ().elements ().isEmpty ())
442513 && containsEliminableFunctionOnly (navigableExpr )
443514 && !ineligibleExprs .contains (navigableExpr .expr ())
444- && containsComprehensionIdentInSubexpr (navigableExpr );
515+ && containsComprehensionIdentInSubexpr (navigableExpr )
516+ && containsProperScopedComprehensionIdents (navigableExpr );
517+ }
518+
519+ private boolean containsProperScopedComprehensionIdents (CelNavigableMutableExpr navExpr ) {
520+ if (!navExpr .getKind ().equals (Kind .COMPREHENSION )) {
521+ return true ;
522+ }
523+
524+ // For nested comprehensions of form [1].exists(x, [2].exists(y, x == y)), the inner
525+ // comprehension [2].exists(y, x == y)
526+ // should not be extracted out into a block index, as it causes issues with scoping.
527+ ImmutableSet <String > mangledIterVars =
528+ navExpr
529+ .descendants ()
530+ .filter (x -> x .getKind ().equals (Kind .IDENT ))
531+ .map (x -> x .expr ().ident ().name ())
532+ .filter (
533+ name ->
534+ name .startsWith (MANGLED_COMPREHENSION_ITER_VAR_PREFIX )
535+ || name .startsWith (MANGLED_COMPREHENSION_ITER_VAR2_PREFIX ))
536+ .collect (toImmutableSet ());
537+
538+ CelNavigableMutableExpr parent = navExpr .parent ().orElse (null );
539+ while (parent != null ) {
540+ if (parent .getKind ().equals (Kind .COMPREHENSION )) {
541+ CelMutableComprehension comp = parent .expr ().comprehension ();
542+ boolean containsParentIterReferences =
543+ mangledIterVars .contains (comp .iterVar ()) || mangledIterVars .contains (comp .iterVar2 ());
544+
545+ if (containsParentIterReferences ) {
546+ return false ;
547+ }
548+ }
549+
550+ parent = parent .parent ().orElse (null );
551+ }
552+
553+ return true ;
445554 }
446555
447556 private boolean containsComprehensionIdentInSubexpr (CelNavigableMutableExpr navExpr ) {
0 commit comments