Skip to content

Commit d1d6175

Browse files
authored
Make batch create stream SendFeedback thread safe (#3215)
1 parent 9a6ca26 commit d1d6175

3 files changed

Lines changed: 267 additions & 10 deletions

File tree

src/brpc/stream.cpp

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ Stream::Stream()
5252
, _remote_consumed(0)
5353
, _cur_buf_size(0)
5454
, _local_consumed(0)
55+
, _atomic_local_consumed(0)
5556
, _parse_rpc_response(false)
5657
, _pending_buf(NULL)
5758
, _start_idle_timer_us(0)
@@ -287,14 +288,21 @@ void Stream::SetConnected(const StreamSettings* remote_settings) {
287288
CHECK(_host_socket != NULL);
288289
RPC_VLOG << "stream=" << id() << " is connected to stream_id="
289290
<< _remote_settings.stream_id() << " at host_socket=" << *_host_socket;
290-
_connected = true;
291+
_connected.store(true, butil::memory_order_release);
291292
_connect_meta.ec = 0;
292293
TriggerOnConnectIfNeed();
293294
if (remote_settings == NULL) {
294295
// Start the timer at server-side
295296
// Client-side timer would triggered in Consume after received the first
296297
// message which is the very RPC response
297298
StartIdleTimer();
299+
} else {
300+
// send first feedback for client-side stream if it already consumed data
301+
if (_remote_settings.need_feedback()) {
302+
auto consumed_bytes = _atomic_local_consumed.load(butil::memory_order_acquire);
303+
if (consumed_bytes > 0)
304+
SendFeedback(consumed_bytes);
305+
}
298306
}
299307
}
300308

@@ -620,20 +628,34 @@ int Stream::Consume(void *meta, bthread::TaskIterator<butil::IOBuf*>& iter) {
620628
}
621629
mb.flush();
622630

623-
if (s->_remote_settings.need_feedback() && mb.total_length() > 0) {
624-
s->_local_consumed += mb.total_length();
625-
s->SendFeedback();
631+
auto total_length = mb.total_length();
632+
if (total_length > 0) {
633+
// fast path for connected stream
634+
if (s->_connected.load(butil::memory_order_acquire)){
635+
if (s->_remote_settings.need_feedback()) {
636+
s->_local_consumed += total_length;
637+
s->SendFeedback(s->_local_consumed);
638+
}
639+
} else {
640+
// Under the scenario of batch creation of Streams, there is concurrency between SetConnected and Consume for the same stream,
641+
// and it is necessary to ensure the memory order.
642+
s->_local_consumed = s->_atomic_local_consumed.fetch_add(total_length, butil::memory_order_release) + total_length;
643+
if (s->_connected.load(butil::memory_order_acquire) && s->_remote_settings.need_feedback()) {
644+
s->SendFeedback(s->_local_consumed);
645+
}
646+
}
626647
}
648+
627649
s->StartIdleTimer();
628650
return 0;
629651
}
630652

631-
void Stream::SendFeedback() {
653+
void Stream::SendFeedback(int64_t _consumed_bytes) {
632654
StreamFrameMeta fm;
633655
fm.set_frame_type(FRAME_TYPE_FEEDBACK);
634656
fm.set_stream_id(_remote_settings.stream_id());
635657
fm.set_source_stream_id(id());
636-
fm.mutable_feedback()->set_consumed_size(_local_consumed);
658+
fm.mutable_feedback()->set_consumed_size(_consumed_bytes);
637659
butil::IOBuf out;
638660
policy::PackStreamMessage(&out, fm, NULL);
639661
WriteToHostSocket(&out);

src/brpc/stream_impl.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ friend struct butil::DefaultDeleter<Stream>;
8181
void TriggerOnConnectIfNeed();
8282
void Wait(void (*on_writable)(StreamId, void*, int), void* arg,
8383
const timespec* due_time, bool new_thread, bthread_id_t *join_id);
84-
void SendFeedback();
84+
void SendFeedback(int64_t _consumed_bytes);
8585
void StartIdleTimer();
8686
void StopIdleTimer();
8787
void HandleRpcResponse(butil::IOBuf* response_buffer);
@@ -115,7 +115,7 @@ friend struct butil::DefaultDeleter<Stream>;
115115

116116
bthread_mutex_t _connect_mutex;
117117
ConnectMeta _connect_meta;
118-
bool _connected;
118+
butil::atomic<bool> _connected;
119119
bool _closed;
120120
int _error_code;
121121
std::string _error_text;
@@ -127,7 +127,8 @@ friend struct butil::DefaultDeleter<Stream>;
127127
bthread_id_list_t _writable_wait_list;
128128

129129
int64_t _local_consumed;
130-
StreamSettings _remote_settings;
130+
butil::atomic<int64_t> _atomic_local_consumed;
131+
StreamSettings _remote_settings;
131132

132133
bool _parse_rpc_response;
133134
bthread::ExecutionQueueId<butil::IOBuf*> _consumer_queue;

test/brpc_streaming_rpc_unittest.cpp

Lines changed: 235 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@
2020
// Date: 2015/10/22 16:28:44
2121

2222
#include <gtest/gtest.h>
23+
#include <atomic>
2324
#include "brpc/server.h"
2425

2526
#include "brpc/controller.h"
2627
#include "brpc/channel.h"
28+
#include "brpc/callback.h"
2729
#include "brpc/socket.h"
2830
#include "brpc/stream_impl.h"
2931
#include "brpc/policy/streaming_rpc_protocol.h"
@@ -54,7 +56,7 @@ class MyServiceWithStream : public test::EchoService {
5456
const ::test::EchoRequest* request,
5557
::test::EchoResponse* response,
5658
::google::protobuf::Closure* done) {
57-
brpc::ClosureGuard done_gurad(done);
59+
brpc::ClosureGuard done_guard(done);
5860
response->set_message(request->message());
5961
brpc::Controller* cntl = (brpc::Controller*)controller;
6062
brpc::StreamId response_stream;
@@ -78,6 +80,158 @@ class StreamingRpcTest : public testing::Test {
7880
test::EchoResponse response;
7981
};
8082

83+
struct BatchStreamFeedbackRaceState {
84+
brpc::StreamId server_first_stream_id{brpc::INVALID_STREAM_ID};
85+
brpc::StreamId server_extra_stream_id{brpc::INVALID_STREAM_ID};
86+
brpc::StreamId client_extra_stream_id{brpc::INVALID_STREAM_ID};
87+
88+
std::atomic<int> server_first_write_rc{-1};
89+
std::atomic<int> server_second_write_rc{-1};
90+
std::atomic<bool> client_got_first_msg{false};
91+
std::atomic<bool> client_got_second_msg{false};
92+
std::atomic<bool> server_write_done{false};
93+
std::atomic<bool> rpc_done{false};
94+
95+
bthread_t server_send_tid{0};
96+
std::atomic<bool> server_send_started{false};
97+
};
98+
99+
class BatchStreamClientHandler : public brpc::StreamInputHandler {
100+
public:
101+
explicit BatchStreamClientHandler(BatchStreamFeedbackRaceState* state)
102+
: _state(state) {}
103+
104+
int on_received_messages(brpc::StreamId id,
105+
butil::IOBuf* const messages[],
106+
size_t size) override {
107+
if (id != _state->client_extra_stream_id) {
108+
// This test only cares about extra stream in batch creation.
109+
return 0;
110+
}
111+
for (size_t i = 0; i < size; ++i) {
112+
const size_t len = messages[i]->length();
113+
messages[i]->clear();
114+
// First payload: 64 bytes. Second payload: 1 byte.
115+
if (len == 64) {
116+
_state->client_got_first_msg.store(true, std::memory_order_release);
117+
} else if (len == 1) {
118+
_state->client_got_second_msg.store(true, std::memory_order_release);
119+
}
120+
}
121+
return 0;
122+
}
123+
124+
void on_idle_timeout(brpc::StreamId /*id*/) override {}
125+
126+
void on_closed(brpc::StreamId /*id*/) override {}
127+
128+
void on_failed(brpc::StreamId /*id*/, int /*error_code*/, const std::string& /*error_text*/) override {}
129+
130+
private:
131+
BatchStreamFeedbackRaceState* _state;
132+
};
133+
134+
static void* SendTwoMessagesOnServerExtraStream(void* arg) {
135+
auto* state = static_cast<BatchStreamFeedbackRaceState*>(arg);
136+
const brpc::StreamId sid = state->server_extra_stream_id;
137+
138+
// Wait until server-side stream is connected.
139+
const int64_t connect_deadline_us = butil::gettimeofday_us() + 2 * 1000 * 1000L;
140+
bool connected = false;
141+
while (butil::gettimeofday_us() < connect_deadline_us) {
142+
brpc::SocketUniquePtr ptr;
143+
if (brpc::Socket::Address(sid, &ptr) == 0) {
144+
brpc::Stream* s = static_cast<brpc::Stream*>(ptr->conn());
145+
if (s->_host_socket != NULL && s->_connected) {
146+
connected = true;
147+
break;
148+
}
149+
}
150+
usleep(1000);
151+
}
152+
153+
if (!connected) {
154+
state->server_first_write_rc.store(ETIMEDOUT, std::memory_order_relaxed);
155+
state->server_second_write_rc.store(ETIMEDOUT, std::memory_order_relaxed);
156+
state->server_write_done.store(true, std::memory_order_release);
157+
return NULL;
158+
}
159+
160+
// 1) Send a payload exactly equal to max_buf_size(64).
161+
{
162+
std::string payload(64, 'a');
163+
butil::IOBuf out;
164+
out.append(payload);
165+
state->server_first_write_rc.store(brpc::StreamWrite(sid, out), std::memory_order_relaxed);
166+
}
167+
168+
// 2) Then send another byte. This write should become writable only after
169+
// client sends FEEDBACK with consumed_size >= 64.
170+
const int64_t write_deadline_us = butil::gettimeofday_us() + 2 * 1000 * 1000L;
171+
int rc = -1;
172+
while (butil::gettimeofday_us() < write_deadline_us) {
173+
butil::IOBuf out;
174+
out.append("b", 1);
175+
rc = brpc::StreamWrite(sid, out);
176+
if (rc == 0) {
177+
break;
178+
}
179+
if (rc != EAGAIN) {
180+
break;
181+
}
182+
const timespec duetime = butil::milliseconds_from_now(100);
183+
(void)brpc::StreamWait(sid, &duetime);
184+
}
185+
state->server_second_write_rc.store(rc, std::memory_order_relaxed);
186+
state->server_write_done.store(true, std::memory_order_release);
187+
return NULL;
188+
}
189+
190+
class MyServiceWithBatchStream : public test::EchoService {
191+
public:
192+
MyServiceWithBatchStream(const brpc::StreamOptions& options,
193+
BatchStreamFeedbackRaceState* state)
194+
: _options(options), _state(state) {}
195+
196+
void Echo(::google::protobuf::RpcController* controller,
197+
const ::test::EchoRequest* request,
198+
::test::EchoResponse* response,
199+
::google::protobuf::Closure* done) override {
200+
brpc::ClosureGuard done_guard(done);
201+
response->set_message(request->message());
202+
brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
203+
204+
brpc::StreamIds response_streams;
205+
ASSERT_EQ(0, brpc::StreamAccept(response_streams, *cntl, &_options));
206+
ASSERT_EQ(2u, response_streams.size());
207+
_state->server_first_stream_id = response_streams[0];
208+
_state->server_extra_stream_id = response_streams[1];
209+
210+
bthread_t tid;
211+
ASSERT_EQ(0, bthread_start_background(
212+
&tid, &BTHREAD_ATTR_NORMAL,
213+
SendTwoMessagesOnServerExtraStream, _state));
214+
_state->server_send_tid = tid;
215+
_state->server_send_started.store(true, std::memory_order_release);
216+
}
217+
218+
private:
219+
brpc::StreamOptions _options;
220+
BatchStreamFeedbackRaceState* _state;
221+
};
222+
223+
static void SetAtomicTrue(std::atomic<bool>* f) {
224+
f->store(true, std::memory_order_release);
225+
}
226+
227+
static bool WaitForTrue(const std::atomic<bool>& f, int timeout_ms) {
228+
const int64_t deadline_us = butil::gettimeofday_us() + (int64_t)timeout_ms * 1000L;
229+
while (!f.load(std::memory_order_acquire) && butil::gettimeofday_us() < deadline_us) {
230+
usleep(1000);
231+
}
232+
return f.load(std::memory_order_acquire);
233+
}
234+
81235
TEST_F(StreamingRpcTest, sanity) {
82236
brpc::Server server;
83237
MyServiceWithStream service;
@@ -98,6 +252,86 @@ TEST_F(StreamingRpcTest, sanity) {
98252
server.Join();
99253
}
100254

255+
TEST_F(StreamingRpcTest, batch_create_stream_feedback_race) {
256+
BatchStreamFeedbackRaceState state;
257+
BatchStreamClientHandler client_handler(&state);
258+
259+
brpc::StreamOptions server_stream_opt;
260+
// Make server-side sender sensitive to FEEDBACK quickly.
261+
server_stream_opt.max_buf_size = 16;
262+
263+
brpc::Server server;
264+
MyServiceWithBatchStream service(server_stream_opt, &state);
265+
ASSERT_EQ(0, server.AddService(&service, brpc::SERVER_DOESNT_OWN_SERVICE));
266+
ASSERT_EQ(0, server.Start(9007, NULL));
267+
268+
brpc::Channel channel;
269+
ASSERT_EQ(0, channel.Init("127.0.0.1:9007", NULL));
270+
271+
brpc::Controller cntl;
272+
brpc::StreamIds request_streams;
273+
brpc::StreamOptions client_stream_opt;
274+
client_stream_opt.handler = &client_handler;
275+
client_stream_opt.max_buf_size = 0;
276+
ASSERT_EQ(0, brpc::StreamCreate(request_streams, 2, cntl, &client_stream_opt));
277+
ASSERT_EQ(2u, request_streams.size());
278+
state.client_extra_stream_id = request_streams[1];
279+
280+
// Block SetConnected() on the extra stream to enlarge the race window.
281+
brpc::SocketUniquePtr client_extra_ptr;
282+
ASSERT_EQ(0, brpc::Socket::Address(state.client_extra_stream_id, &client_extra_ptr));
283+
brpc::Stream* client_extra_stream = static_cast<brpc::Stream*>(client_extra_ptr->conn());
284+
bthread_mutex_lock(&client_extra_stream->_connect_mutex);
285+
struct UnlockGuard {
286+
bthread_mutex_t* m;
287+
~UnlockGuard() {
288+
if (m) {
289+
bthread_mutex_unlock(m);
290+
}
291+
}
292+
} unlock_guard{&client_extra_stream->_connect_mutex};
293+
294+
BRPC_SCOPE_EXIT {
295+
if (state.server_extra_stream_id != brpc::INVALID_STREAM_ID) {
296+
brpc::StreamClose(state.server_extra_stream_id);
297+
}
298+
if (state.server_first_stream_id != brpc::INVALID_STREAM_ID) {
299+
brpc::StreamClose(state.server_first_stream_id);
300+
}
301+
for (auto sid : request_streams) {
302+
brpc::StreamClose(sid);
303+
}
304+
305+
if (state.server_send_tid) {
306+
bthread_join(state.server_send_tid, NULL);
307+
}
308+
server.Stop(0);
309+
server.Join();
310+
};
311+
312+
test::EchoService_Stub stub(&channel);
313+
stub.Echo(&cntl, &request, &response, brpc::NewCallback(SetAtomicTrue, &state.rpc_done));
314+
315+
// Wait until client consumes the first 64B payload on extra stream.
316+
ASSERT_TRUE(WaitForTrue(state.client_got_first_msg, 2000));
317+
318+
// Unblock SetConnected(); the fix in PR 3215 should send the first FEEDBACK
319+
// with consumed_size=64 here, making server-side stream writable again.
320+
bthread_mutex_unlock(&client_extra_stream->_connect_mutex);
321+
unlock_guard.m = NULL;
322+
323+
ASSERT_TRUE(WaitForTrue(state.rpc_done, 2000));
324+
ASSERT_FALSE(cntl.Failed()) << cntl.ErrorText();
325+
326+
// Wait for server-side send thread to be started.
327+
ASSERT_TRUE(WaitForTrue(state.server_send_started, 2000));
328+
329+
ASSERT_TRUE(WaitForTrue(state.server_write_done, 2000));
330+
ASSERT_EQ(0, state.server_first_write_rc.load(std::memory_order_relaxed));
331+
ASSERT_EQ(0, state.server_second_write_rc.load(std::memory_order_relaxed));
332+
ASSERT_TRUE(WaitForTrue(state.client_got_second_msg, 2000));
333+
}
334+
101335
struct HandlerControl {
102336
HandlerControl()
103337
: block(false)

0 commit comments

Comments
 (0)