Skip to content

Commit 707a4ce

Browse files
jnthntatumcopybara-github
authored andcommitted
Add an EmbedderContext type for passing opaque state from the
calling application to extension functions. PiperOrigin-RevId: 846796365
1 parent cc043aa commit 707a4ce

5 files changed

Lines changed: 292 additions & 1 deletion

File tree

common/typeinfo.h

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPEINFO_H_
1616
#define THIRD_PARTY_CEL_CPP_COMMON_TYPEINFO_H_
1717

18+
#include <atomic>
1819
#include <cstddef>
1920
#include <ostream>
2021
#include <string>
@@ -175,7 +176,6 @@ struct TypeTag final {
175176

176177
template <typename T>
177178
TypeInfo TypeId() {
178-
static_assert(!std::is_pointer_v<T>);
179179
static_assert(std::is_same_v<T, std::decay_t<T>>);
180180
static_assert(!std::is_same_v<TypeInfo, std::decay_t<T>>);
181181
#ifdef CEL_INTERNAL_HAVE_RTTI
@@ -193,6 +193,29 @@ inline std::ostream& operator<<(std::ostream& out, TypeInfo id) {
193193
return out << id.DebugString();
194194
}
195195

196+
// Helper class for adapting a type to an index in a tuple or array.
197+
// Scope is an arbitrary type used as a namespace for the index.
198+
template <typename Scope>
199+
class TypeIdInSet {
200+
public:
201+
template <typename T>
202+
static size_t IndexFor() {
203+
static size_t index =
204+
type_id_set_index_.fetch_add(1, std::memory_order_relaxed);
205+
return index;
206+
}
207+
208+
static size_t Size() {
209+
return type_id_set_index_.load(std::memory_order_relaxed);
210+
}
211+
212+
private:
213+
static std::atomic<size_t> type_id_set_index_;
214+
};
215+
216+
template <typename Scope>
217+
std::atomic<size_t> TypeIdInSet<Scope>::type_id_set_index_ = 0;
218+
196219
} // namespace cel
197220

198221
#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPEINFO_H_

common/values/custom_value.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ class CustomValueContent final {
7070
return value;
7171
}
7272

73+
bool IsZero() const {
74+
static const CustomValueContent kZero = Zero();
75+
return std::memcmp(raw_, kZero.raw_, sizeof(raw_)) == 0;
76+
}
77+
7378
private:
7479
alignas(void*) std::byte raw_[16];
7580
};

runtime/BUILD

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,3 +650,26 @@ cc_test(
650650
"@com_google_protobuf//:protobuf",
651651
],
652652
)
653+
654+
cc_library(
655+
name = "embedder_context",
656+
hdrs = ["embedder_context.h"],
657+
deps = [
658+
"//common:typeinfo",
659+
"//common:value",
660+
"@com_google_absl//absl/container:inlined_vector",
661+
"@com_google_absl//absl/log:absl_check",
662+
"@com_google_absl//absl/types:optional",
663+
],
664+
)
665+
666+
cc_test(
667+
name = "embedder_context_test",
668+
srcs = ["embedder_context_test.cc"],
669+
deps = [
670+
":embedder_context",
671+
"//common:typeinfo",
672+
"//internal:testing",
673+
"@com_google_absl//absl/types:optional",
674+
],
675+
)

runtime/embedder_context.h

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_EMBEDDER_CONTEXT_H_
16+
#define THIRD_PARTY_CEL_CPP_RUNTIME_EMBEDDER_CONTEXT_H_
17+
18+
#include <cstddef>
19+
#include <type_traits>
20+
21+
#include "absl/container/inlined_vector.h"
22+
#include "absl/log/absl_check.h"
23+
#include "absl/types/optional.h"
24+
#include "common/typeinfo.h"
25+
#include "common/value.h"
26+
27+
namespace cel {
28+
29+
// EmbedderContext is used to package custom content defined by the embedder
30+
// during CEL evaluation. The custom content is indexed by type. Value types
31+
// are returned as absl::optional<T> where T is the value type. Pointer types
32+
// are returned as T*.
33+
//
34+
// The content values must be trivially copyable and have a size <= 16 bytes.
35+
// These are typically pointers or small value types (e.g. primitives, enums).
36+
//
37+
// An all zero memory value is used to represent an empty value. The caller
38+
// must provide some way to disambiguate if that is a meaningfully distinct
39+
// value from nullopt / nullptr.
40+
//
41+
// Scope is used to provide a distinction between multiple usages of CEL in the
42+
// same binary.
43+
class EmbedderContext {
44+
public:
45+
template <typename Scope, typename... Args>
46+
static EmbedderContext From(Args... args);
47+
48+
// Convenience using a default scope.
49+
template <typename... Args>
50+
static EmbedderContext From(Args... args) {
51+
return From<EmbedderContext, Args...>(args...);
52+
}
53+
54+
template <typename Scope, typename T>
55+
std::enable_if_t<!std::is_pointer_v<T>, absl::optional<T>> Get() const;
56+
57+
template <typename Scope, typename T>
58+
std::enable_if_t<std::is_pointer_v<T>, T> Get() const;
59+
60+
template <typename T>
61+
std::enable_if_t<!std::is_pointer_v<T>, absl::optional<T>> Get() const {
62+
return Get<EmbedderContext, T>();
63+
}
64+
65+
template <typename T>
66+
std::enable_if_t<std::is_pointer_v<T>, T> Get() const {
67+
return Get<EmbedderContext, T>();
68+
}
69+
70+
private:
71+
template <typename Scope, typename T, typename... Ts>
72+
void Set(T arg, Ts... args);
73+
74+
template <typename Scope>
75+
void Set() {}
76+
77+
absl::InlinedVector<cel::CustomValueContent, 2> values_;
78+
// These are included to check for bad accesses in debug mode.
79+
absl::InlinedVector<TypeInfo, 2> type_ids_;
80+
TypeInfo scope_;
81+
};
82+
83+
template <typename Scope, typename Arg, typename... Args>
84+
void EmbedderContext::Set(Arg arg, Args... args) {
85+
using IndexType = std::decay_t<Arg>;
86+
size_t index = TypeIdInSet<Scope>::template IndexFor<IndexType>();
87+
if (index >= values_.size()) {
88+
values_.resize(index + 1, cel::CustomValueContent::Zero());
89+
type_ids_.resize(index + 1);
90+
}
91+
values_[index] = cel::CustomValueContent::From(arg);
92+
type_ids_[index] = cel::TypeId<IndexType>();
93+
Set<Scope>(args...);
94+
}
95+
96+
template <typename Scope, typename T>
97+
std::enable_if_t<!std::is_pointer_v<T>, absl::optional<T>>
98+
EmbedderContext::Get() const {
99+
ABSL_DCHECK_EQ(cel::TypeId<Scope>(), scope_)
100+
<< "EmbedderContext::Get wrong scope";
101+
using IndexType = std::decay_t<T>;
102+
size_t index = TypeIdInSet<Scope>::template IndexFor<IndexType>();
103+
if (index >= values_.size()) {
104+
return absl::nullopt;
105+
}
106+
107+
const auto& content = values_[index];
108+
if (content.IsZero()) return absl::nullopt;
109+
110+
ABSL_DCHECK_EQ(type_ids_.size(), values_.size());
111+
ABSL_DCHECK_EQ(type_ids_[index], cel::TypeId<IndexType>())
112+
<< "EmbedderContext::Get wrong type id";
113+
114+
return content.To<T>();
115+
}
116+
117+
template <typename Scope, typename T>
118+
std::enable_if_t<std::is_pointer_v<T>, T> EmbedderContext::Get() const {
119+
ABSL_DCHECK_EQ(cel::TypeId<Scope>(), scope_)
120+
<< "EmbedderContext::Get wrong scope";
121+
using IndexType = std::decay_t<T>;
122+
size_t index = TypeIdInSet<Scope>::template IndexFor<IndexType>();
123+
if (index >= values_.size()) {
124+
return nullptr;
125+
}
126+
127+
const auto& content = values_[index];
128+
if (content.IsZero()) return nullptr;
129+
130+
ABSL_DCHECK_EQ(type_ids_.size(), values_.size());
131+
ABSL_DCHECK_EQ(type_ids_[index], cel::TypeId<IndexType>())
132+
<< "EmbedderContext::Get wrong type id";
133+
134+
return content.To<T>();
135+
}
136+
137+
template <typename Scope, typename... Args>
138+
EmbedderContext EmbedderContext::From(Args... args) {
139+
EmbedderContext context;
140+
context.scope_ = TypeId<Scope>();
141+
context.Set<Scope>(args...);
142+
return context;
143+
}
144+
145+
} // namespace cel
146+
147+
#endif // THIRD_PARTY_CEL_CPP_RUNTIME_EMBEDDER_CONTEXT_H_

runtime/embedder_context_test.cc

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "runtime/embedder_context.h"
16+
17+
#include <cstdint>
18+
19+
#include "absl/types/optional.h"
20+
#include "common/typeinfo.h"
21+
#include "internal/testing.h"
22+
23+
namespace cel {
24+
namespace {
25+
26+
using ::testing::Optional;
27+
28+
TEST(EmbedderContextTest, From) {
29+
struct TestScope {};
30+
EmbedderContext context = EmbedderContext::From<TestScope>(int64_t{42});
31+
EXPECT_THAT((context.Get<TestScope, int64_t>()), Optional(42));
32+
EXPECT_EQ((context.Get<TestScope, uint64_t>()), absl::nullopt);
33+
34+
EmbedderContext context2 = EmbedderContext::From<TestScope>(uint64_t{42});
35+
EXPECT_THAT((context2.Get<TestScope, uint64_t>()), Optional(42));
36+
EXPECT_EQ((context2.Get<TestScope, int64_t>()), absl::nullopt);
37+
38+
// Side effect, but checking that we keep a dense range.
39+
EXPECT_EQ(cel::TypeIdInSet<TestScope>::Size(), 2);
40+
}
41+
42+
TEST(EmbedderContextTest, FromOutOfLine) {
43+
struct TestScope {};
44+
EmbedderContext context =
45+
EmbedderContext::From<TestScope>(int64_t{42}, uint64_t{43}, double{44});
46+
47+
EXPECT_THAT((context.Get<TestScope, int64_t>()), Optional(42));
48+
EXPECT_THAT((context.Get<TestScope, uint64_t>()), Optional(43));
49+
EXPECT_THAT((context.Get<TestScope, double>()), Optional(44));
50+
EXPECT_EQ((context.Get<TestScope, bool>()), absl::nullopt);
51+
52+
// Note: Referencing a type not intended to be stored will still reserve a
53+
// slot in the TypeIdInSet.
54+
EXPECT_EQ(cel::TypeIdInSet<TestScope>::Size(), 4);
55+
}
56+
57+
TEST(EmbedderContextTest, FromPtrs) {
58+
struct TestScope {};
59+
struct TestPointee {
60+
} foo;
61+
int64_t pointee2;
62+
63+
EmbedderContext context = EmbedderContext::From<TestScope>(
64+
&foo, const_cast<const int64_t*>(&pointee2));
65+
EXPECT_EQ((context.Get<TestScope, const int64_t*>()), &pointee2);
66+
EXPECT_EQ((context.Get<TestScope, TestPointee*>()), &foo);
67+
68+
EmbedderContext context2 = EmbedderContext::From<TestScope>(&foo);
69+
EXPECT_EQ((context2.Get<TestScope, int64_t*>()), nullptr);
70+
EXPECT_EQ((context2.Get<TestScope, TestPointee*>()), &foo);
71+
72+
// Note: const int* not the same as int*.
73+
EXPECT_EQ(cel::TypeIdInSet<TestScope>::Size(), 3);
74+
}
75+
76+
TEST(EmbedderContextTest, FromDefaultScope) {
77+
EmbedderContext context = EmbedderContext::From(int64_t{42});
78+
EXPECT_THAT((context.Get<int64_t>()), Optional(42));
79+
EXPECT_EQ((context.Get<uint64_t>()), absl::nullopt);
80+
}
81+
82+
// These death assertions are only enabled when compiled in debug mode.
83+
// Caller is responsible for adequately testing since we're limited in what
84+
// we can statically check due to the type-erasure.
85+
TEST(EmbedderContextDeathTest, GetWithWrongScope) {
86+
struct TestScope {};
87+
EmbedderContext context = EmbedderContext::From<TestScope>(int64_t{42});
88+
EXPECT_DEBUG_DEATH(
89+
{ context.Get<int64_t>(); }, "EmbedderContext::Get wrong scope");
90+
}
91+
92+
} // namespace
93+
} // namespace cel

0 commit comments

Comments
 (0)