Skip to content

Commit abbb45a

Browse files
authored
Addressed race conditions and flaky tests in QUIC transport (#485)
Addressed various data races in the `network-transport-quic` transport, including on new connections and on connections lost. As a result, the test suite is now much less flaky
1 parent 567023a commit abbb45a

7 files changed

Lines changed: 172 additions & 150 deletions

File tree

packages/network-transport-quic/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
* Eliminated a rare race condition that allowed the transport to read messages before
55
marking the connection as open, violating the interface expectations.
6+
* Better handling of lost connections.
7+
* Fixed rare race condition in establishing connections.
68

79
2026-01-01 Laurent P. René de Cotret <laurent.decotret@outlook.com> 0.1.1
810

packages/network-transport-quic/bench/Bench.hs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,15 @@ quicConfig =
5858
>>= \case
5959
Left errmsg -> throwIO $ userError errmsg
6060
Right credentials ->
61-
QUIC.createTransport "127.0.0.1" "0" (credentials :| [])
61+
QUIC.createTransport
62+
( QUIC.QUICTransportConfig
63+
{ hostName = "127.0.0.1"
64+
, serviceName = "0"
65+
, credentials = credentials :| []
66+
, -- credentials are self-signed
67+
validateCredentials = False
68+
}
69+
)
6270
}
6371

6472
data BenchParams = BenchParams
@@ -115,8 +123,7 @@ throughputBench TransportConfig{mkTransport} BenchParams{messageSize, messageCou
115123
connections <-
116124
replicateM
117125
connectionCount
118-
( connect senderEP receiverAddr ReliableOrdered defaultConnectHints >>= either throwIO pure
119-
)
126+
(connect senderEP receiverAddr ReliableOrdered defaultConnectHints >>= either throwIO pure)
120127

121128
takeMVar receiverReady
122129

packages/network-transport-quic/src/Network/Transport/QUIC/Internal.hs

Lines changed: 44 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ module Network.Transport.QUIC.Internal
2222
where
2323

2424
import Control.Concurrent (forkIO, killThread, modifyMVar_, newEmptyMVar, readMVar)
25-
import Control.Concurrent.MVar (modifyMVar, putMVar, takeMVar, withMVar)
25+
import Control.Concurrent.MVar (modifyMVar, putMVar, takeMVar, tryPutMVar, withMVar)
2626
import Control.Concurrent.STM (atomically, newTQueueIO)
2727
import Control.Concurrent.STM.TQueue
2828
( TQueue,
@@ -34,13 +34,11 @@ import Control.Monad (unless, when)
3434
import Data.Bifunctor (Bifunctor (first))
3535
import Data.Binary qualified as Binary (decodeOrFail)
3636
import Data.ByteString (ByteString, fromStrict)
37-
import Data.Foldable (forM_)
3837
import Data.Function ((&))
3938
import Data.Functor ((<&>))
4039
import Data.IORef (newIORef, readIORef, writeIORef)
4140
import Data.List.NonEmpty (NonEmpty)
4241
import Data.Map.Strict qualified as Map
43-
import Data.Maybe (isNothing)
4442
import Lens.Micro.Platform ((+~))
4543
import Network.QUIC qualified as QUIC
4644
import Network.TLS (Credential)
@@ -62,8 +60,7 @@ import Network.Transport
6260
)
6361
import Network.Transport.QUIC.Internal.Configuration (credentialLoadX509)
6462
import Network.Transport.QUIC.Internal.Messaging
65-
( ClientConnId,
66-
MessageReceived (..),
63+
( MessageReceived (..),
6764
createConnectionId,
6865
decodeMessage,
6966
encodeMessage,
@@ -101,7 +98,6 @@ import Network.Transport.QUIC.Internal.QUICTransport
10198
nextSelfConnOutId,
10299
remoteEndPointAddress,
103100
remoteEndPointState,
104-
remoteIncoming,
105101
remoteServerConnId,
106102
remoteStream,
107103
transportConfig,
@@ -182,21 +178,15 @@ handleNewStream quicTransport stream = do
182178
(remoteEndPoint, _) <- either throwIO pure =<< createRemoteEndPoint ourEndPoint remoteAddress Incoming
183179
doneMVar <- newEmptyMVar
184180

185-
-- Sending an ack is important, because otherwise
186-
-- the client may start sending messages well before we
187-
-- start being able to receive them
188-
189-
clientConnId <- either (throwIO . userError) (pure . fromIntegral) =<< recvWord32 stream
190181
let serverConnId = remoteServerConnId remoteEndPoint
191-
connectionId = createConnectionId serverConnId clientConnId
182+
-- One logical connection per stream; clientConnId is always 0.
183+
connectionId = createConnectionId serverConnId 0
192184

193185
let st =
194186
RemoteEndPointValid $
195187
ValidRemoteEndPointState
196188
{ _remoteStream = stream,
197-
_remoteStreamIsClosed = doneMVar,
198-
_remoteIncoming = Just clientConnId,
199-
_remoteNextConnOutId = 0
189+
_remoteStreamIsClosed = doneMVar
200190
}
201191
modifyMVar_
202192
(remoteEndPoint ^. remoteEndPointState)
@@ -218,6 +208,13 @@ handleNewStream quicTransport stream = do
218208
remoteAddress
219209
)
220210

211+
-- Second handshake ack: only sent after ConnectionOpened is
212+
-- enqueued. The initiator's @connect@ call blocks on this ack, so
213+
-- when it returns, the caller can trust that any peer observing
214+
-- events on this endpoint will see ConnectionOpened before any
215+
-- messages the caller subsequently sends on other connections.
216+
sendAck stream
217+
221218
tid <-
222219
forkIO $
223220
handleIncomingMessages
@@ -250,12 +247,8 @@ handleIncomingMessages ourEndPoint remoteEndPoint =
250247
release (Left err) = closeRemoteEndPoint Incoming remoteEndPoint >> prematureExit err
251248
release (Right _) = closeRemoteEndPoint Incoming remoteEndPoint
252249

253-
connectionId = createConnectionId serverConnId
254-
255-
writeConnectionClosedSTM connId =
256-
writeTQueue
257-
ourQueue
258-
(ConnectionClosed (connectionId connId))
250+
-- One logical connection per stream; clientConnId is always 0.
251+
connectionId = createConnectionId serverConnId 0
259252

260253
go = either prematureExit loop
261254

@@ -265,34 +258,31 @@ handleIncomingMessages ourEndPoint remoteEndPoint =
265258
Left errmsg -> do
266259
-- Throwing will trigger 'prematureExit'
267260
throwIO $ userError $ "(handleIncomingMessages) Failed with: " <> errmsg
268-
Right (Message connId bytes) -> handleMessage connId bytes >> loop stream
261+
Right (Message bytes) -> handleMessage bytes >> loop stream
269262
Right StreamClosed -> throwIO $ userError "(handleIncomingMessages) Stream closed"
270-
Right (CloseConnection connId) -> do
271-
atomically (writeConnectionClosedSTM connId)
263+
Right CloseConnection -> do
264+
atomically (writeTQueue ourQueue (ConnectionClosed connectionId))
272265
mAct <- modifyMVar (remoteEndPoint ^. remoteEndPointState) $ \case
273266
RemoteEndPointInit -> pure (RemoteEndPointClosed, Nothing)
274267
RemoteEndPointClosed -> pure (RemoteEndPointClosed, Nothing)
275-
RemoteEndPointValid (ValidRemoteEndPointState _ isClosed _ _) -> do
268+
RemoteEndPointValid (ValidRemoteEndPointState _ isClosed) -> do
276269
pure (RemoteEndPointClosed, Just $ putMVar isClosed ())
277270
case mAct of
278271
Nothing -> pure ()
279272
Just cleanup -> cleanup
280273
Right CloseEndPoint -> do
281-
connIds <- modifyMVar (remoteEndPoint ^. remoteEndPointState) $ \case
282-
RemoteEndPointValid vst -> do
283-
pure (RemoteEndPointClosed, vst ^. remoteIncoming)
284-
other -> pure (other, Nothing)
285-
unless
286-
(isNothing connIds)
287-
( atomically $
288-
forM_
289-
connIds
290-
(writeTQueue ourQueue . ConnectionClosed . connectionId)
291-
)
274+
-- handleIncomingMessages only runs on incoming remote endpoints, so if
275+
-- the state was still Valid there is exactly one logical connection to
276+
-- surface as closed.
277+
wasValid <- modifyMVar (remoteEndPoint ^. remoteEndPointState) $ \case
278+
RemoteEndPointValid _ -> pure (RemoteEndPointClosed, True)
279+
other -> pure (other, False)
280+
when wasValid $
281+
atomically $ writeTQueue ourQueue (ConnectionClosed connectionId)
292282

293-
handleMessage :: ClientConnId -> [ByteString] -> IO ()
294-
handleMessage clientConnId payload =
295-
atomically (writeTQueue ourQueue (Received (connectionId clientConnId) payload))
283+
handleMessage :: [ByteString] -> IO ()
284+
handleMessage payload =
285+
atomically (writeTQueue ourQueue (Received connectionId payload))
296286

297287
prematureExit :: IOException -> IO ()
298288
prematureExit exc = do
@@ -360,24 +350,24 @@ newConnection ourEndPoint creds validateCreds remoteAddress _reliability _connec
360350
else
361351
createConnectionTo creds validateCreds ourEndPoint remoteAddress >>= \case
362352
Left err -> pure $ Left err
363-
Right (remoteEndPoint, connId) -> do
353+
Right remoteEndPoint -> do
364354
connAlive <- newIORef True
365355
pure
366356
. Right
367357
$ Connection
368-
{ send = sendConn remoteEndPoint connAlive connId,
369-
close = closeConn remoteEndPoint connAlive connId
358+
{ send = sendConn remoteEndPoint connAlive,
359+
close = closeConn remoteEndPoint connAlive
370360
}
371361
where
372362
ourAddress = ourEndPoint ^. localAddress
373-
sendConn remoteEndPoint connAlive connId packets =
363+
sendConn remoteEndPoint connAlive packets =
374364
readMVar (remoteEndPoint ^. remoteEndPointState) >>= \case
375365
RemoteEndPointInit -> undefined
376366
RemoteEndPointValid vst ->
377367
readIORef connAlive >>= \case
378368
False -> pure . Left $ TransportError SendClosed "Connection closed"
379369
True ->
380-
sendMessage (vst ^. remoteStream) connId packets
370+
sendMessage (vst ^. remoteStream) packets
381371
<&> first (TransportError SendFailed . show)
382372
RemoteEndPointClosed -> do
383373
readIORef connAlive >>= \case
@@ -387,16 +377,21 @@ newConnection ourEndPoint creds validateCreds remoteAddress _reliability _connec
387377
-- 'connAlive' IORefs.
388378
False -> pure . Left $ TransportError SendClosed "Connection closed"
389379
True -> pure . Left $ TransportError SendFailed "Remote endpoint closed"
390-
closeConn remoteEndPoint connAlive connId = do
380+
closeConn remoteEndPoint connAlive = do
391381
mCleanup <- modifyMVar (remoteEndPoint ^. remoteEndPointState) $ \case
392-
RemoteEndPointValid vst@(ValidRemoteEndPointState stream isClosed _ _) -> do
382+
RemoteEndPointValid vst@(ValidRemoteEndPointState stream isClosed) -> do
393383
readIORef connAlive >>= \case
394384
False -> pure (RemoteEndPointValid vst, Nothing)
395385
True -> do
396386
writeIORef connAlive False
397-
-- We want to run this cleanup action OUTSIDE of the MVar modification
398-
let cleanup = sendCloseConnection connId stream
399-
pure (RemoteEndPointClosed, Just $ cleanup >> putMVar isClosed ())
387+
-- Run cleanup OUTSIDE the MVar modification. tryPutMVar keeps this
388+
-- safe against races with the finally in streamToEndpoint that can
389+
-- also signal isClosed on QUIC.Client.run exit.
390+
let cleanup = do
391+
_ <- sendCloseConnection stream
392+
_ <- tryPutMVar isClosed ()
393+
pure ()
394+
pure (RemoteEndPointClosed, Just cleanup)
400395
_ -> pure (RemoteEndPointClosed, Nothing)
401396

402397
case mCleanup of

packages/network-transport-quic/src/Network/Transport/QUIC/Internal/Client.hs

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ where
1010

1111
import Control.Concurrent (forkIOWithUnmask, newEmptyMVar)
1212
import Control.Concurrent.Async (withAsync)
13-
import Control.Concurrent.MVar (MVar, putMVar, takeMVar)
14-
import Control.Exception (SomeException, bracket, catch, mask, mask_, throwIO)
13+
import Control.Concurrent.MVar (MVar, putMVar, takeMVar, tryPutMVar)
14+
import Control.Exception (SomeException, bracket, catch, finally, mask, mask_, throwIO)
1515
import Data.List.NonEmpty (NonEmpty)
1616
import Network.QUIC qualified as QUIC
1717
import Network.QUIC.Client qualified as QUIC.Client
@@ -28,9 +28,11 @@ streamToEndpoint ::
2828
EndPointAddress ->
2929
-- | Their address
3030
EndPointAddress ->
31-
-- | On exception
32-
(SomeException -> IO ()) ->
33-
-- | On a message to forcibly close the connection
31+
-- | Called when the QUIC connection or stream ends without us having initiated the
32+
-- close. Must be idempotent (the caller typically gates on remote endpoint state so
33+
-- that repeated invocations are safe) — this handler is invoked from multiple sites
34+
-- (peer-initiated close signal, QUIC.Client.run exception, thread finally) to cover
35+
-- every termination path.
3436
IO () ->
3537
IO
3638
( Either
@@ -40,7 +42,7 @@ streamToEndpoint ::
4042
QUIC.Stream
4143
)
4244
)
43-
streamToEndpoint creds validateCreds ourAddress theirAddress onExc onCloseForcibly =
45+
streamToEndpoint creds validateCreds ourAddress theirAddress onConnLoss =
4446
case decodeQUICAddr theirAddress of
4547
Left errmsg -> pure $ Left (TransportError ConnectNotFound errmsg)
4648
Right (QUICAddr hostname servicename _) -> do
@@ -75,7 +77,8 @@ streamToEndpoint creds validateCreds ourAddress theirAddress onExc onCloseForcib
7577
(throwIO @SomeException)
7678
)
7779
)
78-
onExc
80+
(\(_ :: SomeException) -> pure ())
81+
`finally` onConnLoss
7982

8083
streamOrError <- takeMVar streamMVar
8184

@@ -85,11 +88,24 @@ streamToEndpoint creds validateCreds ourAddress theirAddress onExc onCloseForcib
8588
listenForClose stream doneMVar =
8689
receiveMessage stream
8790
>>= \case
88-
Right StreamClosed -> do
89-
putMVar doneMVar ()
90-
Right (CloseConnection _) -> do
91-
putMVar doneMVar ()
92-
Right CloseEndPoint -> do
93-
putMVar doneMVar ()
94-
onCloseForcibly
91+
-- Any message from the peer on this stream means we're done listening.
92+
-- Peer-initiated closes (StreamClosed/CloseEndPoint) additionally call
93+
-- onConnLoss; the idempotent gate in the handler dedupes with the finally
94+
-- that also fires on QUIC.Client.run exit.
95+
--
96+
-- Mask signalling+onConnLoss as an atomic pair: tryPutMVar unblocks
97+
-- runClient's takeMVar, which causes withAsync to cancel this thread.
98+
-- Without mask, the async ThreadKilled can fire partway through
99+
-- onConnLoss, dropping the ErrorEvent. The finally in the parent thread
100+
-- is a backup but cannot recover if surfaceConnectionLost already
101+
-- transitioned the remote state to Closed.
102+
Right StreamClosed -> mask_ $ do
103+
_ <- tryPutMVar doneMVar ()
104+
onConnLoss
105+
Right CloseConnection ->
106+
-- Peer closed the logical connection cleanly; no ErrorEvent.
107+
() <$ tryPutMVar doneMVar ()
108+
Right CloseEndPoint -> mask_ $ do
109+
_ <- tryPutMVar doneMVar ()
110+
onConnLoss
95111
other -> throwIO . userError $ "Unexpected incoming message to client: " <> show other

0 commit comments

Comments
 (0)