@@ -60,7 +60,7 @@ public final class AstMutator {
6060 private final long iterationLimit ;
6161
6262 /**
63- * Returns a new instance of a AST mutator with the iteration limit set.
63+ * Returns a new instance of an AST mutator with the iteration limit set.
6464 *
6565 * <p>Mutation is performed by walking the existing AST until the expression node to replace is
6666 * found, then the new subtree is walked to complete the mutation. Visiting of each node
@@ -203,22 +203,20 @@ public CelMutableAst renumberIdsConsecutively(CelMutableAst mutableAst) {
203203 * @param newIterVarPrefix Prefix to use for new iteration variable identifier name. For example,
204204 * providing @c will produce @c0:0, @c0:1, @c1:0, @c2:0... as new names.
205205 * @param newAccuVarPrefix Prefix to use for new accumulation variable identifier name.
206- * @param incrementSerially If true, indices for the mangled variables are incremented serially
207- * per occurrence regardless of their nesting level or its types.
208206 */
209207 public MangledComprehensionAst mangleComprehensionIdentifierNames (
210208 CelMutableAst ast ,
211209 String newIterVarPrefix ,
212- String newAccuVarPrefix ,
213- boolean incrementSerially ) {
210+ String newIterVar2Prefix ,
211+ String newAccuVarPrefix ) {
214212 CelNavigableMutableAst navigableMutableAst = CelNavigableMutableAst .fromAst (ast );
215213 Predicate <CelNavigableMutableExpr > comprehensionIdentifierPredicate = x -> true ;
216214 comprehensionIdentifierPredicate =
217215 comprehensionIdentifierPredicate
218216 .and (node -> node .getKind ().equals (Kind .COMPREHENSION ))
219- .and (node -> !node .expr ().comprehension ().iterVar ().startsWith (newIterVarPrefix ))
220- .and (node -> !node .expr ().comprehension ().accuVar ().startsWith (newAccuVarPrefix ));
221-
217+ .and (node -> !node .expr ().comprehension ().iterVar ().startsWith (newIterVarPrefix + ":" ))
218+ .and (node -> !node .expr ().comprehension ().accuVar ().startsWith (newAccuVarPrefix + ":" ))
219+ . and ( node -> ! node . expr (). comprehension (). iterVar2 (). startsWith ( newIterVar2Prefix + ":" ));
222220 LinkedHashMap <CelNavigableMutableExpr , MangledComprehensionType > comprehensionsToMangle =
223221 navigableMutableAst
224222 .getRoot ()
@@ -231,20 +229,25 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
231229 // Ensure the iter_var or the comprehension result is actually referenced in the
232230 // loop_step. If it's not, we can skip mangling.
233231 String iterVar = node .expr ().comprehension ().iterVar ();
232+ String iterVar2 = node .expr ().comprehension ().iterVar2 ();
234233 String result = node .expr ().comprehension ().result ().ident ().name ();
235234 return CelNavigableMutableExpr .fromExpr (node .expr ().comprehension ().loopStep ())
236235 .allNodes ()
237236 .filter (subNode -> subNode .getKind ().equals (Kind .IDENT ))
238237 .map (subNode -> subNode .expr ().ident ())
239238 .anyMatch (
240- ident -> ident .name ().contains (iterVar ) || ident .name ().contains (result ));
239+ ident ->
240+ ident .name ().contains (iterVar )
241+ || ident .name ().contains (iterVar2 )
242+ || ident .name ().contains (result ));
241243 })
242244 .collect (
243245 Collectors .toMap (
244246 k -> k ,
245247 v -> {
246248 CelMutableComprehension comprehension = v .expr ().comprehension ();
247249 String iterVar = comprehension .iterVar ();
250+ String iterVar2 = comprehension .iterVar2 ();
248251 // Identifiers to mangle could be the iteration variable, comprehension
249252 // result or both, but at least one has to exist.
250253 // As an example, [1,2].map(i, 3) would result in optional.empty for iteration
@@ -258,6 +261,16 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
258261 && loopStepNode .expr ().ident ().name ().equals (iterVar ))
259262 .map (CelNavigableMutableExpr ::id )
260263 .findAny ();
264+ Optional <Long > iterVar2Id =
265+ CelNavigableMutableExpr .fromExpr (comprehension .loopStep ())
266+ .allNodes ()
267+ .filter (
268+ loopStepNode ->
269+ iterVar2 .isEmpty ()
270+ && loopStepNode .getKind ().equals (Kind .IDENT )
271+ && loopStepNode .expr ().ident ().name ().equals (iterVar2 ))
272+ .map (CelNavigableMutableExpr ::id )
273+ .findAny ();
261274 Optional <CelType > iterVarType =
262275 iterVarId .map (
263276 id ->
@@ -269,6 +282,17 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
269282 "Checked type not present for iteration"
270283 + " variable: "
271284 + iterVarId )));
285+ Optional <CelType > iterVar2Type =
286+ iterVar2Id .map (
287+ id ->
288+ navigableMutableAst
289+ .getType (id )
290+ .orElseThrow (
291+ () ->
292+ new NoSuchElementException (
293+ "Checked type not present for iteration"
294+ + " variable: "
295+ + iterVarId )));
272296 CelType resultType =
273297 navigableMutableAst
274298 .getType (comprehension .result ().id ())
@@ -278,7 +302,7 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
278302 "Result type was not present for the comprehension ID: "
279303 + comprehension .result ().id ()));
280304
281- return MangledComprehensionType .of (iterVarType , resultType );
305+ return MangledComprehensionType .of (iterVarType , iterVar2Type , resultType );
282306 },
283307 (x , y ) -> {
284308 throw new IllegalStateException (
@@ -301,38 +325,25 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
301325 MangledComprehensionType comprehensionEntryType = comprehensionEntry .getValue ();
302326
303327 CelMutableExpr comprehensionExpr = comprehensionNode .expr ();
304- MangledComprehensionName mangledComprehensionName ;
305- if (incrementSerially ) {
306- // In case of applying CSE via cascaded cel.binds, not only is mangling based on level/types
307- // meaningless (because all comprehensions are nested anyways, thus all indices would be
308- // uinque),
309- // it can lead to an erroneous result due to extracting a common subexpr with accu_var at
310- // the wrong scope.
311- // Example: "[1].exists(k, k > 1) && [2].exists(l, l > 1). The loop step for both branches
312- // are identical, but shouldn't be extracted.
313- String mangledIterVarName = newIterVarPrefix + ":" + iterCount ;
314- String mangledResultName = newAccuVarPrefix + ":" + iterCount ;
315- mangledComprehensionName =
316- MangledComprehensionName .of (mangledIterVarName , mangledResultName );
317- mangledIdentNamesToType .put (mangledComprehensionName , comprehensionEntry .getValue ());
318- } else {
319- mangledComprehensionName =
320- getMangledComprehensionName (
321- newIterVarPrefix ,
322- newAccuVarPrefix ,
323- comprehensionNode ,
324- comprehensionLevelToType ,
325- comprehensionEntryType );
326- }
328+ MangledComprehensionName mangledComprehensionName =
329+ getMangledComprehensionName (
330+ newIterVarPrefix ,
331+ newIterVar2Prefix ,
332+ newAccuVarPrefix ,
333+ comprehensionNode ,
334+ comprehensionLevelToType ,
335+ comprehensionEntryType );
327336 mangledIdentNamesToType .put (mangledComprehensionName , comprehensionEntryType );
328337
329338 String iterVar = comprehensionExpr .comprehension ().iterVar ();
339+ String iterVar2 = comprehensionExpr .comprehension ().iterVar2 ();
330340 String accuVar = comprehensionExpr .comprehension ().accuVar ();
331341 mutatedComprehensionExpr =
332342 mangleIdentsInComprehensionExpr (
333343 mutatedComprehensionExpr ,
334344 comprehensionExpr ,
335345 iterVar ,
346+ iterVar2 ,
336347 accuVar ,
337348 mangledComprehensionName );
338349 // Repeat the mangling process for the macro source.
@@ -341,6 +352,7 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
341352 newSource ,
342353 mutatedComprehensionExpr ,
343354 iterVar ,
355+ iterVar2 ,
344356 mangledComprehensionName ,
345357 comprehensionExpr .id ());
346358 iterCount ++;
@@ -360,6 +372,7 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
360372
361373 private static MangledComprehensionName getMangledComprehensionName (
362374 String newIterVarPrefix ,
375+ String newIterVar2Prefix ,
363376 String newResultPrefix ,
364377 CelNavigableMutableExpr comprehensionNode ,
365378 Table <Integer , MangledComprehensionType , MangledComprehensionName > comprehensionLevelToType ,
@@ -377,7 +390,15 @@ private static MangledComprehensionName getMangledComprehensionName(
377390 newIterVarPrefix + ":" + comprehensionNestingLevel + ":" + uniqueTypeIdx ;
378391 String mangledResultName =
379392 newResultPrefix + ":" + comprehensionNestingLevel + ":" + uniqueTypeIdx ;
380- mangledComprehensionName = MangledComprehensionName .of (mangledIterVarName , mangledResultName );
393+ String mangledIterVar2Name = "" ;
394+
395+ if (!newIterVar2Prefix .isEmpty ()) {
396+ mangledIterVar2Name =
397+ newIterVar2Prefix + ":" + comprehensionNestingLevel + ":" + uniqueTypeIdx ;
398+ }
399+
400+ mangledComprehensionName =
401+ MangledComprehensionName .of (mangledIterVarName , mangledIterVar2Name , mangledResultName );
381402 comprehensionLevelToType .put (
382403 comprehensionNestingLevel , comprehensionEntryType , mangledComprehensionName );
383404 }
@@ -530,6 +551,7 @@ private CelMutableExpr mangleIdentsInComprehensionExpr(
530551 CelMutableExpr root ,
531552 CelMutableExpr comprehensionExpr ,
532553 String originalIterVar ,
554+ String originalIterVar2 ,
533555 String originalAccuVar ,
534556 MangledComprehensionName mangledComprehensionName ) {
535557 CelMutableComprehension comprehension = comprehensionExpr .comprehension ();
@@ -538,11 +560,18 @@ private CelMutableExpr mangleIdentsInComprehensionExpr(
538560 replaceIdentName (comprehensionExpr , originalAccuVar , mangledComprehensionName .resultName ());
539561
540562 comprehension .setIterVar (mangledComprehensionName .iterVarName ());
563+
541564 // Most standard macros set accu_var as __result__, but not all (ex: cel.bind).
542565 if (comprehension .accuVar ().equals (originalAccuVar )) {
543566 comprehension .setAccuVar (mangledComprehensionName .resultName ());
544567 }
545568
569+ if (!originalIterVar2 .isEmpty ()) {
570+ comprehension .setIterVar2 (mangledComprehensionName .iterVar2Name ());
571+ replaceIdentName (
572+ comprehension .loopStep (), originalIterVar2 , mangledComprehensionName .iterVar2Name ());
573+ }
574+
546575 return mutateExpr (NO_OP_ID_GENERATOR , root , comprehensionExpr , comprehensionExpr .id ());
547576 }
548577
@@ -581,6 +610,7 @@ private CelMutableSource mangleIdentsInMacroSource(
581610 CelMutableSource sourceBuilder ,
582611 CelMutableExpr mutatedComprehensionExpr ,
583612 String originalIterVar ,
613+ String originalIterVar2 ,
584614 MangledComprehensionName mangledComprehensionName ,
585615 long originalComprehensionId ) {
586616 if (!sourceBuilder .getMacroCalls ().containsKey (originalComprehensionId )) {
@@ -604,14 +634,25 @@ private CelMutableSource mangleIdentsInMacroSource(
604634 // macro call expression.
605635 CelMutableExpr identToMangle = macroExpr .call ().args ().get (0 );
606636 if (identToMangle .ident ().name ().equals (originalIterVar )) {
607- // if (identToMangle.identOrDefault().name().equals(originalIterVar)) {
608637 macroExpr =
609638 mutateExpr (
610639 NO_OP_ID_GENERATOR ,
611640 macroExpr ,
612641 CelMutableExpr .ofIdent (mangledComprehensionName .iterVarName ()),
613642 identToMangle .id ());
614643 }
644+ if (!originalIterVar2 .isEmpty ()) {
645+ // Similarly by convention, iter_var2 is always the second argument of the macro call.
646+ identToMangle = macroExpr .call ().args ().get (1 );
647+ if (identToMangle .ident ().name ().equals (originalIterVar2 )) {
648+ macroExpr =
649+ mutateExpr (
650+ NO_OP_ID_GENERATOR ,
651+ macroExpr ,
652+ CelMutableExpr .ofIdent (mangledComprehensionName .iterVar2Name ()),
653+ identToMangle .id ());
654+ }
655+ }
615656
616657 newSource .addMacroCalls (originalComprehensionId , macroExpr );
617658
@@ -815,7 +856,7 @@ private static void unwrapListArgumentsInMacroCallExpr(
815856 newMacroCall .addArgs (
816857 existingMacroCall .args ().get (0 )); // iter_var is first argument of the call by convention
817858
818- CelMutableList extraneousList = null ;
859+ CelMutableList extraneousList ;
819860 if (loopStepArgs .size () == 2 ) {
820861 extraneousList = loopStepArgs .get (1 ).list ();
821862 } else {
@@ -895,14 +936,22 @@ private static MangledComprehensionAst of(
895936 @ AutoValue
896937 public abstract static class MangledComprehensionType {
897938
898- /** Type of iter_var */
939+ /**
940+ * Type of iter_var. Empty if iter_var is not referenced in the expression anywhere (ex: "i" in
941+ * "[1].exists(i, true)"
942+ */
899943 public abstract Optional <CelType > iterVarType ();
900944
945+ /** Type of iter_var2. */
946+ public abstract Optional <CelType > iterVar2Type ();
947+
901948 /** Type of comprehension result */
902949 public abstract CelType resultType ();
903950
904- private static MangledComprehensionType of (Optional <CelType > iterVarType , CelType resultType ) {
905- return new AutoValue_AstMutator_MangledComprehensionType (iterVarType , resultType );
951+ private static MangledComprehensionType of (
952+ Optional <CelType > iterVarType , Optional <CelType > iterVarType2 , CelType resultType ) {
953+ return new AutoValue_AstMutator_MangledComprehensionType (
954+ iterVarType , iterVarType2 , resultType );
906955 }
907956 }
908957
@@ -916,11 +965,16 @@ public abstract static class MangledComprehensionName {
916965 /** Mangled name for iter_var */
917966 public abstract String iterVarName ();
918967
968+ /** Mangled name for iter_var2 */
969+ public abstract String iterVar2Name ();
970+
919971 /** Mangled name for comprehension result */
920972 public abstract String resultName ();
921973
922- private static MangledComprehensionName of (String iterVarName , String resultName ) {
923- return new AutoValue_AstMutator_MangledComprehensionName (iterVarName , resultName );
974+ private static MangledComprehensionName of (
975+ String iterVarName , String iterVar2Name , String resultName ) {
976+ return new AutoValue_AstMutator_MangledComprehensionName (
977+ iterVarName , iterVar2Name , resultName );
924978 }
925979 }
926980}
0 commit comments