@@ -20,7 +20,7 @@ use std::{
2020 sync:: { Arc , atomic:: AtomicU64 } ,
2121 time:: Duration ,
2222} ;
23- use tokio:: sync:: { Mutex , watch} ;
23+ use tokio:: sync:: { Mutex , mpsc , watch} ;
2424use tokio_tungstenite:: tungstenite:: {
2525 Message ,
2626 protocol:: frame:: { CloseFrame , coding:: CloseCode } ,
@@ -286,6 +286,7 @@ impl PegboardGateway2 {
286286 . pegboard ( )
287287 . gateway_websocket_open_timeout_ms ( ) ,
288288 ) ,
289+ false ,
289290 )
290291 . await ?;
291292
@@ -321,6 +322,7 @@ impl PegboardGateway2 {
321322 . pegboard ( )
322323 . gateway_websocket_open_timeout_ms ( ) ,
323324 ) ,
325+ true ,
324326 )
325327 . await ?;
326328
@@ -339,6 +341,7 @@ impl PegboardGateway2 {
339341 . resend_pending_websocket_messages ( request_id)
340342 . await ?;
341343
344+ let initial_tunnel_messages = drain_ready_tunnel_messages ( request_id, & mut msg_rx) ;
342345 let ws_rx = client_ws. recv ( ) ;
343346
344347 let ( tunnel_to_ws_abort_tx, tunnel_to_ws_abort_rx) = watch:: channel ( ( ) ) ;
@@ -354,6 +357,7 @@ impl PegboardGateway2 {
354357 stopped_sub,
355358 msg_rx,
356359 drop_rx,
360+ initial_tunnel_messages,
357361 can_hibernate,
358362 egress_bytes. clone ( ) ,
359363 tunnel_to_ws_abort_rx,
@@ -397,18 +401,25 @@ impl PegboardGateway2 {
397401 } else {
398402 None
399403 } ;
404+ let tunnel_to_ws_abort_handle = tunnel_to_ws. abort_handle ( ) ;
405+ let ws_to_tunnel_abort_handle = ws_to_tunnel. abort_handle ( ) ;
400406
401407 // Wait for all tasks to complete
402408 let ( tunnel_to_ws_res, ws_to_tunnel_res, ping_res, keepalive_res, metrics_res) = tokio:: join!(
403409 async {
404- let res = tunnel_to_ws. await ?;
410+ let res = match tunnel_to_ws. await {
411+ Ok ( res) => res,
412+ Err ( err) if err. is_cancelled( ) => Ok ( LifecycleResult :: Aborted ) ,
413+ Err ( err) => Err ( err. into( ) ) ,
414+ } ;
405415
406416 // Abort other if not aborted
407417 if !matches!( res, Ok ( LifecycleResult :: Aborted ) ) {
408418 tracing:: debug!( ?res, "tunnel to ws task completed, aborting counterpart" ) ;
409419
410420 let _ = ping_abort_tx. send( ( ) ) ;
411421 let _ = ws_to_tunnel_abort_tx. send( ( ) ) ;
422+ ws_to_tunnel_abort_handle. abort( ) ;
412423 let _ = keepalive_abort_tx. send( ( ) ) ;
413424 let _ = metrics_abort_tx. send( ( ) ) ;
414425 } else {
@@ -418,14 +429,19 @@ impl PegboardGateway2 {
418429 res
419430 } ,
420431 async {
421- let res = ws_to_tunnel. await ?;
432+ let res = match ws_to_tunnel. await {
433+ Ok ( res) => res,
434+ Err ( err) if err. is_cancelled( ) => Ok ( LifecycleResult :: Aborted ) ,
435+ Err ( err) => Err ( err. into( ) ) ,
436+ } ;
422437
423438 // Abort other if not aborted
424439 if !matches!( res, Ok ( LifecycleResult :: Aborted ) ) {
425440 tracing:: debug!( ?res, "ws to tunnel task completed, aborting counterpart" ) ;
426441
427442 let _ = ping_abort_tx. send( ( ) ) ;
428443 let _ = tunnel_to_ws_abort_tx. send( ( ) ) ;
444+ tunnel_to_ws_abort_handle. abort( ) ;
429445 let _ = keepalive_abort_tx. send( ( ) ) ;
430446 let _ = metrics_abort_tx. send( ( ) ) ;
431447 } else {
@@ -539,6 +555,15 @@ impl PegboardGateway2 {
539555 {
540556 tracing:: error!( ?err, "error sending close message" ) ;
541557 }
558+
559+ if matches ! ( lifecycle_res, Ok ( LifecycleResult :: ClientClose ( _) ) ) {
560+ ctx. op ( pegboard:: ops:: actor:: hibernating_request:: delete:: Input {
561+ actor_id : self . actor_id ,
562+ gateway_id : self . shared_state . gateway_id ( ) ,
563+ request_id,
564+ } )
565+ . await ?;
566+ }
542567 }
543568
544569 // Send WebSocket close message to client
@@ -674,10 +699,27 @@ impl CustomServeTrait for PegboardGateway2 {
674699 . await ?
675700 {
676701 tracing:: debug!( "exiting hibernating due to pending messages" ) ;
702+ tokio:: try_join!(
703+ ctx. op( pegboard:: ops:: actor:: hibernating_request:: upsert:: Input {
704+ actor_id: self . actor_id,
705+ gateway_id: self . shared_state. gateway_id( ) ,
706+ request_id,
707+ } ) ,
708+ self . shared_state. keepalive_hws( request_id) ,
709+ ) ?;
677710
678711 return Ok ( HibernationResult :: Continue ) ;
679712 }
680713
714+ tokio:: try_join!(
715+ ctx. op( pegboard:: ops:: actor:: hibernating_request:: upsert:: Input {
716+ actor_id: self . actor_id,
717+ gateway_id: self . shared_state. gateway_id( ) ,
718+ request_id,
719+ } ) ,
720+ self . shared_state. keepalive_hws( request_id) ,
721+ ) ?;
722+
681723 // Start keepalive task
682724 let ( keepalive_abort_tx, keepalive_abort_rx) = watch:: channel ( ( ) ) ;
683725 let keepalive_handle = tokio:: spawn ( keepalive_task:: task (
@@ -702,6 +744,17 @@ impl CustomServeTrait for PegboardGateway2 {
702744 let _ = keepalive_abort_tx. send ( ( ) ) ;
703745 let _ = keepalive_handle. await ;
704746
747+ if matches ! ( res, Ok ( HibernationResult :: Continue ) ) {
748+ tokio:: try_join!(
749+ ctx. op( pegboard:: ops:: actor:: hibernating_request:: upsert:: Input {
750+ actor_id: self . actor_id,
751+ gateway_id: self . shared_state. gateway_id( ) ,
752+ request_id,
753+ } ) ,
754+ self . shared_state. keepalive_hws( request_id) ,
755+ ) ?;
756+ }
757+
705758 let ( delete_res, _) = tokio:: join!(
706759 async {
707760 match & res {
@@ -810,9 +863,12 @@ async fn hibernate_ws(ws_rx: Arc<Mutex<WebSocketReceiver>>) -> Result<Hibernatio
810863 Ok ( Message :: Binary ( _) ) | Ok ( Message :: Text ( _) ) => {
811864 return Ok ( HibernationResult :: Continue ) ;
812865 }
813- // We don't care about the close frame because we're currently hibernating; there is no
814- // downstream to send the close frame to.
815- Ok ( Message :: Close ( _) ) => return Ok ( HibernationResult :: Close ) ,
866+ // Consume the close frame so the websocket stack can complete the
867+ // close handshake while the actor is hibernating.
868+ Ok ( Message :: Close ( _) ) => {
869+ pinned. try_next ( ) . await ?;
870+ return Ok ( HibernationResult :: Close ) ;
871+ }
816872 // Ignore rest
817873 _ => {
818874 pinned. try_next ( ) . await ?;
@@ -824,47 +880,106 @@ async fn hibernate_ws(ws_rx: Arc<Mutex<WebSocketReceiver>>) -> Result<Hibernatio
824880 }
825881}
826882
883+ fn drain_ready_tunnel_messages (
884+ request_id : protocol:: RequestId ,
885+ msg_rx : & mut mpsc:: Receiver < protocol:: ToRivetTunnelMessageKind > ,
886+ ) -> Vec < protocol:: ToRivetTunnelMessageKind > {
887+ let mut messages = Vec :: new ( ) ;
888+
889+ loop {
890+ match msg_rx. try_recv ( ) {
891+ Ok ( message) => messages. push ( message) ,
892+ Err ( mpsc:: error:: TryRecvError :: Empty ) => break ,
893+ Err ( mpsc:: error:: TryRecvError :: Disconnected ) => break ,
894+ }
895+ }
896+
897+ if !messages. is_empty ( ) {
898+ tracing:: debug!(
899+ request_id=%protocol:: util:: id_to_string( & request_id) ,
900+ message_count=messages. len( ) ,
901+ "drained ready tunnel messages before websocket forwarding task"
902+ ) ;
903+ }
904+
905+ messages
906+ }
907+
827908async fn wait_for_envoy_websocket_open (
828909 msg_rx : & mut tokio:: sync:: mpsc:: Receiver < protocol:: ToRivetTunnelMessageKind > ,
829910 drop_rx : & mut watch:: Receiver < Option < crate :: shared_state:: MsgGcReason > > ,
830911 stopped_sub : & mut message:: SubscriptionHandle < pegboard:: workflows:: actor2:: Stopped > ,
831912 request_id : protocol:: RequestId ,
832913 websocket_open_timeout : Duration ,
914+ stop_on_actor_stopped : bool ,
833915) -> Result < protocol:: ToRivetWebSocketOpen > {
834916 let fut = async {
835917 loop {
836- tokio:: select! {
837- res = msg_rx. recv( ) => {
838- if let Some ( msg) = res {
839- match msg {
840- protocol:: ToRivetTunnelMessageKind :: ToRivetWebSocketOpen ( msg) => {
841- return anyhow:: Ok ( msg) ;
842- }
843- protocol:: ToRivetTunnelMessageKind :: ToRivetWebSocketClose ( close) => {
844- tracing:: warn!( ?close, "websocket closed before opening" ) ;
845- return Err ( WebSocketServiceUnavailable . build( ) ) ;
846- }
847- _ => {
848- tracing:: warn!(
849- "received unexpected message while waiting for websocket open"
850- ) ;
918+ if stop_on_actor_stopped {
919+ tokio:: select! {
920+ res = msg_rx. recv( ) => {
921+ if let Some ( msg) = res {
922+ match msg {
923+ protocol:: ToRivetTunnelMessageKind :: ToRivetWebSocketOpen ( msg) => {
924+ return anyhow:: Ok ( msg) ;
925+ }
926+ protocol:: ToRivetTunnelMessageKind :: ToRivetWebSocketClose ( close) => {
927+ tracing:: warn!( ?close, "websocket closed before opening" ) ;
928+ return Err ( WebSocketServiceUnavailable . build( ) ) ;
929+ }
930+ _ => {
931+ tracing:: warn!(
932+ "received unexpected message while waiting for websocket open"
933+ ) ;
934+ }
851935 }
936+ } else {
937+ tracing:: warn!(
938+ request_id=%protocol:: util:: id_to_string( & request_id) ,
939+ "received no message response during ws init" ,
940+ ) ;
941+ break ;
852942 }
853- } else {
854- tracing:: warn!(
855- request_id=%protocol:: util:: id_to_string( & request_id) ,
856- "received no message response during ws init" ,
857- ) ;
858- break ;
943+ }
944+ _ = stopped_sub. next( ) => {
945+ tracing:: debug!( "actor stopped while waiting for websocket open" ) ;
946+ return Err ( WebSocketServiceUnavailable . build( ) ) ;
947+ }
948+ _ = drop_rx. changed( ) => {
949+ tracing:: warn!( reason=?drop_rx. borrow( ) , "websocket open timeout" ) ;
950+ return Err ( WebSocketServiceUnavailable . build( ) ) ;
859951 }
860952 }
861- _ = stopped_sub. next( ) => {
862- tracing:: debug!( "actor stopped while waiting for websocket open" ) ;
863- return Err ( WebSocketServiceUnavailable . build( ) ) ;
864- }
865- _ = drop_rx. changed( ) => {
866- tracing:: warn!( reason=?drop_rx. borrow( ) , "websocket open timeout" ) ;
867- return Err ( WebSocketServiceUnavailable . build( ) ) ;
953+ } else {
954+ tokio:: select! {
955+ res = msg_rx. recv( ) => {
956+ if let Some ( msg) = res {
957+ match msg {
958+ protocol:: ToRivetTunnelMessageKind :: ToRivetWebSocketOpen ( msg) => {
959+ return anyhow:: Ok ( msg) ;
960+ }
961+ protocol:: ToRivetTunnelMessageKind :: ToRivetWebSocketClose ( close) => {
962+ tracing:: warn!( ?close, "websocket closed before opening" ) ;
963+ return Err ( WebSocketServiceUnavailable . build( ) ) ;
964+ }
965+ _ => {
966+ tracing:: warn!(
967+ "received unexpected message while waiting for websocket open"
968+ ) ;
969+ }
970+ }
971+ } else {
972+ tracing:: warn!(
973+ request_id=%protocol:: util:: id_to_string( & request_id) ,
974+ "received no message response during ws init" ,
975+ ) ;
976+ break ;
977+ }
978+ }
979+ _ = drop_rx. changed( ) => {
980+ tracing:: warn!( reason=?drop_rx. borrow( ) , "websocket open timeout" ) ;
981+ return Err ( WebSocketServiceUnavailable . build( ) ) ;
982+ }
868983 }
869984 }
870985 }
0 commit comments