@@ -20,7 +20,7 @@ use std::{
2020 sync:: { Arc , atomic:: AtomicU64 } ,
2121 time:: Duration ,
2222} ;
23- use tokio:: sync:: { Mutex , watch} ;
23+ use tokio:: sync:: { Mutex , mpsc , watch} ;
2424use tokio_tungstenite:: tungstenite:: {
2525 Message ,
2626 protocol:: frame:: { CloseFrame , coding:: CloseCode } ,
@@ -270,9 +270,31 @@ impl PegboardGateway2 {
270270 "should not be creating a new in flight entry after hibernation"
271271 ) ;
272272
273- // If we are reconnecting after hibernation, don't send an open message
273+ // If we are reconnecting after hibernation, the actor restore path
274+ // re-sends the websocket-open ack once the connection is attached. Wait
275+ // for that ack before replaying buffered client messages.
274276 let can_hibernate = if after_hibernation {
275- true
277+ tracing:: debug!( "gateway waiting for restored websocket open from tunnel" ) ;
278+ let open_msg = wait_for_envoy_websocket_open (
279+ & mut msg_rx,
280+ & mut drop_rx,
281+ & mut stopped_sub,
282+ request_id,
283+ Duration :: from_millis (
284+ self . ctx
285+ . config ( )
286+ . pegboard ( )
287+ . gateway_websocket_open_timeout_ms ( ) ,
288+ ) ,
289+ false ,
290+ )
291+ . await ?;
292+
293+ self . shared_state
294+ . toggle_hibernation ( request_id, open_msg. can_hibernate )
295+ . await ?;
296+
297+ open_msg. can_hibernate
276298 } else {
277299 // Send WebSocket open message
278300 let open_message = protocol:: ToEnvoyTunnelMessageKind :: ToEnvoyWebSocketOpen (
@@ -289,61 +311,20 @@ impl PegboardGateway2 {
289311
290312 tracing:: debug!( "gateway waiting for websocket open from tunnel" ) ;
291313
292- // Wait for WebSocket open acknowledgment
293- let fut = async {
294- loop {
295- tokio:: select! {
296- res = msg_rx. recv( ) => {
297- if let Some ( msg) = res {
298- match msg {
299- protocol:: ToRivetTunnelMessageKind :: ToRivetWebSocketOpen ( msg) => {
300- return anyhow:: Ok ( msg) ;
301- }
302- protocol:: ToRivetTunnelMessageKind :: ToRivetWebSocketClose ( close) => {
303- tracing:: warn!( ?close, "websocket closed before opening" ) ;
304- return Err ( WebSocketServiceUnavailable . build( ) ) ;
305- }
306- _ => {
307- tracing:: warn!(
308- "received unexpected message while waiting for websocket open"
309- ) ;
310- }
311- }
312- } else {
313- tracing:: warn!(
314- request_id=%protocol:: util:: id_to_string( & request_id) ,
315- "received no message response during ws init" ,
316- ) ;
317- break ;
318- }
319- }
320- _ = stopped_sub. next( ) => {
321- tracing:: debug!( "actor stopped while waiting for websocket open" ) ;
322- return Err ( WebSocketServiceUnavailable . build( ) ) ;
323- }
324- _ = drop_rx. changed( ) => {
325- tracing:: warn!( reason=?drop_rx. borrow( ) , "websocket open timeout" ) ;
326- return Err ( WebSocketServiceUnavailable . build( ) ) ;
327- }
328- }
329- }
330-
331- Err ( WebSocketServiceUnavailable . build ( ) )
332- } ;
333-
334- let websocket_open_timeout = Duration :: from_millis (
335- self . ctx
336- . config ( )
337- . pegboard ( )
338- . gateway_websocket_open_timeout_ms ( ) ,
339- ) ;
340- let open_msg = tokio:: time:: timeout ( websocket_open_timeout, fut)
341- . await
342- . map_err ( |_| {
343- tracing:: warn!( "timed out waiting for websocket open from envoy" ) ;
344-
345- WebSocketServiceUnavailable . build ( )
346- } ) ??;
314+ let open_msg = wait_for_envoy_websocket_open (
315+ & mut msg_rx,
316+ & mut drop_rx,
317+ & mut stopped_sub,
318+ request_id,
319+ Duration :: from_millis (
320+ self . ctx
321+ . config ( )
322+ . pegboard ( )
323+ . gateway_websocket_open_timeout_ms ( ) ,
324+ ) ,
325+ true ,
326+ )
327+ . await ?;
347328
348329 self . shared_state
349330 . toggle_hibernation ( request_id, open_msg. can_hibernate )
@@ -360,6 +341,7 @@ impl PegboardGateway2 {
360341 . resend_pending_websocket_messages ( request_id)
361342 . await ?;
362343
344+ let initial_tunnel_messages = drain_ready_tunnel_messages ( request_id, & mut msg_rx) ;
363345 let ws_rx = client_ws. recv ( ) ;
364346
365347 let ( tunnel_to_ws_abort_tx, tunnel_to_ws_abort_rx) = watch:: channel ( ( ) ) ;
@@ -375,6 +357,7 @@ impl PegboardGateway2 {
375357 stopped_sub,
376358 msg_rx,
377359 drop_rx,
360+ initial_tunnel_messages,
378361 can_hibernate,
379362 egress_bytes. clone ( ) ,
380363 tunnel_to_ws_abort_rx,
@@ -418,18 +401,25 @@ impl PegboardGateway2 {
418401 } else {
419402 None
420403 } ;
404+ let tunnel_to_ws_abort_handle = tunnel_to_ws. abort_handle ( ) ;
405+ let ws_to_tunnel_abort_handle = ws_to_tunnel. abort_handle ( ) ;
421406
422407 // Wait for all tasks to complete
423408 let ( tunnel_to_ws_res, ws_to_tunnel_res, ping_res, keepalive_res, metrics_res) = tokio:: join!(
424409 async {
425- let res = tunnel_to_ws. await ?;
410+ let res = match tunnel_to_ws. await {
411+ Ok ( res) => res,
412+ Err ( err) if err. is_cancelled( ) => Ok ( LifecycleResult :: Aborted ) ,
413+ Err ( err) => Err ( err. into( ) ) ,
414+ } ;
426415
427416 // Abort other if not aborted
428417 if !matches!( res, Ok ( LifecycleResult :: Aborted ) ) {
429418 tracing:: debug!( ?res, "tunnel to ws task completed, aborting counterpart" ) ;
430419
431420 let _ = ping_abort_tx. send( ( ) ) ;
432421 let _ = ws_to_tunnel_abort_tx. send( ( ) ) ;
422+ ws_to_tunnel_abort_handle. abort( ) ;
433423 let _ = keepalive_abort_tx. send( ( ) ) ;
434424 let _ = metrics_abort_tx. send( ( ) ) ;
435425 } else {
@@ -439,14 +429,19 @@ impl PegboardGateway2 {
439429 res
440430 } ,
441431 async {
442- let res = ws_to_tunnel. await ?;
432+ let res = match ws_to_tunnel. await {
433+ Ok ( res) => res,
434+ Err ( err) if err. is_cancelled( ) => Ok ( LifecycleResult :: Aborted ) ,
435+ Err ( err) => Err ( err. into( ) ) ,
436+ } ;
443437
444438 // Abort other if not aborted
445439 if !matches!( res, Ok ( LifecycleResult :: Aborted ) ) {
446440 tracing:: debug!( ?res, "ws to tunnel task completed, aborting counterpart" ) ;
447441
448442 let _ = ping_abort_tx. send( ( ) ) ;
449443 let _ = tunnel_to_ws_abort_tx. send( ( ) ) ;
444+ tunnel_to_ws_abort_handle. abort( ) ;
450445 let _ = keepalive_abort_tx. send( ( ) ) ;
451446 let _ = metrics_abort_tx. send( ( ) ) ;
452447 } else {
@@ -560,6 +555,15 @@ impl PegboardGateway2 {
560555 {
561556 tracing:: error!( ?err, "error sending close message" ) ;
562557 }
558+
559+ if matches ! ( lifecycle_res, Ok ( LifecycleResult :: ClientClose ( _) ) ) {
560+ ctx. op ( pegboard:: ops:: actor:: hibernating_request:: delete:: Input {
561+ actor_id : self . actor_id ,
562+ gateway_id : self . shared_state . gateway_id ( ) ,
563+ request_id,
564+ } )
565+ . await ?;
566+ }
563567 }
564568
565569 // Send WebSocket close message to client
@@ -695,10 +699,27 @@ impl CustomServeTrait for PegboardGateway2 {
695699 . await ?
696700 {
697701 tracing:: debug!( "exiting hibernating due to pending messages" ) ;
702+ tokio:: try_join!(
703+ ctx. op( pegboard:: ops:: actor:: hibernating_request:: upsert:: Input {
704+ actor_id: self . actor_id,
705+ gateway_id: self . shared_state. gateway_id( ) ,
706+ request_id,
707+ } ) ,
708+ self . shared_state. keepalive_hws( request_id) ,
709+ ) ?;
698710
699711 return Ok ( HibernationResult :: Continue ) ;
700712 }
701713
714+ tokio:: try_join!(
715+ ctx. op( pegboard:: ops:: actor:: hibernating_request:: upsert:: Input {
716+ actor_id: self . actor_id,
717+ gateway_id: self . shared_state. gateway_id( ) ,
718+ request_id,
719+ } ) ,
720+ self . shared_state. keepalive_hws( request_id) ,
721+ ) ?;
722+
702723 // Start keepalive task
703724 let ( keepalive_abort_tx, keepalive_abort_rx) = watch:: channel ( ( ) ) ;
704725 let keepalive_handle = tokio:: spawn ( keepalive_task:: task (
@@ -723,6 +744,17 @@ impl CustomServeTrait for PegboardGateway2 {
723744 let _ = keepalive_abort_tx. send ( ( ) ) ;
724745 let _ = keepalive_handle. await ;
725746
747+ if matches ! ( res, Ok ( HibernationResult :: Continue ) ) {
748+ tokio:: try_join!(
749+ ctx. op( pegboard:: ops:: actor:: hibernating_request:: upsert:: Input {
750+ actor_id: self . actor_id,
751+ gateway_id: self . shared_state. gateway_id( ) ,
752+ request_id,
753+ } ) ,
754+ self . shared_state. keepalive_hws( request_id) ,
755+ ) ?;
756+ }
757+
726758 let ( delete_res, _) = tokio:: join!(
727759 async {
728760 match & res {
@@ -831,9 +863,12 @@ async fn hibernate_ws(ws_rx: Arc<Mutex<WebSocketReceiver>>) -> Result<Hibernatio
831863 Ok ( Message :: Binary ( _) ) | Ok ( Message :: Text ( _) ) => {
832864 return Ok ( HibernationResult :: Continue ) ;
833865 }
834- // We don't care about the close frame because we're currently hibernating; there is no
835- // downstream to send the close frame to.
836- Ok ( Message :: Close ( _) ) => return Ok ( HibernationResult :: Close ) ,
866+ // Consume the close frame so the websocket stack can complete the
867+ // close handshake while the actor is hibernating.
868+ Ok ( Message :: Close ( _) ) => {
869+ pinned. try_next ( ) . await ?;
870+ return Ok ( HibernationResult :: Close ) ;
871+ }
837872 // Ignore rest
838873 _ => {
839874 pinned. try_next ( ) . await ?;
@@ -845,6 +880,122 @@ async fn hibernate_ws(ws_rx: Arc<Mutex<WebSocketReceiver>>) -> Result<Hibernatio
845880 }
846881}
847882
883+ fn drain_ready_tunnel_messages (
884+ request_id : protocol:: RequestId ,
885+ msg_rx : & mut mpsc:: Receiver < protocol:: ToRivetTunnelMessageKind > ,
886+ ) -> Vec < protocol:: ToRivetTunnelMessageKind > {
887+ let mut messages = Vec :: new ( ) ;
888+
889+ loop {
890+ match msg_rx. try_recv ( ) {
891+ Ok ( message) => messages. push ( message) ,
892+ Err ( mpsc:: error:: TryRecvError :: Empty ) => break ,
893+ Err ( mpsc:: error:: TryRecvError :: Disconnected ) => break ,
894+ }
895+ }
896+
897+ if !messages. is_empty ( ) {
898+ tracing:: debug!(
899+ request_id=%protocol:: util:: id_to_string( & request_id) ,
900+ message_count=messages. len( ) ,
901+ "drained ready tunnel messages before websocket forwarding task"
902+ ) ;
903+ }
904+
905+ messages
906+ }
907+
908+ async fn wait_for_envoy_websocket_open (
909+ msg_rx : & mut tokio:: sync:: mpsc:: Receiver < protocol:: ToRivetTunnelMessageKind > ,
910+ drop_rx : & mut watch:: Receiver < Option < crate :: shared_state:: MsgGcReason > > ,
911+ stopped_sub : & mut message:: SubscriptionHandle < pegboard:: workflows:: actor2:: Stopped > ,
912+ request_id : protocol:: RequestId ,
913+ websocket_open_timeout : Duration ,
914+ stop_on_actor_stopped : bool ,
915+ ) -> Result < protocol:: ToRivetWebSocketOpen > {
916+ let fut = async {
917+ loop {
918+ if stop_on_actor_stopped {
919+ tokio:: select! {
920+ res = msg_rx. recv( ) => {
921+ if let Some ( msg) = res {
922+ match msg {
923+ protocol:: ToRivetTunnelMessageKind :: ToRivetWebSocketOpen ( msg) => {
924+ return anyhow:: Ok ( msg) ;
925+ }
926+ protocol:: ToRivetTunnelMessageKind :: ToRivetWebSocketClose ( close) => {
927+ tracing:: warn!( ?close, "websocket closed before opening" ) ;
928+ return Err ( WebSocketServiceUnavailable . build( ) ) ;
929+ }
930+ _ => {
931+ tracing:: warn!(
932+ "received unexpected message while waiting for websocket open"
933+ ) ;
934+ }
935+ }
936+ } else {
937+ tracing:: warn!(
938+ request_id=%protocol:: util:: id_to_string( & request_id) ,
939+ "received no message response during ws init" ,
940+ ) ;
941+ break ;
942+ }
943+ }
944+ _ = stopped_sub. next( ) => {
945+ tracing:: debug!( "actor stopped while waiting for websocket open" ) ;
946+ return Err ( WebSocketServiceUnavailable . build( ) ) ;
947+ }
948+ _ = drop_rx. changed( ) => {
949+ tracing:: warn!( reason=?drop_rx. borrow( ) , "websocket open timeout" ) ;
950+ return Err ( WebSocketServiceUnavailable . build( ) ) ;
951+ }
952+ }
953+ } else {
954+ tokio:: select! {
955+ res = msg_rx. recv( ) => {
956+ if let Some ( msg) = res {
957+ match msg {
958+ protocol:: ToRivetTunnelMessageKind :: ToRivetWebSocketOpen ( msg) => {
959+ return anyhow:: Ok ( msg) ;
960+ }
961+ protocol:: ToRivetTunnelMessageKind :: ToRivetWebSocketClose ( close) => {
962+ tracing:: warn!( ?close, "websocket closed before opening" ) ;
963+ return Err ( WebSocketServiceUnavailable . build( ) ) ;
964+ }
965+ _ => {
966+ tracing:: warn!(
967+ "received unexpected message while waiting for websocket open"
968+ ) ;
969+ }
970+ }
971+ } else {
972+ tracing:: warn!(
973+ request_id=%protocol:: util:: id_to_string( & request_id) ,
974+ "received no message response during ws init" ,
975+ ) ;
976+ break ;
977+ }
978+ }
979+ _ = drop_rx. changed( ) => {
980+ tracing:: warn!( reason=?drop_rx. borrow( ) , "websocket open timeout" ) ;
981+ return Err ( WebSocketServiceUnavailable . build( ) ) ;
982+ }
983+ }
984+ }
985+ }
986+
987+ Err ( WebSocketServiceUnavailable . build ( ) )
988+ } ;
989+
990+ tokio:: time:: timeout ( websocket_open_timeout, fut)
991+ . await
992+ . map_err ( |_| {
993+ tracing:: warn!( "timed out waiting for websocket open from envoy" ) ;
994+
995+ WebSocketServiceUnavailable . build ( )
996+ } ) ?
997+ }
998+
848999#[ derive( Debug ) ]
8491000enum Metric {
8501001 HttpIngress ( usize ) ,
0 commit comments