Skip to content

Commit e7e9fdf

Browse files
l46kokcopybara-github
authored andcommitted
Enhance CSE to handle two variable comprehensions
PiperOrigin-RevId: 803165826
1 parent bcb02b4 commit e7e9fdf

19 files changed

Lines changed: 11341 additions & 501 deletions

extensions/src/main/java/dev/cel/extensions/CelExtensions.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,9 @@ public static ImmutableSet<String> getAllFunctionNames() {
338338
stream(CelListsExtensions.Function.values())
339339
.map(CelListsExtensions.Function::getFunction),
340340
stream(CelRegexExtensions.Function.values())
341-
.map(CelRegexExtensions.Function::getFunction))
341+
.map(CelRegexExtensions.Function::getFunction),
342+
stream(CelComprehensionsExtensions.Function.values())
343+
.map(CelComprehensionsExtensions.Function::getFunction))
342344
.collect(toImmutableSet());
343345
}
344346

extensions/src/test/java/dev/cel/extensions/CelExtensionsTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ public void getAllFunctionNames() {
187187
"lists.@sortByAssociatedKeys",
188188
"regex.replace",
189189
"regex.extract",
190-
"regex.extractAll");
190+
"regex.extractAll",
191+
"cel.@mapInsert");
191192
}
192193
}

optimizer/src/main/java/dev/cel/optimizer/AstMutator.java

Lines changed: 95 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public final class AstMutator {
6060
private final long iterationLimit;
6161

6262
/**
63-
* Returns a new instance of a AST mutator with the iteration limit set.
63+
* Returns a new instance of an AST mutator with the iteration limit set.
6464
*
6565
* <p>Mutation is performed by walking the existing AST until the expression node to replace is
6666
* found, then the new subtree is walked to complete the mutation. Visiting of each node
@@ -203,22 +203,20 @@ public CelMutableAst renumberIdsConsecutively(CelMutableAst mutableAst) {
203203
* @param newIterVarPrefix Prefix to use for new iteration variable identifier name. For example,
204204
* providing @c will produce @c0:0, @c0:1, @c1:0, @c2:0... as new names.
205205
* @param newAccuVarPrefix Prefix to use for new accumulation variable identifier name.
206-
* @param incrementSerially If true, indices for the mangled variables are incremented serially
207-
* per occurrence regardless of their nesting level or its types.
208206
*/
209207
public MangledComprehensionAst mangleComprehensionIdentifierNames(
210208
CelMutableAst ast,
211209
String newIterVarPrefix,
212-
String newAccuVarPrefix,
213-
boolean incrementSerially) {
210+
String newIterVar2Prefix,
211+
String newAccuVarPrefix) {
214212
CelNavigableMutableAst navigableMutableAst = CelNavigableMutableAst.fromAst(ast);
215213
Predicate<CelNavigableMutableExpr> comprehensionIdentifierPredicate = x -> true;
216214
comprehensionIdentifierPredicate =
217215
comprehensionIdentifierPredicate
218216
.and(node -> node.getKind().equals(Kind.COMPREHENSION))
219-
.and(node -> !node.expr().comprehension().iterVar().startsWith(newIterVarPrefix))
220-
.and(node -> !node.expr().comprehension().accuVar().startsWith(newAccuVarPrefix));
221-
217+
.and(node -> !node.expr().comprehension().iterVar().startsWith(newIterVarPrefix + ":"))
218+
.and(node -> !node.expr().comprehension().accuVar().startsWith(newAccuVarPrefix + ":"))
219+
.and(node -> !node.expr().comprehension().iterVar2().startsWith(newIterVar2Prefix + ":"));
222220
LinkedHashMap<CelNavigableMutableExpr, MangledComprehensionType> comprehensionsToMangle =
223221
navigableMutableAst
224222
.getRoot()
@@ -231,20 +229,25 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
231229
// Ensure the iter_var or the comprehension result is actually referenced in the
232230
// loop_step. If it's not, we can skip mangling.
233231
String iterVar = node.expr().comprehension().iterVar();
232+
String iterVar2 = node.expr().comprehension().iterVar2();
234233
String result = node.expr().comprehension().result().ident().name();
235234
return CelNavigableMutableExpr.fromExpr(node.expr().comprehension().loopStep())
236235
.allNodes()
237236
.filter(subNode -> subNode.getKind().equals(Kind.IDENT))
238237
.map(subNode -> subNode.expr().ident())
239238
.anyMatch(
240-
ident -> ident.name().contains(iterVar) || ident.name().contains(result));
239+
ident ->
240+
ident.name().contains(iterVar)
241+
|| ident.name().contains(iterVar2)
242+
|| ident.name().contains(result));
241243
})
242244
.collect(
243245
Collectors.toMap(
244246
k -> k,
245247
v -> {
246248
CelMutableComprehension comprehension = v.expr().comprehension();
247249
String iterVar = comprehension.iterVar();
250+
String iterVar2 = comprehension.iterVar2();
248251
// Identifiers to mangle could be the iteration variable, comprehension
249252
// result or both, but at least one has to exist.
250253
// As an example, [1,2].map(i, 3) would result in optional.empty for iteration
@@ -258,6 +261,16 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
258261
&& loopStepNode.expr().ident().name().equals(iterVar))
259262
.map(CelNavigableMutableExpr::id)
260263
.findAny();
264+
Optional<Long> iterVar2Id =
265+
CelNavigableMutableExpr.fromExpr(comprehension.loopStep())
266+
.allNodes()
267+
.filter(
268+
loopStepNode ->
269+
iterVar2.isEmpty()
270+
&& loopStepNode.getKind().equals(Kind.IDENT)
271+
&& loopStepNode.expr().ident().name().equals(iterVar2))
272+
.map(CelNavigableMutableExpr::id)
273+
.findAny();
261274
Optional<CelType> iterVarType =
262275
iterVarId.map(
263276
id ->
@@ -269,6 +282,17 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
269282
"Checked type not present for iteration"
270283
+ " variable: "
271284
+ iterVarId)));
285+
Optional<CelType> iterVar2Type =
286+
iterVar2Id.map(
287+
id ->
288+
navigableMutableAst
289+
.getType(id)
290+
.orElseThrow(
291+
() ->
292+
new NoSuchElementException(
293+
"Checked type not present for iteration"
294+
+ " variable: "
295+
+ iterVarId)));
272296
CelType resultType =
273297
navigableMutableAst
274298
.getType(comprehension.result().id())
@@ -278,7 +302,7 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
278302
"Result type was not present for the comprehension ID: "
279303
+ comprehension.result().id()));
280304

281-
return MangledComprehensionType.of(iterVarType, resultType);
305+
return MangledComprehensionType.of(iterVarType, iterVar2Type, resultType);
282306
},
283307
(x, y) -> {
284308
throw new IllegalStateException(
@@ -301,38 +325,25 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
301325
MangledComprehensionType comprehensionEntryType = comprehensionEntry.getValue();
302326

303327
CelMutableExpr comprehensionExpr = comprehensionNode.expr();
304-
MangledComprehensionName mangledComprehensionName;
305-
if (incrementSerially) {
306-
// In case of applying CSE via cascaded cel.binds, not only is mangling based on level/types
307-
// meaningless (because all comprehensions are nested anyways, thus all indices would be
308-
// uinque),
309-
// it can lead to an erroneous result due to extracting a common subexpr with accu_var at
310-
// the wrong scope.
311-
// Example: "[1].exists(k, k > 1) && [2].exists(l, l > 1). The loop step for both branches
312-
// are identical, but shouldn't be extracted.
313-
String mangledIterVarName = newIterVarPrefix + ":" + iterCount;
314-
String mangledResultName = newAccuVarPrefix + ":" + iterCount;
315-
mangledComprehensionName =
316-
MangledComprehensionName.of(mangledIterVarName, mangledResultName);
317-
mangledIdentNamesToType.put(mangledComprehensionName, comprehensionEntry.getValue());
318-
} else {
319-
mangledComprehensionName =
320-
getMangledComprehensionName(
321-
newIterVarPrefix,
322-
newAccuVarPrefix,
323-
comprehensionNode,
324-
comprehensionLevelToType,
325-
comprehensionEntryType);
326-
}
328+
MangledComprehensionName mangledComprehensionName =
329+
getMangledComprehensionName(
330+
newIterVarPrefix,
331+
newIterVar2Prefix,
332+
newAccuVarPrefix,
333+
comprehensionNode,
334+
comprehensionLevelToType,
335+
comprehensionEntryType);
327336
mangledIdentNamesToType.put(mangledComprehensionName, comprehensionEntryType);
328337

329338
String iterVar = comprehensionExpr.comprehension().iterVar();
339+
String iterVar2 = comprehensionExpr.comprehension().iterVar2();
330340
String accuVar = comprehensionExpr.comprehension().accuVar();
331341
mutatedComprehensionExpr =
332342
mangleIdentsInComprehensionExpr(
333343
mutatedComprehensionExpr,
334344
comprehensionExpr,
335345
iterVar,
346+
iterVar2,
336347
accuVar,
337348
mangledComprehensionName);
338349
// Repeat the mangling process for the macro source.
@@ -341,6 +352,7 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
341352
newSource,
342353
mutatedComprehensionExpr,
343354
iterVar,
355+
iterVar2,
344356
mangledComprehensionName,
345357
comprehensionExpr.id());
346358
iterCount++;
@@ -360,6 +372,7 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
360372

361373
private static MangledComprehensionName getMangledComprehensionName(
362374
String newIterVarPrefix,
375+
String newIterVar2Prefix,
363376
String newResultPrefix,
364377
CelNavigableMutableExpr comprehensionNode,
365378
Table<Integer, MangledComprehensionType, MangledComprehensionName> comprehensionLevelToType,
@@ -377,7 +390,15 @@ private static MangledComprehensionName getMangledComprehensionName(
377390
newIterVarPrefix + ":" + comprehensionNestingLevel + ":" + uniqueTypeIdx;
378391
String mangledResultName =
379392
newResultPrefix + ":" + comprehensionNestingLevel + ":" + uniqueTypeIdx;
380-
mangledComprehensionName = MangledComprehensionName.of(mangledIterVarName, mangledResultName);
393+
String mangledIterVar2Name = "";
394+
395+
if (!newIterVar2Prefix.isEmpty()) {
396+
mangledIterVar2Name =
397+
newIterVar2Prefix + ":" + comprehensionNestingLevel + ":" + uniqueTypeIdx;
398+
}
399+
400+
mangledComprehensionName =
401+
MangledComprehensionName.of(mangledIterVarName, mangledIterVar2Name, mangledResultName);
381402
comprehensionLevelToType.put(
382403
comprehensionNestingLevel, comprehensionEntryType, mangledComprehensionName);
383404
}
@@ -530,6 +551,7 @@ private CelMutableExpr mangleIdentsInComprehensionExpr(
530551
CelMutableExpr root,
531552
CelMutableExpr comprehensionExpr,
532553
String originalIterVar,
554+
String originalIterVar2,
533555
String originalAccuVar,
534556
MangledComprehensionName mangledComprehensionName) {
535557
CelMutableComprehension comprehension = comprehensionExpr.comprehension();
@@ -538,11 +560,18 @@ private CelMutableExpr mangleIdentsInComprehensionExpr(
538560
replaceIdentName(comprehensionExpr, originalAccuVar, mangledComprehensionName.resultName());
539561

540562
comprehension.setIterVar(mangledComprehensionName.iterVarName());
563+
541564
// Most standard macros set accu_var as __result__, but not all (ex: cel.bind).
542565
if (comprehension.accuVar().equals(originalAccuVar)) {
543566
comprehension.setAccuVar(mangledComprehensionName.resultName());
544567
}
545568

569+
if (!originalIterVar2.isEmpty()) {
570+
comprehension.setIterVar2(mangledComprehensionName.iterVar2Name());
571+
replaceIdentName(
572+
comprehension.loopStep(), originalIterVar2, mangledComprehensionName.iterVar2Name());
573+
}
574+
546575
return mutateExpr(NO_OP_ID_GENERATOR, root, comprehensionExpr, comprehensionExpr.id());
547576
}
548577

@@ -581,6 +610,7 @@ private CelMutableSource mangleIdentsInMacroSource(
581610
CelMutableSource sourceBuilder,
582611
CelMutableExpr mutatedComprehensionExpr,
583612
String originalIterVar,
613+
String originalIterVar2,
584614
MangledComprehensionName mangledComprehensionName,
585615
long originalComprehensionId) {
586616
if (!sourceBuilder.getMacroCalls().containsKey(originalComprehensionId)) {
@@ -604,14 +634,25 @@ private CelMutableSource mangleIdentsInMacroSource(
604634
// macro call expression.
605635
CelMutableExpr identToMangle = macroExpr.call().args().get(0);
606636
if (identToMangle.ident().name().equals(originalIterVar)) {
607-
// if (identToMangle.identOrDefault().name().equals(originalIterVar)) {
608637
macroExpr =
609638
mutateExpr(
610639
NO_OP_ID_GENERATOR,
611640
macroExpr,
612641
CelMutableExpr.ofIdent(mangledComprehensionName.iterVarName()),
613642
identToMangle.id());
614643
}
644+
if (!originalIterVar2.isEmpty()) {
645+
// Similarly by convention, iter_var2 is always the second argument of the macro call.
646+
identToMangle = macroExpr.call().args().get(1);
647+
if (identToMangle.ident().name().equals(originalIterVar2)) {
648+
macroExpr =
649+
mutateExpr(
650+
NO_OP_ID_GENERATOR,
651+
macroExpr,
652+
CelMutableExpr.ofIdent(mangledComprehensionName.iterVar2Name()),
653+
identToMangle.id());
654+
}
655+
}
615656

616657
newSource.addMacroCalls(originalComprehensionId, macroExpr);
617658

@@ -815,7 +856,7 @@ private static void unwrapListArgumentsInMacroCallExpr(
815856
newMacroCall.addArgs(
816857
existingMacroCall.args().get(0)); // iter_var is first argument of the call by convention
817858

818-
CelMutableList extraneousList = null;
859+
CelMutableList extraneousList;
819860
if (loopStepArgs.size() == 2) {
820861
extraneousList = loopStepArgs.get(1).list();
821862
} else {
@@ -895,14 +936,22 @@ private static MangledComprehensionAst of(
895936
@AutoValue
896937
public abstract static class MangledComprehensionType {
897938

898-
/** Type of iter_var */
939+
/**
940+
* Type of iter_var. Empty if iter_var is not referenced in the expression anywhere (ex: "i" in
941+
* "[1].exists(i, true)"
942+
*/
899943
public abstract Optional<CelType> iterVarType();
900944

945+
/** Type of iter_var2. */
946+
public abstract Optional<CelType> iterVar2Type();
947+
901948
/** Type of comprehension result */
902949
public abstract CelType resultType();
903950

904-
private static MangledComprehensionType of(Optional<CelType> iterVarType, CelType resultType) {
905-
return new AutoValue_AstMutator_MangledComprehensionType(iterVarType, resultType);
951+
private static MangledComprehensionType of(
952+
Optional<CelType> iterVarType, Optional<CelType> iterVarType2, CelType resultType) {
953+
return new AutoValue_AstMutator_MangledComprehensionType(
954+
iterVarType, iterVarType2, resultType);
906955
}
907956
}
908957

@@ -916,11 +965,16 @@ public abstract static class MangledComprehensionName {
916965
/** Mangled name for iter_var */
917966
public abstract String iterVarName();
918967

968+
/** Mangled name for iter_var2 */
969+
public abstract String iterVar2Name();
970+
919971
/** Mangled name for comprehension result */
920972
public abstract String resultName();
921973

922-
private static MangledComprehensionName of(String iterVarName, String resultName) {
923-
return new AutoValue_AstMutator_MangledComprehensionName(iterVarName, resultName);
974+
private static MangledComprehensionName of(
975+
String iterVarName, String iterVar2Name, String resultName) {
976+
return new AutoValue_AstMutator_MangledComprehensionName(
977+
iterVarName, iterVar2Name, resultName);
924978
}
925979
}
926980
}

optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ public class SubexpressionOptimizer implements CelAstOptimizer {
9191
new SubexpressionOptimizer(SubexpressionOptimizerOptions.newBuilder().build());
9292
private static final String BIND_IDENTIFIER_PREFIX = "@r";
9393
private static final String MANGLED_COMPREHENSION_ITER_VAR_PREFIX = "@it";
94+
private static final String MANGLED_COMPREHENSION_ITER_VAR2_PREFIX = "@it2";
9495
private static final String MANGLED_COMPREHENSION_ACCU_VAR_PREFIX = "@ac";
9596
private static final String CEL_BLOCK_FUNCTION = "cel.@block";
9697
private static final String BLOCK_INDEX_PREFIX = "@index";
@@ -136,8 +137,8 @@ private OptimizationResult optimizeUsingCelBlock(CelAbstractSyntaxTree ast, Cel
136137
astMutator.mangleComprehensionIdentifierNames(
137138
astToModify,
138139
MANGLED_COMPREHENSION_ITER_VAR_PREFIX,
139-
MANGLED_COMPREHENSION_ACCU_VAR_PREFIX,
140-
/* incrementSerially= */ false);
140+
MANGLED_COMPREHENSION_ITER_VAR2_PREFIX,
141+
MANGLED_COMPREHENSION_ACCU_VAR_PREFIX);
141142
astToModify = mangledComprehensionAst.mutableAst();
142143
CelMutableSource sourceToModify = astToModify.source();
143144

@@ -197,6 +198,11 @@ private OptimizationResult optimizeUsingCelBlock(CelAbstractSyntaxTree ast, Cel
197198
iterVarType ->
198199
newVarDecls.add(
199200
CelVarDecl.newVarDeclaration(name.iterVarName(), iterVarType)));
201+
type.iterVar2Type()
202+
.ifPresent(
203+
iterVar2Type ->
204+
newVarDecls.add(
205+
CelVarDecl.newVarDeclaration(name.iterVar2Name(), iterVar2Type)));
200206
newVarDecls.add(CelVarDecl.newVarDeclaration(name.resultName(), type.resultType()));
201207
});
202208

0 commit comments

Comments
 (0)