Skip to content

Commit 37e1e9b

Browse files
TristonianJonescopybara-github
authored andcommitted
Variadic logical operators
PiperOrigin-RevId: 926911637
1 parent a75e273 commit 37e1e9b

11 files changed

Lines changed: 613 additions & 144 deletions

checker/internal/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ cc_library(
145145
"//common:container",
146146
"//common:decl",
147147
"//common:expr",
148+
"//common:standard_definitions",
148149
"//common:type",
149150
"//common:type_kind",
150151
"//internal:lexis",
@@ -238,6 +239,7 @@ cc_library(
238239
deps = [
239240
":format_type_name",
240241
"//common:decl",
242+
"//common:standard_definitions",
241243
"//common:type",
242244
"//common:type_kind",
243245
"@com_google_absl//absl/container:flat_hash_map",

checker/internal/type_checker_impl.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
#include "common/constant.h"
5353
#include "common/decl.h"
5454
#include "common/expr.h"
55+
#include "common/standard_definitions.h"
5556
#include "common/type.h"
5657
#include "common/type_kind.h"
5758
#include "internal/status_macros.h"
@@ -894,8 +895,12 @@ const FunctionDecl* ResolveVisitor::ResolveFunctionCallShape(
894895
if (decl == nullptr) {
895896
return true;
896897
}
898+
bool is_logical_op = (candidate == cel::StandardFunctions::kAnd ||
899+
candidate == cel::StandardFunctions::kOr) &&
900+
arg_count >= 2;
897901
for (const auto& ovl : decl->overloads()) {
898-
if (ovl.member() == is_receiver && ovl.args().size() == arg_count) {
902+
if (ovl.member() == is_receiver &&
903+
(ovl.args().size() == arg_count || is_logical_op)) {
899904
return false;
900905
}
901906
}

checker/internal/type_checker_impl_test.cc

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
#include "cel/expr/conformance/proto3/test_all_types.pb.h"
5656
#include "google/protobuf/arena.h"
5757
#include "google/protobuf/message.h"
58+
#include "google/protobuf/text_format.h"
5859

5960
namespace cel {
6061
namespace checker_internal {
@@ -1471,6 +1472,93 @@ TEST(TypeCheckerImplTest, TypeInferredFromStructCreation) {
14711472
std::make_unique<AstType>(DynTypeSpec())))))));
14721473
}
14731474

1475+
struct VariadicLogicalCheckerTestCase {
1476+
std::string expr;
1477+
};
1478+
1479+
class VariadicLogicalCheckerTest
1480+
: public testing::TestWithParam<VariadicLogicalCheckerTestCase> {};
1481+
1482+
TEST_P(VariadicLogicalCheckerTest, Check) {
1483+
const auto& test_case = GetParam();
1484+
1485+
auto builder = cel::NewParserBuilder();
1486+
builder->GetOptions().enable_variadic_logical_operators = true;
1487+
ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build());
1488+
ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource(test_case.expr));
1489+
ASSERT_OK_AND_ASSIGN(auto parsed_ast, parser->Parse(*source));
1490+
1491+
google::protobuf::Arena arena;
1492+
TypeCheckEnv env(GetSharedTestingDescriptorPool());
1493+
ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());
1494+
TypeCheckerImpl impl(std::move(env));
1495+
auto checker_builder = impl.ToBuilder();
1496+
ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("a", BoolType())),
1497+
IsOk());
1498+
ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("b", BoolType())),
1499+
IsOk());
1500+
ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("c", BoolType())),
1501+
IsOk());
1502+
ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("d", BoolType())),
1503+
IsOk());
1504+
ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("e", BoolType())),
1505+
IsOk());
1506+
1507+
ASSERT_OK_AND_ASSIGN(auto checker, checker_builder->Build());
1508+
ASSERT_OK_AND_ASSIGN(ValidationResult result,
1509+
checker->Check(std::move(parsed_ast)));
1510+
1511+
ASSERT_TRUE(result.IsValid())
1512+
<< absl::StrJoin(result.GetIssues(), "\n",
1513+
[](std::string* out, const TypeCheckIssue& issue) {
1514+
absl::StrAppend(out, issue.message());
1515+
});
1516+
1517+
ASSERT_OK_AND_ASSIGN(std::unique_ptr<Ast> checked_ast, result.ReleaseAst());
1518+
EXPECT_THAT(checked_ast->type_map(),
1519+
Contains(Pair(checked_ast->root_expr().id(),
1520+
Eq(AstType(PrimitiveType::kBool)))));
1521+
}
1522+
1523+
INSTANTIATE_TEST_SUITE_P(
1524+
VariadicLogicalChecker, VariadicLogicalCheckerTest,
1525+
testing::Values(VariadicLogicalCheckerTestCase{"true && false && true"},
1526+
VariadicLogicalCheckerTestCase{"a && b && c && d"},
1527+
VariadicLogicalCheckerTestCase{"a || b || c || d"},
1528+
VariadicLogicalCheckerTestCase{"a && b && (c || d || e)"},
1529+
VariadicLogicalCheckerTestCase{"a && b && c"},
1530+
VariadicLogicalCheckerTestCase{"a || b || c"},
1531+
VariadicLogicalCheckerTestCase{"[a, b, c].exists(x, x)"},
1532+
VariadicLogicalCheckerTestCase{"[a, b, c].all(x, x)"}));
1533+
1534+
TEST(TypeCheckerImplTest, VariadicLogicalOperatorsError) {
1535+
cel::expr::ParsedExpr parsed_expr;
1536+
ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
1537+
R"pb(
1538+
expr {
1539+
call_expr {
1540+
function: "_&&_"
1541+
args { const_expr { bool_value: true } }
1542+
}
1543+
}
1544+
)pb",
1545+
&parsed_expr));
1546+
ASSERT_OK_AND_ASSIGN(auto parsed_ast,
1547+
cel::CreateAstFromParsedExpr(parsed_expr));
1548+
1549+
google::protobuf::Arena arena;
1550+
TypeCheckEnv env(GetSharedTestingDescriptorPool());
1551+
ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());
1552+
TypeCheckerImpl impl(std::move(env));
1553+
ASSERT_OK_AND_ASSIGN(ValidationResult result,
1554+
impl.Check(std::move(parsed_ast)));
1555+
1556+
EXPECT_FALSE(result.IsValid());
1557+
EXPECT_THAT(
1558+
result.GetIssues(),
1559+
Contains(IsIssueWithSubstring(Severity::kError, "undeclared reference")));
1560+
}
1561+
14741562
TEST(TypeCheckerImplTest, ExpectedTypeMatches) {
14751563
google::protobuf::Arena arena;
14761564
TypeCheckEnv env(GetSharedTestingDescriptorPool());

checker/internal/type_inference_context.cc

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "absl/types/span.h"
3131
#include "checker/internal/format_type_name.h"
3232
#include "common/decl.h"
33+
#include "common/standard_definitions.h"
3334
#include "common/type.h"
3435
#include "common/type_kind.h"
3536

@@ -537,21 +538,28 @@ TypeInferenceContext::ResolveOverload(const FunctionDecl& decl,
537538
bool is_receiver) {
538539
std::optional<Type> result_type;
539540

541+
bool is_logical_op = (decl.name() == cel::StandardFunctions::kAnd ||
542+
decl.name() == cel::StandardFunctions::kOr) &&
543+
argument_types.size() >= 2;
544+
540545
std::vector<OverloadDecl> matching_overloads;
541546
for (const auto& ovl : decl.overloads()) {
542547
if (ovl.member() != is_receiver ||
543-
argument_types.size() != ovl.args().size()) {
548+
(!is_logical_op && argument_types.size() != ovl.args().size())) {
544549
continue;
545550
}
546551

547552
auto call_type_instance = InstantiateFunctionOverload(*this, ovl);
548-
ABSL_DCHECK_EQ(argument_types.size(),
549-
call_type_instance.param_types.size());
553+
if (!is_logical_op) {
554+
ABSL_DCHECK_EQ(argument_types.size(),
555+
call_type_instance.param_types.size());
556+
}
550557
bool is_match = true;
551558
AssignabilityContext assignability_context = CreateAssignabilityContext();
552559
for (int i = 0; i < argument_types.size(); ++i) {
560+
int param_index = is_logical_op ? 0 : i;
553561
if (!assignability_context.IsAssignable(
554-
argument_types[i], call_type_instance.param_types[i])) {
562+
argument_types[i], call_type_instance.param_types[param_index])) {
555563
is_match = false;
556564
break;
557565
}

0 commit comments

Comments
 (0)