|
13 | 13 | // limitations under the License. |
14 | 14 | // |
15 | 15 | // Tests for memory safety using the CEL Evaluator. |
| 16 | +#include <functional> |
16 | 17 | #include <memory> |
17 | 18 | #include <string> |
18 | 19 | #include <tuple> |
19 | 20 | #include <utility> |
20 | 21 | #include <vector> |
21 | 22 |
|
| 23 | +#include "google/protobuf/any.pb.h" |
22 | 24 | #include "absl/base/no_destructor.h" |
23 | 25 | #include "absl/container/flat_hash_map.h" |
24 | 26 | #include "absl/log/absl_check.h" |
@@ -61,6 +63,8 @@ using ::absl_testing::IsOkAndHolds; |
61 | 63 | using ::cel::expr::conformance::proto3::NestedTestAllTypes; |
62 | 64 | using ::cel::expr::conformance::proto3::TestAllTypes; |
63 | 65 | using ::cel::test::ValueMatcher; |
| 66 | +using ::google::protobuf::Any; |
| 67 | +using ::testing::Not; |
64 | 68 |
|
65 | 69 | struct TestCase { |
66 | 70 | std::string name; |
@@ -91,6 +95,8 @@ absl::StatusOr<std::unique_ptr<Compiler>> CreateCompiler() { |
91 | 95 | cb.AddVariable(MakeVariableDecl("string_var", StringType()))); |
92 | 96 | CEL_RETURN_IF_ERROR( |
93 | 97 | cb.AddVariable(MakeVariableDecl("condition", BoolType()))); |
| 98 | + CEL_RETURN_IF_ERROR(cb.AddVariable(MakeVariableDecl( |
| 99 | + "nested_test_all_types", MessageType(NestedTestAllTypes::descriptor())))); |
94 | 100 |
|
95 | 101 | CEL_RETURN_IF_ERROR(cb.AddFunction( |
96 | 102 | MakeFunctionDecl("IsPrivate", MakeOverloadDecl("IsPrivate_string", |
@@ -246,6 +252,20 @@ Value MakeStringValue(absl::string_view str) { |
246 | 252 | return StringValue::Wrap(str, kArena.get()); |
247 | 253 | } |
248 | 254 |
|
| 255 | +MATCHER_P(ParsedProtoStructEquals, expected, "") { |
| 256 | + const cel::StructValue& got = arg; |
| 257 | + if (!got.IsParsedMessage()) { |
| 258 | + return false; |
| 259 | + } |
| 260 | + auto& msg = got.GetParsedMessage(); |
| 261 | + auto cmp = absl::WrapUnique(msg->New()); |
| 262 | + if (!google::protobuf::TextFormat::ParseFromString(expected, cmp.get())) { |
| 263 | + *result_listener << "Failed to parse expected proto"; |
| 264 | + return false; |
| 265 | + } |
| 266 | + return google::protobuf::util::MessageDifferencer::Equals(*msg, *cmp); |
| 267 | +} |
| 268 | + |
249 | 269 | INSTANTIATE_TEST_SUITE_P( |
250 | 270 | Expression, EvaluatorMemorySafetyTest, |
251 | 271 | testing::Combine( |
@@ -314,21 +334,10 @@ INSTANTIATE_TEST_SUITE_P( |
314 | 334 | } |
315 | 335 | })cel", |
316 | 336 | {}, |
317 | | - test::StructValueIs(testing::Truly([](const StructValue& v) |
318 | | - -> bool { |
319 | | - if (!v.IsParsedMessage()) { |
320 | | - return false; |
321 | | - } |
322 | | - auto& msg = v.GetParsedMessage(); |
323 | | - auto cmp = absl::WrapUnique(msg->New()); |
324 | | - google::protobuf::TextFormat::ParseFromString( |
325 | | - R"pb( |
326 | | - child { payload { repeated_int32: [ 1, 2, 3 ] } } |
327 | | - payload { repeated_string: [ "foo", "bar", "baz" ] } |
328 | | - )pb", |
329 | | - cmp.get()); |
330 | | - return google::protobuf::util::MessageDifferencer::Equals(*msg, *cmp); |
331 | | - })), |
| 337 | + test::StructValueIs(ParsedProtoStructEquals(R"pb( |
| 338 | + child { payload { repeated_int32: [ 1, 2, 3 ] } } |
| 339 | + payload { repeated_string: [ "foo", "bar", "baz" ] } |
| 340 | + )pb")), |
332 | 341 | }, |
333 | 342 | {"extension_function", |
334 | 343 | "IsPrivate('8.8.8.8')", |
@@ -357,5 +366,190 @@ INSTANTIATE_TEST_SUITE_P( |
357 | 366 | Options::kFoldConstants)), |
358 | 367 | &TestCaseName); |
359 | 368 |
|
| 369 | +MATCHER_P(IsSameInstance, expected, "") { |
| 370 | + return std::mem_fn(&ParsedMessageValue::operator->)(&arg) == expected; |
| 371 | +} |
| 372 | + |
| 373 | +class ViewTypesMemorySafetyTest : public testing::TestWithParam<Options> { |
| 374 | + protected: |
| 375 | + Options EvaluationOptions() { return GetParam(); } |
| 376 | +}; |
| 377 | + |
| 378 | +// Test cases demonstrating how inputs as views are handled. |
| 379 | +TEST_P(ViewTypesMemorySafetyTest, WrappedMessage) { |
| 380 | + // Arrange: create the runtime and expression. |
| 381 | + ASSERT_OK_AND_ASSIGN(std::unique_ptr<Runtime> runtime, |
| 382 | + ConfigureRuntimeImpl(false, EvaluationOptions())); |
| 383 | + constexpr absl::string_view kProtoValue = R"pb( |
| 384 | + child { payload { repeated_int32: [ 1, 2, 3 ] } } |
| 385 | + payload { repeated_string: [ "foo", "bar", "baz" ] } |
| 386 | + )pb"; |
| 387 | + |
| 388 | + ASSERT_OK_AND_ASSIGN( |
| 389 | + ValidationResult validation, |
| 390 | + GetCompiler().Compile( |
| 391 | + "condition ? nested_test_all_types : NestedTestAllTypes{}")); |
| 392 | + ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); |
| 393 | + ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); |
| 394 | + ASSERT_OK_AND_ASSIGN(std::unique_ptr<Program> program, |
| 395 | + runtime->CreateProgram(std::move(ast))); |
| 396 | + |
| 397 | + // Act: wrap the message and evaluate the expression. |
| 398 | + google::protobuf::Arena arena; |
| 399 | + NestedTestAllTypes* proto = |
| 400 | + NestedTestAllTypes::default_instance().New(&arena); |
| 401 | + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, proto)); |
| 402 | + Activation activation; |
| 403 | + activation.InsertOrAssignValue("condition", BoolValue(true)); |
| 404 | + activation.InsertOrAssignValue( |
| 405 | + "nested_test_all_types", |
| 406 | + Value::WrapMessage(proto, google::protobuf::DescriptorPool::generated_pool(), |
| 407 | + google::protobuf::MessageFactory::generated_factory(), &arena)); |
| 408 | + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); |
| 409 | + |
| 410 | + // Assert: the result is the input message. |
| 411 | + ASSERT_TRUE(result.IsParsedMessage()); |
| 412 | + const ParsedMessageValue& result_msg = result.GetParsedMessage(); |
| 413 | + EXPECT_THAT(result_msg, |
| 414 | + test::StructValueIs(ParsedProtoStructEquals(kProtoValue))); |
| 415 | + EXPECT_EQ(result_msg->GetArena(), &arena); |
| 416 | + EXPECT_THAT(result_msg, IsSameInstance(proto)); |
| 417 | +} |
| 418 | + |
| 419 | +// Test cases demonstrating how inputs as views are handled. |
| 420 | +TEST_P(ViewTypesMemorySafetyTest, WrappedMessageFields) { |
| 421 | + // Arrange: create the runtime and expression. |
| 422 | + ASSERT_OK_AND_ASSIGN(std::unique_ptr<Runtime> runtime, |
| 423 | + ConfigureRuntimeImpl(false, EvaluationOptions())); |
| 424 | + constexpr absl::string_view kProtoValue = R"pb( |
| 425 | + child { payload { repeated_int32: [ 1, 2, 3 ] } } |
| 426 | + payload { repeated_string: [ "foo", "bar", "baz" ] } |
| 427 | + )pb"; |
| 428 | + ASSERT_OK_AND_ASSIGN( |
| 429 | + ValidationResult validation, |
| 430 | + GetCompiler().Compile("nested_test_all_types.child.payload")); |
| 431 | + ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); |
| 432 | + ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); |
| 433 | + ASSERT_OK_AND_ASSIGN(std::unique_ptr<Program> program, |
| 434 | + runtime->CreateProgram(std::move(ast))); |
| 435 | + |
| 436 | + // Act: wrap the message and evaluate the expression. |
| 437 | + google::protobuf::Arena arena; |
| 438 | + NestedTestAllTypes* proto = |
| 439 | + NestedTestAllTypes::default_instance().New(&arena); |
| 440 | + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, proto)); |
| 441 | + Activation activation; |
| 442 | + activation.InsertOrAssignValue("condition", BoolValue(true)); |
| 443 | + activation.InsertOrAssignValue( |
| 444 | + "nested_test_all_types", |
| 445 | + Value::WrapMessage(proto, google::protobuf::DescriptorPool::generated_pool(), |
| 446 | + google::protobuf::MessageFactory::generated_factory(), &arena)); |
| 447 | + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); |
| 448 | + |
| 449 | + // Assert: the result is an alias of a sub-message in the input. |
| 450 | + ASSERT_TRUE(result.IsParsedMessage()); |
| 451 | + const ParsedMessageValue& result_msg = result.GetParsedMessage(); |
| 452 | + EXPECT_THAT(result_msg, test::StructValueIs(ParsedProtoStructEquals( |
| 453 | + "repeated_int32: [ 1, 2, 3 ]"))); |
| 454 | + EXPECT_EQ(result_msg->GetArena(), &arena); |
| 455 | + EXPECT_THAT(result_msg, IsSameInstance(&(proto->child().payload()))); |
| 456 | +} |
| 457 | + |
| 458 | +TEST_P(ViewTypesMemorySafetyTest, WrappedMessageDifferentArena) { |
| 459 | + // Arrange: create the runtime and expression. |
| 460 | + ASSERT_OK_AND_ASSIGN(std::unique_ptr<Runtime> runtime, |
| 461 | + ConfigureRuntimeImpl(false, EvaluationOptions())); |
| 462 | + constexpr absl::string_view kProtoValue = R"pb( |
| 463 | + child { payload { repeated_int32: [ 1, 2, 3 ] } } |
| 464 | + payload { repeated_string: [ "foo", "bar", "baz" ] } |
| 465 | + )pb"; |
| 466 | + |
| 467 | + ASSERT_OK_AND_ASSIGN( |
| 468 | + ValidationResult validation, |
| 469 | + GetCompiler().Compile( |
| 470 | + "condition ? nested_test_all_types : NestedTestAllTypes{}")); |
| 471 | + ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); |
| 472 | + ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); |
| 473 | + ASSERT_OK_AND_ASSIGN(std::unique_ptr<Program> program, |
| 474 | + runtime->CreateProgram(std::move(ast))); |
| 475 | + |
| 476 | + // Act: wrap the message and evaluate the expression. |
| 477 | + google::protobuf::Arena arena; |
| 478 | + google::protobuf::Arena other_arena; |
| 479 | + NestedTestAllTypes* proto = |
| 480 | + NestedTestAllTypes::default_instance().New(&other_arena); |
| 481 | + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, proto)); |
| 482 | + Activation activation; |
| 483 | + activation.InsertOrAssignValue("condition", BoolValue(true)); |
| 484 | + activation.InsertOrAssignValue( |
| 485 | + "nested_test_all_types", |
| 486 | + Value::WrapMessage(proto, google::protobuf::DescriptorPool::generated_pool(), |
| 487 | + google::protobuf::MessageFactory::generated_factory(), &arena)); |
| 488 | + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); |
| 489 | + |
| 490 | + // Assert: the result is a copy of the input message. |
| 491 | + ASSERT_TRUE(result.IsParsedMessage()); |
| 492 | + const ParsedMessageValue& result_msg = result.GetParsedMessage(); |
| 493 | + EXPECT_THAT(result_msg, |
| 494 | + test::StructValueIs(ParsedProtoStructEquals(kProtoValue))); |
| 495 | + EXPECT_EQ(result_msg->GetArena(), &arena); |
| 496 | + EXPECT_THAT(result_msg, Not(IsSameInstance(proto))); |
| 497 | +} |
| 498 | + |
| 499 | +TEST_P(ViewTypesMemorySafetyTest, WrappedMessageFromAny) { |
| 500 | + // Arrange: create the runtime. |
| 501 | + ASSERT_OK_AND_ASSIGN(std::unique_ptr<Runtime> runtime, |
| 502 | + ConfigureRuntimeImpl(false, EvaluationOptions())); |
| 503 | + constexpr absl::string_view kProtoValue = R"pb( |
| 504 | + child { payload { repeated_int32: [ 1, 2, 3 ] } } |
| 505 | + payload { repeated_string: [ "foo", "bar", "baz" ] } |
| 506 | + )pb"; |
| 507 | + |
| 508 | + ASSERT_OK_AND_ASSIGN( |
| 509 | + ValidationResult validation, |
| 510 | + GetCompiler().Compile( |
| 511 | + "condition ? nested_test_all_types : NestedTestAllTypes{}")); |
| 512 | + ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); |
| 513 | + ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); |
| 514 | + ASSERT_OK_AND_ASSIGN(std::unique_ptr<Program> program, |
| 515 | + runtime->CreateProgram(std::move(ast))); |
| 516 | + |
| 517 | + // Act: wrap the message and evaluate the expression. |
| 518 | + google::protobuf::Arena arena; |
| 519 | + NestedTestAllTypes proto; |
| 520 | + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, &proto)); |
| 521 | + Any any; |
| 522 | + any.PackFrom(proto); |
| 523 | + Activation activation; |
| 524 | + activation.InsertOrAssignValue("condition", BoolValue(true)); |
| 525 | + activation.InsertOrAssignValue( |
| 526 | + "nested_test_all_types", |
| 527 | + Value::WrapMessage(&any, google::protobuf::DescriptorPool::generated_pool(), |
| 528 | + google::protobuf::MessageFactory::generated_factory(), &arena)); |
| 529 | + |
| 530 | + // Assert |
| 531 | + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); |
| 532 | + ASSERT_TRUE(result.IsParsedMessage()); |
| 533 | + const ParsedMessageValue& result_msg = result.GetParsedMessage(); |
| 534 | + EXPECT_THAT(result_msg, |
| 535 | + test::StructValueIs(ParsedProtoStructEquals(kProtoValue))); |
| 536 | + EXPECT_EQ(result_msg->GetArena(), &arena); |
| 537 | +} |
| 538 | + |
| 539 | +INSTANTIATE_TEST_SUITE_P(Cases, ViewTypesMemorySafetyTest, |
| 540 | + testing::Values(Options::kDefault, |
| 541 | + Options::kExhaustive, |
| 542 | + Options::kFoldConstants), |
| 543 | + [](const testing::TestParamInfo<Options>& info) { |
| 544 | + switch (info.param) { |
| 545 | + case Options::kDefault: |
| 546 | + return "default"; |
| 547 | + case Options::kExhaustive: |
| 548 | + return "exhaustive"; |
| 549 | + case Options::kFoldConstants: |
| 550 | + return "opt"; |
| 551 | + } |
| 552 | + }); |
| 553 | + |
360 | 554 | } // namespace |
361 | 555 | } // namespace cel |
0 commit comments