Skip to content

Commit e73283b

Browse files
maskri17copybara-github
authored andcommitted
Adding runtime support for two variable comprehensions
PiperOrigin-RevId: 799733003
1 parent 1ee7d1d commit e73283b

6 files changed

Lines changed: 324 additions & 22 deletions

File tree

checker/src/main/java/dev/cel/checker/ExprChecker.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,7 @@ private CelExpr visit(CelExpr expr, CelExpr.CelComprehension compre) {
530530
case DYN:
531531
case ERROR:
532532
varType = SimpleType.DYN;
533+
varType2 = SimpleType.DYN;
533534
break;
534535
case TYPE_PARAM:
535536
// Mark the range as DYN to avoid its free variable being associated with the wrong type
@@ -538,6 +539,7 @@ private CelExpr visit(CelExpr expr, CelExpr.CelComprehension compre) {
538539
inferenceContext.isAssignable(SimpleType.DYN, rangeType);
539540
// Mark the variable type as DYN.
540541
varType = SimpleType.DYN;
542+
varType2 = SimpleType.DYN;
541543
break;
542544
default:
543545
env.reportError(
@@ -547,6 +549,7 @@ private CelExpr visit(CelExpr expr, CelExpr.CelComprehension compre) {
547549
+ "(must be list, map, or dynamic)",
548550
CelTypes.format(rangeType));
549551
varType = SimpleType.DYN;
552+
varType2 = SimpleType.DYN;
550553
break;
551554
}
552555

extensions/src/main/java/dev/cel/extensions/CelComprehensionsExtensions.java

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,13 @@ private static Optional<CelExpr> expandAllMacro(
5959
checkNotNull(exprFactory);
6060
checkNotNull(target);
6161
checkArgument(arguments.size() == 3);
62-
CelExpr arg0 = checkNotNull(arguments.get(0));
62+
CelExpr arg0 = validatedIterationVariable(exprFactory, arguments.get(0));
6363
if (arg0.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
64-
return Optional.of(reportArgumentError(exprFactory, arg0));
64+
return Optional.of(arg0);
6565
}
66-
CelExpr arg1 = checkNotNull(arguments.get(1));
66+
CelExpr arg1 = validatedIterationVariable(exprFactory, arguments.get(1));
6767
if (arg1.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
68-
return Optional.of(reportArgumentError(exprFactory, arg1));
68+
return Optional.of(arg1);
6969
}
7070
CelExpr arg2 = checkNotNull(arguments.get(2));
7171
CelExpr accuInit = exprFactory.newBoolLiteral(true);
@@ -96,13 +96,13 @@ private static Optional<CelExpr> expandExistsMacro(
9696
checkNotNull(exprFactory);
9797
checkNotNull(target);
9898
checkArgument(arguments.size() == 3);
99-
CelExpr arg0 = checkNotNull(arguments.get(0));
99+
CelExpr arg0 = validatedIterationVariable(exprFactory, arguments.get(0));
100100
if (arg0.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
101-
return Optional.of(reportArgumentError(exprFactory, arg0));
101+
return Optional.of(arg0);
102102
}
103-
CelExpr arg1 = checkNotNull(arguments.get(1));
103+
CelExpr arg1 = validatedIterationVariable(exprFactory, arguments.get(1));
104104
if (arg1.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
105-
return Optional.of(reportArgumentError(exprFactory, arg1));
105+
return Optional.of(arg1);
106106
}
107107
CelExpr arg2 = checkNotNull(arguments.get(2));
108108
CelExpr accuInit = exprFactory.newBoolLiteral(false);
@@ -135,13 +135,13 @@ private static Optional<CelExpr> expandExistsOneMacro(
135135
checkNotNull(exprFactory);
136136
checkNotNull(target);
137137
checkArgument(arguments.size() == 3);
138-
CelExpr arg0 = checkNotNull(arguments.get(0));
138+
CelExpr arg0 = validatedIterationVariable(exprFactory, arguments.get(0));
139139
if (arg0.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
140-
return Optional.of(reportArgumentError(exprFactory, arg0));
140+
return Optional.of(arg0);
141141
}
142-
CelExpr arg1 = checkNotNull(arguments.get(1));
142+
CelExpr arg1 = validatedIterationVariable(exprFactory, arguments.get(1));
143143
if (arg1.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
144-
return Optional.of(reportArgumentError(exprFactory, arg1));
144+
return Optional.of(arg1);
145145
}
146146
CelExpr arg2 = checkNotNull(arguments.get(2));
147147
CelExpr accuInit = exprFactory.newIntLiteral(0);
@@ -177,13 +177,13 @@ private static Optional<CelExpr> transformListMacro(
177177
checkNotNull(exprFactory);
178178
checkNotNull(target);
179179
checkArgument(arguments.size() == 3 || arguments.size() == 4);
180-
CelExpr arg0 = checkNotNull(arguments.get(0));
180+
CelExpr arg0 = validatedIterationVariable(exprFactory, arguments.get(0));
181181
if (arg0.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
182-
return Optional.of(reportArgumentError(exprFactory, arg0));
182+
return Optional.of(arg0);
183183
}
184-
CelExpr arg1 = checkNotNull(arguments.get(1));
184+
CelExpr arg1 = validatedIterationVariable(exprFactory, arguments.get(1));
185185
if (arg1.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
186-
return Optional.of(reportArgumentError(exprFactory, arg1));
186+
return Optional.of(arg1);
187187
}
188188
CelExpr transform;
189189
CelExpr filter = null;
@@ -220,9 +220,32 @@ private static Optional<CelExpr> transformListMacro(
220220
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName())));
221221
}
222222

223+
private static CelExpr validatedIterationVariable(
224+
CelMacroExprFactory exprFactory, CelExpr argument) {
225+
226+
CelExpr arg = checkNotNull(argument);
227+
if (arg.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
228+
return reportArgumentError(exprFactory, arg);
229+
} else if (arg.exprKind().ident().name().equals("__result__")) {
230+
return reportAccumulatorOverwriteError(exprFactory, arg);
231+
} else {
232+
return arg;
233+
}
234+
}
235+
223236
private static CelExpr reportArgumentError(CelMacroExprFactory exprFactory, CelExpr argument) {
224237
return exprFactory.reportError(
225238
CelIssue.formatError(
226239
exprFactory.getSourceLocation(argument), "The argument must be a simple name"));
227240
}
241+
242+
private static CelExpr reportAccumulatorOverwriteError(
243+
CelMacroExprFactory exprFactory, CelExpr argument) {
244+
return exprFactory.reportError(
245+
CelIssue.formatError(
246+
exprFactory.getSourceLocation(argument),
247+
String.format(
248+
"The iteration variable %s overwrites accumulator variable",
249+
argument.ident().name())));
250+
}
228251
}

extensions/src/main/java/dev/cel/extensions/CelExtensions.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -315,9 +315,7 @@ public static CelRegexExtensions regex() {
315315
* <p>This will include all functions denoted in {@link CelComprehensionsExtensions.Function},
316316
* including any future additions.
317317
*/
318-
// TODO: Remove visibility restrictions and make this public once the feature is
319-
// ready.
320-
private static CelComprehensionsExtensions comprehensions() {
318+
public static CelComprehensionsExtensions comprehensions() {
321319
return COMPREHENSIONS_EXTENSIONS;
322320
}
323321

extensions/src/test/java/dev/cel/extensions/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ java_library(
3030
"//extensions:sets_function",
3131
"//extensions:strings",
3232
"//parser:macro",
33+
"//parser:unparser",
3334
"//runtime",
3435
"//runtime:function_binding",
3536
"//runtime:interpreter_util",

0 commit comments

Comments
 (0)