|
54 | 54 | #include "common/ast.h" |
55 | 55 | #include "common/ast_traverse.h" |
56 | 56 | #include "common/ast_visitor.h" |
| 57 | +#include "common/kind.h" |
57 | 58 | #include "common/memory.h" |
58 | 59 | #include "common/type.h" |
59 | 60 | #include "common/value.h" |
|
68 | 69 | #include "eval/eval/create_map_step.h" |
69 | 70 | #include "eval/eval/create_struct_step.h" |
70 | 71 | #include "eval/eval/direct_expression_step.h" |
| 72 | +#include "eval/eval/equality_steps.h" |
71 | 73 | #include "eval/eval/evaluator_core.h" |
72 | 74 | #include "eval/eval/function_step.h" |
73 | 75 | #include "eval/eval/ident_step.h" |
@@ -527,6 +529,36 @@ class FlatExprVisitor : public cel::AstVisitor { |
527 | 529 | const cel::ast_internal::Call& call) { |
528 | 530 | return HandleNot(expr, call); |
529 | 531 | }; |
| 532 | + if (options_.enable_heterogeneous_equality) { |
| 533 | + for (const auto& in_op : |
| 534 | + {cel::builtin::kIn, cel::builtin::kInDeprecated, |
| 535 | + cel::builtin::kInFunction}) { |
| 536 | + call_handlers_[in_op] = [this](const cel::ast_internal::Expr& expr, |
| 537 | + const cel::ast_internal::Call& call) { |
| 538 | + return HandleHeterogeneousEqualityIn(expr, call); |
| 539 | + }; |
| 540 | + } |
| 541 | + // Try to detect if the environment is setup with a custom equality |
| 542 | + // implementation. |
| 543 | + if (resolver_ |
| 544 | + .FindOverloads(cel::builtin::kEqual, |
| 545 | + /*receiver_style=*/false, |
| 546 | + {cel::Kind::kAny, cel::Kind::kAny}) |
| 547 | + .empty()) { |
| 548 | + call_handlers_[cel::builtin::kEqual] = |
| 549 | + [this](const cel::ast_internal::Expr& expr, |
| 550 | + const cel::ast_internal::Call& call) { |
| 551 | + return HandleHeterogeneousEquality(expr, call, |
| 552 | + /*inequality=*/false); |
| 553 | + }; |
| 554 | + call_handlers_[cel::builtin::kInequal] = |
| 555 | + [this](const cel::ast_internal::Expr& expr, |
| 556 | + const cel::ast_internal::Call& call) { |
| 557 | + return HandleHeterogeneousEquality(expr, call, |
| 558 | + /*inequality=*/true); |
| 559 | + }; |
| 560 | + } |
| 561 | + } |
530 | 562 | } |
531 | 563 | } |
532 | 564 |
|
@@ -1874,6 +1906,13 @@ class FlatExprVisitor : public cel::AstVisitor { |
1874 | 1906 | CallHandlerResult HandleNotStrictlyFalse(const cel::ast_internal::Expr& expr, |
1875 | 1907 | const cel::ast_internal::Call& call); |
1876 | 1908 |
|
| 1909 | + CallHandlerResult HandleHeterogeneousEquality( |
| 1910 | + const cel::ast_internal::Expr& expr, const cel::ast_internal::Call& call, |
| 1911 | + bool inequality); |
| 1912 | + |
| 1913 | + CallHandlerResult HandleHeterogeneousEqualityIn( |
| 1914 | + const cel::ast_internal::Expr& expr, const cel::ast_internal::Call& call); |
| 1915 | + |
1877 | 1916 | const Resolver& resolver_; |
1878 | 1917 | ValueManager& value_factory_; |
1879 | 1918 | absl::Status progress_status_; |
@@ -2026,6 +2065,59 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleListAppend( |
2026 | 2065 | return CallHandlerResult::kNotIntercepted; |
2027 | 2066 | } |
2028 | 2067 |
|
| 2068 | +FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleHeterogeneousEquality( |
| 2069 | + const cel::ast_internal::Expr& expr, const cel::ast_internal::Call& call, |
| 2070 | + bool inequality) { |
| 2071 | + if (!ValidateOrError( |
| 2072 | + call.args().size() == 2, |
| 2073 | + "unexpected number of args for builtin equality operator")) { |
| 2074 | + return CallHandlerResult::kIntercepted; |
| 2075 | + } |
| 2076 | + auto depth = RecursionEligible(); |
| 2077 | + |
| 2078 | + if (depth.has_value()) { |
| 2079 | + auto args = ExtractRecursiveDependencies(); |
| 2080 | + if (args.size() != 2) { |
| 2081 | + SetProgressStatusError(absl::InvalidArgumentError( |
| 2082 | + "unexpected number of args for builtin equality operator")); |
| 2083 | + return CallHandlerResult::kIntercepted; |
| 2084 | + } |
| 2085 | + SetRecursiveStep( |
| 2086 | + CreateDirectEqualityStep(std::move(args[0]), std::move(args[1]), |
| 2087 | + inequality, expr.id()), |
| 2088 | + *depth + 1); |
| 2089 | + return CallHandlerResult::kIntercepted; |
| 2090 | + } |
| 2091 | + AddStep(CreateEqualityStep(inequality, expr.id())); |
| 2092 | + return CallHandlerResult::kIntercepted; |
| 2093 | +} |
| 2094 | + |
| 2095 | +FlatExprVisitor::CallHandlerResult |
| 2096 | +FlatExprVisitor::HandleHeterogeneousEqualityIn( |
| 2097 | + const cel::ast_internal::Expr& expr, const cel::ast_internal::Call& call) { |
| 2098 | + if (!ValidateOrError(call.args().size() == 2, |
| 2099 | + "unexpected number of args for builtin 'in' operator")) { |
| 2100 | + return CallHandlerResult::kIntercepted; |
| 2101 | + } |
| 2102 | + |
| 2103 | + auto depth = RecursionEligible(); |
| 2104 | + if (depth.has_value()) { |
| 2105 | + auto args = ExtractRecursiveDependencies(); |
| 2106 | + if (args.size() != 2) { |
| 2107 | + SetProgressStatusError(absl::InvalidArgumentError( |
| 2108 | + "unexpected number of args for builtin 'in' operator")); |
| 2109 | + return CallHandlerResult::kIntercepted; |
| 2110 | + } |
| 2111 | + SetRecursiveStep( |
| 2112 | + CreateDirectInStep(std::move(args[0]), std::move(args[1]), expr.id()), |
| 2113 | + *depth + 1); |
| 2114 | + return CallHandlerResult::kIntercepted; |
| 2115 | + } |
| 2116 | + |
| 2117 | + AddStep(CreateInStep(expr.id())); |
| 2118 | + return CallHandlerResult::kIntercepted; |
| 2119 | +} |
| 2120 | + |
2029 | 2121 | void BinaryCondVisitor::PreVisit(const cel::ast_internal::Expr* expr) { |
2030 | 2122 | switch (cond_) { |
2031 | 2123 | case BinaryCond::kAnd: |
|
0 commit comments