Skip to content

Commit eb68184

Browse files
jnthntatumcopybara-github
authored andcommitted
Add conformance tests with select optimization enabled.
Update the optimizer to permit invalid constant map keys (by not rewriting the select chain when encountered). These should not be allowed per spec, but the CEL implementations generally permit it to tolerate `dyn` keys. PiperOrigin-RevId: 839852761
1 parent 462a880 commit eb68184

9 files changed

Lines changed: 82 additions & 34 deletions

File tree

conformance/BUILD

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ cc_library(
5353
"//extensions:math_ext_decls",
5454
"//extensions:math_ext_macros",
5555
"//extensions:proto_ext",
56+
"//extensions:select_optimization",
5657
"//extensions:strings",
5758
"//extensions/protobuf:enum_adapter",
5859
"//internal:status_macros",
@@ -297,6 +298,25 @@ gen_conformance_tests(
297298
skip_tests = _TESTS_TO_SKIP_LEGACY + _TESTS_TO_SKIP_CHECKED,
298299
)
299300

301+
# select optimization is only supported for checked expressions.
302+
gen_conformance_tests(
303+
name = "conformance_legacy_select_opt",
304+
checked = True,
305+
data = _ALL_TESTS,
306+
modern = False,
307+
select_opt = True,
308+
skip_tests = _TESTS_TO_SKIP_LEGACY + _TESTS_TO_SKIP_CHECKED,
309+
)
310+
311+
gen_conformance_tests(
312+
name = "conformance_select_opt",
313+
checked = True,
314+
data = _ALL_TESTS,
315+
modern = True,
316+
select_opt = True,
317+
skip_tests = _TESTS_TO_SKIP_MODERN + _TESTS_TO_SKIP_CHECKED,
318+
)
319+
300320
# Generates a bunch of `cc_test` whose names follow the pattern
301321
# `conformance_dashboard_..._{arena|refcount}_{optimized|unoptimized}_{recursive|iterative}`.
302322
gen_conformance_tests(

conformance/run.bzl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,14 @@ def _conformance_test_name(name, optimize, recursive):
4646
],
4747
)
4848

49-
def _conformance_test_args(modern, optimize, recursive, skip_check, skip_tests, dashboard):
49+
def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, skip_tests, dashboard):
5050
args = []
5151
if modern:
5252
args.append("--modern")
5353
if optimize:
5454
args.append("--opt")
55+
if select_opt:
56+
args.append("--select_optimization")
5557
if recursive:
5658
args.append("--recursive")
5759
if skip_check:
@@ -63,16 +65,16 @@ def _conformance_test_args(modern, optimize, recursive, skip_check, skip_tests,
6365
args.append("--dashboard")
6466
return args
6567

66-
def _conformance_test(name, data, modern, optimize, recursive, skip_check, skip_tests, tags, dashboard):
68+
def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_check, skip_tests, tags, dashboard):
6769
cc_test(
6870
name = _conformance_test_name(name, optimize, recursive),
69-
args = _conformance_test_args(modern, optimize, recursive, skip_check, skip_tests, dashboard) + ["$(location " + test + ")" for test in data],
71+
args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, skip_tests, dashboard) + ["$(location " + test + ")" for test in data],
7072
data = data,
7173
deps = ["//conformance:run"],
7274
tags = tags,
7375
)
7476

75-
def gen_conformance_tests(name, data, modern = False, checked = False, dashboard = False, skip_tests = [], tags = []):
77+
def gen_conformance_tests(name, data, modern = False, checked = False, select_opt = False, dashboard = False, skip_tests = [], tags = []):
7678
"""Generates conformance tests.
7779
7880
Args:
@@ -97,6 +99,7 @@ def gen_conformance_tests(name, data, modern = False, checked = False, dashboard
9799
modern = modern,
98100
optimize = optimize,
99101
recursive = recursive,
102+
select_opt = select_opt,
100103
skip_check = skip_check,
101104
skip_tests = skip_tests,
102105
tags = tags,

conformance/run.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ ABSL_FLAG(bool, recursive, false,
6363
ABSL_FLAG(std::vector<std::string>, skip_tests, {}, "Tests to skip");
6464
ABSL_FLAG(bool, dashboard, false, "Dashboard mode, ignore test failures");
6565
ABSL_FLAG(bool, skip_check, true, "Skip type checking the expressions");
66+
ABSL_FLAG(bool, select_optimization, false, "Enable select optimization.");
6667

6768
namespace {
6869

@@ -257,7 +258,9 @@ NewConformanceServiceFromFlags() {
257258
cel_conformance::ConformanceServiceOptions{
258259
.optimize = absl::GetFlag(FLAGS_opt),
259260
.modern = absl::GetFlag(FLAGS_modern),
260-
.recursive = absl::GetFlag(FLAGS_recursive)});
261+
.recursive = absl::GetFlag(FLAGS_recursive),
262+
.select_optimization = absl::GetFlag(FLAGS_select_optimization),
263+
});
261264
ABSL_CHECK_OK(status_or_service);
262265
return std::shared_ptr<cel_conformance::ConformanceServiceInterface>(
263266
std::move(*status_or_service));

conformance/service.cc

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,10 @@
4747
#include "checker/type_checker_builder_factory.h"
4848
#include "common/ast.h"
4949
#include "common/ast_proto.h"
50-
#include "common/decl.h"
5150
#include "common/decl_proto_v1alpha1.h"
5251
#include "common/expr.h"
5352
#include "common/internal/value_conversion.h"
5453
#include "common/source.h"
55-
#include "common/type.h"
5654
#include "common/value.h"
5755
#include "eval/public/activation.h"
5856
#include "eval/public/builtin_func_registrar.h"
@@ -70,6 +68,7 @@
7068
#include "extensions/math_ext_macros.h"
7169
#include "extensions/proto_ext.h"
7270
#include "extensions/protobuf/enum_adapter.h"
71+
#include "extensions/select_optimization.h"
7372
#include "extensions/strings.h"
7473
#include "internal/status_macros.h"
7574
#include "parser/macro.h"
@@ -340,7 +339,7 @@ absl::Status CheckImpl(google::protobuf::Arena* arena,
340339
class LegacyConformanceServiceImpl : public ConformanceServiceInterface {
341340
public:
342341
static absl::StatusOr<std::unique_ptr<LegacyConformanceServiceImpl>> Create(
343-
bool optimize, bool recursive) {
342+
bool optimize, bool recursive, bool select_optimization) {
344343
static auto* constant_arena = new Arena();
345344

346345
google::protobuf::LinkMessageReflection<
@@ -385,6 +384,11 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface {
385384
options.constant_arena = constant_arena;
386385
}
387386

387+
if (select_optimization) {
388+
std::cerr << "Enabling select optimizations" << std::endl;
389+
options.enable_select_optimization = true;
390+
}
391+
388392
if (recursive) {
389393
options.max_recursion_depth = 48;
390394
}
@@ -526,7 +530,7 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface {
526530
class ModernConformanceServiceImpl : public ConformanceServiceInterface {
527531
public:
528532
static absl::StatusOr<std::unique_ptr<ModernConformanceServiceImpl>> Create(
529-
bool optimize, bool recursive) {
533+
bool optimize, bool recursive, bool select_optimization) {
530534
google::protobuf::LinkMessageReflection<
531535
cel::expr::conformance::proto3::TestAllTypes>();
532536
google::protobuf::LinkMessageReflection<
@@ -565,8 +569,8 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface {
565569
options.max_recursion_depth = 48;
566570
}
567571

568-
return absl::WrapUnique(
569-
new ModernConformanceServiceImpl(options, optimize));
572+
return absl::WrapUnique(new ModernConformanceServiceImpl(
573+
options, optimize, select_optimization));
570574
}
571575

572576
absl::StatusOr<std::unique_ptr<const cel::Runtime>> Setup(
@@ -583,6 +587,9 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface {
583587
}
584588
CEL_RETURN_IF_ERROR(cel::EnableReferenceResolver(
585589
builder, cel::ReferenceResolverEnabled::kAlways));
590+
if (enable_select_optimization_) {
591+
CEL_RETURN_IF_ERROR(cel::extensions::EnableSelectOptimization(builder));
592+
}
586593

587594
auto& type_registry = builder.type_registry();
588595
// Use linked pbs in the generated descriptor pool.
@@ -704,10 +711,12 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface {
704711
}
705712

706713
private:
707-
explicit ModernConformanceServiceImpl(const RuntimeOptions& options,
708-
bool enable_optimizations)
709-
: options_(options), enable_optimizations_(enable_optimizations) {}
710-
714+
ModernConformanceServiceImpl(const RuntimeOptions& options,
715+
bool enable_optimizations,
716+
bool enable_select_optimization)
717+
: options_(options),
718+
enable_optimizations_(enable_optimizations),
719+
enable_select_optimization_(enable_select_optimization) {}
711720

712721
static absl::StatusOr<std::unique_ptr<cel::TraceableProgram>> Plan(
713722
const cel::Runtime& runtime,
@@ -737,6 +746,7 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface {
737746

738747
RuntimeOptions options_;
739748
bool enable_optimizations_;
749+
bool enable_select_optimization_;
740750
};
741751

742752
} // namespace
@@ -749,10 +759,10 @@ absl::StatusOr<std::unique_ptr<ConformanceServiceInterface>>
749759
NewConformanceService(const ConformanceServiceOptions& options) {
750760
if (options.modern) {
751761
return google::api::expr::runtime::ModernConformanceServiceImpl::Create(
752-
options.optimize, options.recursive);
762+
options.optimize, options.recursive, options.select_optimization);
753763
} else {
754764
return google::api::expr::runtime::LegacyConformanceServiceImpl::Create(
755-
options.optimize, options.recursive);
765+
options.optimize, options.recursive, options.select_optimization);
756766
}
757767
}
758768

conformance/service.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ struct ConformanceServiceOptions {
4545
bool modern;
4646
bool arena;
4747
bool recursive;
48+
bool select_optimization;
4849
};
4950

5051
absl::StatusOr<std::unique_ptr<ConformanceServiceInterface>>

eval/public/cel_options.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,6 @@ struct InterpreterOptions {
163163
// always operates as though it is `true`.
164164
// - `enable_heterogeneous_equality` is ignored and optimized traversals
165165
// always operate as though it is `true`.
166-
//
167-
// Note: implementation in progress -- please consult the CEL team before
168-
// enabling in an existing environment.
169166
bool enable_select_optimization = false;
170167

171168
// Enable lazy cel.bind alias initialization.

extensions/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,7 @@ cc_library(
329329
"//eval/eval:evaluator_core",
330330
"//eval/eval:expression_step_base",
331331
"//internal:casts",
332+
"//internal:number",
332333
"//internal:status_macros",
333334
"//runtime:runtime_builder",
334335
"//runtime/internal:errors",

extensions/select_optimization.cc

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
#include "eval/eval/evaluator_core.h"
5454
#include "eval/eval/expression_step_base.h"
5555
#include "internal/casts.h"
56+
#include "internal/number.h"
5657
#include "internal/status_macros.h"
5758
#include "runtime/internal/errors.h"
5859
#include "runtime/internal/runtime_friend_access.h"
@@ -188,34 +189,45 @@ absl::StatusOr<SelectQualifier> SelectQualifierFromList(const ListExpr& list) {
188189
field_name.const_expr().string_value()};
189190
}
190191

192+
// Returns a qualifier instruction derived from a unoptimized ast.
191193
absl::StatusOr<QualifierInstruction> SelectInstructionFromConstant(
192194
const Constant& constant) {
193-
if (constant.has_int64_value()) {
194-
return QualifierInstruction(constant.int64_value());
195-
} else if (constant.has_uint64_value()) {
196-
return QualifierInstruction(constant.uint64_value());
195+
if (constant.has_int_value()) {
196+
return QualifierInstruction(constant.int_value());
197+
} else if (constant.has_uint_value()) {
198+
return QualifierInstruction(constant.uint_value());
197199
} else if (constant.has_bool_value()) {
198200
return QualifierInstruction(constant.bool_value());
199201
} else if (constant.has_string_value()) {
200202
return QualifierInstruction(constant.string_value());
203+
} else if (constant.has_double_value()) {
204+
cel::internal::Number number(constant.double_value());
205+
if (number.LosslessConvertibleToInt()) {
206+
return QualifierInstruction(number.AsInt());
207+
} else if (number.LosslessConvertibleToUint()) {
208+
return QualifierInstruction(number.AsUint());
209+
}
201210
}
202211

203-
return absl::InvalidArgumentError("Invalid cel.attribute constant");
212+
return absl::InvalidArgumentError("invalid index constant for cel.attribute");
204213
}
205214

206215
absl::StatusOr<SelectQualifier> SelectQualifierFromConstant(
207216
const Constant& constant) {
208-
if (constant.has_int64_value()) {
209-
return AttributeQualifier::OfInt(constant.int64_value());
210-
} else if (constant.has_uint64_value()) {
211-
return AttributeQualifier::OfUint(constant.uint64_value());
217+
if (constant.has_int_value()) {
218+
return AttributeQualifier::OfInt(constant.int_value());
219+
} else if (constant.has_uint_value()) {
220+
return AttributeQualifier::OfUint(constant.uint_value());
212221
} else if (constant.has_bool_value()) {
213222
return AttributeQualifier::OfBool(constant.bool_value());
214223
} else if (constant.has_string_value()) {
215224
return AttributeQualifier::OfString(constant.string_value());
216225
}
226+
// TODO(uncreated-issue/51): double keys could possibly be valid selectors, but
227+
// the other stacks don't implement the optimization yet and we normalize the
228+
// key to a uint or int if we do the late AST rewrite during planning.
217229

218-
return absl::InvalidArgumentError("Invalid cel.attribute constant");
230+
return absl::InvalidArgumentError("invalid cel.attribute constant");
219231
}
220232

221233
absl::StatusOr<size_t> ListIndexFromQualifier(const AttributeQualifier& qual) {
@@ -248,7 +260,7 @@ absl::StatusOr<Value> MapKeyFromQualifier(const AttributeQualifier& qual,
248260
case Kind::kBool:
249261
return cel::BoolValue(*qual.GetBoolKey());
250262
case Kind::kString:
251-
return cel::StringValue(arena, *qual.GetStringKey());
263+
return StringValue::From(*qual.GetStringKey(), arena);
252264
default:
253265
return runtime_internal::CreateNoMatchingOverloadError(
254266
cel::builtin::kIndex);
@@ -424,7 +436,8 @@ class RewriterImpl : public AstRewriterBase {
424436
auto qualifier_or =
425437
SelectInstructionFromConstant(qualifier_expr.const_expr());
426438
if (!qualifier_or.ok()) {
427-
SetProgressStatus(qualifier_or.status());
439+
// TODO(uncreated-issue/54): should warn, but by default warnings fail overall
440+
// program planning.
428441
return;
429442
}
430443
candidates_[&expr] = std::move(qualifier_or).value();

extensions/select_optimization.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ struct SelectOptimizationOptions {
5454
// Assumes the default runtime implementation, an error with code
5555
// InvalidArgument is returned if it is not.
5656
//
57-
// Note: implementation in progress -- please consult the CEL team before
58-
// enabling in an existing environment.
57+
// Note: implementation does not support optional field traversal, and will
58+
// instead revert to the normal implementation instead of trying to optimize.
5959
absl::Status EnableSelectOptimization(
6060
cel::RuntimeBuilder& builder,
6161
const SelectOptimizationOptions& options = {});

0 commit comments

Comments
 (0)