Skip to content

Commit 35ae560

Browse files
jnthntatumcopybara-github
authored andcommitted
Add specialized impls for !_ and @not_strictly_false
Adds custom evaluator steps for !_ and @not_strictly_false as an alternative for the extension function based implementation. These tend to be called at high frequency (e.g. .exists and .all, predicates), so they benefit from the reduced overhead. This is optional since it prevents disabling or extending the operators. PiperOrigin-RevId: 716319261
1 parent aff0663 commit 35ae560

9 files changed

Lines changed: 532 additions & 96 deletions

File tree

eval/compiler/flat_expr_builder.cc

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,23 @@ class FlatExprVisitor : public cel::AstVisitor {
511511
const cel::ast_internal::Call& call) {
512512
return HandleListAppend(expr, call);
513513
};
514+
if (options_.enable_fast_builtins) {
515+
call_handlers_[cel::builtin::kNotStrictlyFalse] =
516+
[this](const cel::ast_internal::Expr& expr,
517+
const cel::ast_internal::Call& call) {
518+
return HandleNotStrictlyFalse(expr, call);
519+
};
520+
call_handlers_[cel::builtin::kNotStrictlyFalseDeprecated] =
521+
[this](const cel::ast_internal::Expr& expr,
522+
const cel::ast_internal::Call& call) {
523+
return HandleNotStrictlyFalse(expr, call);
524+
};
525+
call_handlers_[cel::builtin::kNot] =
526+
[this](const cel::ast_internal::Expr& expr,
527+
const cel::ast_internal::Call& call) {
528+
return HandleNot(expr, call);
529+
};
530+
}
514531
}
515532

516533
void PreVisitExpr(const cel::ast_internal::Expr& expr) override {
@@ -1852,6 +1869,10 @@ class FlatExprVisitor : public cel::AstVisitor {
18521869
const cel::ast_internal::Call& call);
18531870
CallHandlerResult HandleListAppend(const cel::ast_internal::Expr& expr,
18541871
const cel::ast_internal::Call& call);
1872+
CallHandlerResult HandleNot(const cel::ast_internal::Expr& expr,
1873+
const cel::ast_internal::Call& call);
1874+
CallHandlerResult HandleNotStrictlyFalse(const cel::ast_internal::Expr& expr,
1875+
const cel::ast_internal::Call& call);
18551876

18561877
const Resolver& resolver_;
18571878
ValueManager& value_factory_;
@@ -1910,6 +1931,49 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleIndex(
19101931
return CallHandlerResult::kIntercepted;
19111932
}
19121933

1934+
FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNot(
1935+
const cel::ast_internal::Expr& expr,
1936+
const cel::ast_internal::Call& call_expr) {
1937+
ABSL_DCHECK(call_expr.function() == cel::builtin::kNot);
1938+
auto depth = RecursionEligible();
1939+
1940+
if (depth.has_value()) {
1941+
auto args = ExtractRecursiveDependencies();
1942+
if (args.size() != 1) {
1943+
SetProgressStatusError(absl::InvalidArgumentError(
1944+
"unexpected number of args for builtin not operator"));
1945+
return CallHandlerResult::kIntercepted;
1946+
}
1947+
SetRecursiveStep(CreateDirectNotStep(std::move(args[0]), expr.id()),
1948+
*depth + 1);
1949+
return CallHandlerResult::kIntercepted;
1950+
}
1951+
AddStep(CreateNotStep(expr.id()));
1952+
return CallHandlerResult::kIntercepted;
1953+
}
1954+
1955+
FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNotStrictlyFalse(
1956+
const cel::ast_internal::Expr& expr,
1957+
const cel::ast_internal::Call& call_expr) {
1958+
auto depth = RecursionEligible();
1959+
1960+
if (depth.has_value()) {
1961+
auto args = ExtractRecursiveDependencies();
1962+
if (args.size() != 1) {
1963+
SetProgressStatusError(
1964+
absl::InvalidArgumentError("unexpected number of args for builtin "
1965+
"@not_strictly_false operator"));
1966+
return CallHandlerResult::kIntercepted;
1967+
}
1968+
SetRecursiveStep(
1969+
CreateDirectNotStrictlyFalseStep(std::move(args[0]), expr.id()),
1970+
*depth + 1);
1971+
return CallHandlerResult::kIntercepted;
1972+
}
1973+
AddStep(CreateNotStrictlyFalseStep(expr.id()));
1974+
return CallHandlerResult::kIntercepted;
1975+
}
1976+
19131977
FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleBlock(
19141978
const cel::ast_internal::Expr& expr,
19151979
const cel::ast_internal::Call& call_expr) {

eval/eval/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,7 @@ cc_test(
745745
"//runtime/internal:runtime_env_testing",
746746
"@com_google_absl//absl/base:nullability",
747747
"@com_google_absl//absl/status",
748+
"@com_google_absl//absl/status:status_matchers",
748749
"@com_google_absl//absl/strings",
749750
"@com_google_absl//absl/strings:string_view",
750751
"@com_google_protobuf//:protobuf",

eval/eval/logic_step.cc

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,155 @@ std::unique_ptr<DirectExpressionStep> CreateDirectLogicStep(
285285
}
286286
}
287287

288+
class DirectNotStep : public DirectExpressionStep {
289+
public:
290+
explicit DirectNotStep(std::unique_ptr<DirectExpressionStep> operand,
291+
int64_t expr_id)
292+
: DirectExpressionStep(expr_id), operand_(std::move(operand)) {}
293+
absl::Status Evaluate(ExecutionFrameBase& frame, Value& result,
294+
AttributeTrail& attribute_trail) const override;
295+
296+
private:
297+
std::unique_ptr<DirectExpressionStep> operand_;
298+
};
299+
300+
absl::Status DirectNotStep::Evaluate(ExecutionFrameBase& frame, Value& result,
301+
AttributeTrail& attribute_trail) const {
302+
CEL_RETURN_IF_ERROR(operand_->Evaluate(frame, result, attribute_trail));
303+
304+
if (frame.unknown_processing_enabled()) {
305+
if (frame.attribute_utility().CheckForUnknownPartial(attribute_trail)) {
306+
result = frame.attribute_utility().CreateUnknownSet(
307+
attribute_trail.attribute());
308+
return absl::OkStatus();
309+
}
310+
}
311+
312+
switch (result.kind()) {
313+
case ValueKind::kBool:
314+
result = BoolValue{!result.GetBool().NativeValue()};
315+
break;
316+
case ValueKind::kUnknown:
317+
case ValueKind::kError:
318+
// just forward.
319+
break;
320+
default:
321+
result = frame.value_manager().CreateErrorValue(
322+
CreateNoMatchingOverloadError(cel::builtin::kNot));
323+
break;
324+
}
325+
326+
return absl::OkStatus();
327+
}
328+
329+
class IterativeNotStep : public ExpressionStepBase {
330+
public:
331+
explicit IterativeNotStep(int64_t expr_id) : ExpressionStepBase(expr_id) {}
332+
333+
absl::Status Evaluate(ExecutionFrame* frame) const override;
334+
};
335+
336+
absl::Status IterativeNotStep::Evaluate(ExecutionFrame* frame) const {
337+
if (!frame->value_stack().HasEnough(1)) {
338+
return absl::InternalError("Value stack underflow");
339+
}
340+
const Value& operand = frame->value_stack().Peek();
341+
342+
if (frame->unknown_processing_enabled()) {
343+
const AttributeTrail& attribute_trail =
344+
frame->value_stack().PeekAttribute();
345+
if (frame->attribute_utility().CheckForUnknownPartial(attribute_trail)) {
346+
frame->value_stack().PopAndPush(
347+
frame->attribute_utility().CreateUnknownSet(
348+
attribute_trail.attribute()));
349+
return absl::OkStatus();
350+
}
351+
}
352+
353+
switch (operand.kind()) {
354+
case ValueKind::kBool:
355+
frame->value_stack().PopAndPush(
356+
BoolValue{!operand.GetBool().NativeValue()});
357+
break;
358+
case ValueKind::kUnknown:
359+
case ValueKind::kError:
360+
// just forward.
361+
break;
362+
default:
363+
frame->value_stack().PopAndPush(frame->value_factory().CreateErrorValue(
364+
CreateNoMatchingOverloadError(cel::builtin::kNot)));
365+
break;
366+
}
367+
368+
return absl::OkStatus();
369+
}
370+
371+
class DirectNotStrictlyFalseStep : public DirectExpressionStep {
372+
public:
373+
explicit DirectNotStrictlyFalseStep(
374+
std::unique_ptr<DirectExpressionStep> operand, int64_t expr_id)
375+
: DirectExpressionStep(expr_id), operand_(std::move(operand)) {}
376+
absl::Status Evaluate(ExecutionFrameBase& frame, Value& result,
377+
AttributeTrail& attribute_trail) const override;
378+
379+
private:
380+
std::unique_ptr<DirectExpressionStep> operand_;
381+
};
382+
383+
absl::Status DirectNotStrictlyFalseStep::Evaluate(
384+
ExecutionFrameBase& frame, Value& result,
385+
AttributeTrail& attribute_trail) const {
386+
CEL_RETURN_IF_ERROR(operand_->Evaluate(frame, result, attribute_trail));
387+
388+
switch (result.kind()) {
389+
case ValueKind::kBool:
390+
// just forward.
391+
break;
392+
case ValueKind::kUnknown:
393+
case ValueKind::kError:
394+
result = BoolValue(true);
395+
break;
396+
default:
397+
result = frame.value_manager().CreateErrorValue(
398+
CreateNoMatchingOverloadError(cel::builtin::kNot));
399+
break;
400+
}
401+
402+
return absl::OkStatus();
403+
}
404+
405+
class IterativeNotStrictlyFalseStep : public ExpressionStepBase {
406+
public:
407+
explicit IterativeNotStrictlyFalseStep(int64_t expr_id)
408+
: ExpressionStepBase(expr_id) {}
409+
410+
absl::Status Evaluate(ExecutionFrame* frame) const override;
411+
};
412+
413+
absl::Status IterativeNotStrictlyFalseStep::Evaluate(
414+
ExecutionFrame* frame) const {
415+
if (!frame->value_stack().HasEnough(1)) {
416+
return absl::InternalError("Value stack underflow");
417+
}
418+
const Value& operand = frame->value_stack().Peek();
419+
420+
switch (operand.kind()) {
421+
case ValueKind::kBool:
422+
// just forward.
423+
break;
424+
case ValueKind::kUnknown:
425+
case ValueKind::kError:
426+
frame->value_stack().PopAndPush(BoolValue(true));
427+
break;
428+
default:
429+
frame->value_stack().PopAndPush(frame->value_factory().CreateErrorValue(
430+
CreateNoMatchingOverloadError(cel::builtin::kNot)));
431+
break;
432+
}
433+
434+
return absl::OkStatus();
435+
}
436+
288437
} // namespace
289438

290439
// Factory method for "And" Execution step
@@ -315,4 +464,27 @@ absl::StatusOr<std::unique_ptr<ExpressionStep>> CreateOrStep(int64_t expr_id) {
315464
return std::make_unique<LogicalOpStep>(OpType::kOr, expr_id);
316465
}
317466

467+
// Factory method for recursive logical not "!" Execution step
468+
std::unique_ptr<DirectExpressionStep> CreateDirectNotStep(
469+
std::unique_ptr<DirectExpressionStep> operand, int64_t expr_id) {
470+
return std::make_unique<DirectNotStep>(std::move(operand), expr_id);
471+
}
472+
473+
// Factory method for iterative logical not "!" Execution step
474+
std::unique_ptr<ExpressionStep> CreateNotStep(int64_t expr_id) {
475+
return std::make_unique<IterativeNotStep>(expr_id);
476+
}
477+
478+
// Factory method for recursive logical "@not_strictly_false" Execution step.
479+
std::unique_ptr<DirectExpressionStep> CreateDirectNotStrictlyFalseStep(
480+
std::unique_ptr<DirectExpressionStep> operand, int64_t expr_id) {
481+
return std::make_unique<DirectNotStrictlyFalseStep>(std::move(operand),
482+
expr_id);
483+
}
484+
485+
// Factory method for iterative logical "@not_strictly_false" Execution step.
486+
std::unique_ptr<ExpressionStep> CreateNotStrictlyFalseStep(int64_t expr_id) {
487+
return std::make_unique<IterativeNotStrictlyFalseStep>(expr_id);
488+
}
489+
318490
} // namespace google::api::expr::runtime

eval/eval/logic_step.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,20 @@ absl::StatusOr<std::unique_ptr<ExpressionStep>> CreateAndStep(int64_t expr_id);
2828
// Factory method for "Or" Execution step
2929
absl::StatusOr<std::unique_ptr<ExpressionStep>> CreateOrStep(int64_t expr_id);
3030

31+
// Factory method for recursive logical not "!" Execution step
32+
std::unique_ptr<DirectExpressionStep> CreateDirectNotStep(
33+
std::unique_ptr<DirectExpressionStep> operand, int64_t expr_id);
34+
35+
// Factory method for iterative logical not "!" Execution step
36+
std::unique_ptr<ExpressionStep> CreateNotStep(int64_t expr_id);
37+
38+
// Factory method for recursive logical "@not_strictly_false" Execution step.
39+
std::unique_ptr<DirectExpressionStep> CreateDirectNotStrictlyFalseStep(
40+
std::unique_ptr<DirectExpressionStep> operand, int64_t expr_id);
41+
42+
// Factory method for iterative logical "@not_strictly_false" Execution step.
43+
std::unique_ptr<ExpressionStep> CreateNotStrictlyFalseStep(int64_t expr_id);
44+
3145
} // namespace google::api::expr::runtime
3246

3347
#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_LOGIC_STEP_H_

0 commit comments

Comments
 (0)