1616
1717#include < cstddef>
1818#include < memory>
19+ #include < optional>
1920#include < string>
2021#include < utility>
2122#include < vector>
2223
2324#include " absl/base/no_destructor.h"
2425#include " absl/base/nullability.h"
2526#include " absl/cleanup/cleanup.h"
27+ #include " absl/container/btree_set.h"
2628#include " absl/container/flat_hash_map.h"
29+ #include " absl/container/flat_hash_set.h"
2730#include " absl/log/absl_log.h"
2831#include " absl/status/status.h"
2932#include " absl/status/statusor.h"
3033#include " absl/strings/str_cat.h"
3134#include " absl/strings/string_view.h"
3235#include " absl/types/optional.h"
36+ #include " checker/internal/proto_type_mask.h"
3337#include " checker/internal/type_check_env.h"
3438#include " checker/internal/type_checker_impl.h"
3539#include " checker/type_checker.h"
@@ -86,10 +90,19 @@ absl::Status CheckStdMacroOverlap(const FunctionDecl& decl) {
8690}
8791
8892absl::Status AddWellKnownContextDeclarationVariables (
89- const google::protobuf::Descriptor* absl_nonnull descriptor, TypeCheckEnv& env,
90- bool use_json_name) {
93+ const google::protobuf::Descriptor* absl_nonnull descriptor,
94+ const absl::flat_hash_map<absl::string_view,
95+ absl::btree_set<absl::string_view>>&
96+ context_type_fields,
97+ TypeCheckEnv& env, bool use_json_name) {
9198 for (int i = 0 ; i < descriptor->field_count (); ++i) {
9299 const google::protobuf::FieldDescriptor* field = descriptor->field (i);
100+ // Skip fields that are hidden because of a proto type mask.
101+ auto map_iterator = context_type_fields.find (descriptor->full_name ());
102+ if (map_iterator != context_type_fields.end () &&
103+ !map_iterator->second .contains (field->name ())) {
104+ continue ;
105+ }
93106 Type type = MessageTypeField (field).GetType ();
94107 if (type.IsEnum ()) {
95108 type = IntType ();
@@ -109,11 +122,15 @@ absl::Status AddWellKnownContextDeclarationVariables(
109122}
110123
111124absl::Status AddContextDeclarationVariables (
112- const google::protobuf::Descriptor* absl_nonnull descriptor, TypeCheckEnv& env) {
125+ const google::protobuf::Descriptor* absl_nonnull descriptor,
126+ const absl::flat_hash_map<absl::string_view,
127+ absl::btree_set<absl::string_view>>&
128+ context_type_fields,
129+ TypeCheckEnv& env) {
113130 const bool use_json_name = env.proto_type_introspector ().use_json_name ();
114131 if (IsWellKnownMessageType (descriptor)) {
115- return AddWellKnownContextDeclarationVariables (descriptor, env,
116- use_json_name);
132+ return AddWellKnownContextDeclarationVariables (
133+ descriptor, context_type_fields, env, use_json_name);
117134 }
118135 CEL_ASSIGN_OR_RETURN (auto fields,
119136 env.proto_type_introspector ().ListFieldsForStructType (
@@ -131,6 +148,13 @@ absl::Status AddContextDeclarationVariables(
131148
132149 absl::string_view name = field_entry.name ;
133150
151+ // Skip fields that are hidden because of a proto type mask.
152+ auto map_iterator = context_type_fields.find (descriptor->full_name ());
153+ if (map_iterator != context_type_fields.end () &&
154+ !map_iterator->second .contains (name)) {
155+ continue ;
156+ }
157+
134158 if (!env.InsertVariableIfAbsent (MakeVariableDecl (name, type))) {
135159 return absl::AlreadyExistsError (
136160 absl::StrCat (" variable '" , name,
@@ -317,7 +341,8 @@ absl::Status TypeCheckerBuilderImpl::ApplyConfig(
317341 }
318342
319343 for (const google::protobuf::Descriptor* context_type : config.context_types ) {
320- CEL_RETURN_IF_ERROR (AddContextDeclarationVariables (context_type, env));
344+ CEL_RETURN_IF_ERROR (AddContextDeclarationVariables (
345+ context_type, config.context_type_fields , env));
321346 }
322347
323348 for (VariableDeclRecord& var : config.variables ) {
@@ -339,6 +364,8 @@ absl::Status TypeCheckerBuilderImpl::ApplyConfig(
339364 }
340365 }
341366
367+ CEL_RETURN_IF_ERROR (env.CreateProtoTypeMaskRegistry (config.proto_type_masks ));
368+
342369 return absl::OkStatus ();
343370}
344371
@@ -462,6 +489,23 @@ absl::Status TypeCheckerBuilderImpl::AddContextDeclaration(
462489 return absl::OkStatus ();
463490}
464491
492+ absl::Status TypeCheckerBuilderImpl::AddContextDeclarationWithProtoTypeMask (
493+ absl::string_view type, std::vector<std::string> field_paths) {
494+ if (field_paths.empty ()) {
495+ return absl::InvalidArgumentError (" field paths cannot be the empty set" );
496+ }
497+
498+ ProtoTypeMask proto_type_mask (std::string (type), field_paths);
499+ target_config_->proto_type_masks .push_back (proto_type_mask);
500+
501+ CEL_RETURN_IF_ERROR (AddContextDeclaration (type));
502+ CEL_ASSIGN_OR_RETURN (
503+ absl::btree_set<absl::string_view> field_names,
504+ proto_type_mask.GetFieldNames (template_env_.descriptor_pool ()));
505+ target_config_->context_type_fields .insert ({type, std::move (field_names)});
506+ return absl::OkStatus ();
507+ }
508+
465509absl::Status TypeCheckerBuilderImpl::AddFunction (const FunctionDecl& decl) {
466510 CEL_RETURN_IF_ERROR (
467511 ValidateFunctionDecl (decl, options_.enable_type_parameter_name_validation ,
0 commit comments