Skip to content

Commit 1aa40ba

Browse files
committed
WebSocketImpl: Improved handling of deferred tasks
WebSocketImpl has several methods that are called with the mutex locked, but have to call external code that can call back in and deadlock. The current code has some awkward special-purpose mechanisms to defer those calls until the outer method releases the lock. I haven't figured out the best way to do this, but I added a small utility class that allows for queuing up lambdas that will be called later after the mutex unlocks.
1 parent ac7fe95 commit 1aa40ba

1 file changed

Lines changed: 92 additions & 57 deletions

File tree

Networking/WebSockets/WebSocketImpl.cc

Lines changed: 92 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111
//
1212

1313
#include "WebSocketImpl.hh"
14-
#include "Stopwatch.hh"
15-
#include "WebSocketProtocol.hh"
1614
#include "Error.hh"
15+
#include "NumConversion.hh"
16+
#include "Stopwatch.hh"
1717
#include "StringUtil.hh"
1818
#include "Timer.hh"
19-
#include "NumConversion.hh"
19+
#include "WebSocketProtocol.hh"
2020
#include <chrono>
2121
#include <cstdlib>
2222
#include <functional>
@@ -71,6 +71,8 @@ namespace litecore::websocket {
7171
* non-underscored methods, and only when not locked.
7272
*/
7373
struct WebSocketImpl::impl : LoggingProxy {
74+
class LockWithDefer;
75+
7476
enum SocketLifecycleState : int { SOCKET_UNINIT, SOCKET_OPENING, SOCKET_OPENED, SOCKET_CLOSING, SOCKET_CLOSED };
7577

7678
// Immutable state:
@@ -99,22 +101,20 @@ namespace litecore::websocket {
99101
bool _timedOut{false}; // True if _responseTimer timed out
100102
alloc_slice _protocolError; // Error message from WebSocketProtocol
101103
bool _didConnect{false}; // True if I've connected
102-
uint8_t _opToSend{}; // Opcode for _msgToSend
103-
alloc_slice _msgToSend; // Stores a message to send during onReceive
104-
vector<Ref<Message>> _msgsReceived; // Stores messages received during onReceive
105104
SocketLifecycleState _socketLCState{}; // Lifecycle state
105+
LockWithDefer* _lockWithDefer{};
106106

107107
// Connection diagnostics, logged on close:
108108
Stopwatch _timeConnected{false}; // Time since socket opened
109109
uint64_t _bytesSent{0}, _bytesReceived{0}; // Total byte count sent/received
110110

111111
impl(WebSocketImpl& webSocket, bool framing, Parameters parameters)
112-
: LoggingProxy(&webSocket)
113-
, _parameters(std::move(parameters))
114-
, _webSocket{webSocket}
115-
, _framing(framing)
116-
, _heartbeatInterval{computeHeartbeatInterval(_framing, _parameters)}
117-
, _responseTimer(new actor::Timer([this] { timedOut(); })) {
112+
: LoggingProxy(&webSocket)
113+
, _parameters(std::move(parameters))
114+
, _webSocket{webSocket}
115+
, _framing(framing)
116+
, _heartbeatInterval{computeHeartbeatInterval(_framing, _parameters)}
117+
, _responseTimer(new actor::Timer([this] { timedOut(); })) {
118118
if ( framing ) {
119119
if ( webSocket.role() == Role::Server ) _serverProtocol = make_unique<ServerProtocol>();
120120
else
@@ -210,7 +210,7 @@ namespace litecore::websocket {
210210

211211
// Protected API. Called when an async write has completed.
212212
void onWriteComplete(size_t size) {
213-
unique_lock lock(_mutex);
213+
LockWithDefer lock(this);
214214

215215
_bytesSent += size;
216216
bool notify = (_bufferedBytes > kSendBufferSize);
@@ -229,14 +229,11 @@ namespace litecore::websocket {
229229

230230
// Protected API. Called when a WebSocket frame is received.
231231
void onReceive(slice data) {
232-
ssize_t completedBytes = 0;
233-
uint8_t opToSend = 0;
234-
alloc_slice msgToSend;
235-
vector<Ref<Message>> msgsReceived;
232+
ssize_t completedBytes = 0;
236233
{
237234
// Lock the mutex; this protects all methods (below) involved in receiving,
238235
// since they're called from this one.
239-
unique_lock lock(_mutex);
236+
LockWithDefer lock(this);
240237

241238
if ( data.empty() && !_closeReceived ) {
242239
// We assume empty data means a zero-length read, i.e. EOF
@@ -254,26 +251,17 @@ namespace litecore::websocket {
254251
if ( _clientProtocol ) _clientProtocol->consume((byte*)data.buf, data.size, this);
255252
else
256253
_serverProtocol->consume((byte*)data.buf, data.size, this);
257-
opToSend = _opToSend;
258-
msgToSend = std::move(_msgToSend);
259254
// Compute # of bytes consumed: just the framing data, not any partial or
260255
// delivered messages. (Trust me, the math works.)
261256
completedBytes =
262257
narrow_cast<ssize_t>(data.size + prevMessageLength - _curMessageLength - _deliveredBytes);
263258
} else {
264259
_deliverMessageToDelegate(alloc_slice(data));
265260
}
266-
msgsReceived = std::move(_msgsReceived);
267261
}
268262

269263
// After unlocking, tell subclass how many incoming bytes have been handled:
270264
if ( completedBytes > 0 ) _webSocket.receiveComplete(completedBytes);
271-
272-
// Send any message that was generated during the locked block above:
273-
if ( msgToSend ) sendOp(msgToSend, opToSend);
274-
275-
// Similarly, deliver any messages received:
276-
for ( auto& msg : msgsReceived ) _webSocket.delegateWeak()->invoke(&Delegate::onWebSocketMessage, msg);
277265
}
278266

279267
// Called from inside _protocol->consume(), with the _mutex locked
@@ -318,10 +306,12 @@ namespace litecore::websocket {
318306
case CLOSE:
319307
return _receivedClose(message);
320308
case PING:
321-
logInfo("Received PING -- sending PONG");
322-
_opToSend = PONG;
323-
_msgToSend = message ? message : alloc_slice(size_t(0));
324-
return true;
309+
{
310+
logInfo("Received PING -- sending PONG");
311+
alloc_slice msgToSend = message ? message : alloc_slice(size_t(0));
312+
defer([=, this] { sendOp(msgToSend, PONG); });
313+
return true;
314+
}
325315
case PONG:
326316
_receivedPong();
327317
return true;
@@ -337,13 +327,11 @@ namespace litecore::websocket {
337327
_callCloseSocket();
338328
}
339329

340-
void _deliverMessageToDelegate(alloc_slice message) {
341-
logVerbose("Received %zu-byte message", message.size);
342-
_deliveredBytes += message.size;
343-
// We can't call the delegate now because the mutex is locked and from here there's no
344-
// way to unlock it. Instead, we store the message in `_msgsReceived`, and my caller
345-
// `onReceive()` will unlock the mutex and then deliver it.
346-
_msgsReceived.emplace_back(new MessageImpl(&_webSocket, std::move(message), true));
330+
void _deliverMessageToDelegate(alloc_slice messageBody) {
331+
logVerbose("Received %zu-byte message", messageBody.size);
332+
_deliveredBytes += messageBody.size;
333+
auto message = make_retained<MessageImpl>(&_webSocket, std::move(messageBody), true);
334+
defer([=, this] { _webSocket.delegateWeak()->invoke(&Delegate::onWebSocketMessage, message); });
347335
}
348336

349337
#pragma mark - HEARTBEAT:
@@ -395,7 +383,7 @@ namespace litecore::websocket {
395383

396384
// timer callback
397385
void timedOut() {
398-
unique_lock lock(_mutex);
386+
LockWithDefer lock(this);
399387

400388
if ( _timerDisabled ) return;
401389
if ( Timer::clock::now() - _lastReceiveTime < _curTimeout ) return;
@@ -428,7 +416,7 @@ namespace litecore::websocket {
428416
if ( state != SOCKET_OPENED ) { logVerbose("Calling closeSocket before the socket is open"); }
429417
_socketLCState = SOCKET_CLOSING;
430418
_startResponseTimer(kCloseTimeout);
431-
_webSocket.closeSocket();
419+
defer([this] { _webSocket.closeSocket(); });
432420
} else {
433421
logVerbose("Calling closeSocket when the socket is %s",
434422
state == SOCKET_CLOSING ? "pending close" : "already closed");
@@ -442,10 +430,13 @@ namespace litecore::websocket {
442430
logVerbose("Calling requestClose before the socket is connected");
443431
[[fallthrough]];
444432
case SOCKET_OPENED:
445-
_socketLCState = SOCKET_CLOSING;
446-
_startResponseTimer(kCloseTimeout);
447-
_webSocket.requestClose(status, message);
448-
break;
433+
{
434+
_socketLCState = SOCKET_CLOSING;
435+
_startResponseTimer(kCloseTimeout);
436+
alloc_slice allocedMessage(message);
437+
defer([=, this] { _webSocket.requestClose(status, allocedMessage); });
438+
break;
439+
}
449440
case SOCKET_CLOSING:
450441
logVerbose("Calling requestClose when the socket is pending close");
451442
break;
@@ -457,23 +448,23 @@ namespace litecore::websocket {
457448

458449
// Public API. Initiates a request to close the connection cleanly.
459450
void close(int status, slice message) {
460-
unique_lock lock(_mutex);
451+
LockWithDefer lock(this);
461452

462453
switch ( _socketLCState ) {
463454
case SOCKET_CLOSING:
464455
logVerbose("Calling close when the socket is pending close");
465-
return;
456+
break;
466457
case SOCKET_CLOSED:
467458
logVerbose("Calling close when the socket is already closed");
468-
return;
459+
break;
469460
case SOCKET_OPENED:
470461
logInfo("Requesting close with status=%d, message='%.*s'", status, SPLAT(message));
471462
if ( _framing ) {
472463
if ( _closeSent || _closeReceived ) {
473464
logVerbose("Close already processed (_closeSent: %d, _closeReceived: %d), exiting "
474465
"WebSocketImpl::close()",
475466
(int)_closeSent, (int)_closeReceived);
476-
return;
467+
break;
477468
}
478469

479470
auto closeMsg = alloc_slice(2 + message.size);
@@ -483,13 +474,11 @@ namespace litecore::websocket {
483474
_closeSent = true;
484475
_closeMessage = closeMsg;
485476
_startResponseTimer(kCloseTimeout);
486-
487-
lock.unlock(); // UNLOCK MUTEX to call sendOp()
488-
sendOp(closeMsg, CLOSE);
477+
defer([=, this] { sendOp(closeMsg, CLOSE); });
489478
} else {
490479
_callRequestClose(status, message);
491480
}
492-
return;
481+
break;
493482
case SOCKET_OPENING:
494483
logInfo("Closing socket before connection established...");
495484
if ( _framing ) {
@@ -499,10 +488,10 @@ namespace litecore::websocket {
499488
} else {
500489
_callRequestClose(status, message);
501490
}
502-
return;
491+
break;
503492
case SOCKET_UNINIT:
504493
_callCloseSocket();
505-
return;
494+
break;
506495
}
507496
}
508497

@@ -523,9 +512,7 @@ namespace litecore::websocket {
523512
}
524513
_closeSent = true;
525514
_closeMessage = message;
526-
// Don't send the message now or I'll deadlock; remember to do it later in onReceive:
527-
_msgToSend = message;
528-
_opToSend = CLOSE;
515+
defer([=, this] { sendOp(_closeMessage, CLOSE); });
529516
}
530517
_timerDisabled = true;
531518
return true;
@@ -640,6 +627,54 @@ namespace litecore::websocket {
640627
delegate->invoke(&Delegate::onWebSocketClose, status);
641628
}
642629
}
630+
631+
/** Utility class that locks `_mutex` and enables use of the `defer()` function below.
632+
* It should be instantiated as `LockWithDefer lock(this);`. */
633+
class LockWithDefer {
634+
public:
635+
explicit LockWithDefer(impl* owner) : _owner(owner), _lock(owner->_mutex) {
636+
DebugAssert(owner->_lockWithDefer == nullptr);
637+
owner->_lockWithDefer = this;
638+
}
639+
640+
void defer(function<void()> action) {
641+
DebugAssert(_lock.owns_lock());
642+
_actions.emplace_back(std::move(action));
643+
}
644+
645+
void unlock() {
646+
DebugAssert(_lock.owns_lock());
647+
_owner->_lockWithDefer = nullptr;
648+
_lock.unlock();
649+
}
650+
651+
~LockWithDefer() {
652+
if ( _lock.owns_lock() ) _owner->_lockWithDefer = nullptr;
653+
if ( !_actions.empty() ) {
654+
_lock.unlock();
655+
for ( auto& action : _actions ) {
656+
try {
657+
action();
658+
} catch ( ... ) {
659+
#ifdef _MSC_VER
660+
C4Error::warnCurrentException(__FUNCSIG__);
661+
#else
662+
C4Error::warnCurrentException(__PRETTY_FUNCTION__);
663+
#endif
664+
}
665+
}
666+
}
667+
}
668+
669+
private:
670+
impl* const _owner;
671+
unique_lock<mutex> _lock;
672+
vector<function<void()>> _actions; // can't use smallVector: std::function is not trivially moveable
673+
};
674+
675+
/// Schedules a function to be called immediately after the current lock is released.
676+
/// Precondition: Some caller must have a local `LockWithDefer` instance.
677+
void defer(function<void()> fn) { _lockWithDefer->defer(std::move(fn)); }
643678
};
644679

645680
#pragma mark - WEBSOCKET IMPL:

0 commit comments

Comments
 (0)