@@ -180,3 +180,288 @@ async fn run_send_pump(
180180 }
181181 let _ = ws_sender. close ( ) . await ;
182182}
183+
184+ #[ cfg( test) ]
185+ mod tests {
186+ use super :: * ;
187+ use axum:: extract:: State ;
188+ use axum:: extract:: ws:: WebSocketUpgrade ;
189+ use axum:: response:: Response ;
190+ use std:: time:: Duration ;
191+ use tokio:: io:: AsyncReadExt ;
192+ use tokio:: net:: TcpListener ;
193+ use tokio_tungstenite:: connect_async;
194+ use tokio_tungstenite:: tungstenite:: Message as TungsteniteMessage ;
195+
196+ #[ derive( Clone ) ]
197+ struct EchoState ;
198+
199+ async fn echo_handler ( ws : WebSocketUpgrade , State ( _) : State < EchoState > ) -> Response {
200+ ws. on_upgrade ( |socket| async move {
201+ let ( ws_sender, ws_receiver) = socket. split ( ) ;
202+ let ( duplex_write, duplex_read) = tokio:: io:: duplex ( DUPLEX_BUFFER_SIZE ) ;
203+ let recv = run_recv_pump ( ws_receiver, duplex_write) ;
204+ let send = run_send_pump ( ws_sender, duplex_read) ;
205+ tokio:: join!( recv, send) ;
206+ } )
207+ }
208+
209+ async fn start_echo_server ( ) -> String {
210+ let listener = TcpListener :: bind ( "127.0.0.1:0" ) . await . unwrap ( ) ;
211+ let addr = listener. local_addr ( ) . unwrap ( ) ;
212+ let app = axum:: Router :: new ( )
213+ . route ( "/ws" , axum:: routing:: get ( echo_handler) )
214+ . with_state ( EchoState ) ;
215+ tokio:: spawn ( async move {
216+ axum:: serve ( listener, app) . await . unwrap ( ) ;
217+ } ) ;
218+ format ! ( "ws://{}/ws" , addr)
219+ }
220+
221+ #[ tokio:: test]
222+ async fn recv_pump_routes_text_messages_to_duplex ( ) {
223+ let url = start_echo_server ( ) . await ;
224+ let ( mut ws, _) = connect_async ( & url) . await . unwrap ( ) ;
225+
226+ ws. send ( TungsteniteMessage :: Text ( "hello" . into ( ) ) )
227+ . await
228+ . unwrap ( ) ;
229+
230+ let msg = tokio:: time:: timeout ( Duration :: from_secs ( 2 ) , ws. next ( ) )
231+ . await
232+ . expect ( "timeout" )
233+ . expect ( "stream ended" )
234+ . unwrap ( ) ;
235+
236+ match msg {
237+ TungsteniteMessage :: Text ( t) => assert_eq ! ( t, "hello" ) ,
238+ other => panic ! ( "expected Text, got {other:?}" ) ,
239+ }
240+ }
241+
242+ #[ tokio:: test]
243+ async fn recv_pump_routes_binary_utf8_messages ( ) {
244+ let url = start_echo_server ( ) . await ;
245+ let ( mut ws, _) = connect_async ( & url) . await . unwrap ( ) ;
246+
247+ ws. send ( TungsteniteMessage :: Binary ( b"binary-text" . to_vec ( ) . into ( ) ) )
248+ . await
249+ . unwrap ( ) ;
250+
251+ let msg = tokio:: time:: timeout ( Duration :: from_secs ( 2 ) , ws. next ( ) )
252+ . await
253+ . expect ( "timeout" )
254+ . expect ( "stream ended" )
255+ . unwrap ( ) ;
256+
257+ match msg {
258+ TungsteniteMessage :: Text ( t) => assert_eq ! ( t, "binary-text" ) ,
259+ other => panic ! ( "expected Text, got {other:?}" ) ,
260+ }
261+ }
262+
263+ #[ tokio:: test]
264+ async fn recv_pump_drops_non_utf8_binary ( ) {
265+ let url = start_echo_server ( ) . await ;
266+ let ( mut ws, _) = connect_async ( & url) . await . unwrap ( ) ;
267+
268+ ws. send ( TungsteniteMessage :: Binary ( vec ! [ 0xFF , 0xFE ] . into ( ) ) )
269+ . await
270+ . unwrap ( ) ;
271+
272+ ws. send ( TungsteniteMessage :: Text ( "after-invalid" . into ( ) ) )
273+ . await
274+ . unwrap ( ) ;
275+
276+ let msg = tokio:: time:: timeout ( Duration :: from_secs ( 2 ) , ws. next ( ) )
277+ . await
278+ . expect ( "timeout" )
279+ . expect ( "stream ended" )
280+ . unwrap ( ) ;
281+
282+ match msg {
283+ TungsteniteMessage :: Text ( t) => assert_eq ! ( t, "after-invalid" ) ,
284+ other => panic ! ( "expected Text, got {other:?}" ) ,
285+ }
286+ }
287+
288+ #[ tokio:: test]
289+ async fn recv_pump_skips_empty_text ( ) {
290+ let url = start_echo_server ( ) . await ;
291+ let ( mut ws, _) = connect_async ( & url) . await . unwrap ( ) ;
292+
293+ ws. send ( TungsteniteMessage :: Text ( "" . into ( ) ) ) . await . unwrap ( ) ;
294+ ws. send ( TungsteniteMessage :: Text ( "\n " . into ( ) ) )
295+ . await
296+ . unwrap ( ) ;
297+ ws. send ( TungsteniteMessage :: Text ( "\r \n " . into ( ) ) )
298+ . await
299+ . unwrap ( ) ;
300+ ws. send ( TungsteniteMessage :: Text ( "real" . into ( ) ) )
301+ . await
302+ . unwrap ( ) ;
303+
304+ let msg = tokio:: time:: timeout ( Duration :: from_secs ( 2 ) , ws. next ( ) )
305+ . await
306+ . expect ( "timeout" )
307+ . expect ( "stream ended" )
308+ . unwrap ( ) ;
309+
310+ match msg {
311+ TungsteniteMessage :: Text ( t) => assert_eq ! ( t, "real" ) ,
312+ other => panic ! ( "expected Text, got {other:?}" ) ,
313+ }
314+ }
315+
316+ #[ tokio:: test]
317+ async fn recv_pump_close_frame_stops_loop ( ) {
318+ let url = start_echo_server ( ) . await ;
319+ let ( mut ws, _) = connect_async ( & url) . await . unwrap ( ) ;
320+
321+ ws. send ( TungsteniteMessage :: Close ( None ) ) . await . unwrap ( ) ;
322+
323+ let result = tokio:: time:: timeout ( Duration :: from_secs ( 2 ) , ws. next ( ) ) . await ;
324+ match result {
325+ Ok ( Some ( Ok ( TungsteniteMessage :: Close ( _) ) ) ) | Ok ( None ) => { }
326+ other => panic ! ( "expected close or stream end, got {other:?}" ) ,
327+ }
328+ }
329+
330+ #[ tokio:: test]
331+ async fn send_pump_skips_empty_lines ( ) {
332+ let url = start_echo_server ( ) . await ;
333+ let ( mut ws, _) = connect_async ( & url) . await . unwrap ( ) ;
334+
335+ ws. send ( TungsteniteMessage :: Text ( "\n \n first\n \n \n second" . into ( ) ) )
336+ . await
337+ . unwrap ( ) ;
338+
339+ let msg1 = tokio:: time:: timeout ( Duration :: from_secs ( 2 ) , ws. next ( ) )
340+ . await
341+ . expect ( "timeout" )
342+ . expect ( "stream ended" )
343+ . unwrap ( ) ;
344+ match msg1 {
345+ TungsteniteMessage :: Text ( t) => assert_eq ! ( t, "first" ) ,
346+ other => panic ! ( "expected 'first', got {other:?}" ) ,
347+ }
348+
349+ let msg2 = tokio:: time:: timeout ( Duration :: from_secs ( 2 ) , ws. next ( ) )
350+ . await
351+ . expect ( "timeout" )
352+ . expect ( "stream ended" )
353+ . unwrap ( ) ;
354+ match msg2 {
355+ TungsteniteMessage :: Text ( t) => assert_eq ! ( t, "second" ) ,
356+ other => panic ! ( "expected 'second', got {other:?}" ) ,
357+ }
358+ }
359+
360+ #[ tokio:: test]
361+ async fn send_pump_eof_closes_websocket ( ) {
362+ let listener = TcpListener :: bind ( "127.0.0.1:0" ) . await . unwrap ( ) ;
363+ let addr = listener. local_addr ( ) . unwrap ( ) ;
364+
365+ async fn write_then_close_handler (
366+ ws : WebSocketUpgrade ,
367+ State ( _) : State < EchoState > ,
368+ ) -> Response {
369+ ws. on_upgrade ( |socket| async move {
370+ let ( ws_sender, _ws_receiver) = socket. split ( ) ;
371+ let ( mut duplex_write, duplex_read) = tokio:: io:: duplex ( DUPLEX_BUFFER_SIZE ) ;
372+ duplex_write. write_all ( b"msg\n " ) . await . unwrap ( ) ;
373+ drop ( duplex_write) ;
374+ run_send_pump ( ws_sender, duplex_read) . await ;
375+ } )
376+ }
377+
378+ let app = axum:: Router :: new ( )
379+ . route ( "/ws" , axum:: routing:: get ( write_then_close_handler) )
380+ . with_state ( EchoState ) ;
381+ tokio:: spawn ( async move {
382+ axum:: serve ( listener, app) . await . unwrap ( ) ;
383+ } ) ;
384+
385+ let url = format ! ( "ws://{}/ws" , addr) ;
386+ let ( mut ws, _) = connect_async ( & url) . await . unwrap ( ) ;
387+
388+ let msg = tokio:: time:: timeout ( Duration :: from_secs ( 2 ) , ws. next ( ) )
389+ . await
390+ . expect ( "timeout" )
391+ . expect ( "stream ended" )
392+ . unwrap ( ) ;
393+ match msg {
394+ TungsteniteMessage :: Text ( t) => assert_eq ! ( t, "msg" ) ,
395+ other => panic ! ( "expected Text('msg'), got {other:?}" ) ,
396+ }
397+
398+ let close = tokio:: time:: timeout ( Duration :: from_secs ( 2 ) , ws. next ( ) ) . await ;
399+ match close {
400+ Ok ( Some ( Ok ( TungsteniteMessage :: Close ( _) ) ) ) | Ok ( None ) => { }
401+ other => panic ! ( "expected close after EOF, got {other:?}" ) ,
402+ }
403+ }
404+
405+ #[ tokio:: test]
406+ async fn recv_pump_breaks_when_duplex_write_closed ( ) {
407+ let listener = TcpListener :: bind ( "127.0.0.1:0" ) . await . unwrap ( ) ;
408+ let addr = listener. local_addr ( ) . unwrap ( ) ;
409+
410+ async fn drop_duplex_handler ( ws : WebSocketUpgrade , State ( _) : State < EchoState > ) -> Response {
411+ ws. on_upgrade ( |socket| async move {
412+ let ( ws_sender, ws_receiver) = socket. split ( ) ;
413+ let ( duplex_write, _duplex_read) = tokio:: io:: duplex ( 64 ) ;
414+ drop ( _duplex_read) ;
415+ run_recv_pump ( ws_receiver, duplex_write) . await ;
416+ let _ = ws_sender;
417+ } )
418+ }
419+
420+ let app = axum:: Router :: new ( )
421+ . route ( "/ws" , axum:: routing:: get ( drop_duplex_handler) )
422+ . with_state ( EchoState ) ;
423+ tokio:: spawn ( async move {
424+ axum:: serve ( listener, app) . await . unwrap ( ) ;
425+ } ) ;
426+
427+ let url = format ! ( "ws://{}/ws" , addr) ;
428+ let ( mut ws, _) = connect_async ( & url) . await . unwrap ( ) ;
429+
430+ for _ in 0 ..100 {
431+ if ws
432+ . send ( TungsteniteMessage :: Text ( "data" . into ( ) ) )
433+ . await
434+ . is_err ( )
435+ {
436+ break ;
437+ }
438+ }
439+
440+ let _ = tokio:: time:: timeout ( Duration :: from_secs ( 2 ) , ws. next ( ) ) . await ;
441+ }
442+
443+ #[ tokio:: test]
444+ async fn multiple_messages_round_trip ( ) {
445+ let url = start_echo_server ( ) . await ;
446+ let ( mut ws, _) = connect_async ( & url) . await . unwrap ( ) ;
447+
448+ let messages = vec ! [ "alpha" , "beta" , "gamma" ] ;
449+ for msg in & messages {
450+ ws. send ( TungsteniteMessage :: Text ( ( * msg) . into ( ) ) )
451+ . await
452+ . unwrap ( ) ;
453+ }
454+
455+ for expected in & messages {
456+ let msg = tokio:: time:: timeout ( Duration :: from_secs ( 2 ) , ws. next ( ) )
457+ . await
458+ . expect ( "timeout" )
459+ . expect ( "stream ended" )
460+ . unwrap ( ) ;
461+ match msg {
462+ TungsteniteMessage :: Text ( t) => assert_eq ! ( t, * expected) ,
463+ other => panic ! ( "expected Text('{expected}'), got {other:?}" ) ,
464+ }
465+ }
466+ }
467+ }
0 commit comments