1+ use std:: time:: Duration ;
2+
13use crate :: transport:: writer:: { WriterMessage , WriterRef } ;
24use tokio:: io:: { AsyncWrite , AsyncWriteExt , WriteHalf } ;
3- use tokio:: sync:: mpsc;
5+ use tokio:: sync:: { mpsc, oneshot } ;
46use tracing:: { debug, warn} ;
57
6- pub fn spawn_socket_writer ( writer : WriteHalf < impl AsyncWrite + Send + ' static > ) -> WriterRef {
8+ const WRITER_SHUTDOWN_TIMEOUT : Duration = Duration :: from_secs ( 5 ) ;
9+
10+ pub fn spawn_socket_writer < W > ( writer : WriteHalf < W > ) -> ( WriterRef , oneshot:: Receiver < ( ) > )
11+ where
12+ W : AsyncWrite + Send + ' static ,
13+ {
714 let ( sender, mailbox) = mpsc:: channel ( 10 ) ;
15+ let ( exit_tx, exit_rx) = oneshot:: channel ( ) ;
816 let actor = WriterActor :: new ( writer, mailbox) ;
9- tokio:: spawn ( run_writer ( actor) ) ;
17+ tokio:: spawn ( run_writer ( actor, exit_tx ) ) ;
1018
11- WriterRef :: new ( sender)
19+ ( WriterRef :: new ( sender) , exit_rx )
1220}
1321
1422struct WriterActor < W > {
@@ -37,13 +45,23 @@ impl<W: AsyncWrite> WriterActor<W> {
3745 }
3846}
3947
40- async fn run_writer < W : AsyncWrite > ( mut actor : WriterActor < W > ) {
48+ async fn run_writer < W : AsyncWrite > ( mut actor : WriterActor < W > , exit_tx : oneshot :: Sender < ( ) > ) {
4149 while let Some ( msg) = actor. mailbox . recv ( ) . await {
4250 if !actor. handle ( msg) . await {
4351 break ;
4452 }
4553 }
4654
55+ match tokio:: time:: timeout ( WRITER_SHUTDOWN_TIMEOUT , actor. writer . shutdown ( ) ) . await {
56+ Ok ( Ok ( ( ) ) ) => debug ! ( "writer half closed cleanly" ) ,
57+ Ok ( Err ( err) ) => warn ! ( "writer shutdown returned error: {err}" ) ,
58+ Err ( _) => warn ! (
59+ "writer shutdown timed out after {:?}" ,
60+ WRITER_SHUTDOWN_TIMEOUT
61+ ) ,
62+ }
63+
64+ let _ = exit_tx. send ( ( ) ) ;
4765 debug ! ( "writer loop is shutting down" ) ;
4866}
4967
@@ -58,7 +76,7 @@ mod tests {
5876 async fn test_send_single_message ( ) {
5977 let ( reader, writer) = duplex ( 1024 ) ;
6078 let ( _reader_half, writer_half) = tokio:: io:: split ( writer) ;
61- let writer_ref = spawn_socket_writer ( writer_half) ;
79+ let ( writer_ref, _exit_rx ) = spawn_socket_writer ( writer_half) ;
6280
6381 let fix_message = b"8=FIX.4.4\x01 9=77\x01 35=A\x01 34=1\x01 49=sender\x01 52=20230908-08:24:56.574\x01 56=target\x01 98=0\x01 108=30\x01 141=Y\x01 10=037\x01 " ;
6482 let raw_message = RawFixMessage :: new ( fix_message. to_vec ( ) ) ;
@@ -84,7 +102,7 @@ mod tests {
84102 async fn test_send_multiple_messages ( ) {
85103 let ( reader, writer) = duplex ( 2048 ) ;
86104 let ( _reader_half, writer_half) = tokio:: io:: split ( writer) ;
87- let writer_ref = spawn_socket_writer ( writer_half) ;
105+ let ( writer_ref, _exit_rx ) = spawn_socket_writer ( writer_half) ;
88106
89107 let msg1 = b"8=FIX.4.4\x01 9=77\x01 35=A\x01 34=1\x01 49=sender\x01 52=20230908-08:24:56.574\x01 56=target\x01 98=0\x01 108=30\x01 141=Y\x01 10=037\x01 " ;
90108 let msg2 = b"8=FIX.4.4\x01 9=77\x01 35=A\x01 34=2\x01 49=sender\x01 52=20230908-08:24:58.574\x01 56=target\x01 98=0\x01 108=30\x01 141=Y\x01 10=040\x01 " ;
@@ -125,7 +143,7 @@ mod tests {
125143 async fn test_disconnect ( ) {
126144 let ( reader, writer) = duplex ( 1024 ) ;
127145 let ( _reader_half, writer_half) = tokio:: io:: split ( writer) ;
128- let writer_ref = spawn_socket_writer ( writer_half) ;
146+ let ( writer_ref, _exit_rx ) = spawn_socket_writer ( writer_half) ;
129147
130148 // send a message first
131149 let fix_message = b"8=FIX.4.4\x01 9=77\x01 35=A\x01 34=1\x01 49=sender\x01 52=20230908-08:24:56.574\x01 56=target\x01 98=0\x01 108=30\x01 141=Y\x01 10=037\x01 " ;
@@ -155,7 +173,7 @@ mod tests {
155173 async fn test_send_empty_message ( ) {
156174 let ( reader, writer) = duplex ( 1024 ) ;
157175 let ( _reader_half, writer_half) = tokio:: io:: split ( writer) ;
158- let writer_ref = spawn_socket_writer ( writer_half) ;
176+ let ( writer_ref, _exit_rx ) = spawn_socket_writer ( writer_half) ;
159177
160178 let empty_message = RawFixMessage :: new ( vec ! [ ] ) ;
161179 writer_ref. send_raw_message ( empty_message) . await ;
@@ -187,7 +205,7 @@ mod tests {
187205 async fn test_writer_shutdown_on_mailbox_close ( ) {
188206 let ( _reader, writer) = duplex ( 1024 ) ;
189207 let ( _reader_half, writer_half) = tokio:: io:: split ( writer) ;
190- let writer_ref = spawn_socket_writer ( writer_half) ;
208+ let ( writer_ref, _exit_rx ) = spawn_socket_writer ( writer_half) ;
191209
192210 // send a message to ensure the writer is running
193211 let fix_message = b"8=FIX.4.4\x01 9=77\x01 35=A\x01 34=1\x01 49=sender\x01 52=20230908-08:24:56.574\x01 56=target\x01 98=0\x01 108=30\x01 141=Y\x01 10=037\x01 " ;
@@ -210,7 +228,7 @@ mod tests {
210228 async fn test_write_error_handling ( ) {
211229 let ( reader, writer) = duplex ( 1024 ) ;
212230 let ( _reader_half, writer_half) = tokio:: io:: split ( writer) ;
213- let writer_ref = spawn_socket_writer ( writer_half) ;
231+ let ( writer_ref, _exit_rx ) = spawn_socket_writer ( writer_half) ;
214232
215233 // close the reader end, which should cause write errors
216234 drop ( reader) ;
@@ -231,4 +249,96 @@ mod tests {
231249 // and continued running (as per the code comment that it only shuts down
232250 // when explicitly requested)
233251 }
252+
253+ /// After processing Disconnect, the actor calls shutdown() on its WriteHalf,
254+ /// which for a duplex stream surfaces as EOF on the peer read side.
255+ #[ tokio:: test]
256+ async fn shutdown_called_on_disconnect ( ) {
257+ let ( reader, writer) = duplex ( 1024 ) ;
258+ let ( _reader_half, writer_half) = tokio:: io:: split ( writer) ;
259+ let ( writer_ref, exit_rx) = spawn_socket_writer ( writer_half) ;
260+
261+ writer_ref. disconnect ( ) . await ;
262+
263+ tokio:: time:: timeout ( tokio:: time:: Duration :: from_millis ( 200 ) , exit_rx)
264+ . await
265+ . expect ( "exit signal not fired within timeout" )
266+ . expect ( "exit sender dropped without signalling" ) ;
267+
268+ // Peer side of the duplex should observe EOF after shutdown.
269+ let mut reader = reader;
270+ let mut buf = vec ! [ 0u8 ; 16 ] ;
271+ let n = tokio:: time:: timeout (
272+ tokio:: time:: Duration :: from_millis ( 200 ) ,
273+ reader. read ( & mut buf) ,
274+ )
275+ . await
276+ . expect ( "read timed out — shutdown did not surface as EOF" )
277+ . expect ( "read failed" ) ;
278+ assert_eq ! ( n, 0 , "expected EOF after writer shutdown, read {n} bytes" ) ;
279+ }
280+
281+ /// Fallback exit path: all WriterRef clones dropped without sending Disconnect.
282+ /// The actor's mailbox closes, the loop exits, shutdown() runs, and exit fires.
283+ #[ tokio:: test]
284+ async fn exit_signal_fires_when_all_senders_dropped ( ) {
285+ let ( _reader, writer) = duplex ( 1024 ) ;
286+ let ( _reader_half, writer_half) = tokio:: io:: split ( writer) ;
287+ let ( writer_ref, exit_rx) = spawn_socket_writer ( writer_half) ;
288+
289+ drop ( writer_ref) ;
290+
291+ tokio:: time:: timeout ( tokio:: time:: Duration :: from_millis ( 200 ) , exit_rx)
292+ . await
293+ . expect ( "exit signal not fired within timeout" )
294+ . expect ( "exit sender dropped without signalling" ) ;
295+ }
296+
297+ use std:: pin:: Pin ;
298+ use std:: task:: { Context , Poll } ;
299+ use tokio:: io:: AsyncWrite ;
300+
301+ /// `AsyncWrite` where `poll_write` succeeds but `poll_shutdown` hangs forever.
302+ struct StuckShutdownWriter ;
303+
304+ impl AsyncWrite for StuckShutdownWriter {
305+ fn poll_write (
306+ self : Pin < & mut Self > ,
307+ _cx : & mut Context < ' _ > ,
308+ buf : & [ u8 ] ,
309+ ) -> Poll < std:: io:: Result < usize > > {
310+ Poll :: Ready ( Ok ( buf. len ( ) ) )
311+ }
312+
313+ fn poll_flush ( self : Pin < & mut Self > , _cx : & mut Context < ' _ > ) -> Poll < std:: io:: Result < ( ) > > {
314+ Poll :: Ready ( Ok ( ( ) ) )
315+ }
316+
317+ fn poll_shutdown (
318+ self : Pin < & mut Self > ,
319+ _cx : & mut Context < ' _ > ,
320+ ) -> Poll < std:: io:: Result < ( ) > > {
321+ Poll :: Pending
322+ }
323+ }
324+
325+ /// If shutdown() never resolves, the writer still exits after WRITER_SHUTDOWN_TIMEOUT.
326+ /// Virtual time via `start_paused = true` keeps the test fast.
327+ #[ tokio:: test( start_paused = true ) ]
328+ async fn shutdown_timeout_does_not_block_exit ( ) {
329+ // Build a split pair around StuckShutdownWriter. It only implements AsyncWrite;
330+ // we wrap with `tokio::io::join` to supply a dummy AsyncRead.
331+ let stuck = tokio:: io:: join ( tokio:: io:: empty ( ) , StuckShutdownWriter ) ;
332+ let ( _read_half, write_half) = tokio:: io:: split ( stuck) ;
333+ let ( writer_ref, exit_rx) = spawn_socket_writer ( write_half) ;
334+
335+ writer_ref. disconnect ( ) . await ;
336+
337+ // Advance virtual time past the shutdown timeout.
338+ tokio:: time:: advance ( WRITER_SHUTDOWN_TIMEOUT + std:: time:: Duration :: from_millis ( 100 ) )
339+ . await ;
340+
341+ // Exit should have fired by now.
342+ exit_rx. await . expect ( "exit sender dropped without signalling" ) ;
343+ }
234344}
0 commit comments