@@ -60,7 +60,7 @@ public final class AstMutator {
6060 private final long iterationLimit ;
6161
6262 /**
63- * Returns a new instance of a AST mutator with the iteration limit set.
63+ * Returns a new instance of an AST mutator with the iteration limit set.
6464 *
6565 * <p>Mutation is performed by walking the existing AST until the expression node to replace is
6666 * found, then the new subtree is walked to complete the mutation. Visiting of each node
@@ -203,22 +203,22 @@ public CelMutableAst renumberIdsConsecutively(CelMutableAst mutableAst) {
203203 * @param newIterVarPrefix Prefix to use for new iteration variable identifier name. For example,
204204 * providing @c will produce @c0:0, @c0:1, @c1:0, @c2:0... as new names.
205205 * @param newAccuVarPrefix Prefix to use for new accumulation variable identifier name.
206- * @param incrementSerially If true, indices for the mangled variables are incremented serially
207- * per occurrence regardless of their nesting level or its types.
208206 */
209207 public MangledComprehensionAst mangleComprehensionIdentifierNames (
210208 CelMutableAst ast ,
211209 String newIterVarPrefix ,
212- String newAccuVarPrefix ,
213- boolean incrementSerially ) {
210+ String newIterVar2Prefix ,
211+ String newAccuVarPrefix ) {
214212 CelNavigableMutableAst navigableMutableAst = CelNavigableMutableAst .fromAst (ast );
215213 Predicate <CelNavigableMutableExpr > comprehensionIdentifierPredicate = x -> true ;
216214 comprehensionIdentifierPredicate =
217215 comprehensionIdentifierPredicate
218216 .and (node -> node .getKind ().equals (Kind .COMPREHENSION ))
219- .and (node -> !node .expr ().comprehension ().iterVar ().startsWith (newIterVarPrefix ))
220- .and (node -> !node .expr ().comprehension ().accuVar ().startsWith (newAccuVarPrefix ));
221-
217+ .and (node -> !node .expr ().comprehension ().iterVar ().startsWith (newIterVarPrefix + ":" ))
218+ .and (node -> !node .expr ().comprehension ().accuVar ().startsWith (newAccuVarPrefix + ":" ))
219+ .and (
220+ node ->
221+ !node .expr ().comprehension ().iterVar2 ().startsWith (newIterVar2Prefix + ":" ));
222222 LinkedHashMap <CelNavigableMutableExpr , MangledComprehensionType > comprehensionsToMangle =
223223 navigableMutableAst
224224 .getRoot ()
@@ -231,20 +231,25 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
231231 // Ensure the iter_var or the comprehension result is actually referenced in the
232232 // loop_step. If it's not, we can skip mangling.
233233 String iterVar = node .expr ().comprehension ().iterVar ();
234+ String iterVar2 = node .expr ().comprehension ().iterVar2 ();
234235 String result = node .expr ().comprehension ().result ().ident ().name ();
235236 return CelNavigableMutableExpr .fromExpr (node .expr ().comprehension ().loopStep ())
236237 .allNodes ()
237238 .filter (subNode -> subNode .getKind ().equals (Kind .IDENT ))
238239 .map (subNode -> subNode .expr ().ident ())
239240 .anyMatch (
240- ident -> ident .name ().contains (iterVar ) || ident .name ().contains (result ));
241+ ident ->
242+ ident .name ().contains (iterVar )
243+ || ident .name ().contains (iterVar2 )
244+ || ident .name ().contains (result ));
241245 })
242246 .collect (
243247 Collectors .toMap (
244248 k -> k ,
245249 v -> {
246250 CelMutableComprehension comprehension = v .expr ().comprehension ();
247251 String iterVar = comprehension .iterVar ();
252+ String iterVar2 = comprehension .iterVar2 ();
248253 // Identifiers to mangle could be the iteration variable, comprehension
249254 // result or both, but at least one has to exist.
250255 // As an example, [1,2].map(i, 3) would result in optional.empty for iteration
@@ -258,6 +263,16 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
258263 && loopStepNode .expr ().ident ().name ().equals (iterVar ))
259264 .map (CelNavigableMutableExpr ::id )
260265 .findAny ();
266+ Optional <Long > iterVar2Id =
267+ CelNavigableMutableExpr .fromExpr (comprehension .loopStep ())
268+ .allNodes ()
269+ .filter (
270+ loopStepNode ->
271+ !iterVar2 .isEmpty ()
272+ && loopStepNode .getKind ().equals (Kind .IDENT )
273+ && loopStepNode .expr ().ident ().name ().equals (iterVar2 ))
274+ .map (CelNavigableMutableExpr ::id )
275+ .findAny ();
261276 Optional <CelType > iterVarType =
262277 iterVarId .map (
263278 id ->
@@ -269,6 +284,17 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
269284 "Checked type not present for iteration"
270285 + " variable: "
271286 + iterVarId )));
287+ Optional <CelType > iterVar2Type =
288+ iterVar2Id .map (
289+ id ->
290+ navigableMutableAst
291+ .getType (id )
292+ .orElseThrow (
293+ () ->
294+ new NoSuchElementException (
295+ "Checked type not present for iteration"
296+ + " variable: "
297+ + iterVar2Id )));
272298 CelType resultType =
273299 navigableMutableAst
274300 .getType (comprehension .result ().id ())
@@ -278,7 +304,7 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
278304 "Result type was not present for the comprehension ID: "
279305 + comprehension .result ().id ()));
280306
281- return MangledComprehensionType .of (iterVarType , resultType );
307+ return MangledComprehensionType .of (iterVarType , iterVar2Type , resultType );
282308 },
283309 (x , y ) -> {
284310 throw new IllegalStateException (
@@ -301,38 +327,25 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
301327 MangledComprehensionType comprehensionEntryType = comprehensionEntry .getValue ();
302328
303329 CelMutableExpr comprehensionExpr = comprehensionNode .expr ();
304- MangledComprehensionName mangledComprehensionName ;
305- if (incrementSerially ) {
306- // In case of applying CSE via cascaded cel.binds, not only is mangling based on level/types
307- // meaningless (because all comprehensions are nested anyways, thus all indices would be
308- // uinque),
309- // it can lead to an erroneous result due to extracting a common subexpr with accu_var at
310- // the wrong scope.
311- // Example: "[1].exists(k, k > 1) && [2].exists(l, l > 1). The loop step for both branches
312- // are identical, but shouldn't be extracted.
313- String mangledIterVarName = newIterVarPrefix + ":" + iterCount ;
314- String mangledResultName = newAccuVarPrefix + ":" + iterCount ;
315- mangledComprehensionName =
316- MangledComprehensionName .of (mangledIterVarName , mangledResultName );
317- mangledIdentNamesToType .put (mangledComprehensionName , comprehensionEntry .getValue ());
318- } else {
319- mangledComprehensionName =
320- getMangledComprehensionName (
321- newIterVarPrefix ,
322- newAccuVarPrefix ,
323- comprehensionNode ,
324- comprehensionLevelToType ,
325- comprehensionEntryType );
326- }
330+ MangledComprehensionName mangledComprehensionName =
331+ getMangledComprehensionName (
332+ newIterVarPrefix ,
333+ newIterVar2Prefix ,
334+ newAccuVarPrefix ,
335+ comprehensionNode ,
336+ comprehensionLevelToType ,
337+ comprehensionEntryType );
327338 mangledIdentNamesToType .put (mangledComprehensionName , comprehensionEntryType );
328339
329340 String iterVar = comprehensionExpr .comprehension ().iterVar ();
341+ String iterVar2 = comprehensionExpr .comprehension ().iterVar2 ();
330342 String accuVar = comprehensionExpr .comprehension ().accuVar ();
331343 mutatedComprehensionExpr =
332344 mangleIdentsInComprehensionExpr (
333345 mutatedComprehensionExpr ,
334346 comprehensionExpr ,
335347 iterVar ,
348+ iterVar2 ,
336349 accuVar ,
337350 mangledComprehensionName );
338351 // Repeat the mangling process for the macro source.
@@ -341,6 +354,7 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
341354 newSource ,
342355 mutatedComprehensionExpr ,
343356 iterVar ,
357+ iterVar2 ,
344358 mangledComprehensionName ,
345359 comprehensionExpr .id ());
346360 iterCount ++;
@@ -360,6 +374,7 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
360374
361375 private static MangledComprehensionName getMangledComprehensionName (
362376 String newIterVarPrefix ,
377+ String newIterVar2Prefix ,
363378 String newResultPrefix ,
364379 CelNavigableMutableExpr comprehensionNode ,
365380 Table <Integer , MangledComprehensionType , MangledComprehensionName > comprehensionLevelToType ,
@@ -377,7 +392,11 @@ private static MangledComprehensionName getMangledComprehensionName(
377392 newIterVarPrefix + ":" + comprehensionNestingLevel + ":" + uniqueTypeIdx ;
378393 String mangledResultName =
379394 newResultPrefix + ":" + comprehensionNestingLevel + ":" + uniqueTypeIdx ;
380- mangledComprehensionName = MangledComprehensionName .of (mangledIterVarName , mangledResultName );
395+ String mangledIterVar2Name =
396+ newIterVar2Prefix + ":" + comprehensionNestingLevel + ":" + uniqueTypeIdx ;
397+
398+ mangledComprehensionName =
399+ MangledComprehensionName .of (mangledIterVarName , mangledIterVar2Name , mangledResultName );
381400 comprehensionLevelToType .put (
382401 comprehensionNestingLevel , comprehensionEntryType , mangledComprehensionName );
383402 }
@@ -530,6 +549,7 @@ private CelMutableExpr mangleIdentsInComprehensionExpr(
530549 CelMutableExpr root ,
531550 CelMutableExpr comprehensionExpr ,
532551 String originalIterVar ,
552+ String originalIterVar2 ,
533553 String originalAccuVar ,
534554 MangledComprehensionName mangledComprehensionName ) {
535555 CelMutableComprehension comprehension = comprehensionExpr .comprehension ();
@@ -538,11 +558,18 @@ private CelMutableExpr mangleIdentsInComprehensionExpr(
538558 replaceIdentName (comprehensionExpr , originalAccuVar , mangledComprehensionName .resultName ());
539559
540560 comprehension .setIterVar (mangledComprehensionName .iterVarName ());
561+
541562 // Most standard macros set accu_var as __result__, but not all (ex: cel.bind).
542563 if (comprehension .accuVar ().equals (originalAccuVar )) {
543564 comprehension .setAccuVar (mangledComprehensionName .resultName ());
544565 }
545566
567+ if (!originalIterVar2 .isEmpty ()) {
568+ comprehension .setIterVar2 (mangledComprehensionName .iterVar2Name ());
569+ replaceIdentName (
570+ comprehension .loopStep (), originalIterVar2 , mangledComprehensionName .iterVar2Name ());
571+ }
572+
546573 return mutateExpr (NO_OP_ID_GENERATOR , root , comprehensionExpr , comprehensionExpr .id ());
547574 }
548575
@@ -581,6 +608,7 @@ private CelMutableSource mangleIdentsInMacroSource(
581608 CelMutableSource sourceBuilder ,
582609 CelMutableExpr mutatedComprehensionExpr ,
583610 String originalIterVar ,
611+ String originalIterVar2 ,
584612 MangledComprehensionName mangledComprehensionName ,
585613 long originalComprehensionId ) {
586614 if (!sourceBuilder .getMacroCalls ().containsKey (originalComprehensionId )) {
@@ -604,14 +632,25 @@ private CelMutableSource mangleIdentsInMacroSource(
604632 // macro call expression.
605633 CelMutableExpr identToMangle = macroExpr .call ().args ().get (0 );
606634 if (identToMangle .ident ().name ().equals (originalIterVar )) {
607- // if (identToMangle.identOrDefault().name().equals(originalIterVar)) {
608635 macroExpr =
609636 mutateExpr (
610637 NO_OP_ID_GENERATOR ,
611638 macroExpr ,
612639 CelMutableExpr .ofIdent (mangledComprehensionName .iterVarName ()),
613640 identToMangle .id ());
614641 }
642+ if (!originalIterVar2 .isEmpty ()) {
643+ // Similarly by convention, iter_var2 is always the second argument of the macro call.
644+ identToMangle = macroExpr .call ().args ().get (1 );
645+ if (identToMangle .ident ().name ().equals (originalIterVar2 )) {
646+ macroExpr =
647+ mutateExpr (
648+ NO_OP_ID_GENERATOR ,
649+ macroExpr ,
650+ CelMutableExpr .ofIdent (mangledComprehensionName .iterVar2Name ()),
651+ identToMangle .id ());
652+ }
653+ }
615654
616655 newSource .addMacroCalls (originalComprehensionId , macroExpr );
617656
@@ -815,7 +854,7 @@ private static void unwrapListArgumentsInMacroCallExpr(
815854 newMacroCall .addArgs (
816855 existingMacroCall .args ().get (0 )); // iter_var is first argument of the call by convention
817856
818- CelMutableList extraneousList = null ;
857+ CelMutableList extraneousList ;
819858 if (loopStepArgs .size () == 2 ) {
820859 extraneousList = loopStepArgs .get (1 ).list ();
821860 } else {
@@ -895,14 +934,22 @@ private static MangledComprehensionAst of(
895934 @ AutoValue
896935 public abstract static class MangledComprehensionType {
897936
898- /** Type of iter_var */
937+ /**
938+ * Type of iter_var. Empty if iter_var is not referenced in the expression anywhere (ex: "i" in
939+ * "[1].exists(i, true)"
940+ */
899941 public abstract Optional <CelType > iterVarType ();
900942
943+ /** Type of iter_var2. */
944+ public abstract Optional <CelType > iterVar2Type ();
945+
901946 /** Type of comprehension result */
902947 public abstract CelType resultType ();
903948
904- private static MangledComprehensionType of (Optional <CelType > iterVarType , CelType resultType ) {
905- return new AutoValue_AstMutator_MangledComprehensionType (iterVarType , resultType );
949+ private static MangledComprehensionType of (
950+ Optional <CelType > iterVarType , Optional <CelType > iterVarType2 , CelType resultType ) {
951+ return new AutoValue_AstMutator_MangledComprehensionType (
952+ iterVarType , iterVarType2 , resultType );
906953 }
907954 }
908955
@@ -916,11 +963,16 @@ public abstract static class MangledComprehensionName {
916963 /** Mangled name for iter_var */
917964 public abstract String iterVarName ();
918965
966+ /** Mangled name for iter_var2 */
967+ public abstract String iterVar2Name ();
968+
919969 /** Mangled name for comprehension result */
920970 public abstract String resultName ();
921971
922- private static MangledComprehensionName of (String iterVarName , String resultName ) {
923- return new AutoValue_AstMutator_MangledComprehensionName (iterVarName , resultName );
972+ private static MangledComprehensionName of (
973+ String iterVarName , String iterVar2Name , String resultName ) {
974+ return new AutoValue_AstMutator_MangledComprehensionName (
975+ iterVarName , iterVar2Name , resultName );
924976 }
925977 }
926978}
0 commit comments