Skip to content

Commit 80814ee

Browse files
jnthntatumcopybara-github
authored andcommitted
Add test cases demonstrating Value::WrapMessage behavior.
PiperOrigin-RevId: 831422369
1 parent b48a3b6 commit 80814ee

2 files changed

Lines changed: 210 additions & 15 deletions

File tree

runtime/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,7 @@ cc_test(
643643
"@com_google_absl//absl/status:statusor",
644644
"@com_google_absl//absl/strings",
645645
"@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto",
646+
"@com_google_protobuf//:any_cc_proto",
646647
"@com_google_protobuf//:differencer",
647648
"@com_google_protobuf//:protobuf",
648649
],

runtime/memory_safety_test.cc

Lines changed: 209 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
// limitations under the License.
1414
//
1515
// Tests for memory safety using the CEL Evaluator.
16+
#include <functional>
1617
#include <memory>
1718
#include <string>
1819
#include <tuple>
1920
#include <utility>
2021
#include <vector>
2122

23+
#include "google/protobuf/any.pb.h"
2224
#include "absl/base/no_destructor.h"
2325
#include "absl/container/flat_hash_map.h"
2426
#include "absl/log/absl_check.h"
@@ -61,6 +63,8 @@ using ::absl_testing::IsOkAndHolds;
6163
using ::cel::expr::conformance::proto3::NestedTestAllTypes;
6264
using ::cel::expr::conformance::proto3::TestAllTypes;
6365
using ::cel::test::ValueMatcher;
66+
using ::google::protobuf::Any;
67+
using ::testing::Not;
6468

6569
struct TestCase {
6670
std::string name;
@@ -91,6 +95,8 @@ absl::StatusOr<std::unique_ptr<Compiler>> CreateCompiler() {
9195
cb.AddVariable(MakeVariableDecl("string_var", StringType())));
9296
CEL_RETURN_IF_ERROR(
9397
cb.AddVariable(MakeVariableDecl("condition", BoolType())));
98+
CEL_RETURN_IF_ERROR(cb.AddVariable(MakeVariableDecl(
99+
"nested_test_all_types", MessageType(NestedTestAllTypes::descriptor()))));
94100

95101
CEL_RETURN_IF_ERROR(cb.AddFunction(
96102
MakeFunctionDecl("IsPrivate", MakeOverloadDecl("IsPrivate_string",
@@ -246,6 +252,20 @@ Value MakeStringValue(absl::string_view str) {
246252
return StringValue::Wrap(str, kArena.get());
247253
}
248254

255+
MATCHER_P(ParsedProtoStructEquals, expected, "") {
256+
const cel::StructValue& got = arg;
257+
if (!got.IsParsedMessage()) {
258+
return false;
259+
}
260+
auto& msg = got.GetParsedMessage();
261+
auto cmp = absl::WrapUnique(msg->New());
262+
if (!google::protobuf::TextFormat::ParseFromString(expected, cmp.get())) {
263+
*result_listener << "Failed to parse expected proto";
264+
return false;
265+
}
266+
return google::protobuf::util::MessageDifferencer::Equals(*msg, *cmp);
267+
}
268+
249269
INSTANTIATE_TEST_SUITE_P(
250270
Expression, EvaluatorMemorySafetyTest,
251271
testing::Combine(
@@ -314,21 +334,10 @@ INSTANTIATE_TEST_SUITE_P(
314334
}
315335
})cel",
316336
{},
317-
test::StructValueIs(testing::Truly([](const StructValue& v)
318-
-> bool {
319-
if (!v.IsParsedMessage()) {
320-
return false;
321-
}
322-
auto& msg = v.GetParsedMessage();
323-
auto cmp = absl::WrapUnique(msg->New());
324-
google::protobuf::TextFormat::ParseFromString(
325-
R"pb(
326-
child { payload { repeated_int32: [ 1, 2, 3 ] } }
327-
payload { repeated_string: [ "foo", "bar", "baz" ] }
328-
)pb",
329-
cmp.get());
330-
return google::protobuf::util::MessageDifferencer::Equals(*msg, *cmp);
331-
})),
337+
test::StructValueIs(ParsedProtoStructEquals(R"pb(
338+
child { payload { repeated_int32: [ 1, 2, 3 ] } }
339+
payload { repeated_string: [ "foo", "bar", "baz" ] }
340+
)pb")),
332341
},
333342
{"extension_function",
334343
"IsPrivate('8.8.8.8')",
@@ -357,5 +366,190 @@ INSTANTIATE_TEST_SUITE_P(
357366
Options::kFoldConstants)),
358367
&TestCaseName);
359368

369+
MATCHER_P(IsSameInstance, expected, "") {
370+
return std::mem_fn(&ParsedMessageValue::operator->)(&arg) == expected;
371+
}
372+
373+
class ViewTypesMemorySafetyTest : public testing::TestWithParam<Options> {
374+
protected:
375+
Options EvaluationOptions() { return GetParam(); }
376+
};
377+
378+
// Test cases demonstrating how inputs as views are handled.
379+
TEST_P(ViewTypesMemorySafetyTest, WrappedMessage) {
380+
// Arrange: create the runtime and expression.
381+
ASSERT_OK_AND_ASSIGN(std::unique_ptr<Runtime> runtime,
382+
ConfigureRuntimeImpl(false, EvaluationOptions()));
383+
constexpr absl::string_view kProtoValue = R"pb(
384+
child { payload { repeated_int32: [ 1, 2, 3 ] } }
385+
payload { repeated_string: [ "foo", "bar", "baz" ] }
386+
)pb";
387+
388+
ASSERT_OK_AND_ASSIGN(
389+
ValidationResult validation,
390+
GetCompiler().Compile(
391+
"condition ? nested_test_all_types : NestedTestAllTypes{}"));
392+
ASSERT_TRUE(validation.IsValid()) << validation.FormatError();
393+
ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst());
394+
ASSERT_OK_AND_ASSIGN(std::unique_ptr<Program> program,
395+
runtime->CreateProgram(std::move(ast)));
396+
397+
// Act: wrap the message and evaluate the expression.
398+
google::protobuf::Arena arena;
399+
NestedTestAllTypes* proto =
400+
NestedTestAllTypes::default_instance().New(&arena);
401+
ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, proto));
402+
Activation activation;
403+
activation.InsertOrAssignValue("condition", BoolValue(true));
404+
activation.InsertOrAssignValue(
405+
"nested_test_all_types",
406+
Value::WrapMessage(proto, google::protobuf::DescriptorPool::generated_pool(),
407+
google::protobuf::MessageFactory::generated_factory(), &arena));
408+
ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation));
409+
410+
// Assert: the result is the input message.
411+
ASSERT_TRUE(result.IsParsedMessage());
412+
const ParsedMessageValue& result_msg = result.GetParsedMessage();
413+
EXPECT_THAT(result_msg,
414+
test::StructValueIs(ParsedProtoStructEquals(kProtoValue)));
415+
EXPECT_EQ(result_msg->GetArena(), &arena);
416+
EXPECT_THAT(result_msg, IsSameInstance(proto));
417+
}
418+
419+
// Test cases demonstrating how inputs as views are handled.
420+
TEST_P(ViewTypesMemorySafetyTest, WrappedMessageFields) {
421+
// Arrange: create the runtime and expression.
422+
ASSERT_OK_AND_ASSIGN(std::unique_ptr<Runtime> runtime,
423+
ConfigureRuntimeImpl(false, EvaluationOptions()));
424+
constexpr absl::string_view kProtoValue = R"pb(
425+
child { payload { repeated_int32: [ 1, 2, 3 ] } }
426+
payload { repeated_string: [ "foo", "bar", "baz" ] }
427+
)pb";
428+
ASSERT_OK_AND_ASSIGN(
429+
ValidationResult validation,
430+
GetCompiler().Compile("nested_test_all_types.child.payload"));
431+
ASSERT_TRUE(validation.IsValid()) << validation.FormatError();
432+
ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst());
433+
ASSERT_OK_AND_ASSIGN(std::unique_ptr<Program> program,
434+
runtime->CreateProgram(std::move(ast)));
435+
436+
// Act: wrap the message and evaluate the expression.
437+
google::protobuf::Arena arena;
438+
NestedTestAllTypes* proto =
439+
NestedTestAllTypes::default_instance().New(&arena);
440+
ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, proto));
441+
Activation activation;
442+
activation.InsertOrAssignValue("condition", BoolValue(true));
443+
activation.InsertOrAssignValue(
444+
"nested_test_all_types",
445+
Value::WrapMessage(proto, google::protobuf::DescriptorPool::generated_pool(),
446+
google::protobuf::MessageFactory::generated_factory(), &arena));
447+
ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation));
448+
449+
// Assert: the result is an alias of a sub-message in the input.
450+
ASSERT_TRUE(result.IsParsedMessage());
451+
const ParsedMessageValue& result_msg = result.GetParsedMessage();
452+
EXPECT_THAT(result_msg, test::StructValueIs(ParsedProtoStructEquals(
453+
"repeated_int32: [ 1, 2, 3 ]")));
454+
EXPECT_EQ(result_msg->GetArena(), &arena);
455+
EXPECT_THAT(result_msg, IsSameInstance(&(proto->child().payload())));
456+
}
457+
458+
TEST_P(ViewTypesMemorySafetyTest, WrappedMessageDifferentArena) {
459+
// Arrange: create the runtime and expression.
460+
ASSERT_OK_AND_ASSIGN(std::unique_ptr<Runtime> runtime,
461+
ConfigureRuntimeImpl(false, EvaluationOptions()));
462+
constexpr absl::string_view kProtoValue = R"pb(
463+
child { payload { repeated_int32: [ 1, 2, 3 ] } }
464+
payload { repeated_string: [ "foo", "bar", "baz" ] }
465+
)pb";
466+
467+
ASSERT_OK_AND_ASSIGN(
468+
ValidationResult validation,
469+
GetCompiler().Compile(
470+
"condition ? nested_test_all_types : NestedTestAllTypes{}"));
471+
ASSERT_TRUE(validation.IsValid()) << validation.FormatError();
472+
ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst());
473+
ASSERT_OK_AND_ASSIGN(std::unique_ptr<Program> program,
474+
runtime->CreateProgram(std::move(ast)));
475+
476+
// Act: wrap the message and evaluate the expression.
477+
google::protobuf::Arena arena;
478+
google::protobuf::Arena other_arena;
479+
NestedTestAllTypes* proto =
480+
NestedTestAllTypes::default_instance().New(&other_arena);
481+
ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, proto));
482+
Activation activation;
483+
activation.InsertOrAssignValue("condition", BoolValue(true));
484+
activation.InsertOrAssignValue(
485+
"nested_test_all_types",
486+
Value::WrapMessage(proto, google::protobuf::DescriptorPool::generated_pool(),
487+
google::protobuf::MessageFactory::generated_factory(), &arena));
488+
ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation));
489+
490+
// Assert: the result is a copy of the input message.
491+
ASSERT_TRUE(result.IsParsedMessage());
492+
const ParsedMessageValue& result_msg = result.GetParsedMessage();
493+
EXPECT_THAT(result_msg,
494+
test::StructValueIs(ParsedProtoStructEquals(kProtoValue)));
495+
EXPECT_EQ(result_msg->GetArena(), &arena);
496+
EXPECT_THAT(result_msg, Not(IsSameInstance(proto)));
497+
}
498+
499+
TEST_P(ViewTypesMemorySafetyTest, WrappedMessageFromAny) {
500+
// Arrange: create the runtime.
501+
ASSERT_OK_AND_ASSIGN(std::unique_ptr<Runtime> runtime,
502+
ConfigureRuntimeImpl(false, EvaluationOptions()));
503+
constexpr absl::string_view kProtoValue = R"pb(
504+
child { payload { repeated_int32: [ 1, 2, 3 ] } }
505+
payload { repeated_string: [ "foo", "bar", "baz" ] }
506+
)pb";
507+
508+
ASSERT_OK_AND_ASSIGN(
509+
ValidationResult validation,
510+
GetCompiler().Compile(
511+
"condition ? nested_test_all_types : NestedTestAllTypes{}"));
512+
ASSERT_TRUE(validation.IsValid()) << validation.FormatError();
513+
ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst());
514+
ASSERT_OK_AND_ASSIGN(std::unique_ptr<Program> program,
515+
runtime->CreateProgram(std::move(ast)));
516+
517+
// Act: wrap the message and evaluate the expression.
518+
google::protobuf::Arena arena;
519+
NestedTestAllTypes proto;
520+
ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, &proto));
521+
Any any;
522+
any.PackFrom(proto);
523+
Activation activation;
524+
activation.InsertOrAssignValue("condition", BoolValue(true));
525+
activation.InsertOrAssignValue(
526+
"nested_test_all_types",
527+
Value::WrapMessage(&any, google::protobuf::DescriptorPool::generated_pool(),
528+
google::protobuf::MessageFactory::generated_factory(), &arena));
529+
530+
// Assert
531+
ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation));
532+
ASSERT_TRUE(result.IsParsedMessage());
533+
const ParsedMessageValue& result_msg = result.GetParsedMessage();
534+
EXPECT_THAT(result_msg,
535+
test::StructValueIs(ParsedProtoStructEquals(kProtoValue)));
536+
EXPECT_EQ(result_msg->GetArena(), &arena);
537+
}
538+
539+
INSTANTIATE_TEST_SUITE_P(Cases, ViewTypesMemorySafetyTest,
540+
testing::Values(Options::kDefault,
541+
Options::kExhaustive,
542+
Options::kFoldConstants),
543+
[](const testing::TestParamInfo<Options>& info) {
544+
switch (info.param) {
545+
case Options::kDefault:
546+
return "default";
547+
case Options::kExhaustive:
548+
return "exhaustive";
549+
case Options::kFoldConstants:
550+
return "opt";
551+
}
552+
});
553+
360554
} // namespace
361555
} // namespace cel

0 commit comments

Comments
 (0)