1111//
1212
1313#include " WebSocketImpl.hh"
14- #include " Stopwatch.hh"
15- #include " WebSocketProtocol.hh"
1614#include " Error.hh"
15+ #include " NumConversion.hh"
16+ #include " Stopwatch.hh"
1717#include " StringUtil.hh"
1818#include " Timer.hh"
19- #include " NumConversion .hh"
19+ #include " WebSocketProtocol .hh"
2020#include < chrono>
2121#include < cstdlib>
2222#include < functional>
@@ -71,6 +71,8 @@ namespace litecore::websocket {
7171 * non-underscored methods, and only when not locked.
7272 */
7373 struct WebSocketImpl ::impl : LoggingProxy {
74+ class LockWithDefer ;
75+
7476 enum SocketLifecycleState : int { SOCKET_UNINIT, SOCKET_OPENING, SOCKET_OPENED, SOCKET_CLOSING, SOCKET_CLOSED };
7577
7678 // Immutable state:
@@ -99,22 +101,20 @@ namespace litecore::websocket {
99101 bool _timedOut{false }; // True if _responseTimer timed out
100102 alloc_slice _protocolError; // Error message from WebSocketProtocol
101103 bool _didConnect{false }; // True if I've connected
102- uint8_t _opToSend{}; // Opcode for _msgToSend
103- alloc_slice _msgToSend; // Stores a message to send during onReceive
104- vector<Ref<Message>> _msgsReceived; // Stores messages received during onReceive
105104 SocketLifecycleState _socketLCState{}; // Lifecycle state
105+ LockWithDefer* _lockWithDefer{};
106106
107107 // Connection diagnostics, logged on close:
108108 Stopwatch _timeConnected{false }; // Time since socket opened
109109 uint64_t _bytesSent{0 }, _bytesReceived{0 }; // Total byte count sent/received
110110
111111 impl (WebSocketImpl& webSocket, bool framing, Parameters parameters)
112- : LoggingProxy(&webSocket)
113- , _parameters(std::move(parameters))
114- , _webSocket{webSocket}
115- , _framing(framing)
116- , _heartbeatInterval{computeHeartbeatInterval (_framing, _parameters)}
117- , _responseTimer(new actor::Timer([this ] { timedOut (); })) {
112+ : LoggingProxy(&webSocket)
113+ , _parameters(std::move(parameters))
114+ , _webSocket{webSocket}
115+ , _framing(framing)
116+ , _heartbeatInterval{computeHeartbeatInterval (_framing, _parameters)}
117+ , _responseTimer(new actor::Timer([this ] { timedOut (); })) {
118118 if ( framing ) {
119119 if ( webSocket.role () == Role::Server ) _serverProtocol = make_unique<ServerProtocol>();
120120 else
@@ -210,7 +210,7 @@ namespace litecore::websocket {
210210
211211 // Protected API. Called when an async write has completed.
212212 void onWriteComplete (size_t size) {
213- unique_lock lock (_mutex );
213+ LockWithDefer lock (this );
214214
215215 _bytesSent += size;
216216 bool notify = (_bufferedBytes > kSendBufferSize );
@@ -229,14 +229,11 @@ namespace litecore::websocket {
229229
230230 // Protected API. Called when a WebSocket frame is received.
231231 void onReceive (slice data) {
232- ssize_t completedBytes = 0 ;
233- uint8_t opToSend = 0 ;
234- alloc_slice msgToSend;
235- vector<Ref<Message>> msgsReceived;
232+ ssize_t completedBytes = 0 ;
236233 {
237234 // Lock the mutex; this protects all methods (below) involved in receiving,
238235 // since they're called from this one.
239- unique_lock lock (_mutex );
236+ LockWithDefer lock (this );
240237
241238 if ( data.empty () && !_closeReceived ) {
242239 // We assume empty data means a zero-length read, i.e. EOF
@@ -254,26 +251,17 @@ namespace litecore::websocket {
254251 if ( _clientProtocol ) _clientProtocol->consume ((byte*)data.buf , data.size , this );
255252 else
256253 _serverProtocol->consume ((byte*)data.buf , data.size , this );
257- opToSend = _opToSend;
258- msgToSend = std::move (_msgToSend);
259254 // Compute # of bytes consumed: just the framing data, not any partial or
260255 // delivered messages. (Trust me, the math works.)
261256 completedBytes =
262257 narrow_cast<ssize_t >(data.size + prevMessageLength - _curMessageLength - _deliveredBytes);
263258 } else {
264259 _deliverMessageToDelegate (alloc_slice (data));
265260 }
266- msgsReceived = std::move (_msgsReceived);
267261 }
268262
269263 // After unlocking, tell subclass how many incoming bytes have been handled:
270264 if ( completedBytes > 0 ) _webSocket.receiveComplete (completedBytes);
271-
272- // Send any message that was generated during the locked block above:
273- if ( msgToSend ) sendOp (msgToSend, opToSend);
274-
275- // Similarly, deliver any messages received:
276- for ( auto & msg : msgsReceived ) _webSocket.delegateWeak ()->invoke (&Delegate::onWebSocketMessage, msg);
277265 }
278266
279267 // Called from inside _protocol->consume(), with the _mutex locked
@@ -318,10 +306,12 @@ namespace litecore::websocket {
318306 case CLOSE:
319307 return _receivedClose (message);
320308 case PING:
321- logInfo (" Received PING -- sending PONG" );
322- _opToSend = PONG;
323- _msgToSend = message ? message : alloc_slice (size_t (0 ));
324- return true ;
309+ {
310+ logInfo (" Received PING -- sending PONG" );
311+ alloc_slice msgToSend = message ? message : alloc_slice (size_t (0 ));
312+ defer ([=, this ] { sendOp (msgToSend, PONG); });
313+ return true ;
314+ }
325315 case PONG:
326316 _receivedPong ();
327317 return true ;
@@ -337,13 +327,11 @@ namespace litecore::websocket {
337327 _callCloseSocket ();
338328 }
339329
340- void _deliverMessageToDelegate (alloc_slice message) {
341- logVerbose (" Received %zu-byte message" , message.size );
342- _deliveredBytes += message.size ;
343- // We can't call the delegate now because the mutex is locked and from here there's no
344- // way to unlock it. Instead, we store the message in `_msgsReceived`, and my caller
345- // `onReceive()` will unlock the mutex and then deliver it.
346- _msgsReceived.emplace_back (new MessageImpl (&_webSocket, std::move (message), true ));
330+ void _deliverMessageToDelegate (alloc_slice messageBody) {
331+ logVerbose (" Received %zu-byte message" , messageBody.size );
332+ _deliveredBytes += messageBody.size ;
333+ auto message = make_retained<MessageImpl>(&_webSocket, std::move (messageBody), true );
334+ defer ([=, this ] { _webSocket.delegateWeak ()->invoke (&Delegate::onWebSocketMessage, message); });
347335 }
348336
349337#pragma mark - HEARTBEAT:
@@ -395,7 +383,7 @@ namespace litecore::websocket {
395383
396384 // timer callback
397385 void timedOut () {
398- unique_lock lock (_mutex );
386+ LockWithDefer lock (this );
399387
400388 if ( _timerDisabled ) return ;
401389 if ( Timer::clock::now () - _lastReceiveTime < _curTimeout ) return ;
@@ -428,7 +416,7 @@ namespace litecore::websocket {
428416 if ( state != SOCKET_OPENED ) { logVerbose (" Calling closeSocket before the socket is open" ); }
429417 _socketLCState = SOCKET_CLOSING;
430418 _startResponseTimer (kCloseTimeout );
431- _webSocket.closeSocket ();
419+ defer ([ this ] { _webSocket.closeSocket (); } );
432420 } else {
433421 logVerbose (" Calling closeSocket when the socket is %s" ,
434422 state == SOCKET_CLOSING ? " pending close" : " already closed" );
@@ -442,10 +430,13 @@ namespace litecore::websocket {
442430 logVerbose (" Calling requestClose before the socket is connected" );
443431 [[fallthrough]];
444432 case SOCKET_OPENED:
445- _socketLCState = SOCKET_CLOSING;
446- _startResponseTimer (kCloseTimeout );
447- _webSocket.requestClose (status, message);
448- break ;
433+ {
434+ _socketLCState = SOCKET_CLOSING;
435+ _startResponseTimer (kCloseTimeout );
436+ alloc_slice allocedMessage (message);
437+ defer ([=, this ] { _webSocket.requestClose (status, allocedMessage); });
438+ break ;
439+ }
449440 case SOCKET_CLOSING:
450441 logVerbose (" Calling requestClose when the socket is pending close" );
451442 break ;
@@ -457,23 +448,23 @@ namespace litecore::websocket {
457448
458449 // Public API. Initiates a request to close the connection cleanly.
459450 void close (int status, slice message) {
460- unique_lock lock (_mutex );
451+ LockWithDefer lock (this );
461452
462453 switch ( _socketLCState ) {
463454 case SOCKET_CLOSING:
464455 logVerbose (" Calling close when the socket is pending close" );
465- return ;
456+ break ;
466457 case SOCKET_CLOSED:
467458 logVerbose (" Calling close when the socket is already closed" );
468- return ;
459+ break ;
469460 case SOCKET_OPENED:
470461 logInfo (" Requesting close with status=%d, message='%.*s'" , status, SPLAT (message));
471462 if ( _framing ) {
472463 if ( _closeSent || _closeReceived ) {
473464 logVerbose (" Close already processed (_closeSent: %d, _closeReceived: %d), exiting "
474465 " WebSocketImpl::close()" ,
475466 (int )_closeSent, (int )_closeReceived);
476- return ;
467+ break ;
477468 }
478469
479470 auto closeMsg = alloc_slice (2 + message.size );
@@ -483,13 +474,11 @@ namespace litecore::websocket {
483474 _closeSent = true ;
484475 _closeMessage = closeMsg;
485476 _startResponseTimer (kCloseTimeout );
486-
487- lock.unlock (); // UNLOCK MUTEX to call sendOp()
488- sendOp (closeMsg, CLOSE);
477+ defer ([=, this ] { sendOp (closeMsg, CLOSE); });
489478 } else {
490479 _callRequestClose (status, message);
491480 }
492- return ;
481+ break ;
493482 case SOCKET_OPENING:
494483 logInfo (" Closing socket before connection established..." );
495484 if ( _framing ) {
@@ -499,10 +488,10 @@ namespace litecore::websocket {
499488 } else {
500489 _callRequestClose (status, message);
501490 }
502- return ;
491+ break ;
503492 case SOCKET_UNINIT:
504493 _callCloseSocket ();
505- return ;
494+ break ;
506495 }
507496 }
508497
@@ -523,9 +512,7 @@ namespace litecore::websocket {
523512 }
524513 _closeSent = true ;
525514 _closeMessage = message;
526- // Don't send the message now or I'll deadlock; remember to do it later in onReceive:
527- _msgToSend = message;
528- _opToSend = CLOSE;
515+ defer ([=, this ] { sendOp (_closeMessage, CLOSE); });
529516 }
530517 _timerDisabled = true ;
531518 return true ;
@@ -640,6 +627,54 @@ namespace litecore::websocket {
640627 delegate->invoke (&Delegate::onWebSocketClose, status);
641628 }
642629 }
630+
631+ /* * Utility class that locks `_mutex` and enables use of the `defer()` function below.
632+ * It should be instantiated as `LockWithDefer lock(this);`. */
633+ class LockWithDefer {
634+ public:
635+ explicit LockWithDefer (impl* owner) : _owner(owner), _lock(owner->_mutex) {
636+ DebugAssert (owner->_lockWithDefer == nullptr );
637+ owner->_lockWithDefer = this ;
638+ }
639+
640+ void defer (function<void ()> action) {
641+ DebugAssert (_lock.owns_lock ());
642+ _actions.emplace_back (std::move (action));
643+ }
644+
645+ void unlock () {
646+ DebugAssert (_lock.owns_lock ());
647+ _owner->_lockWithDefer = nullptr ;
648+ _lock.unlock ();
649+ }
650+
651+ ~LockWithDefer () {
652+ if ( _lock.owns_lock () ) _owner->_lockWithDefer = nullptr ;
653+ if ( !_actions.empty () ) {
654+ _lock.unlock ();
655+ for ( auto & action : _actions ) {
656+ try {
657+ action ();
658+ } catch ( ... ) {
659+ #ifdef _MSC_VER
660+ C4Error::warnCurrentException (__FUNCSIG__);
661+ #else
662+ C4Error::warnCurrentException (__PRETTY_FUNCTION__);
663+ #endif
664+ }
665+ }
666+ }
667+ }
668+
669+ private:
670+ impl* const _owner;
671+ unique_lock<mutex> _lock;
672+ vector<function<void ()>> _actions; // can't use smallVector: std::function is not trivially moveable
673+ };
674+
675+ // / Schedules a function to be called immediately after the current lock is released.
676+ // / Precondition: Some caller must have a local `LockWithDefer` instance.
677+ void defer (function<void ()> fn) { _lockWithDefer->defer (std::move (fn)); }
643678 };
644679
645680#pragma mark - WEBSOCKET IMPL:
0 commit comments