@@ -180,3 +180,66 @@ async fn run_send_pump(
180180 }
181181 let _ = ws_sender. close ( ) . await ;
182182}
183+
184+ #[ cfg( test) ]
185+ mod tests {
186+ use super :: * ;
187+ use axum:: extract:: State ;
188+ use axum:: extract:: ws:: WebSocketUpgrade ;
189+ use axum:: response:: Response ;
190+ use std:: time:: Duration ;
191+ use tokio:: io:: AsyncReadExt ;
192+ use tokio:: net:: TcpListener ;
193+ use tokio_tungstenite:: connect_async;
194+ use tokio_tungstenite:: tungstenite:: Message as TungsteniteMessage ;
195+
196+ #[ derive( Clone ) ]
197+ struct EchoState ;
198+
199+ async fn echo_handler ( ws : WebSocketUpgrade , State ( _) : State < EchoState > ) -> Response {
200+ ws. on_upgrade ( |socket| async move {
201+ let ( ws_sender, ws_receiver) = socket. split ( ) ;
202+ let ( duplex_write, duplex_read) = tokio:: io:: duplex ( DUPLEX_BUFFER_SIZE ) ;
203+ let recv = run_recv_pump ( ws_receiver, duplex_write) ;
204+ let send = run_send_pump ( ws_sender, duplex_read) ;
205+ tokio:: join!( recv, send) ;
206+ } )
207+ }
208+
209+ async fn start_echo_server ( ) -> String {
210+ let listener = TcpListener :: bind ( "127.0.0.1:0" ) . await . unwrap ( ) ;
211+ let addr = listener. local_addr ( ) . unwrap ( ) ;
212+ let app = axum:: Router :: new ( )
213+ . route ( "/ws" , axum:: routing:: get ( echo_handler) )
214+ . with_state ( EchoState ) ;
215+ tokio:: spawn ( async move {
216+ axum:: serve ( listener, app) . await . unwrap ( ) ;
217+ } ) ;
218+ format ! ( "ws://{}/ws" , addr)
219+ }
220+
221+ #[ tokio:: test]
222+ async fn multiple_messages_round_trip ( ) {
223+ let url = start_echo_server ( ) . await ;
224+ let ( mut ws, _) = connect_async ( & url) . await . unwrap ( ) ;
225+
226+ let messages = vec ! [ "alpha" , "beta" , "gamma" ] ;
227+ for msg in & messages {
228+ ws. send ( TungsteniteMessage :: Text ( ( * msg) . into ( ) ) )
229+ . await
230+ . unwrap ( ) ;
231+ }
232+
233+ for expected in & messages {
234+ let msg = tokio:: time:: timeout ( Duration :: from_secs ( 2 ) , ws. next ( ) )
235+ . await
236+ . expect ( "timeout" )
237+ . expect ( "stream ended" )
238+ . unwrap ( ) ;
239+ match msg {
240+ TungsteniteMessage :: Text ( t) => assert_eq ! ( t, * expected) ,
241+ other => panic ! ( "expected Text('{expected}'), got {other:?}" ) ,
242+ }
243+ }
244+ }
245+ }
0 commit comments