Skip to content

Commit ba9e838

Browse files
authored
Fix thread safety of Wrapper (#2952)
1 parent 1337db0 commit ba9e838

4 files changed

Lines changed: 78 additions & 86 deletions

File tree

src/brpc/policy/http_rpc_protocol.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -331,11 +331,7 @@ static bool ProtoMessageToProtoJson(const google::protobuf::Message& message,
331331
butil::IOBufAsZeroCopyOutputStream* wrapper,
332332
Controller* cntl, int error_code) {
333333
json2pb::Pb2ProtoJsonOptions options;
334-
#if GOOGLE_PROTOBUF_VERSION >= 5026002
335-
options.always_print_fields_with_no_presence = cntl->has_always_print_primitive_fields();
336-
#else
337-
options.always_print_primitive_fields = cntl->has_always_print_primitive_fields();
338-
#endif
334+
AlwaysPrintPrimitiveFields(options) = cntl->has_always_print_primitive_fields();
339335
options.always_print_enums_as_ints = FLAGS_pb_enum_as_number;
340336
std::string error;
341337
bool ok = json2pb::ProtoMessageToProtoJson(message, wrapper, options, &error);

src/brpc/server.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1972,7 +1972,7 @@ bool IsDummyServerRunning() {
19721972
}
19731973

19741974
const Server::MethodProperty*
1975-
Server::FindMethodPropertyByFullName(const butil::StringPiece&fullname) const {
1975+
Server::FindMethodPropertyByFullName(const butil::StringPiece& fullname) const {
19761976
return _method_map.seek(fullname);
19771977
}
19781978

src/butil/containers/doubly_buffered_data.h

Lines changed: 74 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
#define BUTIL_DOUBLY_BUFFERED_DATA_H
2222

2323
#include <deque>
24-
#include <vector> // std::vector
24+
#include <vector>
25+
#include <memory>
2526
#include <pthread.h>
2627
#include "butil/scoped_lock.h"
2728
#include "butil/thread_local.h"
@@ -87,6 +88,8 @@ class DoublyBufferedData {
8788
class Wrapper;
8889
class WrapperTLSGroup;
8990
typedef int WrapperTLSId;
91+
typedef std::shared_ptr<Wrapper> WrapperSharedPtr;
92+
typedef std::weak_ptr<Wrapper> WrapperWeakPtr;
9093
public:
9194
class ScopedPtr {
9295
friend class DoublyBufferedData;
@@ -111,7 +114,7 @@ class DoublyBufferedData {
111114
const T* _data;
112115
// Index of foreground instance used by ScopedPtr.
113116
int _index;
114-
Wrapper* _w;
117+
WrapperSharedPtr _w;
115118
};
116119

117120
DoublyBufferedData();
@@ -152,8 +155,7 @@ class DoublyBufferedData {
152155
return _data + index;
153156
}
154157

155-
Wrapper* AddWrapper(Wrapper*);
156-
void RemoveWrapper(Wrapper*);
158+
WrapperSharedPtr GetWrapper();
157159

158160
// Foreground and background void.
159161
T _data[2];
@@ -165,7 +167,7 @@ class DoublyBufferedData {
165167
WrapperTLSId _wrapper_key;
166168

167169
// All thread-local instances.
168-
std::vector<Wrapper*> _wrappers;
170+
std::vector<WrapperWeakPtr> _wrappers;
169171

170172
// Sequence access to _wrappers.
171173
pthread_mutex_t _wrappers_mutex{};
@@ -195,18 +197,22 @@ class DoublyBufferedData<T, TLS, AllowBthreadSuspended>::WrapperTLSGroup {
195197
public:
196198
const static size_t RAW_BLOCK_SIZE = 4096;
197199
const static size_t ELEMENTS_PER_BLOCK =
198-
RAW_BLOCK_SIZE / sizeof(Wrapper) > 0 ? RAW_BLOCK_SIZE / sizeof(Wrapper) : 1;
200+
RAW_BLOCK_SIZE / sizeof(WrapperSharedPtr) > 0 ?
201+
RAW_BLOCK_SIZE / sizeof(WrapperSharedPtr) : 1;
199202

200203
struct BAIDU_CACHELINE_ALIGNMENT ThreadBlock {
201-
inline DoublyBufferedData::Wrapper* at(size_t offset) {
202-
return _data + offset;
204+
WrapperSharedPtr at(size_t offset) {
205+
if (NULL == _data[offset]) {
206+
_data[offset] = std::make_shared<Wrapper>();
207+
}
208+
return _data[offset];
203209
};
204210

205211
private:
206-
DoublyBufferedData::Wrapper _data[ELEMENTS_PER_BLOCK];
212+
WrapperSharedPtr _data[ELEMENTS_PER_BLOCK];
207213
};
208214

209-
inline static WrapperTLSId key_create() {
215+
static WrapperTLSId key_create() {
210216
BAIDU_SCOPED_LOCK(_s_mutex);
211217
WrapperTLSId id = 0;
212218
if (!_get_free_ids().empty()) {
@@ -218,7 +224,7 @@ class DoublyBufferedData<T, TLS, AllowBthreadSuspended>::WrapperTLSGroup {
218224
return id;
219225
}
220226

221-
inline static int key_delete(WrapperTLSId id) {
227+
static int key_delete(WrapperTLSId id) {
222228
BAIDU_SCOPED_LOCK(_s_mutex);
223229
if (id < 0 || id >= _s_id) {
224230
errno = EINVAL;
@@ -228,17 +234,13 @@ class DoublyBufferedData<T, TLS, AllowBthreadSuspended>::WrapperTLSGroup {
228234
return 0;
229235
}
230236

231-
inline static DoublyBufferedData::Wrapper* get_or_create_tls_data(WrapperTLSId id) {
237+
static WrapperSharedPtr get_or_create_tls_data(WrapperTLSId id) {
232238
if (BAIDU_UNLIKELY(id < 0)) {
233239
CHECK(false) << "Invalid id=" << id;
234240
return NULL;
235241
}
236242
if (_s_tls_blocks == NULL) {
237-
_s_tls_blocks = new (std::nothrow) std::vector<ThreadBlock*>;
238-
if (BAIDU_UNLIKELY(_s_tls_blocks == NULL)) {
239-
LOG(FATAL) << "Fail to create vector, " << berror();
240-
return NULL;
241-
}
243+
_s_tls_blocks = new std::vector<ThreadBlock*>;
242244
butil::thread_atexit(_destroy_tls_blocks);
243245
}
244246
const size_t block_id = (size_t)id / ELEMENTS_PER_BLOCK;
@@ -248,12 +250,8 @@ class DoublyBufferedData<T, TLS, AllowBthreadSuspended>::WrapperTLSGroup {
248250
}
249251
ThreadBlock* tb = (*_s_tls_blocks)[block_id];
250252
if (tb == NULL) {
251-
ThreadBlock* new_block = new (std::nothrow) ThreadBlock;
252-
if (BAIDU_UNLIKELY(new_block == NULL)) {
253-
return NULL;
254-
}
255-
tb = new_block;
256-
(*_s_tls_blocks)[block_id] = new_block;
253+
tb = new ThreadBlock;
254+
(*_s_tls_blocks)[block_id] = tb;
257255
}
258256
return tb->at(id - block_id * ELEMENTS_PER_BLOCK);
259257
}
@@ -316,10 +314,6 @@ friend class DoublyBufferedData;
316314
}
317315

318316
~Wrapper() {
319-
if (_control != NULL) {
320-
_control->RemoveWrapper(this);
321-
}
322-
323317
if (AllowBthreadSuspended) {
324318
WaitReadDone(0);
325319
WaitReadDone(1);
@@ -406,9 +400,9 @@ friend class DoublyBufferedData;
406400

407401
// Called when thread initializes thread-local wrapper.
408402
template <typename T, typename TLS, bool AllowBthreadSuspended>
409-
typename DoublyBufferedData<T, TLS, AllowBthreadSuspended>::Wrapper*
410-
DoublyBufferedData<T, TLS, AllowBthreadSuspended>::AddWrapper(
411-
typename DoublyBufferedData<T, TLS, AllowBthreadSuspended>::Wrapper* w) {
403+
typename DoublyBufferedData<T, TLS, AllowBthreadSuspended>::WrapperSharedPtr
404+
DoublyBufferedData<T, TLS, AllowBthreadSuspended>::GetWrapper() {
405+
WrapperSharedPtr w = WrapperTLSGroup::get_or_create_tls_data(_wrapper_key);
412406
if (NULL == w) {
413407
return NULL;
414408
}
@@ -423,29 +417,19 @@ DoublyBufferedData<T, TLS, AllowBthreadSuspended>::AddWrapper(
423417
w->_control = this;
424418
BAIDU_SCOPED_LOCK(_wrappers_mutex);
425419
_wrappers.push_back(w);
420+
// The chance to remove expired weak_ptr.
421+
_wrappers.erase(
422+
std::remove_if(_wrappers.begin(), _wrappers.end(),
423+
[](const WrapperWeakPtr& w) {
424+
return w.expired();
425+
}),
426+
_wrappers.end());
426427
} catch (std::exception& e) {
427428
return NULL;
428429
}
429430
return w;
430431
}
431432

432-
// Called when thread quits.
433-
template <typename T, typename TLS, bool AllowBthreadSuspended>
434-
void DoublyBufferedData<T, TLS, AllowBthreadSuspended>::RemoveWrapper(
435-
typename DoublyBufferedData<T, TLS, AllowBthreadSuspended>::Wrapper* w) {
436-
if (NULL == w) {
437-
return;
438-
}
439-
BAIDU_SCOPED_LOCK(_wrappers_mutex);
440-
for (size_t i = 0; i < _wrappers.size(); ++i) {
441-
if (_wrappers[i] == w) {
442-
_wrappers[i] = _wrappers.back();
443-
_wrappers.pop_back();
444-
return;
445-
}
446-
}
447-
}
448-
449433
template <typename T, typename TLS, bool AllowBthreadSuspended>
450434
DoublyBufferedData<T, TLS, AllowBthreadSuspended>::DoublyBufferedData()
451435
: _index(0)
@@ -474,7 +458,10 @@ DoublyBufferedData<T, TLS, AllowBthreadSuspended>::~DoublyBufferedData() {
474458
{
475459
BAIDU_SCOPED_LOCK(_wrappers_mutex);
476460
for (size_t i = 0; i < _wrappers.size(); ++i) {
477-
_wrappers[i]->_control = NULL; // hack: disable removal.
461+
WrapperSharedPtr w = _wrappers[i].lock();
462+
if (NULL != w) {
463+
w->_control = NULL; // hack: disable removal.
464+
}
478465
}
479466
_wrappers.clear();
480467
}
@@ -487,29 +474,28 @@ DoublyBufferedData<T, TLS, AllowBthreadSuspended>::~DoublyBufferedData() {
487474
template <typename T, typename TLS, bool AllowBthreadSuspended>
488475
int DoublyBufferedData<T, TLS, AllowBthreadSuspended>::Read(
489476
typename DoublyBufferedData<T, TLS, AllowBthreadSuspended>::ScopedPtr* ptr) {
490-
Wrapper* p = WrapperTLSGroup::get_or_create_tls_data(_wrapper_key);
491-
Wrapper* w = AddWrapper(p);
492-
if (BAIDU_LIKELY(w != NULL)) {
493-
if (AllowBthreadSuspended) {
494-
// Use reference count instead of mutex to indicate read of
495-
// foreground instance, so during the read process, there is
496-
// no need to lock mutex and bthread is allowed to be suspended.
497-
w->BeginRead();
498-
int index = -1;
499-
ptr->_data = UnsafeRead(index);
500-
ptr->_index = index;
501-
w->AddRef(index);
502-
ptr->_w = w;
503-
w->BeginReadRelease();
504-
} else {
505-
w->BeginRead();
506-
ptr->_data = UnsafeRead();
507-
ptr->_w = w;
508-
}
477+
WrapperSharedPtr w = GetWrapper();
478+
if (BAIDU_UNLIKELY(w == NULL)) {
479+
return -1;
480+
}
509481

510-
return 0;
482+
if (AllowBthreadSuspended) {
483+
// Use reference count instead of mutex to indicate read of
484+
// foreground instance, so during the read process, there is
485+
// no need to lock mutex and bthread is allowed to be suspended.
486+
w->BeginRead();
487+
int index = -1;
488+
ptr->_data = UnsafeRead(index);
489+
ptr->_index = index;
490+
w->AddRef(index);
491+
ptr->_w = w;
492+
w->BeginReadRelease();
493+
} else {
494+
w->BeginRead();
495+
ptr->_data = UnsafeRead();
496+
ptr->_w = w;
511497
}
512-
return -1;
498+
return 0;
513499
}
514500

515501
template <typename T, typename TLS, bool AllowBthreadSuspended>
@@ -530,7 +516,7 @@ template <typename Fn, typename... Args>
530516
size_t DoublyBufferedData<T, TLS, AllowBthreadSuspended>::Modify(Fn&& fn, Args&&... args) {
531517
// _modify_mutex sequences modifications. Using a separate mutex rather
532518
// than _wrappers_mutex is to avoid blocking threads calling
533-
// AddWrapper() or RemoveWrapper() too long. Most of the time, modifications
519+
// GetWrapper() too long. Most of the time, modifications
534520
// are done by one thread, contention should be negligible.
535521
BAIDU_SCOPED_LOCK(_modify_mutex);
536522
int bg_index = !_index.load(butil::memory_order_relaxed);
@@ -552,14 +538,24 @@ size_t DoublyBufferedData<T, TLS, AllowBthreadSuspended>::Modify(Fn&& fn, Args&&
552538
// read, they should see updated _index.
553539
{
554540
BAIDU_SCOPED_LOCK(_wrappers_mutex);
555-
for (size_t i = 0; i < _wrappers.size(); ++i) {
556-
// Wait read of old foreground instance done.
557-
if (AllowBthreadSuspended) {
558-
_wrappers[i]->WaitReadDone(bg_index);
559-
} else {
560-
_wrappers[i]->WaitReadDone();
561-
}
562-
}
541+
// The chance to remove expired weak_ptr.
542+
_wrappers.erase(
543+
std::remove_if(_wrappers.begin(), _wrappers.end(),
544+
[bg_index](const WrapperWeakPtr& weak) {
545+
WrapperSharedPtr w = weak.lock();
546+
bool expired = NULL == w;
547+
if (!expired) {
548+
// Notify all threads waiting for read done.
549+
if (AllowBthreadSuspended) {
550+
w->WaitReadDone(bg_index);
551+
} else {
552+
w->WaitReadDone();
553+
}
554+
}
555+
// Remove expired weak_ptr.
556+
return expired;
557+
}),
558+
_wrappers.end());
563559
}
564560

565561
const size_t ret2 = fn(_data[bg_index], std::forward<Args>(args)...);

src/json2pb/pb_to_json.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,14 +336,14 @@ bool ProtoMessageToJson(const google::protobuf::Message& message,
336336
}
337337

338338
bool ProtoMessageToJson(const google::protobuf::Message& message,
339-
google::protobuf::io::ZeroCopyOutputStream *stream,
339+
google::protobuf::io::ZeroCopyOutputStream* stream,
340340
const Pb2JsonOptions& options, std::string* error) {
341341
json2pb::ZeroCopyStreamWriter wrapper(stream);
342342
return json2pb::ProtoMessageToJsonStream(message, options, wrapper, error);
343343
}
344344

345345
bool ProtoMessageToJson(const google::protobuf::Message& message,
346-
google::protobuf::io::ZeroCopyOutputStream *stream,
346+
google::protobuf::io::ZeroCopyOutputStream* stream,
347347
std::string* error) {
348348
return ProtoMessageToJson(message, stream, Pb2JsonOptions(), error);
349349
}

0 commit comments

Comments
 (0)