Skip to content

Commit b047e87

Browse files
dmitriplotnikovcopybara-github
authored andcommitted
Support overload lookup and filtering by signature as an alternative to overload_id
PiperOrigin-RevId: 928833427
1 parent 87fea87 commit b047e87

14 files changed

Lines changed: 240 additions & 108 deletions

checker/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ cc_library(
229229
hdrs = ["type_checker_subset_factory.h"],
230230
deps = [
231231
":type_checker_builder",
232+
"//common:decl",
232233
"@com_google_absl//absl/container:flat_hash_set",
233234
"@com_google_absl//absl/strings:string_view",
234235
"@com_google_absl//absl/types:span",

checker/internal/type_checker_builder_impl.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,8 @@ std::optional<FunctionDecl> FilterDecl(FunctionDecl decl,
163163
FunctionDecl filtered;
164164
std::string name = decl.release_name();
165165
std::vector<OverloadDecl> overloads = decl.release_overloads();
166-
for (const auto& ovl : overloads) {
167-
if (subset.should_include_overload(name, ovl.id())) {
166+
for (auto& ovl : overloads) {
167+
if (subset.should_include_overload(name, ovl)) {
168168
absl::Status s = filtered.AddOverload(std::move(ovl));
169169
if (!s.ok()) {
170170
// Should not be possible to construct the original decl in a way that

checker/type_checker_builder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ struct CheckerLibrary {
5151
// Represents a declaration to only use a subset of a library.
5252
struct TypeCheckerSubset {
5353
using FunctionPredicate = absl::AnyInvocable<bool(
54-
absl::string_view function, absl::string_view overload_id) const>;
54+
absl::string_view function, const OverloadDecl& overload) const>;
5555

5656
// The id of the library to subset. Only one subset can be applied per
5757
// library id.

checker/type_checker_builder_factory_test.cc

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,8 @@ TEST(TypeCheckerBuilderTest, AddLibraryIncludeSubset) {
235235
ASSERT_THAT(
236236
builder->AddLibrarySubset(
237237
{"testlib",
238-
[](absl::string_view /*function*/, absl::string_view overload_id) {
239-
return (overload_id == "add_int" || overload_id == "sub_int");
238+
[](absl::string_view /*function*/, const OverloadDecl& overload) {
239+
return (overload.id() == "add_int" || overload.id() == "sub_int");
240240
}}),
241241
IsOk());
242242
ASSERT_OK_AND_ASSIGN(auto checker, builder->Build());
@@ -274,9 +274,8 @@ TEST(TypeCheckerBuilderTest, AddLibraryExcludeSubset) {
274274
ASSERT_THAT(
275275
builder->AddLibrarySubset(
276276
{"testlib",
277-
[](absl::string_view /*function*/, absl::string_view overload_id) {
278-
return (overload_id != "add_int" && overload_id != "sub_int");
279-
;
277+
[](absl::string_view /*function*/, const OverloadDecl& overload) {
278+
return (overload.id() != "add_int" && overload.id() != "sub_int");
280279
}}),
281280
IsOk());
282281
ASSERT_OK_AND_ASSIGN(auto checker, builder->Build());
@@ -313,7 +312,7 @@ TEST(TypeCheckerBuilderTest, AddLibrarySubsetRemoveAllOvl) {
313312
ASSERT_THAT(builder->AddLibrary(SubsetTestlib()), IsOk());
314313
ASSERT_THAT(builder->AddLibrarySubset({"testlib",
315314
[](absl::string_view function,
316-
absl::string_view /*overload_id*/) {
315+
const OverloadDecl& /*overload*/) {
317316
return function != "add";
318317
}}),
319318
IsOk());
@@ -352,12 +351,12 @@ TEST(TypeCheckerBuilderTest, AddLibraryOneSubsetPerLibraryId) {
352351
ASSERT_THAT(
353352
builder->AddLibrarySubset(
354353
{"testlib", [](absl::string_view function,
355-
absl::string_view /*overload_id*/) { return true; }}),
354+
const OverloadDecl& /*overload*/) { return true; }}),
356355
IsOk());
357356
EXPECT_THAT(
358357
builder->AddLibrarySubset(
359358
{"testlib", [](absl::string_view function,
360-
absl::string_view /*overload_id*/) { return true; }}),
359+
const OverloadDecl& /*overload*/) { return true; }}),
361360
StatusIs(absl::StatusCode::kAlreadyExists));
362361
}
363362

@@ -369,7 +368,7 @@ TEST(TypeCheckerBuilderTest, AddLibrarySubsetLibraryIdRequireds) {
369368
ASSERT_THAT(builder->AddLibrary(SubsetTestlib()), IsOk());
370369
EXPECT_THAT(builder->AddLibrarySubset({"",
371370
[](absl::string_view function,
372-
absl::string_view /*overload_id*/) {
371+
const OverloadDecl& /*overload*/) {
373372
return function == "add";
374373
}}),
375374
StatusIs(absl::StatusCode::kInvalidArgument));

checker/type_checker_subset_factory.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,16 @@
2121
#include "absl/strings/string_view.h"
2222
#include "absl/types/span.h"
2323
#include "checker/type_checker_builder.h"
24+
#include "common/decl.h"
2425

2526
namespace cel {
2627

2728
TypeCheckerSubset::FunctionPredicate IncludeOverloadsByIdPredicate(
2829
absl::flat_hash_set<std::string> overload_ids) {
2930
return [overload_ids = std::move(overload_ids)](
30-
absl::string_view /*function*/, absl::string_view overload_id) {
31-
return overload_ids.contains(overload_id);
31+
absl::string_view /*function*/, const OverloadDecl& overload) {
32+
return overload_ids.contains(overload.id()) ||
33+
overload_ids.contains(overload.signature());
3234
};
3335
}
3436

@@ -41,8 +43,9 @@ TypeCheckerSubset::FunctionPredicate IncludeOverloadsByIdPredicate(
4143
TypeCheckerSubset::FunctionPredicate ExcludeOverloadsByIdPredicate(
4244
absl::flat_hash_set<std::string> overload_ids) {
4345
return [overload_ids = std::move(overload_ids)](
44-
absl::string_view /*function*/, absl::string_view overload_id) {
45-
return !overload_ids.contains(overload_id);
46+
absl::string_view /*function*/, const OverloadDecl& overload) {
47+
return !overload_ids.contains(overload.id()) &&
48+
!overload_ids.contains(overload.signature());
4649
};
4750
}
4851

checker/type_checker_subset_factory_test.cc

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ TEST(TypeCheckerSubsetFactoryTest, IncludeOverloadsByIdPredicate) {
4343
StandardOverloadIds::kEquals,
4444
StandardOverloadIds::kNotEquals,
4545
StandardOverloadIds::kNotStrictlyFalse,
46+
"matches(string,string)",
47+
"string.matches(string)",
4648
};
4749
ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk());
4850
ASSERT_THAT(builder->GetCheckerBuilder().AddLibrarySubset({
@@ -65,15 +67,19 @@ TEST(TypeCheckerSubsetFactoryTest, IncludeOverloadsByIdPredicate) {
6567

6668
EXPECT_TRUE(r.IsValid());
6769

70+
// Allowed by signature.
71+
ASSERT_OK_AND_ASSIGN(r, compiler->Compile("r'foo.*'.matches('foobar')"));
72+
EXPECT_TRUE(r.IsValid());
73+
74+
ASSERT_OK_AND_ASSIGN(r, compiler->Compile("matches(r'foo.*', 'foobar')"));
75+
EXPECT_TRUE(r.IsValid());
76+
6877
// Not in allowlist.
6978
ASSERT_OK_AND_ASSIGN(r, compiler->Compile("1 + 2 < 3"));
7079
EXPECT_FALSE(r.IsValid());
7180

7281
ASSERT_OK_AND_ASSIGN(r, compiler->Compile("'abc' + 'def'"));
7382
EXPECT_FALSE(r.IsValid());
74-
75-
ASSERT_OK_AND_ASSIGN(r, compiler->Compile("r'foo.*'.matches('foobar')"));
76-
EXPECT_FALSE(r.IsValid());
7783
}
7884

7985
TEST(TypeCheckerSubsetFactoryTest, ExcludeOverloadsByIdPredicate) {
@@ -83,6 +89,8 @@ TEST(TypeCheckerSubsetFactoryTest, ExcludeOverloadsByIdPredicate) {
8389
absl::string_view exclude_list[] = {
8490
StandardOverloadIds::kMatches,
8591
StandardOverloadIds::kMatchesMember,
92+
"size(string)",
93+
"string.size()",
8694
};
8795
ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk());
8896
ASSERT_THAT(builder->GetCheckerBuilder().AddLibrarySubset({
@@ -105,18 +113,35 @@ TEST(TypeCheckerSubsetFactoryTest, ExcludeOverloadsByIdPredicate) {
105113

106114
EXPECT_TRUE(r.IsValid());
107115

108-
// Not in allowlist.
116+
// Allowed.
109117
ASSERT_OK_AND_ASSIGN(r, compiler->Compile("1 + 2 < 3"));
110118
EXPECT_TRUE(r.IsValid());
111119

112120
ASSERT_OK_AND_ASSIGN(r, compiler->Compile("'abc' + 'def'"));
113121
EXPECT_TRUE(r.IsValid());
114122

123+
// Excluded by ID.
115124
ASSERT_OK_AND_ASSIGN(r, compiler->Compile("r'foo.*'.matches('foobar')"));
116125
EXPECT_FALSE(r.IsValid());
117126

118127
ASSERT_OK_AND_ASSIGN(r, compiler->Compile("matches(r'foo.*', 'foobar')"));
119128
EXPECT_FALSE(r.IsValid());
129+
130+
// Excluded by signature (top-level function).
131+
ASSERT_OK_AND_ASSIGN(r, compiler->Compile("size('abc')"));
132+
EXPECT_FALSE(r.IsValid());
133+
134+
// Allowed (other overloads of size).
135+
ASSERT_OK_AND_ASSIGN(r, compiler->Compile("size([1, 2, 3])"));
136+
EXPECT_TRUE(r.IsValid());
137+
138+
// Excluded by signature (member function).
139+
ASSERT_OK_AND_ASSIGN(r, compiler->Compile("'abc'.size()"));
140+
EXPECT_FALSE(r.IsValid());
141+
142+
// Allowed (other overloads of size member).
143+
ASSERT_OK_AND_ASSIGN(r, compiler->Compile("[1, 2, 3].size()"));
144+
EXPECT_TRUE(r.IsValid());
120145
}
121146

122147
} // namespace

common/BUILD

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,8 @@ cc_library(
151151
"//internal:status_macros",
152152
"@com_google_absl//absl/algorithm:container",
153153
"@com_google_absl//absl/base:core_headers",
154+
"@com_google_absl//absl/container:flat_hash_map",
154155
"@com_google_absl//absl/container:flat_hash_set",
155-
"@com_google_absl//absl/hash",
156-
"@com_google_absl//absl/log:absl_check",
157156
"@com_google_absl//absl/status",
158157
"@com_google_absl//absl/status:statusor",
159158
"@com_google_absl//absl/strings",

common/decl.cc

Lines changed: 42 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
#include <utility>
2121
#include <vector>
2222

23+
#include "absl/container/flat_hash_map.h"
2324
#include "absl/container/flat_hash_set.h"
24-
#include "absl/log/absl_check.h"
2525
#include "absl/status/status.h"
2626
#include "absl/status/statusor.h"
2727
#include "absl/strings/str_cat.h"
@@ -109,43 +109,47 @@ bool SignaturesOverlap(const OverloadDecl& lhs, const OverloadDecl& rhs) {
109109
template <typename Overload>
110110
void AddOverloadInternal(std::string_view function_name,
111111
std::vector<OverloadDecl>& insertion_order,
112-
OverloadDeclHashSet& overloads, Overload&& overload,
113-
absl::Status& status) {
112+
absl::flat_hash_map<std::string, size_t>& by_id,
113+
absl::flat_hash_map<std::string, size_t>& by_signature,
114+
Overload&& overload, absl::Status& status) {
114115
if (!status.ok()) {
115116
return;
116117
}
117118

118-
if (overload.id().empty()) {
119-
OverloadDecl overload_decl = overload;
120-
absl::StatusOr<std::string> overload_id =
121-
common_internal::MakeOverloadSignature(
122-
function_name, overload_decl.args(), overload_decl.member());
123-
if (!overload_id.ok()) {
124-
status = overload_id.status();
125-
return;
126-
}
127-
overload_decl.set_id(*overload_id);
128-
AddOverloadInternal(function_name, insertion_order, overloads,
129-
std::move(overload_decl), status);
119+
absl::StatusOr<std::string> signature =
120+
common_internal::MakeOverloadSignature(function_name, overload.args(),
121+
overload.member());
122+
if (!signature.ok()) {
123+
status = signature.status();
130124
return;
131125
}
132126

133-
if (auto it = overloads.find(overload.id()); it != overloads.end()) {
127+
OverloadDecl mutable_overload = std::forward<Overload>(overload);
128+
mutable_overload.set_signature(*signature);
129+
130+
if (mutable_overload.id().empty()) {
131+
mutable_overload.set_id(mutable_overload.signature());
132+
}
133+
134+
if (auto it = by_id.find(mutable_overload.id()); it != by_id.end()) {
134135
status = absl::AlreadyExistsError(
135-
absl::StrCat("overload already exists: ", overload.id()));
136+
absl::StrCat("overload exists: ", mutable_overload.id()));
136137
return;
137138
}
138-
for (const auto& existing : overloads) {
139-
if (SignaturesOverlap(overload, existing)) {
139+
140+
for (const auto& existing : insertion_order) {
141+
if (SignaturesOverlap(mutable_overload, existing)) {
140142
status = absl::InvalidArgumentError(
141143
absl::StrCat("overload signature collision: ", existing.id(),
142-
" collides with ", overload.id()));
144+
" collides with ", mutable_overload.id()));
143145
return;
144146
}
145147
}
146-
const auto inserted = overloads.insert(std::forward<Overload>(overload));
147-
ABSL_DCHECK(inserted.second);
148-
insertion_order.push_back(*inserted.first);
148+
149+
size_t index = insertion_order.size();
150+
by_id[mutable_overload.id()] = index;
151+
by_signature[mutable_overload.signature()] = index;
152+
insertion_order.push_back(std::move(mutable_overload));
149153
}
150154

151155
void CollectTypeParams(absl::flat_hash_set<std::string>& type_params,
@@ -195,14 +199,25 @@ absl::flat_hash_set<std::string> OverloadDecl::GetTypeParams() const {
195199

196200
void FunctionDecl::AddOverloadImpl(const OverloadDecl& overload,
197201
absl::Status& status) {
198-
AddOverloadInternal(name_, overloads_.insertion_order, overloads_.set,
199-
overload, status);
202+
AddOverloadInternal(name_, overloads_.insertion_order, overloads_.by_id,
203+
overloads_.by_signature, overload, status);
200204
}
201205

202206
void FunctionDecl::AddOverloadImpl(OverloadDecl&& overload,
203207
absl::Status& status) {
204-
AddOverloadInternal(name_, overloads_.insertion_order, overloads_.set,
205-
std::move(overload), status);
208+
AddOverloadInternal(name_, overloads_.insertion_order, overloads_.by_id,
209+
overloads_.by_signature, std::move(overload), status);
210+
}
211+
212+
const OverloadDecl* FunctionDecl::FindOverloadById(absl::string_view id) const {
213+
if (auto it = overloads_.by_id.find(id); it != overloads_.by_id.end()) {
214+
return &overloads_.insertion_order[it->second];
215+
}
216+
if (auto it = overloads_.by_signature.find(id);
217+
it != overloads_.by_signature.end()) {
218+
return &overloads_.insertion_order[it->second];
219+
}
220+
return nullptr;
206221
}
207222

208223
} // namespace cel

0 commit comments

Comments
 (0)