Skip to content

Commit 39bc2b0

Browse files
l46kok and copybara-github
authored and committed
Enhance CSE to handle two variable comprehensions
PiperOrigin-RevId: 803165826
1 parent bcb02b4 commit 39bc2b0

19 files changed

Lines changed: 14673 additions & 862 deletions

extensions/src/main/java/dev/cel/extensions/CelExtensions.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,9 @@ public static ImmutableSet<String> getAllFunctionNames() {
338338
stream(CelListsExtensions.Function.values())
339339
.map(CelListsExtensions.Function::getFunction),
340340
stream(CelRegexExtensions.Function.values())
341-
.map(CelRegexExtensions.Function::getFunction))
341+
.map(CelRegexExtensions.Function::getFunction),
342+
stream(CelComprehensionsExtensions.Function.values())
343+
.map(CelComprehensionsExtensions.Function::getFunction))
342344
.collect(toImmutableSet());
343345
}
344346

extensions/src/test/java/dev/cel/extensions/CelExtensionsTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ public void getAllFunctionNames() {
187187
"lists.@sortByAssociatedKeys",
188188
"regex.replace",
189189
"regex.extract",
190-
"regex.extractAll");
190+
"regex.extractAll",
191+
"cel.@mapInsert");
191192
}
192193
}

optimizer/src/main/java/dev/cel/optimizer/AstMutator.java

Lines changed: 93 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public final class AstMutator {
6060
private final long iterationLimit;
6161

6262
/**
63-
* Returns a new instance of a AST mutator with the iteration limit set.
63+
* Returns a new instance of an AST mutator with the iteration limit set.
6464
*
6565
* <p>Mutation is performed by walking the existing AST until the expression node to replace is
6666
* found, then the new subtree is walked to complete the mutation. Visiting of each node
@@ -203,22 +203,22 @@ public CelMutableAst renumberIdsConsecutively(CelMutableAst mutableAst) {
203203
* @param newIterVarPrefix Prefix to use for new iteration variable identifier name. For example,
204204
* providing @c will produce @c:0:0, @c:0:1, @c:1:0, @c:2:0... as new names.
205205
* @param newAccuVarPrefix Prefix to use for new accumulation variable identifier name.
206-
* @param incrementSerially If true, indices for the mangled variables are incremented serially
207-
* per occurrence regardless of their nesting level or its types.
208206
*/
209207
public MangledComprehensionAst mangleComprehensionIdentifierNames(
210208
CelMutableAst ast,
211209
String newIterVarPrefix,
212-
String newAccuVarPrefix,
213-
boolean incrementSerially) {
210+
String newIterVar2Prefix,
211+
String newAccuVarPrefix) {
214212
CelNavigableMutableAst navigableMutableAst = CelNavigableMutableAst.fromAst(ast);
215213
Predicate<CelNavigableMutableExpr> comprehensionIdentifierPredicate = x -> true;
216214
comprehensionIdentifierPredicate =
217215
comprehensionIdentifierPredicate
218216
.and(node -> node.getKind().equals(Kind.COMPREHENSION))
219-
.and(node -> !node.expr().comprehension().iterVar().startsWith(newIterVarPrefix))
220-
.and(node -> !node.expr().comprehension().accuVar().startsWith(newAccuVarPrefix));
221-
217+
.and(node -> !node.expr().comprehension().iterVar().startsWith(newIterVarPrefix + ":"))
218+
.and(node -> !node.expr().comprehension().accuVar().startsWith(newAccuVarPrefix + ":"))
219+
.and(
220+
node ->
221+
!node.expr().comprehension().iterVar2().startsWith(newIterVar2Prefix + ":"));
222222
LinkedHashMap<CelNavigableMutableExpr, MangledComprehensionType> comprehensionsToMangle =
223223
navigableMutableAst
224224
.getRoot()
@@ -231,20 +231,25 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
231231
// Ensure the iter_var, iter_var2 or the comprehension result is actually referenced in
232232
// the loop_step. If it's not, we can skip mangling.
233233
String iterVar = node.expr().comprehension().iterVar();
234+
String iterVar2 = node.expr().comprehension().iterVar2();
234235
String result = node.expr().comprehension().result().ident().name();
235236
return CelNavigableMutableExpr.fromExpr(node.expr().comprehension().loopStep())
236237
.allNodes()
237238
.filter(subNode -> subNode.getKind().equals(Kind.IDENT))
238239
.map(subNode -> subNode.expr().ident())
239240
.anyMatch(
240-
ident -> ident.name().contains(iterVar) || ident.name().contains(result));
241+
ident ->
242+
ident.name().contains(iterVar)
243+
|| ident.name().contains(iterVar2)
244+
|| ident.name().contains(result));
241245
})
242246
.collect(
243247
Collectors.toMap(
244248
k -> k,
245249
v -> {
246250
CelMutableComprehension comprehension = v.expr().comprehension();
247251
String iterVar = comprehension.iterVar();
252+
String iterVar2 = comprehension.iterVar2();
248253
// Identifiers to mangle could be the iteration variable, comprehension
249254
// result or both, but at least one has to exist.
250255
// As an example, [1,2].map(i, 3) would result in optional.empty for iteration
@@ -258,6 +263,16 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
258263
&& loopStepNode.expr().ident().name().equals(iterVar))
259264
.map(CelNavigableMutableExpr::id)
260265
.findAny();
266+
Optional<Long> iterVar2Id =
267+
CelNavigableMutableExpr.fromExpr(comprehension.loopStep())
268+
.allNodes()
269+
.filter(
270+
loopStepNode ->
271+
!iterVar2.isEmpty()
272+
&& loopStepNode.getKind().equals(Kind.IDENT)
273+
&& loopStepNode.expr().ident().name().equals(iterVar2))
274+
.map(CelNavigableMutableExpr::id)
275+
.findAny();
261276
Optional<CelType> iterVarType =
262277
iterVarId.map(
263278
id ->
@@ -269,6 +284,17 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
269284
"Checked type not present for iteration"
270285
+ " variable: "
271286
+ iterVarId)));
287+
Optional<CelType> iterVar2Type =
288+
iterVar2Id.map(
289+
id ->
290+
navigableMutableAst
291+
.getType(id)
292+
.orElseThrow(
293+
() ->
294+
new NoSuchElementException(
295+
"Checked type not present for iteration"
296+
+ " variable: "
297+
+ iterVar2Id)));
272298
CelType resultType =
273299
navigableMutableAst
274300
.getType(comprehension.result().id())
@@ -278,7 +304,7 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
278304
"Result type was not present for the comprehension ID: "
279305
+ comprehension.result().id()));
280306

281-
return MangledComprehensionType.of(iterVarType, resultType);
307+
return MangledComprehensionType.of(iterVarType, iterVar2Type, resultType);
282308
},
283309
(x, y) -> {
284310
throw new IllegalStateException(
@@ -301,38 +327,25 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
301327
MangledComprehensionType comprehensionEntryType = comprehensionEntry.getValue();
302328

303329
CelMutableExpr comprehensionExpr = comprehensionNode.expr();
304-
MangledComprehensionName mangledComprehensionName;
305-
if (incrementSerially) {
306-
// In case of applying CSE via cascaded cel.binds, not only is mangling based on level/types
307-
// meaningless (because all comprehensions are nested anyways, thus all indices would be
308-
// uinque),
309-
// it can lead to an erroneous result due to extracting a common subexpr with accu_var at
310-
// the wrong scope.
311-
// Example: "[1].exists(k, k > 1) && [2].exists(l, l > 1). The loop step for both branches
312-
// are identical, but shouldn't be extracted.
313-
String mangledIterVarName = newIterVarPrefix + ":" + iterCount;
314-
String mangledResultName = newAccuVarPrefix + ":" + iterCount;
315-
mangledComprehensionName =
316-
MangledComprehensionName.of(mangledIterVarName, mangledResultName);
317-
mangledIdentNamesToType.put(mangledComprehensionName, comprehensionEntry.getValue());
318-
} else {
319-
mangledComprehensionName =
320-
getMangledComprehensionName(
321-
newIterVarPrefix,
322-
newAccuVarPrefix,
323-
comprehensionNode,
324-
comprehensionLevelToType,
325-
comprehensionEntryType);
326-
}
330+
MangledComprehensionName mangledComprehensionName =
331+
getMangledComprehensionName(
332+
newIterVarPrefix,
333+
newIterVar2Prefix,
334+
newAccuVarPrefix,
335+
comprehensionNode,
336+
comprehensionLevelToType,
337+
comprehensionEntryType);
327338
mangledIdentNamesToType.put(mangledComprehensionName, comprehensionEntryType);
328339

329340
String iterVar = comprehensionExpr.comprehension().iterVar();
341+
String iterVar2 = comprehensionExpr.comprehension().iterVar2();
330342
String accuVar = comprehensionExpr.comprehension().accuVar();
331343
mutatedComprehensionExpr =
332344
mangleIdentsInComprehensionExpr(
333345
mutatedComprehensionExpr,
334346
comprehensionExpr,
335347
iterVar,
348+
iterVar2,
336349
accuVar,
337350
mangledComprehensionName);
338351
// Repeat the mangling process for the macro source.
@@ -341,6 +354,7 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
341354
newSource,
342355
mutatedComprehensionExpr,
343356
iterVar,
357+
iterVar2,
344358
mangledComprehensionName,
345359
comprehensionExpr.id());
346360
iterCount++;
@@ -360,6 +374,7 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
360374

361375
private static MangledComprehensionName getMangledComprehensionName(
362376
String newIterVarPrefix,
377+
String newIterVar2Prefix,
363378
String newResultPrefix,
364379
CelNavigableMutableExpr comprehensionNode,
365380
Table<Integer, MangledComprehensionType, MangledComprehensionName> comprehensionLevelToType,
@@ -377,7 +392,11 @@ private static MangledComprehensionName getMangledComprehensionName(
377392
newIterVarPrefix + ":" + comprehensionNestingLevel + ":" + uniqueTypeIdx;
378393
String mangledResultName =
379394
newResultPrefix + ":" + comprehensionNestingLevel + ":" + uniqueTypeIdx;
380-
mangledComprehensionName = MangledComprehensionName.of(mangledIterVarName, mangledResultName);
395+
String mangledIterVar2Name =
396+
newIterVar2Prefix + ":" + comprehensionNestingLevel + ":" + uniqueTypeIdx;
397+
398+
mangledComprehensionName =
399+
MangledComprehensionName.of(mangledIterVarName, mangledIterVar2Name, mangledResultName);
381400
comprehensionLevelToType.put(
382401
comprehensionNestingLevel, comprehensionEntryType, mangledComprehensionName);
383402
}
@@ -530,6 +549,7 @@ private CelMutableExpr mangleIdentsInComprehensionExpr(
530549
CelMutableExpr root,
531550
CelMutableExpr comprehensionExpr,
532551
String originalIterVar,
552+
String originalIterVar2,
533553
String originalAccuVar,
534554
MangledComprehensionName mangledComprehensionName) {
535555
CelMutableComprehension comprehension = comprehensionExpr.comprehension();
@@ -538,11 +558,18 @@ private CelMutableExpr mangleIdentsInComprehensionExpr(
538558
replaceIdentName(comprehensionExpr, originalAccuVar, mangledComprehensionName.resultName());
539559

540560
comprehension.setIterVar(mangledComprehensionName.iterVarName());
561+
541562
// Most standard macros set accu_var as __result__, but not all (ex: cel.bind).
542563
if (comprehension.accuVar().equals(originalAccuVar)) {
543564
comprehension.setAccuVar(mangledComprehensionName.resultName());
544565
}
545566

567+
if (!originalIterVar2.isEmpty()) {
568+
comprehension.setIterVar2(mangledComprehensionName.iterVar2Name());
569+
replaceIdentName(
570+
comprehension.loopStep(), originalIterVar2, mangledComprehensionName.iterVar2Name());
571+
}
572+
546573
return mutateExpr(NO_OP_ID_GENERATOR, root, comprehensionExpr, comprehensionExpr.id());
547574
}
548575

@@ -581,6 +608,7 @@ private CelMutableSource mangleIdentsInMacroSource(
581608
CelMutableSource sourceBuilder,
582609
CelMutableExpr mutatedComprehensionExpr,
583610
String originalIterVar,
611+
String originalIterVar2,
584612
MangledComprehensionName mangledComprehensionName,
585613
long originalComprehensionId) {
586614
if (!sourceBuilder.getMacroCalls().containsKey(originalComprehensionId)) {
@@ -604,14 +632,25 @@ private CelMutableSource mangleIdentsInMacroSource(
604632
// macro call expression.
605633
CelMutableExpr identToMangle = macroExpr.call().args().get(0);
606634
if (identToMangle.ident().name().equals(originalIterVar)) {
607-
// if (identToMangle.identOrDefault().name().equals(originalIterVar)) {
608635
macroExpr =
609636
mutateExpr(
610637
NO_OP_ID_GENERATOR,
611638
macroExpr,
612639
CelMutableExpr.ofIdent(mangledComprehensionName.iterVarName()),
613640
identToMangle.id());
614641
}
642+
if (!originalIterVar2.isEmpty()) {
643+
// Similarly by convention, iter_var2 is always the second argument of the macro call.
644+
identToMangle = macroExpr.call().args().get(1);
645+
if (identToMangle.ident().name().equals(originalIterVar2)) {
646+
macroExpr =
647+
mutateExpr(
648+
NO_OP_ID_GENERATOR,
649+
macroExpr,
650+
CelMutableExpr.ofIdent(mangledComprehensionName.iterVar2Name()),
651+
identToMangle.id());
652+
}
653+
}
615654

616655
newSource.addMacroCalls(originalComprehensionId, macroExpr);
617656

@@ -815,7 +854,7 @@ private static void unwrapListArgumentsInMacroCallExpr(
815854
newMacroCall.addArgs(
816855
existingMacroCall.args().get(0)); // iter_var is first argument of the call by convention
817856

818-
CelMutableList extraneousList = null;
857+
CelMutableList extraneousList;
819858
if (loopStepArgs.size() == 2) {
820859
extraneousList = loopStepArgs.get(1).list();
821860
} else {
@@ -895,14 +934,22 @@ private static MangledComprehensionAst of(
895934
@AutoValue
896935
public abstract static class MangledComprehensionType {
897936

898-
/** Type of iter_var */
937+
/**
938+
* Type of iter_var. Empty if iter_var is not referenced in the expression anywhere (ex: "i" in
939+
* "[1].exists(i, true)").
940+
*/
899941
public abstract Optional<CelType> iterVarType();
900942

943+
/** Type of iter_var2. */
944+
public abstract Optional<CelType> iterVar2Type();
945+
901946
/** Type of comprehension result */
902947
public abstract CelType resultType();
903948

904-
private static MangledComprehensionType of(Optional<CelType> iterVarType, CelType resultType) {
905-
return new AutoValue_AstMutator_MangledComprehensionType(iterVarType, resultType);
949+
private static MangledComprehensionType of(
950+
Optional<CelType> iterVarType, Optional<CelType> iterVarType2, CelType resultType) {
951+
return new AutoValue_AstMutator_MangledComprehensionType(
952+
iterVarType, iterVarType2, resultType);
906953
}
907954
}
908955

@@ -916,11 +963,16 @@ public abstract static class MangledComprehensionName {
916963
/** Mangled name for iter_var */
917964
public abstract String iterVarName();
918965

966+
/** Mangled name for iter_var2 */
967+
public abstract String iterVar2Name();
968+
919969
/** Mangled name for comprehension result */
920970
public abstract String resultName();
921971

922-
private static MangledComprehensionName of(String iterVarName, String resultName) {
923-
return new AutoValue_AstMutator_MangledComprehensionName(iterVarName, resultName);
972+
private static MangledComprehensionName of(
973+
String iterVarName, String iterVar2Name, String resultName) {
974+
return new AutoValue_AstMutator_MangledComprehensionName(
975+
iterVarName, iterVar2Name, resultName);
924976
}
925977
}
926978
}

optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ public class SubexpressionOptimizer implements CelAstOptimizer {
9191
new SubexpressionOptimizer(SubexpressionOptimizerOptions.newBuilder().build());
9292
private static final String BIND_IDENTIFIER_PREFIX = "@r";
9393
private static final String MANGLED_COMPREHENSION_ITER_VAR_PREFIX = "@it";
94+
private static final String MANGLED_COMPREHENSION_ITER_VAR2_PREFIX = "@it2";
9495
private static final String MANGLED_COMPREHENSION_ACCU_VAR_PREFIX = "@ac";
9596
private static final String CEL_BLOCK_FUNCTION = "cel.@block";
9697
private static final String BLOCK_INDEX_PREFIX = "@index";
@@ -136,8 +137,8 @@ private OptimizationResult optimizeUsingCelBlock(CelAbstractSyntaxTree ast, Cel
136137
astMutator.mangleComprehensionIdentifierNames(
137138
astToModify,
138139
MANGLED_COMPREHENSION_ITER_VAR_PREFIX,
139-
MANGLED_COMPREHENSION_ACCU_VAR_PREFIX,
140-
/* incrementSerially= */ false);
140+
MANGLED_COMPREHENSION_ITER_VAR2_PREFIX,
141+
MANGLED_COMPREHENSION_ACCU_VAR_PREFIX);
141142
astToModify = mangledComprehensionAst.mutableAst();
142143
CelMutableSource sourceToModify = astToModify.source();
143144

@@ -197,6 +198,12 @@ private OptimizationResult optimizeUsingCelBlock(CelAbstractSyntaxTree ast, Cel
197198
iterVarType ->
198199
newVarDecls.add(
199200
CelVarDecl.newVarDeclaration(name.iterVarName(), iterVarType)));
201+
type.iterVar2Type()
202+
.ifPresent(
203+
iterVar2Type ->
204+
newVarDecls.add(
205+
CelVarDecl.newVarDeclaration(name.iterVar2Name(), iterVar2Type)));
206+
200207
newVarDecls.add(CelVarDecl.newVarDeclaration(name.resultName(), type.resultType()));
201208
});
202209

@@ -446,16 +453,16 @@ private boolean containsComprehensionIdentInSubexpr(CelNavigableMutableExpr navE
446453
navExpr
447454
.allNodes()
448455
.filter(
449-
node ->
450-
node.getKind().equals(Kind.IDENT)
451-
&& (node.expr()
452-
.ident()
453-
.name()
454-
.startsWith(MANGLED_COMPREHENSION_ITER_VAR_PREFIX)
455-
|| node.expr()
456-
.ident()
457-
.name()
458-
.startsWith(MANGLED_COMPREHENSION_ACCU_VAR_PREFIX)))
456+
node -> {
457+
if (!node.getKind().equals(Kind.IDENT)) {
458+
return false;
459+
}
460+
461+
String identName = node.expr().ident().name();
462+
return identName.startsWith(MANGLED_COMPREHENSION_ITER_VAR_PREFIX)
463+
|| identName.startsWith(MANGLED_COMPREHENSION_ITER_VAR2_PREFIX)
464+
|| identName.startsWith(MANGLED_COMPREHENSION_ACCU_VAR_PREFIX);
465+
})
459466
.collect(toImmutableList());
460467

461468
if (comprehensionIdents.isEmpty()) {

0 commit comments

Comments
 (0)