@@ -22,7 +22,7 @@ module Network.Transport.QUIC.Internal
2222where
2323
2424import Control.Concurrent (forkIO , killThread , modifyMVar_ , newEmptyMVar , readMVar )
25- import Control.Concurrent.MVar (modifyMVar , putMVar , takeMVar , withMVar )
25+ import Control.Concurrent.MVar (modifyMVar , putMVar , takeMVar , tryPutMVar , withMVar )
2626import Control.Concurrent.STM (atomically , newTQueueIO )
2727import Control.Concurrent.STM.TQueue
2828 ( TQueue ,
@@ -34,13 +34,11 @@ import Control.Monad (unless, when)
3434import Data.Bifunctor (Bifunctor (first ))
3535import Data.Binary qualified as Binary (decodeOrFail )
3636import Data.ByteString (ByteString , fromStrict )
37- import Data.Foldable (forM_ )
3837import Data.Function ((&) )
3938import Data.Functor ((<&>) )
4039import Data.IORef (newIORef , readIORef , writeIORef )
4140import Data.List.NonEmpty (NonEmpty )
4241import Data.Map.Strict qualified as Map
43- import Data.Maybe (isNothing )
4442import Lens.Micro.Platform ((+~) )
4543import Network.QUIC qualified as QUIC
4644import Network.TLS (Credential )
@@ -62,8 +60,7 @@ import Network.Transport
6260 )
6361import Network.Transport.QUIC.Internal.Configuration (credentialLoadX509 )
6462import Network.Transport.QUIC.Internal.Messaging
65- ( ClientConnId ,
66- MessageReceived (.. ),
63+ ( MessageReceived (.. ),
6764 createConnectionId ,
6865 decodeMessage ,
6966 encodeMessage ,
@@ -101,7 +98,6 @@ import Network.Transport.QUIC.Internal.QUICTransport
10198 nextSelfConnOutId ,
10299 remoteEndPointAddress ,
103100 remoteEndPointState ,
104- remoteIncoming ,
105101 remoteServerConnId ,
106102 remoteStream ,
107103 transportConfig ,
@@ -182,21 +178,15 @@ handleNewStream quicTransport stream = do
182178 (remoteEndPoint, _) <- either throwIO pure =<< createRemoteEndPoint ourEndPoint remoteAddress Incoming
183179 doneMVar <- newEmptyMVar
184180
185- -- Sending an ack is important, because otherwise
186- -- the client may start sending messages well before we
187- -- start being able to receive them
188-
189- clientConnId <- either (throwIO . userError ) (pure . fromIntegral ) =<< recvWord32 stream
190181 let serverConnId = remoteServerConnId remoteEndPoint
191- connectionId = createConnectionId serverConnId clientConnId
182+ -- One logical connection per stream; clientConnId is always 0.
183+ connectionId = createConnectionId serverConnId 0
192184
193185 let st =
194186 RemoteEndPointValid $
195187 ValidRemoteEndPointState
196188 { _remoteStream = stream,
197- _remoteStreamIsClosed = doneMVar,
198- _remoteIncoming = Just clientConnId,
199- _remoteNextConnOutId = 0
189+ _remoteStreamIsClosed = doneMVar
200190 }
201191 modifyMVar_
202192 (remoteEndPoint ^. remoteEndPointState)
@@ -218,6 +208,13 @@ handleNewStream quicTransport stream = do
218208 remoteAddress
219209 )
220210
211+ -- Second handshake ack: only sent after ConnectionOpened is
212+ -- enqueued. The initiator's @connect@ call blocks on this ack, so
213+ -- when it returns, the caller can trust that any peer observing
214+ -- events on this endpoint will see ConnectionOpened before any
215+ -- messages the caller subsequently sends on other connections.
216+ sendAck stream
217+
221218 tid <-
222219 forkIO $
223220 handleIncomingMessages
@@ -250,12 +247,8 @@ handleIncomingMessages ourEndPoint remoteEndPoint =
250247 release (Left err) = closeRemoteEndPoint Incoming remoteEndPoint >> prematureExit err
251248 release (Right _) = closeRemoteEndPoint Incoming remoteEndPoint
252249
253- connectionId = createConnectionId serverConnId
254-
255- writeConnectionClosedSTM connId =
256- writeTQueue
257- ourQueue
258- (ConnectionClosed (connectionId connId))
250+ -- One logical connection per stream; clientConnId is always 0.
251+ connectionId = createConnectionId serverConnId 0
259252
260253 go = either prematureExit loop
261254
@@ -265,34 +258,31 @@ handleIncomingMessages ourEndPoint remoteEndPoint =
265258 Left errmsg -> do
266259 -- Throwing will trigger 'prematureExit'
267260 throwIO $ userError $ " (handleIncomingMessages) Failed with: " <> errmsg
268- Right (Message connId bytes) -> handleMessage connId bytes >> loop stream
261+ Right (Message bytes) -> handleMessage bytes >> loop stream
269262 Right StreamClosed -> throwIO $ userError " (handleIncomingMessages) Stream closed"
270- Right ( CloseConnection connId) -> do
271- atomically (writeConnectionClosedSTM connId )
263+ Right CloseConnection -> do
264+ atomically (writeTQueue ourQueue ( ConnectionClosed connectionId) )
272265 mAct <- modifyMVar (remoteEndPoint ^. remoteEndPointState) $ \ case
273266 RemoteEndPointInit -> pure (RemoteEndPointClosed , Nothing )
274267 RemoteEndPointClosed -> pure (RemoteEndPointClosed , Nothing )
275- RemoteEndPointValid (ValidRemoteEndPointState _ isClosed _ _ ) -> do
268+ RemoteEndPointValid (ValidRemoteEndPointState _ isClosed) -> do
276269 pure (RemoteEndPointClosed , Just $ putMVar isClosed () )
277270 case mAct of
278271 Nothing -> pure ()
279272 Just cleanup -> cleanup
280273 Right CloseEndPoint -> do
281- connIds <- modifyMVar (remoteEndPoint ^. remoteEndPointState) $ \ case
282- RemoteEndPointValid vst -> do
283- pure (RemoteEndPointClosed , vst ^. remoteIncoming)
284- other -> pure (other, Nothing )
285- unless
286- (isNothing connIds)
287- ( atomically $
288- forM_
289- connIds
290- (writeTQueue ourQueue . ConnectionClosed . connectionId)
291- )
274+ -- handleIncomingMessages only runs on incoming remote endpoints, so if
275+ -- the state was still Valid there is exactly one logical connection to
276+ -- surface as closed.
277+ wasValid <- modifyMVar (remoteEndPoint ^. remoteEndPointState) $ \ case
278+ RemoteEndPointValid _ -> pure (RemoteEndPointClosed , True )
279+ other -> pure (other, False )
280+ when wasValid $
281+ atomically $ writeTQueue ourQueue (ConnectionClosed connectionId)
292282
293- handleMessage :: ClientConnId -> [ByteString ] -> IO ()
294- handleMessage clientConnId payload =
295- atomically (writeTQueue ourQueue (Received ( connectionId clientConnId) payload))
283+ handleMessage :: [ByteString ] -> IO ()
284+ handleMessage payload =
285+ atomically (writeTQueue ourQueue (Received connectionId payload))
296286
297287 prematureExit :: IOException -> IO ()
298288 prematureExit exc = do
@@ -360,24 +350,24 @@ newConnection ourEndPoint creds validateCreds remoteAddress _reliability _connec
360350 else
361351 createConnectionTo creds validateCreds ourEndPoint remoteAddress >>= \ case
362352 Left err -> pure $ Left err
363- Right ( remoteEndPoint, connId) -> do
353+ Right remoteEndPoint -> do
364354 connAlive <- newIORef True
365355 pure
366356 . Right
367357 $ Connection
368- { send = sendConn remoteEndPoint connAlive connId ,
369- close = closeConn remoteEndPoint connAlive connId
358+ { send = sendConn remoteEndPoint connAlive,
359+ close = closeConn remoteEndPoint connAlive
370360 }
371361 where
372362 ourAddress = ourEndPoint ^. localAddress
373- sendConn remoteEndPoint connAlive connId packets =
363+ sendConn remoteEndPoint connAlive packets =
374364 readMVar (remoteEndPoint ^. remoteEndPointState) >>= \ case
375365 RemoteEndPointInit -> undefined
376366 RemoteEndPointValid vst ->
377367 readIORef connAlive >>= \ case
378368 False -> pure . Left $ TransportError SendClosed " Connection closed"
379369 True ->
380- sendMessage (vst ^. remoteStream) connId packets
370+ sendMessage (vst ^. remoteStream) packets
381371 <&> first (TransportError SendFailed . show )
382372 RemoteEndPointClosed -> do
383373 readIORef connAlive >>= \ case
@@ -387,16 +377,21 @@ newConnection ourEndPoint creds validateCreds remoteAddress _reliability _connec
387377 -- 'connAlive' IORefs.
388378 False -> pure . Left $ TransportError SendClosed " Connection closed"
389379 True -> pure . Left $ TransportError SendFailed " Remote endpoint closed"
390- closeConn remoteEndPoint connAlive connId = do
380+ closeConn remoteEndPoint connAlive = do
391381 mCleanup <- modifyMVar (remoteEndPoint ^. remoteEndPointState) $ \ case
392- RemoteEndPointValid vst@ (ValidRemoteEndPointState stream isClosed _ _ ) -> do
382+ RemoteEndPointValid vst@ (ValidRemoteEndPointState stream isClosed) -> do
393383 readIORef connAlive >>= \ case
394384 False -> pure (RemoteEndPointValid vst, Nothing )
395385 True -> do
396386 writeIORef connAlive False
397- -- We want to run this cleanup action OUTSIDE of the MVar modification
398- let cleanup = sendCloseConnection connId stream
399- pure (RemoteEndPointClosed , Just $ cleanup >> putMVar isClosed () )
387+ -- Run cleanup OUTSIDE the MVar modification. tryPutMVar keeps this
388+ -- safe against races with the finally in streamToEndpoint that can
389+ -- also signal isClosed on QUIC.Client.run exit.
390+ let cleanup = do
391+ _ <- sendCloseConnection stream
392+ _ <- tryPutMVar isClosed ()
393+ pure ()
394+ pure (RemoteEndPointClosed , Just cleanup)
400395 _ -> pure (RemoteEndPointClosed , Nothing )
401396
402397 case mCleanup of
0 commit comments