|
13 | 13 | // limitations under the License. |
14 | 14 |
|
15 | 15 | #include "google/cloud/storage/internal/async/writer_connection_resumed.h" |
| 16 | +#include "google/cloud/mocks/mock_async_streaming_read_write_rpc.h" |
16 | 17 | #include "google/cloud/storage/async/connection.h" |
17 | 18 | #include "google/cloud/storage/mocks/mock_async_writer_connection.h" |
18 | 19 | #include "google/cloud/storage/testing/canonical_errors.h" |
@@ -615,6 +616,128 @@ TEST(WriterConnectionResumed, OnQueryUpdatesWriteHandle) { |
615 | 616 | EXPECT_EQ(current_handle->handle(), "updated-handle"); |
616 | 617 | } |
617 | 618 |
|
| 619 | +TEST(WriterConnectionResumed, ResetWriteOffsetOnResume) { |
| 620 | + AsyncSequencer<bool> sequencer; |
| 621 | + auto mock = std::make_unique<MockAsyncWriterConnection>(); |
| 622 | + auto* mock_ptr = mock.get(); |
| 623 | + |
| 624 | + auto initial_request = google::storage::v2::BidiWriteObjectRequest{}; |
| 625 | + google::storage::v2::BidiWriteObjectResponse first_response; |
| 626 | + first_response.mutable_write_handle()->set_handle("initial-handle"); |
| 627 | + |
| 628 | + auto mock_hash = |
| 629 | + std::make_shared<google::cloud::storage::testing::MockHashFunction>(); |
| 630 | + EXPECT_CALL(*mock_hash, Update(::testing::An<std::int64_t>(), |
| 631 | + ::testing::An<absl::Cord const&>(), |
| 632 | + ::testing::An<std::uint32_t>())) |
| 633 | + .WillRepeatedly(Return(Status())); |
| 634 | + |
| 635 | + EXPECT_CALL(*mock_ptr, PersistedState) |
| 636 | + .WillOnce( |
| 637 | + Return(MakePersistedState(0))) // Initial state: 0 bytes persisted. |
| 638 | + .WillOnce(Return( |
| 639 | + MakePersistedState(1024))); // Resumed state: 1024 bytes persisted. |
| 640 | + |
| 641 | + EXPECT_CALL(*mock_ptr, Flush(_)).WillOnce([&](auto) { |
| 642 | + return sequencer.PushBack("Flush").then([](auto f) { |
| 643 | + if (f.get()) return Status{}; |
| 644 | + return TransientError(); // Return a transient error to trigger resume. |
| 645 | + }); |
| 646 | + }); |
| 647 | + |
| 648 | + MockFactory mock_factory; |
| 649 | + auto mock_stream = |
| 650 | + std::make_unique<google::cloud::mocks::MockAsyncStreamingReadWriteRpc< |
| 651 | + google::storage::v2::BidiWriteObjectRequest, |
| 652 | + google::storage::v2::BidiWriteObjectResponse>>(); |
| 653 | + auto* mock_stream_ptr = mock_stream.get(); |
| 654 | + |
| 655 | + // The mock factory is called when the connection resumes. |
| 656 | + EXPECT_CALL(mock_factory, Call(_)) |
| 657 | + .WillOnce([&](google::storage::v2::BidiWriteObjectRequest) { |
| 658 | + WriteObject::WriteResult result; |
| 659 | + result.stream = std::move(mock_stream); |
| 660 | + result.first_response.mutable_write_handle()->set_handle("new-handle"); |
| 661 | + return sequencer.PushBack("Factory").then( |
| 662 | + [r = std::move(result)](auto) mutable { |
| 663 | + return StatusOr<WriteObject::WriteResult>(std::move(r)); |
| 664 | + }); |
| 665 | + }); |
| 666 | + |
| 667 | + // After resuming, the connection should write the remaining payload. |
| 668 | + EXPECT_CALL(*mock_stream_ptr, Write(_, _)) |
| 669 | + .WillOnce([&](google::storage::v2::BidiWriteObjectRequest const& request, |
| 670 | + grpc::WriteOptions) { |
| 671 | + // We expect the next write on the resumed stream to send the remaining |
| 672 | + // 1024 bytes. If the write offset was not reset to 0, this size would |
| 673 | + // be incorrect. |
| 674 | + EXPECT_EQ(request.checksummed_data().content().size(), 1024); |
| 675 | + return sequencer.PushBack("StreamWrite").then([](auto) { |
| 676 | + return true; |
| 677 | + }); |
| 678 | + }) |
| 679 | + .WillOnce([&](google::storage::v2::BidiWriteObjectRequest const& request, |
| 680 | + grpc::WriteOptions) { |
| 681 | + // Expect a final "ghost" write to flush. |
| 682 | + EXPECT_TRUE(request.checksummed_data().content().empty()); |
| 683 | + EXPECT_TRUE(request.flush()); |
| 684 | + return sequencer.PushBack("GhostWrite").then([](auto) { return true; }); |
| 685 | + }); |
| 686 | + |
| 687 | + google::storage::v2::BidiWriteObjectResponse read_response1; |
| 688 | + read_response1.set_persisted_size(2048); |
| 689 | + google::storage::v2::BidiWriteObjectResponse read_response2; |
| 690 | + read_response2.set_persisted_size(2048); |
| 691 | + EXPECT_CALL(*mock_stream_ptr, Read) |
| 692 | + .WillOnce([&, read_response1]() { |
| 693 | + return sequencer.PushBack("StreamRead1").then([read_response1](auto) { |
| 694 | + return absl::make_optional(read_response1); |
| 695 | + }); |
| 696 | + }) |
| 697 | + .WillOnce([&, read_response2]() { |
| 698 | + return sequencer.PushBack("StreamRead2").then([read_response2](auto) { |
| 699 | + return absl::make_optional(read_response2); |
| 700 | + }); |
| 701 | + }); |
| 702 | + |
| 703 | + EXPECT_CALL(*mock_stream_ptr, Finish) |
| 704 | + .WillOnce(Return(make_ready_future(Status{}))); |
| 705 | + EXPECT_CALL(*mock_stream_ptr, Cancel).WillRepeatedly(Return()); |
| 706 | + |
| 707 | + auto connection = MakeWriterConnectionResumed( |
| 708 | + mock_factory.AsStdFunction(), std::move(mock), initial_request, mock_hash, |
| 709 | + first_response, Options{}); |
| 710 | + |
| 711 | + // Write a total of 2048 bytes. |
| 712 | + auto write = connection->Write(TestPayload(2048)); |
| 713 | + |
| 714 | + auto next = sequencer.PopFrontWithName(); |
| 715 | + EXPECT_EQ(next.second, "Flush"); |
| 716 | + next.first.set_value(false); |
| 717 | + |
| 718 | + next = sequencer.PopFrontWithName(); |
| 719 | + EXPECT_EQ(next.second, "Factory"); |
| 720 | + next.first.set_value(true); |
| 721 | + |
| 722 | + next = sequencer.PopFrontWithName(); |
| 723 | + EXPECT_EQ(next.second, "StreamWrite"); |
| 724 | + next.first.set_value(true); |
| 725 | + |
| 726 | + next = sequencer.PopFrontWithName(); |
| 727 | + EXPECT_EQ(next.second, "StreamRead1"); |
| 728 | + next.first.set_value(true); |
| 729 | + |
| 730 | + next = sequencer.PopFrontWithName(); |
| 731 | + EXPECT_EQ(next.second, "GhostWrite"); |
| 732 | + next.first.set_value(true); |
| 733 | + |
| 734 | + next = sequencer.PopFrontWithName(); |
| 735 | + EXPECT_EQ(next.second, "StreamRead2"); |
| 736 | + next.first.set_value(true); |
| 737 | + |
| 738 | + EXPECT_THAT(write.get(), StatusIs(StatusCode::kOk)); |
| 739 | +} |
| 740 | + |
618 | 741 | } // namespace |
619 | 742 | GOOGLE_CLOUD_CPP_INLINE_NAMESPACE_END |
620 | 743 | } // namespace storage_internal |
|
0 commit comments