Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions extensions/src/main/java/dev/cel/extensions/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ java_library(
deps = [
"//common:compiler_common",
"//common/ast",
"//common/types",
"//compiler:compiler_builder",
"//extensions:extension_library",
"//parser:macro",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
import com.google.errorprone.annotations.Immutable;
import dev.cel.common.CelFunctionDecl;
import dev.cel.common.CelIssue;
import dev.cel.common.CelOverloadDecl;
import dev.cel.common.ast.CelExpr;
import dev.cel.common.types.ListType;
import dev.cel.common.types.SimpleType;
import dev.cel.common.types.TypeParamType;
import dev.cel.compiler.CelCompilerLibrary;
import dev.cel.parser.CelMacro;
import dev.cel.parser.CelMacroExprFactory;
Expand Down Expand Up @@ -62,7 +66,15 @@ public int version() {

@Override
public ImmutableSet<CelFunctionDecl> functions() {
return ImmutableSet.of();
// TODO: Add bindings for block once decorator support is available.
return ImmutableSet.of(
CelFunctionDecl.newFunctionDeclaration(
"cel.@block",
CelOverloadDecl.newGlobalOverload(
"cel_block_list",
TypeParamType.create("T"),
ListType.create(SimpleType.DYN),
TypeParamType.create("T"))));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ public void library() {
CelExtensions.getExtensionLibrary("bindings", CelOptions.DEFAULT);
assertThat(library.name()).isEqualTo("bindings");
assertThat(library.latest().version()).isEqualTo(0);
assertThat(library.version(0).functions()).isEmpty();
assertThat(library.version(0).functions().stream().map(CelFunctionDecl::name))
.containsExactly("cel.@block");
assertThat(library.version(0).macros().stream().map(CelMacro::getFunction))
.containsExactly("bind");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ java_library(
"//parser:unparser",
"//runtime",
"//runtime:function_binding",
"//runtime:partial_vars",
"//runtime:program",
"//runtime:unknown_attributes",
"//testing:baseline_test_case",
"@maven//:junit_junit",
"@maven//:com_google_testparameterinjector_test_parameter_injector",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
// import com.google.testing.testsize.MediumTest;
import dev.cel.bundle.Cel;
import dev.cel.bundle.CelBuilder;
import dev.cel.bundle.CelExperimentalFactory;
import dev.cel.bundle.CelFactory;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelContainer;
Expand All @@ -43,6 +44,7 @@
import dev.cel.parser.CelUnparserFactory;
import dev.cel.runtime.CelFunctionBinding;
import dev.cel.testing.BaselineTestCase;
import java.util.EnumSet;
import java.util.Optional;
import org.junit.Before;
import org.junit.Test;
Expand All @@ -51,6 +53,50 @@
// @MediumTest
@RunWith(TestParameterInjector.class)
public class SubexpressionOptimizerBaselineTest extends BaselineTestCase {
private enum RuntimeEnv {
LEGACY(setupCelEnv(CelFactory.standardCelBuilder())),
PLANNER(setupCelEnv(CelExperimentalFactory.plannerCelBuilder()));

private final Cel cel;

private static Cel setupCelEnv(CelBuilder celBuilder) {
return celBuilder
.addMessageTypes(TestAllTypes.getDescriptor())
.setContainer(CelContainer.ofName("cel.expr.conformance.proto3"))
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
.setOptions(
CelOptions.current()
.populateMacroCalls(true)
.enableHeterogeneousNumericComparisons(true)
.build())
.addCompilerLibraries(
CelExtensions.optional(), CelExtensions.bindings(), CelExtensions.comprehensions())
.addRuntimeLibraries(CelExtensions.optional(), CelExtensions.comprehensions())
.addFunctionDeclarations(
CelFunctionDecl.newFunctionDeclaration(
"pure_custom_func",
newGlobalOverload("pure_custom_func_overload", SimpleType.INT, SimpleType.INT)),
CelFunctionDecl.newFunctionDeclaration(
"non_pure_custom_func",
newGlobalOverload(
"non_pure_custom_func_overload", SimpleType.INT, SimpleType.INT)))
.addFunctionBindings(
// This is pure, but for the purposes of excluding it as a CSE candidate, pretend that
// it isn't.
CelFunctionBinding.from("non_pure_custom_func_overload", Long.class, val -> val),
CelFunctionBinding.from("pure_custom_func_overload", Long.class, val -> val))
.addVar("x", SimpleType.DYN)
.addVar("y", SimpleType.DYN)
.addVar("opt_x", OptionalType.create(SimpleType.DYN))
.addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName()))
.build();
}

RuntimeEnv(Cel cel) {
this.cel = cel;
}
}

private static final CelUnparser CEL_UNPARSER = CelUnparserFactory.newUnparser();
private static final TestAllTypes TEST_ALL_TYPES_INPUT =
TestAllTypes.newBuilder()
Expand All @@ -67,7 +113,6 @@ public class SubexpressionOptimizerBaselineTest extends BaselineTestCase {
.putMapInt32Int64(2, 2)
.putMapStringString("key", "A")))
.build();
private static final Cel CEL = newCelBuilder().build();

private static final SubexpressionOptimizerOptions OPTIMIZER_COMMON_OPTIONS =
SubexpressionOptimizerOptions.newBuilder()
Expand All @@ -90,45 +135,49 @@ protected String baselineFileName() {
return overriddenBaseFilePath;
}

@TestParameter RuntimeEnv runtimeEnv;

@Test
public void allOptimizers_producesSameEvaluationResult(
@TestParameter CseTestOptimizer cseTestOptimizer, @TestParameter CseTestCase cseTestCase)
throws Exception {
skipBaselineVerification();
CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst();
CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst();
ImmutableMap<String, Object> inputMap =
ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L, "y", 6L, "opt_x", Optional.of(5L));
Object expectedEvalResult = CEL.createProgram(ast).eval(inputMap);
Object expectedEvalResult = runtimeEnv.cel.createProgram(ast).eval(inputMap);

CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.cseOptimizer.optimize(ast);
CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.newCseOptimizer(runtimeEnv).optimize(ast);

Object optimizedEvalResult = CEL.createProgram(optimizedAst).eval(inputMap);
Object optimizedEvalResult = runtimeEnv.cel.createProgram(optimizedAst).eval(inputMap);
assertThat(optimizedEvalResult).isEqualTo(expectedEvalResult);
}

@Test
public void subexpression_unparsed() throws Exception {
for (CseTestCase cseTestCase : CseTestCase.values()) {
for (CseTestCase cseTestCase : EnumSet.allOf(CseTestCase.class)) {
testOutput().println("Test case: " + cseTestCase.name());
testOutput().println("Source: " + cseTestCase.source);
testOutput().println("=====>");
CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst();
CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst();
boolean resultPrinted = false;
for (CseTestOptimizer cseTestOptimizer : CseTestOptimizer.values()) {
String optimizerName = cseTestOptimizer.name();
CelAbstractSyntaxTree optimizedAst;
try {
optimizedAst = cseTestOptimizer.cseOptimizer.optimize(ast);
optimizedAst = cseTestOptimizer.newCseOptimizer(runtimeEnv).optimize(ast);
} catch (Exception e) {
testOutput().printf("[%s]: Optimization Error: %s", optimizerName, e);
continue;
}
if (!resultPrinted) {
Object optimizedEvalResult =
CEL.createProgram(optimizedAst)
runtimeEnv
.cel
.createProgram(optimizedAst)
.eval(
ImmutableMap.of(
"msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L)));
"msg", TEST_ALL_TYPES_INPUT, "x", 5L, "y", 6L, "opt_x", Optional.of(5L)));
testOutput().println("Result: " + optimizedEvalResult);
resultPrinted = true;
}
Expand All @@ -145,22 +194,24 @@ public void subexpression_unparsed() throws Exception {

@Test
public void constfold_before_subexpression_unparsed() throws Exception {
for (CseTestCase cseTestCase : CseTestCase.values()) {
for (CseTestCase cseTestCase : EnumSet.allOf(CseTestCase.class)) {
testOutput().println("Test case: " + cseTestCase.name());
testOutput().println("Source: " + cseTestCase.source);
testOutput().println("=====>");
CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst();
CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst();
boolean resultPrinted = false;
for (CseTestOptimizer cseTestOptimizer : CseTestOptimizer.values()) {
for (CseTestOptimizer cseTestOptimizer : EnumSet.allOf(CseTestOptimizer.class)) {
String optimizerName = cseTestOptimizer.name();
CelAbstractSyntaxTree optimizedAst =
cseTestOptimizer.cseWithConstFoldingOptimizer.optimize(ast);
cseTestOptimizer.newCseWithConstFoldingOptimizer(runtimeEnv).optimize(ast);
if (!resultPrinted) {
Object optimizedEvalResult =
CEL.createProgram(optimizedAst)
runtimeEnv
.cel
.createProgram(optimizedAst)
.eval(
ImmutableMap.of(
"msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L)));
"msg", TEST_ALL_TYPES_INPUT, "x", 5L, "y", 6L, "opt_x", Optional.of(5L)));
testOutput().println("Result: " + optimizedEvalResult);
resultPrinted = true;
}
Expand All @@ -179,12 +230,13 @@ public void constfold_before_subexpression_unparsed() throws Exception {
public void subexpression_ast(@TestParameter CseTestOptimizer cseTestOptimizer) throws Exception {
String testBasefileName = "subexpression_ast_" + Ascii.toLowerCase(cseTestOptimizer.name());
overriddenBaseFilePath = String.format("%s%s.baseline", testdataDir(), testBasefileName);
for (CseTestCase cseTestCase : CseTestCase.values()) {
for (CseTestCase cseTestCase : EnumSet.allOf(CseTestCase.class)) {
testOutput().println("Test case: " + cseTestCase.name());
testOutput().println("Source: " + cseTestCase.source);
testOutput().println("=====>");
CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst();
CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.cseOptimizer.optimize(ast);
CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst();
CelAbstractSyntaxTree optimizedAst =
newCseOptimizer(runtimeEnv.cel, cseTestOptimizer.option).optimize(ast);
testOutput().println(optimizedAst.getExpr());
}
}
Expand All @@ -193,7 +245,8 @@ public void subexpression_ast(@TestParameter CseTestOptimizer cseTestOptimizer)
public void large_expressions_block_common_subexpr() throws Exception {
CelOptimizer celOptimizer =
newCseOptimizer(
CEL, SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build());
runtimeEnv.cel,
SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build());

runLargeTestCases(celOptimizer);
}
Expand All @@ -202,7 +255,7 @@ public void large_expressions_block_common_subexpr() throws Exception {
public void large_expressions_block_recursion_depth_1() throws Exception {
CelOptimizer celOptimizer =
newCseOptimizer(
CEL,
runtimeEnv.cel,
SubexpressionOptimizerOptions.newBuilder()
.populateMacroCalls(true)
.subexpressionMaxRecursionDepth(1)
Expand All @@ -215,7 +268,7 @@ public void large_expressions_block_recursion_depth_1() throws Exception {
public void large_expressions_block_recursion_depth_2() throws Exception {
CelOptimizer celOptimizer =
newCseOptimizer(
CEL,
runtimeEnv.cel,
SubexpressionOptimizerOptions.newBuilder()
.populateMacroCalls(true)
.subexpressionMaxRecursionDepth(2)
Expand All @@ -228,7 +281,7 @@ public void large_expressions_block_recursion_depth_2() throws Exception {
public void large_expressions_block_recursion_depth_3() throws Exception {
CelOptimizer celOptimizer =
newCseOptimizer(
CEL,
runtimeEnv.cel,
SubexpressionOptimizerOptions.newBuilder()
.populateMacroCalls(true)
.subexpressionMaxRecursionDepth(3)
Expand All @@ -238,15 +291,16 @@ public void large_expressions_block_recursion_depth_3() throws Exception {
}

private void runLargeTestCases(CelOptimizer celOptimizer) throws Exception {
for (CseLargeTestCase cseTestCase : CseLargeTestCase.values()) {
for (CseLargeTestCase cseTestCase : EnumSet.allOf(CseLargeTestCase.class)) {
testOutput().println("Test case: " + cseTestCase.name());
testOutput().println("Source: " + cseTestCase.source);
testOutput().println("=====>");
CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst();

CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst();
CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast);
Object optimizedEvalResult =
CEL.createProgram(optimizedAst)
runtimeEnv
.cel
.createProgram(optimizedAst)
.eval(
ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L)));
testOutput().println("Result: " + optimizedEvalResult);
Expand All @@ -260,33 +314,6 @@ private void runLargeTestCases(CelOptimizer celOptimizer) throws Exception {
}
}

private static CelBuilder newCelBuilder() {
return CelFactory.standardCelBuilder()
.addMessageTypes(TestAllTypes.getDescriptor())
.setContainer(CelContainer.ofName("cel.expr.conformance.proto3"))
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
.setOptions(CelOptions.current().populateMacroCalls(true).build())
.addCompilerLibraries(
CelExtensions.optional(), CelExtensions.bindings(), CelExtensions.comprehensions())
.addRuntimeLibraries(CelExtensions.optional(), CelExtensions.comprehensions())
.addFunctionDeclarations(
CelFunctionDecl.newFunctionDeclaration(
"pure_custom_func",
newGlobalOverload("pure_custom_func_overload", SimpleType.INT, SimpleType.INT)),
CelFunctionDecl.newFunctionDeclaration(
"non_pure_custom_func",
newGlobalOverload("non_pure_custom_func_overload", SimpleType.INT, SimpleType.INT)))
.addFunctionBindings(
// This is pure, but for the purposes of excluding it as a CSE candidate, pretend that
// it isn't.
CelFunctionBinding.from("non_pure_custom_func_overload", Long.class, val -> val),
CelFunctionBinding.from("pure_custom_func_overload", Long.class, val -> val))
.addVar("x", SimpleType.DYN)
.addVar("y", SimpleType.DYN)
.addVar("opt_x", OptionalType.create(SimpleType.DYN))
.addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName()));
}

private static CelOptimizer newCseOptimizer(Cel cel, SubexpressionOptimizerOptions options) {
return CelOptimizerFactory.standardCelOptimizerBuilder(cel)
.addAstOptimizers(SubexpressionOptimizer.newInstance(options))
Expand Down Expand Up @@ -315,17 +342,23 @@ private enum CseTestOptimizer {
BLOCK_RECURSION_DEPTH_9(
OPTIMIZER_COMMON_OPTIONS.toBuilder().subexpressionMaxRecursionDepth(9).build());

private final CelOptimizer cseOptimizer;
private final CelOptimizer cseWithConstFoldingOptimizer;
private final SubexpressionOptimizerOptions option;

CseTestOptimizer(SubexpressionOptimizerOptions option) {
this.cseOptimizer = newCseOptimizer(CEL, option);
this.cseWithConstFoldingOptimizer =
CelOptimizerFactory.standardCelOptimizerBuilder(CEL)
.addAstOptimizers(
ConstantFoldingOptimizer.getInstance(),
SubexpressionOptimizer.newInstance(option))
.build();
this.option = option;
}

// Defers building the optimizer until the test runs
private CelOptimizer newCseOptimizer(RuntimeEnv env) {
return SubexpressionOptimizerBaselineTest.newCseOptimizer(env.cel, option);
}

// Defers building the optimizer until the test runs
private CelOptimizer newCseWithConstFoldingOptimizer(RuntimeEnv env) {
return CelOptimizerFactory.standardCelOptimizerBuilder(env.cel)
.addAstOptimizers(
ConstantFoldingOptimizer.getInstance(), SubexpressionOptimizer.newInstance(option))
.build();
}
}

Expand Down
Loading
Loading