Skip to content

Commit 6975536

Browse files
TristonianJonescopybara-github
authored andcommitted
Fix variadic logical operator planning
PiperOrigin-RevId: 930627361
1 parent e51522f commit 6975536

8 files changed

Lines changed: 113 additions & 67 deletions

File tree

conformance/BUILD

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ cc_library(
3232
"//common:ast",
3333
"//common:ast_proto",
3434
"//common:decl_proto_v1alpha1",
35-
"//common:expr",
3635
"//common:source",
3736
"//common:value",
3837
"//common/internal:value_conversion",
@@ -57,8 +56,6 @@ cc_library(
5756
"//extensions/protobuf:enum_adapter",
5857
"//internal:status_macros",
5958
"//parser",
60-
"//parser:macro",
61-
"//parser:macro_expr_factory",
6259
"//parser:macro_registry",
6360
"//parser:options",
6461
"//parser:standard_macros",
@@ -75,8 +72,6 @@ cc_library(
7572
"@com_google_absl//absl/status",
7673
"@com_google_absl//absl/status:statusor",
7774
"@com_google_absl//absl/strings",
78-
"@com_google_absl//absl/types:optional",
79-
"@com_google_absl//absl/types:span",
8075
"@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
8176
"@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto",
8277
"@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto",
@@ -302,6 +297,24 @@ gen_conformance_tests(
302297
skip_tests = _TESTS_TO_SKIP_MODERN + _TESTS_TO_SKIP_CHECKED,
303298
)
304299

300+
gen_conformance_tests(
301+
name = "conformance_variadic",
302+
checked = True,
303+
data = _ALL_TESTS,
304+
enable_variadic_logical_operators = True,
305+
modern = True,
306+
skip_tests = _TESTS_TO_SKIP_MODERN + _TESTS_TO_SKIP_CHECKED,
307+
)
308+
309+
gen_conformance_tests(
310+
name = "conformance_legacy_variadic",
311+
checked = True,
312+
data = _ALL_TESTS,
313+
enable_variadic_logical_operators = True,
314+
modern = False,
315+
skip_tests = _TESTS_TO_SKIP_LEGACY + _TESTS_TO_SKIP_CHECKED,
316+
)
317+
305318
# Generates a bunch of `cc_test` whose names follow the pattern
306319
# `conformance_dashboard_..._{arena|refcount}_{optimized|unoptimized}_{recursive|iterative}`.
307320
gen_conformance_tests(

conformance/run.bzl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def _conformance_test_name(name, optimize, recursive):
5656
],
5757
)
5858

59-
def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard):
59+
def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard, enable_variadic_logical_operators):
6060
args = []
6161
if modern:
6262
args.append("--modern")
@@ -72,12 +72,14 @@ def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check,
7272
args.append("--noskip_check")
7373
if dashboard:
7474
args.append("--dashboard")
75+
if enable_variadic_logical_operators:
76+
args.append("--enable_variadic_logical_operators")
7577
return args
7678

77-
def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_check, skip_tests, tags, dashboard):
79+
def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_check, skip_tests, tags, dashboard, enable_variadic_logical_operators):
7880
cc_test(
7981
name = _conformance_test_name(name, optimize, recursive),
80-
args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard) + ["$(rlocationpath {})".format(test) for test in data],
82+
args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard, enable_variadic_logical_operators) + ["$(rlocationpath {})".format(test) for test in data],
8183
env = select(
8284
{
8385
"@platforms//os:windows": {"CEL_SKIP_TESTS": ",".join(skip_tests + _TESTS_TO_SKIP_WINDOWS)},
@@ -89,18 +91,20 @@ def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_
8991
tags = tags,
9092
)
9193

92-
def gen_conformance_tests(name, data, modern = False, checked = False, select_opt = False, dashboard = False, skip_tests = [], tags = []):
94+
def gen_conformance_tests(name, data, modern = False, checked = False, select_opt = False, dashboard = False, skip_tests = [], tags = [], enable_variadic_logical_operators = False):
9395
"""Generates conformance tests.
9496
9597
Args:
9698
name: prefix for all tests
99+
data: textproto targets describing conformance tests
97100
modern: run using modern APIs
98101
checked: whether to apply type checking
99-
data: textproto targets describing conformance tests
102+
select_opt: enable select optimization
103+
dashboard: enable dashboard mode
100104
skip_tests: tests to skip in the format of the cel-spec test runner. See documentation
101105
in github.com/google/cel-spec/tests/simple/simple_test.go
102106
tags: tags added to the generated targets
103-
dashboard: enable dashboard mode
107+
enable_variadic_logical_operators: enable variadic logical operators
104108
"""
105109
skip_check = not checked
106110
tests = []
@@ -119,6 +123,7 @@ def gen_conformance_tests(name, data, modern = False, checked = False, select_op
119123
skip_tests = _expand_tests_to_skip(skip_tests),
120124
tags = tags,
121125
dashboard = dashboard,
126+
enable_variadic_logical_operators = enable_variadic_logical_operators,
122127
)
123128
native.test_suite(
124129
name = name,

conformance/run.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ ABSL_FLAG(std::vector<std::string>, skip_tests, {}, "Tests to skip");
6666
ABSL_FLAG(bool, dashboard, false, "Dashboard mode, ignore test failures");
6767
ABSL_FLAG(bool, skip_check, true, "Skip type checking the expressions");
6868
ABSL_FLAG(bool, select_optimization, false, "Enable select optimization.");
69+
ABSL_FLAG(bool, enable_variadic_logical_operators, false,
70+
"Enable parsing logical AND & OR operators as a single flat variadic "
71+
"call.");
6972

7073
namespace {
7174

@@ -261,6 +264,8 @@ NewConformanceServiceFromFlags() {
261264
.modern = absl::GetFlag(FLAGS_modern),
262265
.recursive = absl::GetFlag(FLAGS_recursive),
263266
.select_optimization = absl::GetFlag(FLAGS_select_optimization),
267+
.enable_variadic_logical_operators =
268+
absl::GetFlag(FLAGS_enable_variadic_logical_operators),
264269
});
265270
ABSL_CHECK_OK(status_or_service);
266271
return std::shared_ptr<cel_conformance::ConformanceServiceInterface>(

conformance/service.cc

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,15 @@ cel::expr::Expr ExtractExpr(
128128

129129
absl::Status LegacyParse(const conformance::v1alpha1::ParseRequest& request,
130130
conformance::v1alpha1::ParseResponse& response,
131-
bool enable_optional_syntax) {
131+
bool enable_optional_syntax,
132+
bool enable_variadic_logical_operators) {
132133
if (request.cel_source().empty()) {
133134
return absl::InvalidArgumentError("no source code");
134135
}
135136
cel::ParserOptions options;
136137
options.enable_optional_syntax = enable_optional_syntax;
137138
options.enable_quoted_identifiers = true;
139+
options.enable_variadic_logical_operators = enable_variadic_logical_operators;
138140
cel::MacroRegistry macros;
139141
CEL_RETURN_IF_ERROR(cel::RegisterStandardMacros(macros, options));
140142
CEL_RETURN_IF_ERROR(
@@ -236,7 +238,8 @@ absl::Status CheckImpl(google::protobuf::Arena* arena,
236238
class LegacyConformanceServiceImpl : public ConformanceServiceInterface {
237239
public:
238240
static absl::StatusOr<std::unique_ptr<LegacyConformanceServiceImpl>> Create(
239-
bool optimize, bool recursive, bool select_optimization) {
241+
bool optimize, bool recursive, bool select_optimization,
242+
bool enable_variadic_logical_operators) {
240243
static auto* constant_arena = new Arena();
241244

242245
google::protobuf::LinkMessageReflection<
@@ -313,14 +316,15 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface {
313316
CEL_RETURN_IF_ERROR(cel::extensions::RegisterMathExtensionFunctions(
314317
builder->GetRegistry(), options));
315318

316-
return absl::WrapUnique(
317-
new LegacyConformanceServiceImpl(std::move(builder)));
319+
return absl::WrapUnique(new LegacyConformanceServiceImpl(
320+
std::move(builder), enable_variadic_logical_operators));
318321
}
319322

320323
void Parse(const conformance::v1alpha1::ParseRequest& request,
321324
conformance::v1alpha1::ParseResponse& response) override {
322325
auto status =
323-
LegacyParse(request, response, /*enable_optional_syntax=*/false);
326+
LegacyParse(request, response, /*enable_optional_syntax=*/false,
327+
enable_variadic_logical_operators_);
324328
if (!status.ok()) {
325329
auto* issue = response.add_issues();
326330
issue->set_code(ToGrpcCode(status.code()));
@@ -418,17 +422,20 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface {
418422
}
419423

420424
private:
421-
explicit LegacyConformanceServiceImpl(
422-
std::unique_ptr<CelExpressionBuilder> builder)
423-
: builder_(std::move(builder)) {}
425+
LegacyConformanceServiceImpl(std::unique_ptr<CelExpressionBuilder> builder,
426+
bool enable_variadic_logical_operators)
427+
: builder_(std::move(builder)),
428+
enable_variadic_logical_operators_(enable_variadic_logical_operators) {}
424429

425430
std::unique_ptr<CelExpressionBuilder> builder_;
431+
bool enable_variadic_logical_operators_;
426432
};
427433

428434
class ModernConformanceServiceImpl : public ConformanceServiceInterface {
429435
public:
430436
static absl::StatusOr<std::unique_ptr<ModernConformanceServiceImpl>> Create(
431-
bool optimize, bool recursive, bool select_optimization) {
437+
bool optimize, bool recursive, bool select_optimization,
438+
bool enable_variadic_logical_operators) {
432439
google::protobuf::LinkMessageReflection<
433440
cel::expr::conformance::proto3::TestAllTypes>();
434441
google::protobuf::LinkMessageReflection<
@@ -470,8 +477,9 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface {
470477
options.max_recursion_depth = 48;
471478
}
472479

473-
return absl::WrapUnique(new ModernConformanceServiceImpl(
474-
options, optimize, select_optimization));
480+
return absl::WrapUnique(
481+
new ModernConformanceServiceImpl(options, optimize, select_optimization,
482+
enable_variadic_logical_operators));
475483
}
476484

477485
absl::StatusOr<std::unique_ptr<const cel::Runtime>> Setup(
@@ -523,7 +531,8 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface {
523531
void Parse(const conformance::v1alpha1::ParseRequest& request,
524532
conformance::v1alpha1::ParseResponse& response) override {
525533
auto status =
526-
LegacyParse(request, response, /*enable_optional_syntax=*/true);
534+
LegacyParse(request, response, /*enable_optional_syntax=*/true,
535+
enable_variadic_logical_operators_);
527536
if (!status.ok()) {
528537
auto* issue = response.add_issues();
529538
issue->set_code(ToGrpcCode(status.code()));
@@ -614,10 +623,12 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface {
614623
private:
615624
ModernConformanceServiceImpl(const RuntimeOptions& options,
616625
bool enable_optimizations,
617-
bool enable_select_optimization)
626+
bool enable_select_optimization,
627+
bool enable_variadic_logical_operators)
618628
: options_(options),
619629
enable_optimizations_(enable_optimizations),
620-
enable_select_optimization_(enable_select_optimization) {}
630+
enable_select_optimization_(enable_select_optimization),
631+
enable_variadic_logical_operators_(enable_variadic_logical_operators) {}
621632

622633
static absl::StatusOr<std::unique_ptr<cel::TraceableProgram>> Plan(
623634
const cel::Runtime& runtime,
@@ -648,6 +659,7 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface {
648659
RuntimeOptions options_;
649660
bool enable_optimizations_;
650661
bool enable_select_optimization_;
662+
bool enable_variadic_logical_operators_;
651663
};
652664

653665
} // namespace
@@ -660,10 +672,12 @@ absl::StatusOr<std::unique_ptr<ConformanceServiceInterface>>
660672
NewConformanceService(const ConformanceServiceOptions& options) {
661673
if (options.modern) {
662674
return google::api::expr::runtime::ModernConformanceServiceImpl::Create(
663-
options.optimize, options.recursive, options.select_optimization);
675+
options.optimize, options.recursive, options.select_optimization,
676+
options.enable_variadic_logical_operators);
664677
} else {
665678
return google::api::expr::runtime::LegacyConformanceServiceImpl::Create(
666-
options.optimize, options.recursive, options.select_optimization);
679+
options.optimize, options.recursive, options.select_optimization,
680+
options.enable_variadic_logical_operators);
667681
}
668682
}
669683

conformance/service.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ struct ConformanceServiceOptions {
4646
bool arena;
4747
bool recursive;
4848
bool select_optimization;
49+
bool enable_variadic_logical_operators = false;
4950
};
5051

5152
absl::StatusOr<std::unique_ptr<ConformanceServiceInterface>>

eval/compiler/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ cc_test(
193193
"//internal:status_macros",
194194
"//internal:testing",
195195
"//parser",
196+
"//parser:options",
196197
"//runtime:function",
197198
"//runtime:function_adapter",
198199
"//runtime:runtime_options",

eval/compiler/flat_expr_builder.cc

Lines changed: 41 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2154,7 +2154,7 @@ void BinaryCondVisitor::PreVisit(const cel::Expr* expr) {
21542154
case BinaryCond::kOr:
21552155
visitor_->ValidateOrError(
21562156
!expr->call_expr().has_target() &&
2157-
expr->call_expr().args().size() == 2,
2157+
expr->call_expr().args().size() >= 2,
21582158
"Invalid argument count for a binary function call.");
21592159
break;
21602160
case BinaryCond::kOptionalOr:
@@ -2172,28 +2172,40 @@ void BinaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) {
21722172
return;
21732173
}
21742174
const int last_arg_index = expr->call_expr().args().size() - 1;
2175-
if (short_circuiting_ && arg_num < last_arg_index &&
2176-
(cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr)) {
2177-
// If first branch evaluation result is enough to determine output,
2178-
// jump over the second branch and provide result of the first argument as
2179-
// final output.
2180-
// Retain pointers to the jump steps so we can update the target after
2181-
// planning the next arguments.
2182-
std::unique_ptr<JumpStepBase> jump_step;
2183-
switch (cond_) {
2184-
case BinaryCond::kAnd:
2185-
jump_step = CreateCondJumpStep(false, true, {}, expr->id());
2186-
break;
2187-
case BinaryCond::kOr:
2188-
jump_step = CreateCondJumpStep(true, true, {}, expr->id());
2189-
break;
2190-
default:
2191-
ABSL_UNREACHABLE();
2175+
if (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr) {
2176+
if (arg_num > 0) {
2177+
switch (cond_) {
2178+
case BinaryCond::kAnd:
2179+
visitor_->AddStep(CreateAndStep(expr->id()));
2180+
break;
2181+
case BinaryCond::kOr:
2182+
visitor_->AddStep(CreateOrStep(expr->id()));
2183+
break;
2184+
default:
2185+
break;
2186+
}
2187+
if (short_circuiting_ && !jump_steps_.empty()) {
2188+
visitor_->SetProgressStatusIfError(
2189+
jump_steps_.back().set_target(visitor_->GetCurrentIndex()));
2190+
}
21922191
}
2193-
ProgramStepIndex index = visitor_->GetCurrentIndex();
2194-
if (JumpStepBase* jump_step_ptr = visitor_->AddStep(std::move(jump_step));
2195-
jump_step_ptr) {
2196-
jump_steps_.push_back(Jump(index, jump_step_ptr));
2192+
if (short_circuiting_ && arg_num < last_arg_index) {
2193+
std::unique_ptr<JumpStepBase> jump_step;
2194+
switch (cond_) {
2195+
case BinaryCond::kAnd:
2196+
jump_step = CreateCondJumpStep(false, true, {}, expr->id());
2197+
break;
2198+
case BinaryCond::kOr:
2199+
jump_step = CreateCondJumpStep(true, true, {}, expr->id());
2200+
break;
2201+
default:
2202+
ABSL_UNREACHABLE();
2203+
}
2204+
ProgramStepIndex index = visitor_->GetCurrentIndex();
2205+
if (JumpStepBase* jump_step_ptr = visitor_->AddStep(std::move(jump_step));
2206+
jump_step_ptr) {
2207+
jump_steps_.push_back(Jump(index, jump_step_ptr));
2208+
}
21972209
}
21982210
}
21992211
}
@@ -2251,17 +2263,9 @@ void BinaryCondVisitor::PostVisit(const cel::Expr* expr) {
22512263
return;
22522264
}
22532265

2254-
int args_count = (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr)
2255-
? expr->call_expr().args().size()
2256-
: 2;
2257-
for (int i = 0; i < args_count - 1; ++i) {
2266+
if (cond_ == BinaryCond::kOptionalOr ||
2267+
cond_ == BinaryCond::kOptionalOrValue) {
22582268
switch (cond_) {
2259-
case BinaryCond::kAnd:
2260-
visitor_->AddStep(CreateAndStep(expr->id()));
2261-
break;
2262-
case BinaryCond::kOr:
2263-
visitor_->AddStep(CreateOrStep(expr->id()));
2264-
break;
22652269
case BinaryCond::kOptionalOr:
22662270
visitor_->AddStep(
22672271
CreateOptionalOrStep(/*is_or_value=*/false, expr->id()));
@@ -2273,13 +2277,11 @@ void BinaryCondVisitor::PostVisit(const cel::Expr* expr) {
22732277
default:
22742278
ABSL_UNREACHABLE();
22752279
}
2276-
}
2277-
if (short_circuiting_) {
2278-
// If short-circuiting is enabled, point the conditional jump past the
2279-
// boolean operator step.
2280-
for (auto& jump : jump_steps_) {
2281-
visitor_->SetProgressStatusIfError(
2282-
jump.set_target(visitor_->GetCurrentIndex()));
2280+
if (short_circuiting_) {
2281+
for (auto& jump : jump_steps_) {
2282+
visitor_->SetProgressStatusIfError(
2283+
jump.set_target(visitor_->GetCurrentIndex()));
2284+
}
22832285
}
22842286
}
22852287
}

0 commit comments

Comments
 (0)