From bfb9f187d90903573fcde39dfd410d7b07af0f43 Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Mon, 8 Jun 2026 16:30:07 -0700 Subject: [PATCH] Support overload lookup and filtering by signature as an alternative to overload_id PiperOrigin-RevId: 928833427 --- checker/BUILD | 1 + checker/internal/type_checker_builder_impl.cc | 4 +- checker/type_checker_builder.h | 2 +- checker/type_checker_builder_factory_test.cc | 17 +++-- checker/type_checker_subset_factory.cc | 11 +-- checker/type_checker_subset_factory_test.cc | 33 +++++++-- common/BUILD | 3 +- common/decl.cc | 69 +++++++++++-------- common/decl.h | 64 +++++++---------- common/decl_test.cc | 47 +++++++++++++ env/config_test.cc | 55 +++++++++++++++ env/env.cc | 34 +++++---- env/env_test.cc | 16 +++++ env/env_yaml_test.cc | 12 ++-- 14 files changed, 260 insertions(+), 108 deletions(-) diff --git a/checker/BUILD b/checker/BUILD index 27a1eb84e..6705cd749 100644 --- a/checker/BUILD +++ b/checker/BUILD @@ -229,6 +229,7 @@ cc_library( hdrs = ["type_checker_subset_factory.h"], deps = [ ":type_checker_builder", + "//common:decl", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", diff --git a/checker/internal/type_checker_builder_impl.cc b/checker/internal/type_checker_builder_impl.cc index 9b91fc926..f0332b999 100644 --- a/checker/internal/type_checker_builder_impl.cc +++ b/checker/internal/type_checker_builder_impl.cc @@ -187,8 +187,8 @@ std::optional FilterDecl(FunctionDecl decl, FunctionDecl filtered; std::string name = decl.release_name(); std::vector overloads = decl.release_overloads(); - for (const auto& ovl : overloads) { - if (subset.should_include_overload(name, ovl.id())) { + for (auto& ovl : overloads) { + if (subset.should_include_overload(name, ovl)) { absl::Status s = filtered.AddOverload(std::move(ovl)); if (!s.ok()) { // Should not be possible to construct the original decl in a way that diff --git a/checker/type_checker_builder.h b/checker/type_checker_builder.h index f145b8a98..c2d0cbf7b 100644 --- a/checker/type_checker_builder.h +++ b/checker/type_checker_builder.h @@ -52,7 +52,7 @@ struct CheckerLibrary { // Represents a declaration to only use a subset of a library. struct TypeCheckerSubset { using FunctionPredicate = absl::AnyInvocable; + absl::string_view function, const OverloadDecl& overload) const>; // The id of the library to subset. Only one subset can be applied per // library id. diff --git a/checker/type_checker_builder_factory_test.cc b/checker/type_checker_builder_factory_test.cc index 9c4775e7f..40406948d 100644 --- a/checker/type_checker_builder_factory_test.cc +++ b/checker/type_checker_builder_factory_test.cc @@ -235,8 +235,8 @@ TEST(TypeCheckerBuilderTest, AddLibraryIncludeSubset) { ASSERT_THAT( builder->AddLibrarySubset( {"testlib", - [](absl::string_view /*function*/, absl::string_view overload_id) { - return (overload_id == "add_int" || overload_id == "sub_int"); + [](absl::string_view /*function*/, const OverloadDecl& overload) { + return (overload.id() == "add_int" || overload.id() == "sub_int"); }}), IsOk()); ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); @@ -274,9 +274,8 @@ TEST(TypeCheckerBuilderTest, AddLibraryExcludeSubset) { ASSERT_THAT( builder->AddLibrarySubset( {"testlib", - [](absl::string_view /*function*/, absl::string_view overload_id) { - return (overload_id != "add_int" && overload_id != "sub_int"); - ; + [](absl::string_view /*function*/, const OverloadDecl& overload) { + return (overload.id() != "add_int" && overload.id() != "sub_int"); }}), IsOk()); ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); @@ -313,7 +312,7 @@ TEST(TypeCheckerBuilderTest, AddLibrarySubsetRemoveAllOvl) { ASSERT_THAT(builder->AddLibrary(SubsetTestlib()), IsOk()); ASSERT_THAT(builder->AddLibrarySubset({"testlib", [](absl::string_view function, - absl::string_view /*overload_id*/) { + const OverloadDecl& /*overload*/) { return function != "add"; }}), IsOk()); @@ -352,12 +351,12 @@ TEST(TypeCheckerBuilderTest, AddLibraryOneSubsetPerLibraryId) { ASSERT_THAT( builder->AddLibrarySubset( {"testlib", [](absl::string_view function, - absl::string_view /*overload_id*/) { return true; }}), + const OverloadDecl& /*overload*/) { return true; }}), IsOk()); EXPECT_THAT( builder->AddLibrarySubset( {"testlib", [](absl::string_view function, - absl::string_view /*overload_id*/) { return true; }}), + const OverloadDecl& /*overload*/) { return true; }}), StatusIs(absl::StatusCode::kAlreadyExists)); } @@ -369,7 +368,7 @@ TEST(TypeCheckerBuilderTest, AddLibrarySubsetLibraryIdRequireds) { ASSERT_THAT(builder->AddLibrary(SubsetTestlib()), IsOk()); EXPECT_THAT(builder->AddLibrarySubset({"", [](absl::string_view function, - absl::string_view /*overload_id*/) { + const OverloadDecl& /*overload*/) { return function == "add"; }}), StatusIs(absl::StatusCode::kInvalidArgument)); diff --git a/checker/type_checker_subset_factory.cc b/checker/type_checker_subset_factory.cc index 6a05ce220..80deea3a2 100644 --- a/checker/type_checker_subset_factory.cc +++ b/checker/type_checker_subset_factory.cc @@ -21,14 +21,16 @@ #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "checker/type_checker_builder.h" +#include "common/decl.h" namespace cel { TypeCheckerSubset::FunctionPredicate IncludeOverloadsByIdPredicate( absl::flat_hash_set overload_ids) { return [overload_ids = std::move(overload_ids)]( - absl::string_view /*function*/, absl::string_view overload_id) { - return overload_ids.contains(overload_id); + absl::string_view /*function*/, const OverloadDecl& overload) { + return overload_ids.contains(overload.id()) || + overload_ids.contains(overload.signature()); }; } @@ -41,8 +43,9 @@ TypeCheckerSubset::FunctionPredicate IncludeOverloadsByIdPredicate( TypeCheckerSubset::FunctionPredicate ExcludeOverloadsByIdPredicate( absl::flat_hash_set overload_ids) { return [overload_ids = std::move(overload_ids)]( - absl::string_view /*function*/, absl::string_view overload_id) { - return !overload_ids.contains(overload_id); + absl::string_view /*function*/, const OverloadDecl& overload) { + return !overload_ids.contains(overload.id()) && + !overload_ids.contains(overload.signature()); }; } diff --git a/checker/type_checker_subset_factory_test.cc b/checker/type_checker_subset_factory_test.cc index fa38e1c0d..5b644ec7c 100644 --- a/checker/type_checker_subset_factory_test.cc +++ b/checker/type_checker_subset_factory_test.cc @@ -43,6 +43,8 @@ TEST(TypeCheckerSubsetFactoryTest, IncludeOverloadsByIdPredicate) { StandardOverloadIds::kEquals, StandardOverloadIds::kNotEquals, StandardOverloadIds::kNotStrictlyFalse, + "matches(string,string)", + "string.matches(string)", }; ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); ASSERT_THAT(builder->GetCheckerBuilder().AddLibrarySubset({ @@ -65,15 +67,19 @@ TEST(TypeCheckerSubsetFactoryTest, IncludeOverloadsByIdPredicate) { EXPECT_TRUE(r.IsValid()); + // Allowed by signature. + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("r'foo.*'.matches('foobar')")); + EXPECT_TRUE(r.IsValid()); + + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("matches(r'foo.*', 'foobar')")); + EXPECT_TRUE(r.IsValid()); + // Not in allowlist. ASSERT_OK_AND_ASSIGN(r, compiler->Compile("1 + 2 < 3")); EXPECT_FALSE(r.IsValid()); ASSERT_OK_AND_ASSIGN(r, compiler->Compile("'abc' + 'def'")); EXPECT_FALSE(r.IsValid()); - - ASSERT_OK_AND_ASSIGN(r, compiler->Compile("r'foo.*'.matches('foobar')")); - EXPECT_FALSE(r.IsValid()); } TEST(TypeCheckerSubsetFactoryTest, ExcludeOverloadsByIdPredicate) { @@ -83,6 +89,8 @@ TEST(TypeCheckerSubsetFactoryTest, ExcludeOverloadsByIdPredicate) { absl::string_view exclude_list[] = { StandardOverloadIds::kMatches, StandardOverloadIds::kMatchesMember, + "size(string)", + "string.size()", }; ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); ASSERT_THAT(builder->GetCheckerBuilder().AddLibrarySubset({ @@ -105,18 +113,35 @@ TEST(TypeCheckerSubsetFactoryTest, ExcludeOverloadsByIdPredicate) { EXPECT_TRUE(r.IsValid()); - // Not in allowlist. + // Allowed. ASSERT_OK_AND_ASSIGN(r, compiler->Compile("1 + 2 < 3")); EXPECT_TRUE(r.IsValid()); ASSERT_OK_AND_ASSIGN(r, compiler->Compile("'abc' + 'def'")); EXPECT_TRUE(r.IsValid()); + // Excluded by ID. ASSERT_OK_AND_ASSIGN(r, compiler->Compile("r'foo.*'.matches('foobar')")); EXPECT_FALSE(r.IsValid()); ASSERT_OK_AND_ASSIGN(r, compiler->Compile("matches(r'foo.*', 'foobar')")); EXPECT_FALSE(r.IsValid()); + + // Excluded by signature (top-level function). + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("size('abc')")); + EXPECT_FALSE(r.IsValid()); + + // Allowed (other overloads of size). + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("size([1, 2, 3])")); + EXPECT_TRUE(r.IsValid()); + + // Excluded by signature (member function). + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("'abc'.size()")); + EXPECT_FALSE(r.IsValid()); + + // Allowed (other overloads of size member). + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("[1, 2, 3].size()")); + EXPECT_TRUE(r.IsValid()); } } // namespace diff --git a/common/BUILD b/common/BUILD index f7c897e57..154b64a15 100644 --- a/common/BUILD +++ b/common/BUILD @@ -151,9 +151,8 @@ cc_library( "//internal:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/common/decl.cc b/common/decl.cc index b338bfd4f..7d87958f0 100644 --- a/common/decl.cc +++ b/common/decl.cc @@ -20,8 +20,8 @@ #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -109,43 +109,47 @@ bool SignaturesOverlap(const OverloadDecl& lhs, const OverloadDecl& rhs) { template void AddOverloadInternal(std::string_view function_name, std::vector& insertion_order, - OverloadDeclHashSet& overloads, Overload&& overload, - absl::Status& status) { + absl::flat_hash_map& by_id, + absl::flat_hash_map& by_signature, + Overload&& overload, absl::Status& status) { if (!status.ok()) { return; } - if (overload.id().empty()) { - OverloadDecl overload_decl = overload; - absl::StatusOr overload_id = - common_internal::MakeOverloadSignature( - function_name, overload_decl.args(), overload_decl.member()); - if (!overload_id.ok()) { - status = overload_id.status(); - return; - } - overload_decl.set_id(*overload_id); - AddOverloadInternal(function_name, insertion_order, overloads, - std::move(overload_decl), status); + absl::StatusOr signature = + common_internal::MakeOverloadSignature(function_name, overload.args(), + overload.member()); + if (!signature.ok()) { + status = signature.status(); return; } - if (auto it = overloads.find(overload.id()); it != overloads.end()) { + OverloadDecl mutable_overload = std::forward(overload); + mutable_overload.set_signature(*signature); + + if (mutable_overload.id().empty()) { + mutable_overload.set_id(mutable_overload.signature()); + } + + if (auto it = by_id.find(mutable_overload.id()); it != by_id.end()) { status = absl::AlreadyExistsError( - absl::StrCat("overload already exists: ", overload.id())); + absl::StrCat("overload exists: ", mutable_overload.id())); return; } - for (const auto& existing : overloads) { - if (SignaturesOverlap(overload, existing)) { + + for (const auto& existing : insertion_order) { + if (SignaturesOverlap(mutable_overload, existing)) { status = absl::InvalidArgumentError( absl::StrCat("overload signature collision: ", existing.id(), - " collides with ", overload.id())); + " collides with ", mutable_overload.id())); return; } } - const auto inserted = overloads.insert(std::forward(overload)); - ABSL_DCHECK(inserted.second); - insertion_order.push_back(*inserted.first); + + size_t index = insertion_order.size(); + by_id[mutable_overload.id()] = index; + by_signature[mutable_overload.signature()] = index; + insertion_order.push_back(std::move(mutable_overload)); } void CollectTypeParams(absl::flat_hash_set& type_params, @@ -195,14 +199,25 @@ absl::flat_hash_set OverloadDecl::GetTypeParams() const { void FunctionDecl::AddOverloadImpl(const OverloadDecl& overload, absl::Status& status) { - AddOverloadInternal(name_, overloads_.insertion_order, overloads_.set, - overload, status); + AddOverloadInternal(name_, overloads_.insertion_order, overloads_.by_id, + overloads_.by_signature, overload, status); } void FunctionDecl::AddOverloadImpl(OverloadDecl&& overload, absl::Status& status) { - AddOverloadInternal(name_, overloads_.insertion_order, overloads_.set, - std::move(overload), status); + AddOverloadInternal(name_, overloads_.insertion_order, overloads_.by_id, + overloads_.by_signature, std::move(overload), status); +} + +const OverloadDecl* FunctionDecl::FindOverloadById(absl::string_view id) const { + if (auto it = overloads_.by_id.find(id); it != overloads_.by_id.end()) { + return &overloads_.insertion_order[it->second]; + } + if (auto it = overloads_.by_signature.find(id); + it != overloads_.by_signature.end()) { + return &overloads_.insertion_order[it->second]; + } + return nullptr; } } // namespace cel diff --git a/common/decl.h b/common/decl.h index 22ee8cbf0..e6e256560 100644 --- a/common/decl.h +++ b/common/decl.h @@ -22,8 +22,8 @@ #include "absl/algorithm/container.h" #include "absl/base/attributes.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/hash/hash.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -191,6 +191,19 @@ class OverloadDecl final { void set_member(bool member) { member_ = member; } + ABSL_MUST_USE_RESULT const std::string& signature() const + ABSL_ATTRIBUTE_LIFETIME_BOUND { + return signature_; + } + + void set_signature(std::string signature) { + signature_ = std::move(signature); + } + + void set_signature(absl::string_view signature) { + signature_.assign(signature.data(), signature.size()); + } + absl::flat_hash_set GetTypeParams() const; private: @@ -198,11 +211,13 @@ class OverloadDecl final { std::vector args_; Type result_ = DynType{}; bool member_ = false; + std::string signature_; }; inline bool operator==(const OverloadDecl& lhs, const OverloadDecl& rhs) { return lhs.id() == rhs.id() && absl::c_equal(lhs.args(), rhs.args()) && - lhs.result() == rhs.result() && lhs.member() == rhs.member(); + lhs.result() == rhs.result() && lhs.member() == rhs.member() && + lhs.signature() == rhs.signature(); } inline bool operator!=(const OverloadDecl& lhs, const OverloadDecl& rhs) { @@ -264,39 +279,6 @@ OverloadDecl MakeMemberOverloadDecl(absl::string_view id, Type result, return overload_decl; } -struct OverloadDeclHash { - using is_transparent = void; - - size_t operator()(const OverloadDecl& overload_decl) const { - return (*this)(overload_decl.id()); - } - - size_t operator()(absl::string_view id) const { return absl::HashOf(id); } -}; - -struct OverloadDeclEqualTo { - using is_transparent = void; - - bool operator()(const OverloadDecl& lhs, const OverloadDecl& rhs) const { - return (*this)(lhs.id(), rhs.id()); - } - - bool operator()(const OverloadDecl& lhs, absl::string_view rhs) const { - return (*this)(lhs.id(), rhs); - } - - bool operator()(absl::string_view lhs, const OverloadDecl& rhs) const { - return (*this)(lhs, rhs.id()); - } - - bool operator()(absl::string_view lhs, absl::string_view rhs) const { - return lhs == rhs; - } -}; - -using OverloadDeclHashSet = - absl::flat_hash_set; - template absl::StatusOr MakeFunctionDecl(std::string name, Overloads&&... overloads); @@ -346,21 +328,27 @@ class FunctionDecl final { return overloads_.insertion_order; } + ABSL_MUST_USE_RESULT const OverloadDecl* FindOverloadById( + absl::string_view id) const; + std::vector release_overloads() { std::vector released = std::move(overloads_.insertion_order); overloads_.insertion_order.clear(); - overloads_.set.clear(); + overloads_.by_id.clear(); + overloads_.by_signature.clear(); return released; } private: struct Overloads { std::vector insertion_order; - OverloadDeclHashSet set; + absl::flat_hash_map by_id; + absl::flat_hash_map by_signature; void Reserve(size_t size) { insertion_order.reserve(size); - set.reserve(size); + by_id.reserve(size); + by_signature.reserve(size); } }; diff --git a/common/decl_test.cc b/common/decl_test.cc index 510cd5017..72e7f1b93 100644 --- a/common/decl_test.cc +++ b/common/decl_test.cc @@ -165,6 +165,53 @@ TEST(FunctionDecl, Overloads) { StatusIs(absl::StatusCode::kInvalidArgument)); } +TEST(FunctionDecl, AddOverloadInvalidSignature) { + FunctionDecl function_decl; + function_decl.set_name("foo"); + // Member overload must have at least one argument (the receiver). + // This should fail to add because signature generation fails. + EXPECT_THAT(function_decl.AddOverload(MakeMemberOverloadDecl(StringType{})), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(FunctionDecl, AddOverloadDuplicateId) { + ASSERT_OK_AND_ASSIGN( + auto function_decl, + MakeFunctionDecl("hello", + MakeOverloadDecl("foo", StringType{}, StringType{}))); + // Adding another overload with the same ID "foo" should fail. + EXPECT_THAT( + function_decl.AddOverload(MakeOverloadDecl("foo", IntType{}, IntType{})), + StatusIs(absl::StatusCode::kAlreadyExists)); +} + +TEST(FunctionDecl, FindOverload) { + ASSERT_OK_AND_ASSIGN( + auto function_decl, + MakeFunctionDecl( + "hello", MakeOverloadDecl("foo", StringType{}, StringType{}), + MakeMemberOverloadDecl("bar", StringType{}, StringType{}), + MakeOverloadDecl(IntType{}, IntType{}))); + + // Find by explicit ID + const OverloadDecl* overload = function_decl.FindOverloadById("foo"); + ASSERT_NE(overload, nullptr); + EXPECT_EQ(overload->id(), "foo"); + + // Find by ID fallback to signature + overload = function_decl.FindOverloadById("hello(string)"); + ASSERT_NE(overload, nullptr); + EXPECT_EQ(overload->id(), "foo"); + + // Find implicit overload (where ID == signature) + overload = function_decl.FindOverloadById("hello(int)"); + ASSERT_NE(overload, nullptr); + EXPECT_EQ(overload->id(), "hello(int)"); + + // Non-existent + EXPECT_EQ(function_decl.FindOverloadById("non_existent"), nullptr); +} + TEST(FunctionDecl, OverloadId) { google::protobuf::Arena arena; const auto* descriptor = diff --git a/env/config_test.cc b/env/config_test.cc index df0d6f875..8cfc3cf7f 100644 --- a/env/config_test.cc +++ b/env/config_test.cc @@ -88,6 +88,34 @@ INSTANTIATE_TEST_SUITE_P( StandardLibraryConfigTestCase{ .standard_library_config = {}, }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", "add_int64"}, + {"_+_", "add_list"}}, + }, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", "add(int,int)"}, + {"_+_", "add(list,list)"}}, + }, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .excluded_functions = {{"_+_", "add_int64"}, + {"_+_", "add_list"}}, + }, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .excluded_functions = {{"_+_", "add(int,int)"}, + {"_+_", "add(list,list)"}}, + }, + }, StandardLibraryConfigTestCase{ .standard_library_config = { @@ -106,6 +134,15 @@ INSTANTIATE_TEST_SUITE_P( .expected_error = "Cannot set both included and excluded functions.", }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", "add(int,int)"}}, + .excluded_functions = {{"_-_", ""}}, + }, + .expected_error = + "Cannot set both included and excluded functions.", + }, StandardLibraryConfigTestCase{ .standard_library_config = { @@ -114,6 +151,15 @@ INSTANTIATE_TEST_SUITE_P( .expected_error = "Cannot include function '_+_' and also its " "specific overload 'add_list'", }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", ""}, + {"_+_", "add(int,int)"}}, + }, + .expected_error = "Cannot include function '_+_' and also its " + "specific overload 'add(int,int)'", + }, StandardLibraryConfigTestCase{ .standard_library_config = { @@ -121,6 +167,15 @@ INSTANTIATE_TEST_SUITE_P( }, .expected_error = "Cannot exclude function '_+_' and also its " "specific overload 'add_list'", + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .excluded_functions = {{"_+_", ""}, + {"_+_", "add(int,int)"}}, + }, + .expected_error = "Cannot exclude function '_+_' and also its " + "specific overload 'add(int,int)'", })); TEST(VariableConfigTest, VariableConfig) { diff --git a/env/env.cc b/env/env.cc index 6cd3a3cdc..38903d1e9 100644 --- a/env/env.cc +++ b/env/env.cc @@ -57,19 +57,24 @@ bool ShouldIncludeMacro(const Config::StandardLibraryConfig& config, bool ShouldIncludeFunction(const Config::StandardLibraryConfig& config, absl::string_view function, - absl::string_view overload_id) { - if (config.excluded_functions.contains( - std::make_pair(std::string(function), std::string(overload_id))) || - config.excluded_functions.contains( - std::make_pair(std::string(function), ""))) { - return false; + const OverloadDecl& overload) { + if (!config.excluded_functions.empty()) { + if (config.excluded_functions.contains(std::make_pair( + std::string(function), std::string(overload.id()))) || + config.excluded_functions.contains( + std::make_pair(std::string(function), overload.signature())) || + config.excluded_functions.contains( + std::make_pair(std::string(function), ""))) { + return false; + } } - if (!config.included_functions.empty() && - !config.included_functions.contains( - std::make_pair(std::string(function), "")) && - !config.included_functions.contains( - std::make_pair(std::string(function), std::string(overload_id)))) { - return false; + if (!config.included_functions.empty()) { + return config.included_functions.contains(std::make_pair( + std::string(function), std::string(overload.id()))) || + config.included_functions.contains( + std::make_pair(std::string(function), overload.signature())) || + config.included_functions.contains( + std::make_pair(std::string(function), "")); } return true; } @@ -87,9 +92,8 @@ absl::StatusOr MakeStdlibSubset( }; subset.should_include_overload = [&standard_library_config]( absl::string_view function, - absl::string_view overload_id) { - return ShouldIncludeFunction(standard_library_config, function, - overload_id); + const OverloadDecl& overload) { + return ShouldIncludeFunction(standard_library_config, function, overload); }; return subset; } diff --git a/env/env_test.cc b/env/env_test.cc index b599aa569..f81482ebb 100644 --- a/env/env_test.cc +++ b/env/env_test.cc @@ -280,6 +280,15 @@ INSTANTIATE_TEST_SUITE_P( .expected_invalid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]", "'hello' + 'world'"}, }, + StandardLibraryConfigTestCase{ + .standard_library_config = + {.excluded_functions = {{"_+_", "_+_(bytes,bytes)"}, + {"_+_", "_+_(list<~A>,list<~A>)"}, + {"_+_", "_+_(string,string)"}}}, + .expected_valid_expressions = {"1 + 2"}, + .expected_invalid_expressions = {"[1, 2, 3] + [4, 5, 6]", + "'hello' + 'world'"}, + }, StandardLibraryConfigTestCase{ .standard_library_config = {.excluded_functions = {{"_+_", "add_bytes"}, @@ -294,6 +303,13 @@ INSTANTIATE_TEST_SUITE_P( .expected_valid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]", "'hello' + 'world'"}, }, + StandardLibraryConfigTestCase{ + .standard_library_config = + {.included_functions = {{"_+_", "_+_(int,int)"}, + {"_+_", "_+_(list<~A>,list<~A>)"}}}, + .expected_valid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]"}, + .expected_invalid_expressions = {"'hello' + 'world'"}, + }, StandardLibraryConfigTestCase{ .standard_library_config = {.included_functions = {{"_+_", "add_int64"}, diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc index a60048617..8df6ab532 100644 --- a/env/env_yaml_test.cc +++ b/env/env_yaml_test.cc @@ -153,7 +153,7 @@ TEST(EnvYamlTest, ParseStdlibConfig_InclusionStyle) { - name: "_+_" overloads: - id: add_bytes - - id: add_list + - id: "_+_(list<~A>,list<~A>)" - name: "matches" - name: "timestamp" overloads: @@ -166,7 +166,7 @@ TEST(EnvYamlTest, ParseStdlibConfig_InclusionStyle) { EXPECT_THAT( stdlib_config.included_functions, UnorderedElementsAre(std::make_pair("_+_", "add_bytes"), - std::make_pair("_+_", "add_list"), + std::make_pair("_+_", "_+_(list<~A>,list<~A>)"), std::make_pair("matches", ""), std::make_pair("timestamp", "string_to_timestamp"))) << " Actual stdlib config: " << stdlib_config; @@ -1364,9 +1364,9 @@ std::vector GetExportTestCases() { .included_functions = { std::make_pair("timestamp", "string_to_timestamp"), - std::make_pair("_+_", "add_list"), + std::make_pair("_+_", "_+_(list<~A>,list<~A>)"), std::make_pair("matches", ""), - std::make_pair("_+_", "add_bytes"), + std::make_pair("_+_", "_+_(bytes,bytes)"), }, })); return config; @@ -1376,8 +1376,8 @@ std::vector GetExportTestCases() { include_functions: - name: "_+_" overloads: - - id: "add_bytes" - - id: "add_list" + - id: "_+_(bytes,bytes)" + - id: "_+_(list<~A>,list<~A>)" - name: "matches" - name: "timestamp" overloads: