Skip to content

Commit 664c31b

Browse files
l46kokcopybara-github
authored andcommitted
Implement cel.@block for planner
PiperOrigin-RevId: 894210773
1 parent 7d73658 commit 664c31b

File tree

13 files changed

+466
-130
lines changed

13 files changed

+466
-130
lines changed

extensions/src/main/java/dev/cel/extensions/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ java_library(
142142
deps = [
143143
"//common:compiler_common",
144144
"//common/ast",
145+
"//common/types",
145146
"//compiler:compiler_builder",
146147
"//extensions:extension_library",
147148
"//parser:macro",

extensions/src/main/java/dev/cel/extensions/CelBindingsExtensions.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@
2222
import com.google.errorprone.annotations.Immutable;
2323
import dev.cel.common.CelFunctionDecl;
2424
import dev.cel.common.CelIssue;
25+
import dev.cel.common.CelOverloadDecl;
2526
import dev.cel.common.ast.CelExpr;
27+
import dev.cel.common.types.ListType;
28+
import dev.cel.common.types.SimpleType;
29+
import dev.cel.common.types.TypeParamType;
2630
import dev.cel.compiler.CelCompilerLibrary;
2731
import dev.cel.parser.CelMacro;
2832
import dev.cel.parser.CelMacroExprFactory;
@@ -62,7 +66,15 @@ public int version() {
6266

6367
@Override
6468
public ImmutableSet<CelFunctionDecl> functions() {
65-
return ImmutableSet.of();
69+
// TODO: Add bindings for block once decorator support is available.
70+
return ImmutableSet.of(
71+
CelFunctionDecl.newFunctionDeclaration(
72+
"cel.@block",
73+
CelOverloadDecl.newGlobalOverload(
74+
"cel_block_list",
75+
TypeParamType.create("T"),
76+
ListType.create(SimpleType.DYN),
77+
TypeParamType.create("T"))));
6678
}
6779

6880
@Override

extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ public void library() {
6363
CelExtensions.getExtensionLibrary("bindings", CelOptions.DEFAULT);
6464
assertThat(library.name()).isEqualTo("bindings");
6565
assertThat(library.latest().version()).isEqualTo(0);
66-
assertThat(library.version(0).functions()).isEmpty();
66+
assertThat(library.version(0).functions().stream().map(CelFunctionDecl::name))
67+
.containsExactly("cel.@block");
6768
assertThat(library.version(0).macros().stream().map(CelMacro::getFunction))
6869
.containsExactly("bind");
6970
}

optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ java_library(
3333
"//parser:unparser",
3434
"//runtime",
3535
"//runtime:function_binding",
36+
"//runtime:partial_vars",
37+
"//runtime:program",
38+
"//runtime:unknown_attributes",
3639
"//testing:baseline_test_case",
3740
"@maven//:junit_junit",
3841
"@maven//:com_google_testparameterinjector_test_parameter_injector",

optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java

Lines changed: 96 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
// import com.google.testing.testsize.MediumTest;
2525
import dev.cel.bundle.Cel;
2626
import dev.cel.bundle.CelBuilder;
27+
import dev.cel.bundle.CelExperimentalFactory;
2728
import dev.cel.bundle.CelFactory;
2829
import dev.cel.common.CelAbstractSyntaxTree;
2930
import dev.cel.common.CelContainer;
@@ -43,6 +44,7 @@
4344
import dev.cel.parser.CelUnparserFactory;
4445
import dev.cel.runtime.CelFunctionBinding;
4546
import dev.cel.testing.BaselineTestCase;
47+
import java.util.EnumSet;
4648
import java.util.Optional;
4749
import org.junit.Before;
4850
import org.junit.Test;
@@ -51,6 +53,50 @@
5153
// @MediumTest
5254
@RunWith(TestParameterInjector.class)
5355
public class SubexpressionOptimizerBaselineTest extends BaselineTestCase {
56+
private enum RuntimeEnv {
57+
LEGACY(setupCelEnv(CelFactory.standardCelBuilder())),
58+
PLANNER(setupCelEnv(CelExperimentalFactory.plannerCelBuilder()));
59+
60+
private final Cel cel;
61+
62+
private static Cel setupCelEnv(CelBuilder celBuilder) {
63+
return celBuilder
64+
.addMessageTypes(TestAllTypes.getDescriptor())
65+
.setContainer(CelContainer.ofName("cel.expr.conformance.proto3"))
66+
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
67+
.setOptions(
68+
CelOptions.current()
69+
.populateMacroCalls(true)
70+
.enableHeterogeneousNumericComparisons(true)
71+
.build())
72+
.addCompilerLibraries(
73+
CelExtensions.optional(), CelExtensions.bindings(), CelExtensions.comprehensions())
74+
.addRuntimeLibraries(CelExtensions.optional(), CelExtensions.comprehensions())
75+
.addFunctionDeclarations(
76+
CelFunctionDecl.newFunctionDeclaration(
77+
"pure_custom_func",
78+
newGlobalOverload("pure_custom_func_overload", SimpleType.INT, SimpleType.INT)),
79+
CelFunctionDecl.newFunctionDeclaration(
80+
"non_pure_custom_func",
81+
newGlobalOverload(
82+
"non_pure_custom_func_overload", SimpleType.INT, SimpleType.INT)))
83+
.addFunctionBindings(
84+
// This is pure, but for the purposes of excluding it as a CSE candidate, pretend that
85+
// it isn't.
86+
CelFunctionBinding.from("non_pure_custom_func_overload", Long.class, val -> val),
87+
CelFunctionBinding.from("pure_custom_func_overload", Long.class, val -> val))
88+
.addVar("x", SimpleType.DYN)
89+
.addVar("y", SimpleType.DYN)
90+
.addVar("opt_x", OptionalType.create(SimpleType.DYN))
91+
.addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName()))
92+
.build();
93+
}
94+
95+
RuntimeEnv(Cel cel) {
96+
this.cel = cel;
97+
}
98+
}
99+
54100
private static final CelUnparser CEL_UNPARSER = CelUnparserFactory.newUnparser();
55101
private static final TestAllTypes TEST_ALL_TYPES_INPUT =
56102
TestAllTypes.newBuilder()
@@ -67,7 +113,6 @@ public class SubexpressionOptimizerBaselineTest extends BaselineTestCase {
67113
.putMapInt32Int64(2, 2)
68114
.putMapStringString("key", "A")))
69115
.build();
70-
private static final Cel CEL = newCelBuilder().build();
71116

72117
private static final SubexpressionOptimizerOptions OPTIMIZER_COMMON_OPTIONS =
73118
SubexpressionOptimizerOptions.newBuilder()
@@ -90,45 +135,49 @@ protected String baselineFileName() {
90135
return overriddenBaseFilePath;
91136
}
92137

138+
@TestParameter RuntimeEnv runtimeEnv;
139+
93140
@Test
94141
public void allOptimizers_producesSameEvaluationResult(
95142
@TestParameter CseTestOptimizer cseTestOptimizer, @TestParameter CseTestCase cseTestCase)
96143
throws Exception {
97144
skipBaselineVerification();
98-
CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst();
145+
CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst();
99146
ImmutableMap<String, Object> inputMap =
100147
ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L, "y", 6L, "opt_x", Optional.of(5L));
101-
Object expectedEvalResult = CEL.createProgram(ast).eval(inputMap);
148+
Object expectedEvalResult = runtimeEnv.cel.createProgram(ast).eval(inputMap);
102149

103-
CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.cseOptimizer.optimize(ast);
150+
CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.newCseOptimizer(runtimeEnv).optimize(ast);
104151

105-
Object optimizedEvalResult = CEL.createProgram(optimizedAst).eval(inputMap);
152+
Object optimizedEvalResult = runtimeEnv.cel.createProgram(optimizedAst).eval(inputMap);
106153
assertThat(optimizedEvalResult).isEqualTo(expectedEvalResult);
107154
}
108155

109156
@Test
110157
public void subexpression_unparsed() throws Exception {
111-
for (CseTestCase cseTestCase : CseTestCase.values()) {
158+
for (CseTestCase cseTestCase : EnumSet.allOf(CseTestCase.class)) {
112159
testOutput().println("Test case: " + cseTestCase.name());
113160
testOutput().println("Source: " + cseTestCase.source);
114161
testOutput().println("=====>");
115-
CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst();
162+
CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst();
116163
boolean resultPrinted = false;
117164
for (CseTestOptimizer cseTestOptimizer : CseTestOptimizer.values()) {
118165
String optimizerName = cseTestOptimizer.name();
119166
CelAbstractSyntaxTree optimizedAst;
120167
try {
121-
optimizedAst = cseTestOptimizer.cseOptimizer.optimize(ast);
168+
optimizedAst = cseTestOptimizer.newCseOptimizer(runtimeEnv).optimize(ast);
122169
} catch (Exception e) {
123170
testOutput().printf("[%s]: Optimization Error: %s", optimizerName, e);
124171
continue;
125172
}
126173
if (!resultPrinted) {
127174
Object optimizedEvalResult =
128-
CEL.createProgram(optimizedAst)
175+
runtimeEnv
176+
.cel
177+
.createProgram(optimizedAst)
129178
.eval(
130179
ImmutableMap.of(
131-
"msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L)));
180+
"msg", TEST_ALL_TYPES_INPUT, "x", 5L, "y", 6L, "opt_x", Optional.of(5L)));
132181
testOutput().println("Result: " + optimizedEvalResult);
133182
resultPrinted = true;
134183
}
@@ -145,22 +194,24 @@ public void subexpression_unparsed() throws Exception {
145194

146195
@Test
147196
public void constfold_before_subexpression_unparsed() throws Exception {
148-
for (CseTestCase cseTestCase : CseTestCase.values()) {
197+
for (CseTestCase cseTestCase : EnumSet.allOf(CseTestCase.class)) {
149198
testOutput().println("Test case: " + cseTestCase.name());
150199
testOutput().println("Source: " + cseTestCase.source);
151200
testOutput().println("=====>");
152-
CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst();
201+
CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst();
153202
boolean resultPrinted = false;
154-
for (CseTestOptimizer cseTestOptimizer : CseTestOptimizer.values()) {
203+
for (CseTestOptimizer cseTestOptimizer : EnumSet.allOf(CseTestOptimizer.class)) {
155204
String optimizerName = cseTestOptimizer.name();
156205
CelAbstractSyntaxTree optimizedAst =
157-
cseTestOptimizer.cseWithConstFoldingOptimizer.optimize(ast);
206+
cseTestOptimizer.newCseWithConstFoldingOptimizer(runtimeEnv).optimize(ast);
158207
if (!resultPrinted) {
159208
Object optimizedEvalResult =
160-
CEL.createProgram(optimizedAst)
209+
runtimeEnv
210+
.cel
211+
.createProgram(optimizedAst)
161212
.eval(
162213
ImmutableMap.of(
163-
"msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L)));
214+
"msg", TEST_ALL_TYPES_INPUT, "x", 5L, "y", 6L, "opt_x", Optional.of(5L)));
164215
testOutput().println("Result: " + optimizedEvalResult);
165216
resultPrinted = true;
166217
}
@@ -179,12 +230,13 @@ public void constfold_before_subexpression_unparsed() throws Exception {
179230
public void subexpression_ast(@TestParameter CseTestOptimizer cseTestOptimizer) throws Exception {
180231
String testBasefileName = "subexpression_ast_" + Ascii.toLowerCase(cseTestOptimizer.name());
181232
overriddenBaseFilePath = String.format("%s%s.baseline", testdataDir(), testBasefileName);
182-
for (CseTestCase cseTestCase : CseTestCase.values()) {
233+
for (CseTestCase cseTestCase : EnumSet.allOf(CseTestCase.class)) {
183234
testOutput().println("Test case: " + cseTestCase.name());
184235
testOutput().println("Source: " + cseTestCase.source);
185236
testOutput().println("=====>");
186-
CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst();
187-
CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.cseOptimizer.optimize(ast);
237+
CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst();
238+
CelAbstractSyntaxTree optimizedAst =
239+
newCseOptimizer(runtimeEnv.cel, cseTestOptimizer.option).optimize(ast);
188240
testOutput().println(optimizedAst.getExpr());
189241
}
190242
}
@@ -193,7 +245,8 @@ public void subexpression_ast(@TestParameter CseTestOptimizer cseTestOptimizer)
193245
public void large_expressions_block_common_subexpr() throws Exception {
194246
CelOptimizer celOptimizer =
195247
newCseOptimizer(
196-
CEL, SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build());
248+
runtimeEnv.cel,
249+
SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build());
197250

198251
runLargeTestCases(celOptimizer);
199252
}
@@ -202,7 +255,7 @@ public void large_expressions_block_common_subexpr() throws Exception {
202255
public void large_expressions_block_recursion_depth_1() throws Exception {
203256
CelOptimizer celOptimizer =
204257
newCseOptimizer(
205-
CEL,
258+
runtimeEnv.cel,
206259
SubexpressionOptimizerOptions.newBuilder()
207260
.populateMacroCalls(true)
208261
.subexpressionMaxRecursionDepth(1)
@@ -215,7 +268,7 @@ public void large_expressions_block_recursion_depth_1() throws Exception {
215268
public void large_expressions_block_recursion_depth_2() throws Exception {
216269
CelOptimizer celOptimizer =
217270
newCseOptimizer(
218-
CEL,
271+
runtimeEnv.cel,
219272
SubexpressionOptimizerOptions.newBuilder()
220273
.populateMacroCalls(true)
221274
.subexpressionMaxRecursionDepth(2)
@@ -228,7 +281,7 @@ public void large_expressions_block_recursion_depth_2() throws Exception {
228281
public void large_expressions_block_recursion_depth_3() throws Exception {
229282
CelOptimizer celOptimizer =
230283
newCseOptimizer(
231-
CEL,
284+
runtimeEnv.cel,
232285
SubexpressionOptimizerOptions.newBuilder()
233286
.populateMacroCalls(true)
234287
.subexpressionMaxRecursionDepth(3)
@@ -238,15 +291,16 @@ public void large_expressions_block_recursion_depth_3() throws Exception {
238291
}
239292

240293
private void runLargeTestCases(CelOptimizer celOptimizer) throws Exception {
241-
for (CseLargeTestCase cseTestCase : CseLargeTestCase.values()) {
294+
for (CseLargeTestCase cseTestCase : EnumSet.allOf(CseLargeTestCase.class)) {
242295
testOutput().println("Test case: " + cseTestCase.name());
243296
testOutput().println("Source: " + cseTestCase.source);
244297
testOutput().println("=====>");
245-
CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst();
246-
298+
CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst();
247299
CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast);
248300
Object optimizedEvalResult =
249-
CEL.createProgram(optimizedAst)
301+
runtimeEnv
302+
.cel
303+
.createProgram(optimizedAst)
250304
.eval(
251305
ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L)));
252306
testOutput().println("Result: " + optimizedEvalResult);
@@ -260,33 +314,6 @@ private void runLargeTestCases(CelOptimizer celOptimizer) throws Exception {
260314
}
261315
}
262316

263-
private static CelBuilder newCelBuilder() {
264-
return CelFactory.standardCelBuilder()
265-
.addMessageTypes(TestAllTypes.getDescriptor())
266-
.setContainer(CelContainer.ofName("cel.expr.conformance.proto3"))
267-
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
268-
.setOptions(CelOptions.current().populateMacroCalls(true).build())
269-
.addCompilerLibraries(
270-
CelExtensions.optional(), CelExtensions.bindings(), CelExtensions.comprehensions())
271-
.addRuntimeLibraries(CelExtensions.optional(), CelExtensions.comprehensions())
272-
.addFunctionDeclarations(
273-
CelFunctionDecl.newFunctionDeclaration(
274-
"pure_custom_func",
275-
newGlobalOverload("pure_custom_func_overload", SimpleType.INT, SimpleType.INT)),
276-
CelFunctionDecl.newFunctionDeclaration(
277-
"non_pure_custom_func",
278-
newGlobalOverload("non_pure_custom_func_overload", SimpleType.INT, SimpleType.INT)))
279-
.addFunctionBindings(
280-
// This is pure, but for the purposes of excluding it as a CSE candidate, pretend that
281-
// it isn't.
282-
CelFunctionBinding.from("non_pure_custom_func_overload", Long.class, val -> val),
283-
CelFunctionBinding.from("pure_custom_func_overload", Long.class, val -> val))
284-
.addVar("x", SimpleType.DYN)
285-
.addVar("y", SimpleType.DYN)
286-
.addVar("opt_x", OptionalType.create(SimpleType.DYN))
287-
.addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName()));
288-
}
289-
290317
private static CelOptimizer newCseOptimizer(Cel cel, SubexpressionOptimizerOptions options) {
291318
return CelOptimizerFactory.standardCelOptimizerBuilder(cel)
292319
.addAstOptimizers(SubexpressionOptimizer.newInstance(options))
@@ -315,17 +342,23 @@ private enum CseTestOptimizer {
315342
BLOCK_RECURSION_DEPTH_9(
316343
OPTIMIZER_COMMON_OPTIONS.toBuilder().subexpressionMaxRecursionDepth(9).build());
317344

318-
private final CelOptimizer cseOptimizer;
319-
private final CelOptimizer cseWithConstFoldingOptimizer;
345+
private final SubexpressionOptimizerOptions option;
320346

321347
CseTestOptimizer(SubexpressionOptimizerOptions option) {
322-
this.cseOptimizer = newCseOptimizer(CEL, option);
323-
this.cseWithConstFoldingOptimizer =
324-
CelOptimizerFactory.standardCelOptimizerBuilder(CEL)
325-
.addAstOptimizers(
326-
ConstantFoldingOptimizer.getInstance(),
327-
SubexpressionOptimizer.newInstance(option))
328-
.build();
348+
this.option = option;
349+
}
350+
351+
// Defers building the optimizer until the test runs
352+
private CelOptimizer newCseOptimizer(RuntimeEnv env) {
353+
return SubexpressionOptimizerBaselineTest.newCseOptimizer(env.cel, option);
354+
}
355+
356+
// Defers building the optimizer until the test runs
357+
private CelOptimizer newCseWithConstFoldingOptimizer(RuntimeEnv env) {
358+
return CelOptimizerFactory.standardCelOptimizerBuilder(env.cel)
359+
.addAstOptimizers(
360+
ConstantFoldingOptimizer.getInstance(), SubexpressionOptimizer.newInstance(option))
361+
.build();
329362
}
330363
}
331364

0 commit comments

Comments
 (0)