Skip to content

Commit 37828da

Browse files
l46kokcopybara-github
authored andcommitted
Fix replaceSubtree to properly populate three arg map macro source
PiperOrigin-RevId: 797536828
1 parent 4d2670d commit 37828da

2 files changed

Lines changed: 43 additions & 4 deletions

File tree

optimizer/src/main/java/dev/cel/optimizer/AstMutator.java

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -910,21 +910,32 @@ private static void unwrapListArgumentsInMacroCallExpr(
910910

911911
CelMutableExpr loopStepExpr = comprehension.loopStep();
912912
List<CelMutableExpr> loopStepArgs = loopStepExpr.call().args();
913-
if (loopStepArgs.size() != 2) {
913+
if (loopStepArgs.size() != 2 && loopStepArgs.size() != 3) {
914914
throw new IllegalArgumentException(
915915
String.format(
916-
"Expected exactly 2 arguments but got %d instead on expr id: %d",
916+
"Expected exactly 2 or 3 arguments but got %d instead on expr id: %d",
917917
loopStepArgs.size(), loopStepExpr.id()));
918918
}
919919

920920
CelMutableCall existingMacroCall = newMacroCallExpr.call();
921921
CelMutableCall newMacroCall =
922922
existingMacroCall.target().isPresent()
923-
? CelMutableCall.create(existingMacroCall.target().get(), existingMacroCall.function())
923+
? CelMutableCall.create(existingMacroCall.target().get(),
924+
existingMacroCall.function())
924925
: CelMutableCall.create(existingMacroCall.function());
925926
newMacroCall.addArgs(
926927
existingMacroCall.args().get(0)); // iter_var is first argument of the call by convention
927-
newMacroCall.addArgs(loopStepArgs.get(1).list().elements());
928+
929+
CelMutableList extraneousList = null;
930+
if (loopStepArgs.size() == 2) {
931+
extraneousList = loopStepArgs.get(1).list();
932+
} else {
933+
newMacroCall.addArgs(loopStepArgs.get(0));
934+
// For map(x,y,z), z is wrapped in a _+_(@result, [z])
935+
extraneousList = loopStepArgs.get(1).call().args().get(1).list();
936+
}
937+
938+
newMacroCall.addArgs(extraneousList.elements());
928939

929940
newMacroCallExpr.setCall(newMacroCall);
930941
}

optimizer/src/test/java/dev/cel/optimizer/AstMutatorTest.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,34 @@ public void replaceSubtree_replaceExtraneousListCreatedByMacro_unparseSuccess()
581581
.containsExactly(2L);
582582
}
583583

584+
@Test
585+
@SuppressWarnings("unchecked") // Test only
586+
public void replaceSubtree_replaceExtraneousListCreatedByThreeArgMacro_unparseSuccess()
587+
throws Exception {
588+
CelAbstractSyntaxTree ast = CEL.compile("[1].map(x, true, 1)").getAst();
589+
CelMutableAst mutableAst = CelMutableAst.fromCelAst(ast);
590+
CelMutableAst mutableAst2 = CelMutableAst.fromCelAst(ast);
591+
592+
// These two mutation are equivalent.
593+
CelAbstractSyntaxTree mutatedAstWithList =
594+
AST_MUTATOR
595+
.replaceSubtree(
596+
mutableAst,
597+
CelMutableExpr.ofList(
598+
CelMutableList.create(CelMutableExpr.ofConstant(CelConstant.ofValue(2L)))),
599+
10L)
600+
.toParsedAst();
601+
CelAbstractSyntaxTree mutatedAstWithConstant =
602+
AST_MUTATOR
603+
.replaceSubtree(mutableAst2, CelMutableExpr.ofConstant(CelConstant.ofValue(2L)), 6L)
604+
.toParsedAst();
605+
606+
assertThat(CEL_UNPARSER.unparse(mutatedAstWithList)).isEqualTo("[1].map(x, true, 2)");
607+
assertThat(CEL_UNPARSER.unparse(mutatedAstWithConstant)).isEqualTo("[1].map(x, true, 2)");
608+
assertThat((List<Long>) CEL.createProgram(CEL.check(mutatedAstWithList).getAst()).eval())
609+
.containsExactly(2L);
610+
}
611+
584612
@Test
585613
public void globalCallExpr_replaceRoot() throws Exception {
586614
// Tree shape (brackets are expr IDs):

0 commit comments

Comments
 (0)