Skip to content

Commit 502fb0a

Browse files
jnthntatumcopybara-github
authored andcommitted
Add specialized implementations for ==/!=/@in.
PiperOrigin-RevId: 716429211
1 parent b243e37 commit 502fb0a

18 files changed

Lines changed: 1267 additions & 21 deletions

eval/compiler/BUILD

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ cc_library(
101101
"//common:ast",
102102
"//common:ast_traverse",
103103
"//common:ast_visitor",
104+
"//common:kind",
104105
"//common:memory",
105106
"//common:type",
106107
"//common:value",
@@ -111,6 +112,7 @@ cc_library(
111112
"//eval/eval:create_map_step",
112113
"//eval/eval:create_struct_step",
113114
"//eval/eval:direct_expression_step",
115+
"//eval/eval:equality_steps",
114116
"//eval/eval:evaluator_core",
115117
"//eval/eval:function_step",
116118
"//eval/eval:ident_step",
@@ -159,6 +161,7 @@ cc_test(
159161
":constant_folding",
160162
":flat_expr_builder",
161163
":qualified_reference_resolver",
164+
"//base:builtins",
162165
"//base:function",
163166
"//base:function_descriptor",
164167
"//common:value",
@@ -178,18 +181,18 @@ cc_test(
178181
"//eval/public/containers:container_backed_map_impl",
179182
"//eval/public/structs:cel_proto_descriptor_pool_builder",
180183
"//eval/public/structs:cel_proto_wrapper",
181-
"//eval/public/structs:protobuf_descriptor_type_provider",
182184
"//eval/public/testing:matchers",
183185
"//eval/testutil:test_message_cc_proto",
184-
"//internal:proto_file_util",
185186
"//internal:proto_matchers",
186187
"//internal:status_macros",
187188
"//internal:testing",
188189
"//parser",
190+
"//runtime:function_adapter",
189191
"//runtime:runtime_options",
190192
"//runtime/internal:runtime_env_testing",
191193
"@com_google_absl//absl/container:flat_hash_map",
192194
"@com_google_absl//absl/status",
195+
"@com_google_absl//absl/status:status_matchers",
193196
"@com_google_absl//absl/strings",
194197
"@com_google_absl//absl/types:span",
195198
"@com_google_cel_spec//proto/cel/expr:checked_cc_proto",

eval/compiler/flat_expr_builder.cc

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
#include "common/ast.h"
5555
#include "common/ast_traverse.h"
5656
#include "common/ast_visitor.h"
57+
#include "common/kind.h"
5758
#include "common/memory.h"
5859
#include "common/type.h"
5960
#include "common/value.h"
@@ -68,6 +69,7 @@
6869
#include "eval/eval/create_map_step.h"
6970
#include "eval/eval/create_struct_step.h"
7071
#include "eval/eval/direct_expression_step.h"
72+
#include "eval/eval/equality_steps.h"
7173
#include "eval/eval/evaluator_core.h"
7274
#include "eval/eval/function_step.h"
7375
#include "eval/eval/ident_step.h"
@@ -527,6 +529,36 @@ class FlatExprVisitor : public cel::AstVisitor {
527529
const cel::ast_internal::Call& call) {
528530
return HandleNot(expr, call);
529531
};
532+
if (options_.enable_heterogeneous_equality) {
533+
for (const auto& in_op :
534+
{cel::builtin::kIn, cel::builtin::kInDeprecated,
535+
cel::builtin::kInFunction}) {
536+
call_handlers_[in_op] = [this](const cel::ast_internal::Expr& expr,
537+
const cel::ast_internal::Call& call) {
538+
return HandleHeterogeneousEqualityIn(expr, call);
539+
};
540+
}
541+
// Try to detect if the environment is setup with a custom equality
542+
// implementation.
543+
if (resolver_
544+
.FindOverloads(cel::builtin::kEqual,
545+
/*receiver_style=*/false,
546+
{cel::Kind::kAny, cel::Kind::kAny})
547+
.empty()) {
548+
call_handlers_[cel::builtin::kEqual] =
549+
[this](const cel::ast_internal::Expr& expr,
550+
const cel::ast_internal::Call& call) {
551+
return HandleHeterogeneousEquality(expr, call,
552+
/*inequality=*/false);
553+
};
554+
call_handlers_[cel::builtin::kInequal] =
555+
[this](const cel::ast_internal::Expr& expr,
556+
const cel::ast_internal::Call& call) {
557+
return HandleHeterogeneousEquality(expr, call,
558+
/*inequality=*/true);
559+
};
560+
}
561+
}
530562
}
531563
}
532564

@@ -1874,6 +1906,13 @@ class FlatExprVisitor : public cel::AstVisitor {
18741906
CallHandlerResult HandleNotStrictlyFalse(const cel::ast_internal::Expr& expr,
18751907
const cel::ast_internal::Call& call);
18761908

1909+
CallHandlerResult HandleHeterogeneousEquality(
1910+
const cel::ast_internal::Expr& expr, const cel::ast_internal::Call& call,
1911+
bool inequality);
1912+
1913+
CallHandlerResult HandleHeterogeneousEqualityIn(
1914+
const cel::ast_internal::Expr& expr, const cel::ast_internal::Call& call);
1915+
18771916
const Resolver& resolver_;
18781917
ValueManager& value_factory_;
18791918
absl::Status progress_status_;
@@ -2026,6 +2065,59 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleListAppend(
20262065
return CallHandlerResult::kNotIntercepted;
20272066
}
20282067

2068+
FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleHeterogeneousEquality(
2069+
const cel::ast_internal::Expr& expr, const cel::ast_internal::Call& call,
2070+
bool inequality) {
2071+
if (!ValidateOrError(
2072+
call.args().size() == 2,
2073+
"unexpected number of args for builtin equality operator")) {
2074+
return CallHandlerResult::kIntercepted;
2075+
}
2076+
auto depth = RecursionEligible();
2077+
2078+
if (depth.has_value()) {
2079+
auto args = ExtractRecursiveDependencies();
2080+
if (args.size() != 2) {
2081+
SetProgressStatusError(absl::InvalidArgumentError(
2082+
"unexpected number of args for builtin equality operator"));
2083+
return CallHandlerResult::kIntercepted;
2084+
}
2085+
SetRecursiveStep(
2086+
CreateDirectEqualityStep(std::move(args[0]), std::move(args[1]),
2087+
inequality, expr.id()),
2088+
*depth + 1);
2089+
return CallHandlerResult::kIntercepted;
2090+
}
2091+
AddStep(CreateEqualityStep(inequality, expr.id()));
2092+
return CallHandlerResult::kIntercepted;
2093+
}
2094+
2095+
FlatExprVisitor::CallHandlerResult
2096+
FlatExprVisitor::HandleHeterogeneousEqualityIn(
2097+
const cel::ast_internal::Expr& expr, const cel::ast_internal::Call& call) {
2098+
if (!ValidateOrError(call.args().size() == 2,
2099+
"unexpected number of args for builtin 'in' operator")) {
2100+
return CallHandlerResult::kIntercepted;
2101+
}
2102+
2103+
auto depth = RecursionEligible();
2104+
if (depth.has_value()) {
2105+
auto args = ExtractRecursiveDependencies();
2106+
if (args.size() != 2) {
2107+
SetProgressStatusError(absl::InvalidArgumentError(
2108+
"unexpected number of args for builtin 'in' operator"));
2109+
return CallHandlerResult::kIntercepted;
2110+
}
2111+
SetRecursiveStep(
2112+
CreateDirectInStep(std::move(args[0]), std::move(args[1]), expr.id()),
2113+
*depth + 1);
2114+
return CallHandlerResult::kIntercepted;
2115+
}
2116+
2117+
AddStep(CreateInStep(expr.id()));
2118+
return CallHandlerResult::kIntercepted;
2119+
}
2120+
20292121
void BinaryCondVisitor::PreVisit(const cel::ast_internal::Expr* expr) {
20302122
switch (cond_) {
20312123
case BinaryCond::kAnd:

eval/compiler/flat_expr_builder_test.cc

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "eval/compiler/flat_expr_builder.h"
1818

19+
#include <cstdint>
1920
#include <functional>
2021
#include <memory>
2122
#include <string>
@@ -28,9 +29,11 @@
2829
#include "google/protobuf/descriptor.pb.h"
2930
#include "absl/container/flat_hash_map.h"
3031
#include "absl/status/status.h"
32+
#include "absl/status/status_matchers.h"
3133
#include "absl/strings/str_split.h"
3234
#include "absl/strings/string_view.h"
3335
#include "absl/types/span.h"
36+
#include "base/builtins.h"
3437
#include "base/function.h"
3538
#include "base/function_descriptor.h"
3639
#include "common/value.h"
@@ -51,16 +54,15 @@
5154
#include "eval/public/portable_cel_function_adapter.h"
5255
#include "eval/public/structs/cel_proto_descriptor_pool_builder.h"
5356
#include "eval/public/structs/cel_proto_wrapper.h"
54-
#include "eval/public/structs/protobuf_descriptor_type_provider.h"
5557
#include "eval/public/testing/matchers.h"
5658
#include "eval/public/unknown_attribute_set.h"
5759
#include "eval/public/unknown_set.h"
5860
#include "eval/testutil/test_message.pb.h"
59-
#include "internal/proto_file_util.h"
6061
#include "internal/proto_matchers.h"
6162
#include "internal/status_macros.h"
6263
#include "internal/testing.h"
6364
#include "parser/parser.h"
65+
#include "runtime/function_adapter.h"
6466
#include "runtime/internal/runtime_env_testing.h"
6567
#include "runtime/runtime_options.h"
6668
#include "cel/expr/conformance/proto3/test_all_types.pb.h"
@@ -73,7 +75,9 @@ namespace google::api::expr::runtime {
7375

7476
namespace {
7577

78+
using ::absl_testing::IsOk;
7679
using ::absl_testing::StatusIs;
80+
using ::cel::BytesValue;
7781
using ::cel::Value;
7882
using ::cel::expr::conformance::proto3::TestAllTypes;
7983
using ::cel::internal::test::EqualsProto;
@@ -1842,6 +1846,64 @@ TEST(FlatExprBuilderTest, TypeResolve) {
18421846
EXPECT_TRUE(result.BoolOrDie());
18431847
}
18441848

1849+
TEST(FlatExprBuilderTest, FastEquality) {
1850+
TestMessage message;
1851+
ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("'foo' == 'bar'"));
1852+
cel::RuntimeOptions options;
1853+
options.enable_fast_builtins = true;
1854+
InterpreterOptions legacy_options;
1855+
legacy_options.enable_fast_builtins = true;
1856+
CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);
1857+
ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry(), legacy_options),
1858+
IsOk());
1859+
ASSERT_OK_AND_ASSIGN(auto expression,
1860+
builder.CreateExpression(&parsed_expr.expr(),
1861+
&parsed_expr.source_info()));
1862+
1863+
Activation activation;
1864+
google::protobuf::Arena arena;
1865+
ASSERT_OK_AND_ASSIGN(CelValue result,
1866+
expression->Evaluate(activation, &arena));
1867+
1868+
ASSERT_TRUE(result.IsBool()) << result.DebugString();
1869+
EXPECT_FALSE(result.BoolOrDie());
1870+
}
1871+
1872+
TEST(FlatExprBuilderTest, FastEqualityDisabledWithCustomEquality) {
1873+
TestMessage message;
1874+
ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("1 == b'\001'"));
1875+
cel::RuntimeOptions options;
1876+
options.enable_fast_builtins = true;
1877+
InterpreterOptions legacy_options;
1878+
legacy_options.enable_fast_builtins = true;
1879+
CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);
1880+
ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry(), legacy_options),
1881+
IsOk());
1882+
1883+
auto& registry = builder.GetRegistry()->InternalGetRegistry();
1884+
1885+
auto status = cel::BinaryFunctionAdapter<bool, int64_t, const BytesValue&>::
1886+
RegisterGlobalOverload(
1887+
"_==_",
1888+
[](auto&, int64_t lhs, const cel::BytesValue& rhs) -> bool {
1889+
return true;
1890+
},
1891+
registry);
1892+
ASSERT_THAT(status, IsOk());
1893+
1894+
ASSERT_OK_AND_ASSIGN(auto expression,
1895+
builder.CreateExpression(&parsed_expr.expr(),
1896+
&parsed_expr.source_info()));
1897+
1898+
Activation activation;
1899+
google::protobuf::Arena arena;
1900+
ASSERT_OK_AND_ASSIGN(CelValue result,
1901+
expression->Evaluate(activation, &arena));
1902+
1903+
ASSERT_TRUE(result.IsBool()) << result.DebugString();
1904+
EXPECT_TRUE(result.BoolOrDie());
1905+
}
1906+
18451907
TEST(FlatExprBuilderTest, AnyPackingList) {
18461908
google::protobuf::LinkMessageReflection<TestAllTypes>();
18471909
ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr,

eval/compiler/qualified_reference_resolver.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,14 @@ bool IsSpecialFunction(absl::string_view function_name) {
5959
function_name == cel::builtin::kIndex ||
6060
function_name == cel::builtin::kTernary ||
6161
function_name == kOptionalOr || function_name == kOptionalOrValue ||
62+
function_name == cel::builtin::kEqual ||
63+
function_name == cel::builtin::kInequal ||
64+
function_name == cel::builtin::kNot ||
65+
function_name == cel::builtin::kNotStrictlyFalse ||
66+
function_name == cel::builtin::kNotStrictlyFalseDeprecated ||
67+
function_name == cel::builtin::kIn ||
68+
function_name == cel::builtin::kInDeprecated ||
69+
function_name == cel::builtin::kInFunction ||
6270
function_name == "cel.@block";
6371
}
6472

eval/eval/BUILD

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,57 @@ cc_library(
457457
],
458458
)
459459

460+
cc_library(
461+
name = "equality_steps",
462+
srcs = [
463+
"equality_steps.cc",
464+
],
465+
hdrs = [
466+
"equality_steps.h",
467+
],
468+
deps = [
469+
":attribute_trail",
470+
":direct_expression_step",
471+
":evaluator_core",
472+
":expression_step_base",
473+
"//base:builtins",
474+
"//common:value",
475+
"//common:value_kind",
476+
"//internal:number",
477+
"//internal:status_macros",
478+
"//runtime/internal:errors",
479+
"//runtime/standard:equality_functions",
480+
"@com_google_absl//absl/status",
481+
"@com_google_absl//absl/status:statusor",
482+
],
483+
)
484+
485+
cc_test(
486+
name = "equality_steps_test",
487+
srcs = [
488+
"equality_steps_test.cc",
489+
],
490+
deps = [
491+
":attribute_trail",
492+
":direct_expression_step",
493+
":equality_steps",
494+
":evaluator_core",
495+
"//base:attributes",
496+
"//common:memory",
497+
"//common:type",
498+
"//common:value",
499+
"//common:value_kind",
500+
"//common:value_testing",
501+
"//internal:testing",
502+
"//runtime:activation",
503+
"//runtime:runtime_options",
504+
"@com_google_absl//absl/log:absl_check",
505+
"@com_google_absl//absl/status",
506+
"@com_google_absl//absl/status:status_matchers",
507+
"@com_google_protobuf//:protobuf",
508+
],
509+
)
510+
460511
cc_library(
461512
name = "comprehension_step",
462513
srcs = [

eval/eval/attribute_utility.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,16 @@ void Accumulator::Add(const UnknownValue& value) {
202202
void Accumulator::Add(const AttributeTrail& attr) { parent_.Add(*this, attr); }
203203

204204
void Accumulator::MaybeAdd(const Value& v) {
205-
if (InstanceOf<UnknownValue>(v)) {
206-
Add(Cast<UnknownValue>(v));
205+
if (v.IsUnknown()) {
206+
Add(v.GetUnknown());
207+
}
208+
}
209+
210+
void Accumulator::MaybeAdd(const Value& v, const AttributeTrail& attr) {
211+
if (v.IsUnknown()) {
212+
Add(v.GetUnknown());
213+
} else if (parent_.CheckForUnknown(attr, /*use_partial=*/true)) {
214+
Add(attr);
207215
}
208216
}
209217

eval/eval/attribute_utility.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,15 @@ class AttributeUtility {
3434
// Add to the accumulated set of unknowns if value is UnknownValue.
3535
void MaybeAdd(const cel::Value& v);
3636

37+
// Add to the accumulated set of unknowns if value is UnknownValue or
38+
// the attribute trail is (partially) unknown. This version prefers
39+
// preserving an already present unknown value over a new one matching the
40+
// attribute trail.
41+
//
42+
// Uses partial matching (a pattern matches the attribute or any
43+
// sub-attribute).
44+
void MaybeAdd(const cel::Value& v, const AttributeTrail& attr);
45+
3746
bool IsEmpty() const;
3847

3948
cel::UnknownValue Build() &&;

0 commit comments

Comments
 (0)