-
Notifications
You must be signed in to change notification settings - Fork 37
Expand file tree
/
Copy pathcheck_unknown_fields.cc
More file actions
199 lines (169 loc) · 7.27 KB
/
check_unknown_fields.cc
File metadata and controls
199 lines (169 loc) · 7.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
#include "pybind11_protobuf/check_unknown_fields.h"
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/optional.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"
#include "google/protobuf/unknown_field_set.h"
#include "python/google/protobuf/proto_api.h"
namespace pybind11_protobuf::check_unknown_fields {
namespace {
// Set of suppression keys. Entries are the strings produced by
// `MakeAllowListKey` ("<top message full name>:<unknown field parent FQN>").
using AllowListSet = absl::flat_hash_set<std::string>;
// Per-descriptor cache used by `MessageMayContainExtensionsRecursive`: maps a
// message descriptor to whether an extendable message type was found beneath
// it.
using MayContainExtensionsMap =
absl::flat_hash_map<const ::google::protobuf::Descriptor*, bool>;
/// Returns the single allow-list shared by `AllowUnknownFieldsFor` and
/// `CheckRecursively`. The set is heap-allocated on the first call and is
/// never deleted in this file.
AllowListSet* GetAllowList() {
  static AllowListSet* const instance = new AllowListSet;
  return instance;
}
/// Builds the allow-list lookup key "<top message full name>:<parent FQN>".
///
/// @param top_message_descriptor_full_name  Full name of the root message type.
/// @param unknown_field_parent_message_fqn  Dotted field path to the message
///     holding the unknown field (empty when it is the root message itself).
/// @return The two parts joined by a single ':'.
std::string MakeAllowListKey(
    absl::string_view top_message_descriptor_full_name,
    absl::string_view unknown_field_parent_message_fqn) {
  constexpr char kSeparator[] = ":";
  return absl::StrCat(top_message_descriptor_full_name, kSeparator,
                      unknown_field_parent_message_fqn);
}
/// Reports whether `descriptor`, or any message type reachable from it through
/// message-typed fields, declares an extension range.
///
/// Results are cached in `memoized` with the following invariants:
///   * `true`  - an extendable message type is reachable from the key.
///   * `false` - no extendable message type is reachable from the key.
///
/// The traversal uses an explicit worklist instead of caching a provisional
/// `false` for a descriptor that is still being explored. With mutually
/// recursive message types (A -> B -> A) a provisional value could be observed
/// by B and remain cached as a permanent `false`, even if A later turned out
/// to reach an extendable type through another field.
///
/// @param descriptor  Message type to inspect; must not be null.
/// @param memoized    Cache shared between calls; must not be null. The caller
///                    is responsible for synchronizing access to it.
/// @return true if an extension range is declared on `descriptor` or on any
///         message type reachable from it.
bool MessageMayContainExtensionsRecursive(const ::google::protobuf::Descriptor* descriptor,
                                          MayContainExtensionsMap* memoized) {
  // Descriptors that are themselves extendable are answered without caching.
  if (descriptor->extension_range_count() > 0) return true;

  if (auto cached = memoized->find(descriptor); cached != memoized->end()) {
    return cached->second;
  }

  // Every descriptor placed in `visited` is reachable from `descriptor`, has
  // no extension range of its own, and has no cached answer yet.
  absl::flat_hash_set<const ::google::protobuf::Descriptor*> visited;
  std::vector<const ::google::protobuf::Descriptor*> pending;
  visited.insert(descriptor);
  pending.push_back(descriptor);

  bool found = false;
  while (!pending.empty() && !found) {
    const ::google::protobuf::Descriptor* current = pending.back();
    pending.pop_back();
    for (int i = 0; i < current->field_count(); ++i) {
      const ::google::protobuf::FieldDescriptor* fd = current->field(i);
      if (fd->cpp_type() != ::google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) continue;

      const ::google::protobuf::Descriptor* child = fd->message_type();
      if (child->extension_range_count() > 0) {
        found = true;
        break;
      }
      if (auto cached = memoized->find(child); cached != memoized->end()) {
        if (cached->second) {
          found = true;
          break;
        }
        // A cached `false` is final: nothing below `child` is extendable.
        continue;
      }
      if (visited.insert(child).second) pending.push_back(child);
    }
  }

  if (found) {
    // Only the root is known to reach the extendable type; other visited
    // descriptors may not, so they are left for a later query to resolve.
    (*memoized)[descriptor] = true;
    return true;
  }

  // The whole reachable set was exhausted without finding an extension range,
  // so the negative answer holds for every descriptor visited on the way.
  for (const ::google::protobuf::Descriptor* seen : visited) {
    memoized->try_emplace(seen, false);
  }
  return false;
}
/// Mutex-guarded entry point to `MessageMayContainExtensionsRecursive`.
///
/// Holds a function-local mutex for the duration of the lookup and passes a
/// function-local, heap-allocated cache that is reused across calls.
///
/// @param descriptor  Message type to inspect.
/// @return Whatever `MessageMayContainExtensionsRecursive` returns for it.
bool MessageMayContainExtensionsMemoized(const ::google::protobuf::Descriptor* descriptor) {
  static absl::Mutex cache_mutex;
  static MayContainExtensionsMap* const cache = new MayContainExtensionsMap();
  absl::MutexLock guard(cache_mutex);
  return MessageMayContainExtensionsRecursive(descriptor, cache);
}
/// State of one search for an unknown field inside a message tree.
///
/// `FindUnknownFieldsRecursive` fills in the result members; the accessors and
/// `BuildErrorMessage` are meaningful only after it has returned true.
struct HasUnknownFields {
  /// @param py_proto_api     API object used to reach the Python default
  ///                         descriptor pool.
  /// @param root_descriptor  Descriptor of the top-level message being checked.
  HasUnknownFields(const ::google::protobuf::python::PyProto_API* py_proto_api,
                   const ::google::protobuf::Descriptor* root_descriptor)
      : py_proto_api(py_proto_api), root_descriptor(root_descriptor) {}

  /// Field names in `field_fqn_parts` joined with '.'; empty when the unknown
  /// field sits directly on the root message.
  std::string FieldFQN() const { return absl::StrJoin(field_fqn_parts, "."); }

  /// `FieldFQN()` followed by ".<unknown field number>", or just the number
  /// when the path is empty.
  std::string FieldFQNWithFieldNumber() const {
    return field_fqn_parts.empty()
               ? absl::StrCat(unknown_field_number)
               : absl::StrCat(FieldFQN(), ".", unknown_field_number);
  }

  bool FindUnknownFieldsRecursive(const ::google::protobuf::Message* sub_message,
                                  uint32_t depth);

  std::string BuildErrorMessage() const;

  const ::google::protobuf::python::PyProto_API* py_proto_api = nullptr;
  const ::google::protobuf::Descriptor* root_descriptor = nullptr;
  // Descriptor of the message that directly holds the reported unknown field.
  const ::google::protobuf::Descriptor* unknown_field_parent_descriptor = nullptr;
  // Field names leading from the root message to that holder.
  std::vector<std::string> field_fqn_parts;
  // Initialized so FieldFQNWithFieldNumber() never reads an indeterminate
  // value if it is called before a successful search.
  int unknown_field_number = 0;
};
/// Searches `sub_message` and, depth-first, its populated message-typed fields
/// for an unknown field whose number is registered in the Python default
/// descriptor pool as an extension of the message type holding it.
///
/// On success this records the holder's descriptor in
/// `unknown_field_parent_descriptor`, the field number in
/// `unknown_field_number`, and the field names leading from the root to the
/// holder in `field_fqn_parts`.
///
/// @param sub_message  Message to inspect; must not be null.
/// @param depth        Number of fields traversed from the root message to
///                     reach `sub_message` (0 for the root itself).
/// @return true if such an unknown field was found, false otherwise.
bool HasUnknownFields::FindUnknownFieldsRecursive(
    const ::google::protobuf::Message* sub_message, uint32_t depth) {
  const ::google::protobuf::Reflection& reflection = *sub_message->GetReflection();
  const ::google::protobuf::Descriptor* descriptor = sub_message->GetDescriptor();

  const ::google::protobuf::UnknownFieldSet& unknown_field_set =
      reflection.GetUnknownFields(*sub_message);
  if (!unknown_field_set.empty()) {
    const auto* python_pool = py_proto_api->GetDefaultDescriptorPool();
    // Examine every unknown field rather than only the first one: an
    // extension known to Python may be stored after an unrelated unknown
    // field, and would otherwise go unreported.
    for (int i = 0; i < unknown_field_set.field_count(); ++i) {
      const int number = unknown_field_set.field(i).number();
      // Stop only if the extension is known by Python.
      if (python_pool->FindExtensionByNumber(descriptor, number)) {
        unknown_field_parent_descriptor = descriptor;
        unknown_field_number = number;
        // Reserve one slot per ancestor field; each caller up the stack fills
        // in its own slot while unwinding.
        field_fqn_parts.resize(depth);
        return true;
      }
    }
  }

  // If neither this message type nor anything reachable from it allows
  // extensions, there is nothing further to look for below this message.
  if (!MessageMayContainExtensionsMemoized(descriptor)) {
    return false;
  }

  // Otherwise every populated message-typed field (extensions included, as
  // reported by ListFields) has to be searched.
  std::vector<const ::google::protobuf::FieldDescriptor*> present_fields;
  reflection.ListFields(*sub_message, &present_fields);

  for (const ::google::protobuf::FieldDescriptor* field : present_fields) {
    if (field->cpp_type() != ::google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
      continue;
    }

    bool found = false;
    if (field->is_repeated()) {
      const int field_size = reflection.FieldSize(*sub_message, field);
      for (int i = 0; i < field_size && !found; ++i) {
        found = FindUnknownFieldsRecursive(
            &reflection.GetRepeatedMessage(*sub_message, field, i),
            depth + 1U);
      }
    } else {
      found = FindUnknownFieldsRecursive(
          &reflection.GetMessage(*sub_message, field), depth + 1U);
    }

    if (found) {
      // The successful callee sized `field_fqn_parts` to cover this depth.
      field_fqn_parts[depth] = field->name();
      return true;
    }
  }

  return false;
}
/// Formats the diagnostic for the unknown field recorded by a successful
/// search.
///
/// The text names the root message type, the holder's type when it differs
/// from the root, the field path with the unknown field number, the .proto
/// file(s) involved, and the `AllowUnknownFieldsFor` call that would suppress
/// the error.
///
/// Both `root_descriptor` and `unknown_field_parent_descriptor` must be
/// non-null (checked with `assert`).
std::string HasUnknownFields::BuildErrorMessage() const {
  assert(unknown_field_parent_descriptor != nullptr);
  assert(root_descriptor != nullptr);

  const auto* const root = root_descriptor;
  const auto* const parent = unknown_field_parent_descriptor;

  std::string message = "Proto Message of type ";
  absl::StrAppend(&message, root->full_name(), " has an Unknown Field");

  // Mention the holder's type only when it is not the root type itself.
  if (parent != root) {
    absl::StrAppend(&message, " with parent of type ", parent->full_name());
  }

  absl::StrAppend(&message, ": ", FieldFQNWithFieldNumber(), " (",
                  root->file()->name());

  // Likewise, list the holder's file only when it differs from the root's.
  if (parent->file() != root->file()) {
    absl::StrAppend(&message, ", ", parent->file()->name());
  }

  absl::StrAppend(
      &message,
      "). Please add the required `cc_proto_library` `deps`. "
      "Only if there is no alternative to suppressing this error, use "
      "`pybind11_protobuf::AllowUnknownFieldsFor(\"",
      root->full_name(), "\", \"", FieldFQN(),
      "\");` (Warning: suppressions may mask critical bugs.)");
  return message;
}
} // namespace
/// Registers a suppression: `CheckRecursively` returns no error for a root
/// message type / field path pair that has been registered here.
///
/// @param top_message_descriptor_full_name  Full name of the root message type.
/// @param unknown_field_parent_message_fqn  Dotted field path to the message
///     holding the unknown field, as printed in the error message.
void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name,
                           absl::string_view unknown_field_parent_message_fqn) {
  AllowListSet& allow_list = *GetAllowList();
  allow_list.insert(MakeAllowListKey(top_message_descriptor_full_name,
                                     unknown_field_parent_message_fqn));
}
/// Checks `message` and its sub-messages for an unknown field that the search
/// in `HasUnknownFields` reports.
///
/// @param py_proto_api  API object handed to the search.
/// @param message       Root message to check; must not be null.
/// @return `absl::nullopt` when nothing is found or when the finding has been
///         suppressed through `AllowUnknownFieldsFor`; otherwise the error
///         message describing the unknown field.
absl::optional<std::string> CheckRecursively(
    const ::google::protobuf::python::PyProto_API* py_proto_api,
    const ::google::protobuf::Message* message) {
  const ::google::protobuf::Descriptor* root_descriptor = message->GetDescriptor();
  HasUnknownFields search(py_proto_api, root_descriptor);

  const bool found = search.FindUnknownFieldsRecursive(message, 0u);
  if (found) {
    const std::string key =
        MakeAllowListKey(root_descriptor->full_name(), search.FieldFQN());
    const bool suppressed = GetAllowList()->count(key) > 0;
    if (!suppressed) {
      return search.BuildErrorMessage();
    }
  }
  return absl::nullopt;
}
} // namespace pybind11_protobuf::check_unknown_fields