|
| 1 | +// Copyright 2025 Google LLC |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_EMBEDDER_CONTEXT_H_ |
| 16 | +#define THIRD_PARTY_CEL_CPP_RUNTIME_EMBEDDER_CONTEXT_H_ |
| 17 | + |
| 18 | +#include <cstddef> |
| 19 | +#include <type_traits> |
| 20 | + |
| 21 | +#include "absl/container/inlined_vector.h" |
| 22 | +#include "absl/log/absl_check.h" |
| 23 | +#include "absl/types/optional.h" |
| 24 | +#include "common/typeinfo.h" |
| 25 | +#include "common/value.h" |
| 26 | + |
| 27 | +namespace cel { |
| 28 | + |
| 29 | +// EmbedderContext is used to package custom content defined by the embedder |
| 30 | +// during CEL evaluation. The custom content is indexed by type. Value types |
| 31 | +// are returned as absl::optional<T> where T is the value type. Pointer types |
| 32 | +// are returned as T*. |
| 33 | +// |
| 34 | +// The content values must be trivially copyable and have a size <= 16 bytes. |
| 35 | +// These are typically pointers or small value types (e.g. primitives, enums). |
| 36 | +// |
| 37 | +// An all zero memory value is used to represent an empty value. The caller |
| 38 | +// must provide some way to disambiguate if that is a meaningfully distinct |
| 39 | +// value from nullopt / nullptr. |
| 40 | +// |
| 41 | +// Scope is used to provide a distinction between multiple usages of CEL in the |
| 42 | +// same binary. |
| 43 | +class EmbedderContext { |
| 44 | + public: |
| 45 | + template <typename Scope, typename... Args> |
| 46 | + static EmbedderContext From(Args... args); |
| 47 | + |
| 48 | + // Convenience using a default scope. |
| 49 | + template <typename... Args> |
| 50 | + static EmbedderContext From(Args... args) { |
| 51 | + return From<EmbedderContext, Args...>(args...); |
| 52 | + } |
| 53 | + |
| 54 | + template <typename Scope, typename T> |
| 55 | + std::enable_if_t<!std::is_pointer_v<T>, absl::optional<T>> Get() const; |
| 56 | + |
| 57 | + template <typename Scope, typename T> |
| 58 | + std::enable_if_t<std::is_pointer_v<T>, T> Get() const; |
| 59 | + |
| 60 | + template <typename T> |
| 61 | + std::enable_if_t<!std::is_pointer_v<T>, absl::optional<T>> Get() const { |
| 62 | + return Get<EmbedderContext, T>(); |
| 63 | + } |
| 64 | + |
| 65 | + template <typename T> |
| 66 | + std::enable_if_t<std::is_pointer_v<T>, T> Get() const { |
| 67 | + return Get<EmbedderContext, T>(); |
| 68 | + } |
| 69 | + |
| 70 | + private: |
| 71 | + template <typename Scope, typename T, typename... Ts> |
| 72 | + void Set(T arg, Ts... args); |
| 73 | + |
| 74 | + template <typename Scope> |
| 75 | + void Set() {} |
| 76 | + |
| 77 | + absl::InlinedVector<cel::CustomValueContent, 2> values_; |
| 78 | + // These are included to check for bad accesses in debug mode. |
| 79 | + absl::InlinedVector<TypeInfo, 2> type_ids_; |
| 80 | + TypeInfo scope_; |
| 81 | +}; |
| 82 | + |
| 83 | +template <typename Scope, typename Arg, typename... Args> |
| 84 | +void EmbedderContext::Set(Arg arg, Args... args) { |
| 85 | + using IndexType = std::decay_t<Arg>; |
| 86 | + size_t index = TypeIdInSet<Scope>::template IndexFor<IndexType>(); |
| 87 | + if (index >= values_.size()) { |
| 88 | + values_.resize(index + 1, cel::CustomValueContent::Zero()); |
| 89 | + type_ids_.resize(index + 1); |
| 90 | + } |
| 91 | + values_[index] = cel::CustomValueContent::From(arg); |
| 92 | + type_ids_[index] = cel::TypeId<IndexType>(); |
| 93 | + Set<Scope>(args...); |
| 94 | +} |
| 95 | + |
| 96 | +template <typename Scope, typename T> |
| 97 | +std::enable_if_t<!std::is_pointer_v<T>, absl::optional<T>> |
| 98 | +EmbedderContext::Get() const { |
| 99 | + ABSL_DCHECK_EQ(cel::TypeId<Scope>(), scope_) |
| 100 | + << "EmbedderContext::Get wrong scope"; |
| 101 | + using IndexType = std::decay_t<T>; |
| 102 | + size_t index = TypeIdInSet<Scope>::template IndexFor<IndexType>(); |
| 103 | + if (index >= values_.size()) { |
| 104 | + return absl::nullopt; |
| 105 | + } |
| 106 | + |
| 107 | + const auto& content = values_[index]; |
| 108 | + if (content.IsZero()) return absl::nullopt; |
| 109 | + |
| 110 | + ABSL_DCHECK_EQ(type_ids_.size(), values_.size()); |
| 111 | + ABSL_DCHECK_EQ(type_ids_[index], cel::TypeId<IndexType>()) |
| 112 | + << "EmbedderContext::Get wrong type id"; |
| 113 | + |
| 114 | + return content.To<T>(); |
| 115 | +} |
| 116 | + |
| 117 | +template <typename Scope, typename T> |
| 118 | +std::enable_if_t<std::is_pointer_v<T>, T> EmbedderContext::Get() const { |
| 119 | + ABSL_DCHECK_EQ(cel::TypeId<Scope>(), scope_) |
| 120 | + << "EmbedderContext::Get wrong scope"; |
| 121 | + using IndexType = std::decay_t<T>; |
| 122 | + size_t index = TypeIdInSet<Scope>::template IndexFor<IndexType>(); |
| 123 | + if (index >= values_.size()) { |
| 124 | + return nullptr; |
| 125 | + } |
| 126 | + |
| 127 | + const auto& content = values_[index]; |
| 128 | + if (content.IsZero()) return nullptr; |
| 129 | + |
| 130 | + ABSL_DCHECK_EQ(type_ids_.size(), values_.size()); |
| 131 | + ABSL_DCHECK_EQ(type_ids_[index], cel::TypeId<IndexType>()) |
| 132 | + << "EmbedderContext::Get wrong type id"; |
| 133 | + |
| 134 | + return content.To<T>(); |
| 135 | +} |
| 136 | + |
| 137 | +template <typename Scope, typename... Args> |
| 138 | +EmbedderContext EmbedderContext::From(Args... args) { |
| 139 | + EmbedderContext context; |
| 140 | + context.scope_ = TypeId<Scope>(); |
| 141 | + context.Set<Scope>(args...); |
| 142 | + return context; |
| 143 | +} |
| 144 | + |
| 145 | +} // namespace cel |
| 146 | + |
| 147 | +#endif // THIRD_PARTY_CEL_CPP_RUNTIME_EMBEDDER_CONTEXT_H_ |
0 commit comments