diff --git a/engine/packages/pegboard-gateway2/src/lib.rs b/engine/packages/pegboard-gateway2/src/lib.rs index e1571c9699..8ad1d3dae6 100644 --- a/engine/packages/pegboard-gateway2/src/lib.rs +++ b/engine/packages/pegboard-gateway2/src/lib.rs @@ -20,7 +20,7 @@ use std::{ sync::{Arc, atomic::AtomicU64}, time::Duration, }; -use tokio::sync::{Mutex, watch}; +use tokio::sync::{Mutex, mpsc, watch}; use tokio_tungstenite::tungstenite::{ Message, protocol::frame::{CloseFrame, coding::CloseCode}, @@ -270,9 +270,31 @@ impl PegboardGateway2 { "should not be creating a new in flight entry after hibernation" ); - // If we are reconnecting after hibernation, don't send an open message + // If we are reconnecting after hibernation, the actor restore path + // re-sends the websocket-open ack once the connection is attached. Wait + // for that ack before replaying buffered client messages. let can_hibernate = if after_hibernation { - true + tracing::debug!("gateway waiting for restored websocket open from tunnel"); + let open_msg = wait_for_envoy_websocket_open( + &mut msg_rx, + &mut drop_rx, + &mut stopped_sub, + request_id, + Duration::from_millis( + self.ctx + .config() + .pegboard() + .gateway_websocket_open_timeout_ms(), + ), + false, + ) + .await?; + + self.shared_state + .toggle_hibernation(request_id, open_msg.can_hibernate) + .await?; + + open_msg.can_hibernate } else { // Send WebSocket open message let open_message = protocol::ToEnvoyTunnelMessageKind::ToEnvoyWebSocketOpen( @@ -289,61 +311,20 @@ impl PegboardGateway2 { tracing::debug!("gateway waiting for websocket open from tunnel"); - // Wait for WebSocket open acknowledgment - let fut = async { - loop { - tokio::select! { - res = msg_rx.recv() => { - if let Some(msg) = res { - match msg { - protocol::ToRivetTunnelMessageKind::ToRivetWebSocketOpen(msg) => { - return anyhow::Ok(msg); - } - protocol::ToRivetTunnelMessageKind::ToRivetWebSocketClose(close) => { - tracing::warn!(?close, "websocket closed before opening"); - return Err(WebSocketServiceUnavailable.build()); - } - _ => { - tracing::warn!( - "received unexpected message while waiting for websocket open" - ); - } - } - } else { - tracing::warn!( - request_id=%protocol::util::id_to_string(&request_id), - "received no message response during ws init", - ); - break; - } - } - _ = stopped_sub.next() => { - tracing::debug!("actor stopped while waiting for websocket open"); - return Err(WebSocketServiceUnavailable.build()); - } - _ = drop_rx.changed() => { - tracing::warn!(reason=?drop_rx.borrow(), "websocket open timeout"); - return Err(WebSocketServiceUnavailable.build()); - } - } - } - - Err(WebSocketServiceUnavailable.build()) - }; - - let websocket_open_timeout = Duration::from_millis( - self.ctx - .config() - .pegboard() - .gateway_websocket_open_timeout_ms(), - ); - let open_msg = tokio::time::timeout(websocket_open_timeout, fut) - .await - .map_err(|_| { - tracing::warn!("timed out waiting for websocket open from envoy"); - - WebSocketServiceUnavailable.build() - })??; + let open_msg = wait_for_envoy_websocket_open( + &mut msg_rx, + &mut drop_rx, + &mut stopped_sub, + request_id, + Duration::from_millis( + self.ctx + .config() + .pegboard() + .gateway_websocket_open_timeout_ms(), + ), + true, + ) + .await?; self.shared_state .toggle_hibernation(request_id, open_msg.can_hibernate) @@ -360,6 +341,7 @@ impl PegboardGateway2 { .resend_pending_websocket_messages(request_id) .await?; + let initial_tunnel_messages = drain_ready_tunnel_messages(request_id, &mut msg_rx); let ws_rx = client_ws.recv(); let (tunnel_to_ws_abort_tx, tunnel_to_ws_abort_rx) = watch::channel(()); @@ -375,6 +357,7 @@ impl PegboardGateway2 { stopped_sub, msg_rx, drop_rx, + initial_tunnel_messages, can_hibernate, egress_bytes.clone(), tunnel_to_ws_abort_rx, @@ -418,11 +401,17 @@ impl PegboardGateway2 { } else { None }; + let tunnel_to_ws_abort_handle = tunnel_to_ws.abort_handle(); + let ws_to_tunnel_abort_handle = ws_to_tunnel.abort_handle(); // Wait for all tasks to complete let (tunnel_to_ws_res, ws_to_tunnel_res, ping_res, keepalive_res, metrics_res) = tokio::join!( async { - let res = tunnel_to_ws.await?; + let res = match tunnel_to_ws.await { + Ok(res) => res, + Err(err) if err.is_cancelled() => Ok(LifecycleResult::Aborted), + Err(err) => Err(err.into()), + }; // Abort other if not aborted if !matches!(res, Ok(LifecycleResult::Aborted)) { @@ -430,6 +419,7 @@ impl PegboardGateway2 { let _ = ping_abort_tx.send(()); let _ = ws_to_tunnel_abort_tx.send(()); + ws_to_tunnel_abort_handle.abort(); let _ = keepalive_abort_tx.send(()); let _ = metrics_abort_tx.send(()); } else { @@ -439,7 +429,11 @@ impl PegboardGateway2 { res }, async { - let res = ws_to_tunnel.await?; + let res = match ws_to_tunnel.await { + Ok(res) => res, + Err(err) if err.is_cancelled() => Ok(LifecycleResult::Aborted), + Err(err) => Err(err.into()), + }; // Abort other if not aborted if !matches!(res, Ok(LifecycleResult::Aborted)) { @@ -447,6 +441,7 @@ impl PegboardGateway2 { let _ = ping_abort_tx.send(()); let _ = tunnel_to_ws_abort_tx.send(()); + tunnel_to_ws_abort_handle.abort(); let _ = keepalive_abort_tx.send(()); let _ = metrics_abort_tx.send(()); } else { @@ -560,6 +555,15 @@ impl PegboardGateway2 { { tracing::error!(?err, "error sending close message"); } + + if matches!(lifecycle_res, Ok(LifecycleResult::ClientClose(_))) { + ctx.op(pegboard::ops::actor::hibernating_request::delete::Input { + actor_id: self.actor_id, + gateway_id: self.shared_state.gateway_id(), + request_id, + }) + .await?; + } } // Send WebSocket close message to client @@ -695,10 +699,27 @@ impl CustomServeTrait for PegboardGateway2 { .await? { tracing::debug!("exiting hibernating due to pending messages"); + tokio::try_join!( + ctx.op(pegboard::ops::actor::hibernating_request::upsert::Input { + actor_id: self.actor_id, + gateway_id: self.shared_state.gateway_id(), + request_id, + }), + self.shared_state.keepalive_hws(request_id), + )?; return Ok(HibernationResult::Continue); } + tokio::try_join!( + ctx.op(pegboard::ops::actor::hibernating_request::upsert::Input { + actor_id: self.actor_id, + gateway_id: self.shared_state.gateway_id(), + request_id, + }), + self.shared_state.keepalive_hws(request_id), + )?; + // Start keepalive task let (keepalive_abort_tx, keepalive_abort_rx) = watch::channel(()); let keepalive_handle = tokio::spawn(keepalive_task::task( @@ -723,6 +744,17 @@ impl CustomServeTrait for PegboardGateway2 { let _ = keepalive_abort_tx.send(()); let _ = keepalive_handle.await; + if matches!(res, Ok(HibernationResult::Continue)) { + tokio::try_join!( + ctx.op(pegboard::ops::actor::hibernating_request::upsert::Input { + actor_id: self.actor_id, + gateway_id: self.shared_state.gateway_id(), + request_id, + }), + self.shared_state.keepalive_hws(request_id), + )?; + } + let (delete_res, _) = tokio::join!( async { match &res { @@ -831,9 +863,12 @@ async fn hibernate_ws(ws_rx: Arc>) -> Result { return Ok(HibernationResult::Continue); } - // We don't care about the close frame because we're currently hibernating; there is no - // downstream to send the close frame to. - Ok(Message::Close(_)) => return Ok(HibernationResult::Close), + // Consume the close frame so the websocket stack can complete the + // close handshake while the actor is hibernating. + Ok(Message::Close(_)) => { + pinned.try_next().await?; + return Ok(HibernationResult::Close); + } // Ignore rest _ => { pinned.try_next().await?; @@ -845,6 +880,122 @@ async fn hibernate_ws(ws_rx: Arc>) -> Result, +) -> Vec { + let mut messages = Vec::new(); + + loop { + match msg_rx.try_recv() { + Ok(message) => messages.push(message), + Err(mpsc::error::TryRecvError::Empty) => break, + Err(mpsc::error::TryRecvError::Disconnected) => break, + } + } + + if !messages.is_empty() { + tracing::debug!( + request_id=%protocol::util::id_to_string(&request_id), + message_count=messages.len(), + "drained ready tunnel messages before websocket forwarding task" + ); + } + + messages +} + +async fn wait_for_envoy_websocket_open( + msg_rx: &mut tokio::sync::mpsc::Receiver, + drop_rx: &mut watch::Receiver>, + stopped_sub: &mut message::SubscriptionHandle, + request_id: protocol::RequestId, + websocket_open_timeout: Duration, + stop_on_actor_stopped: bool, +) -> Result { + let fut = async { + loop { + if stop_on_actor_stopped { + tokio::select! { + res = msg_rx.recv() => { + if let Some(msg) = res { + match msg { + protocol::ToRivetTunnelMessageKind::ToRivetWebSocketOpen(msg) => { + return anyhow::Ok(msg); + } + protocol::ToRivetTunnelMessageKind::ToRivetWebSocketClose(close) => { + tracing::warn!(?close, "websocket closed before opening"); + return Err(WebSocketServiceUnavailable.build()); + } + _ => { + tracing::warn!( + "received unexpected message while waiting for websocket open" + ); + } + } + } else { + tracing::warn!( + request_id=%protocol::util::id_to_string(&request_id), + "received no message response during ws init", + ); + break; + } + } + _ = stopped_sub.next() => { + tracing::debug!("actor stopped while waiting for websocket open"); + return Err(WebSocketServiceUnavailable.build()); + } + _ = drop_rx.changed() => { + tracing::warn!(reason=?drop_rx.borrow(), "websocket open timeout"); + return Err(WebSocketServiceUnavailable.build()); + } + } + } else { + tokio::select! { + res = msg_rx.recv() => { + if let Some(msg) = res { + match msg { + protocol::ToRivetTunnelMessageKind::ToRivetWebSocketOpen(msg) => { + return anyhow::Ok(msg); + } + protocol::ToRivetTunnelMessageKind::ToRivetWebSocketClose(close) => { + tracing::warn!(?close, "websocket closed before opening"); + return Err(WebSocketServiceUnavailable.build()); + } + _ => { + tracing::warn!( + "received unexpected message while waiting for websocket open" + ); + } + } + } else { + tracing::warn!( + request_id=%protocol::util::id_to_string(&request_id), + "received no message response during ws init", + ); + break; + } + } + _ = drop_rx.changed() => { + tracing::warn!(reason=?drop_rx.borrow(), "websocket open timeout"); + return Err(WebSocketServiceUnavailable.build()); + } + } + } + } + + Err(WebSocketServiceUnavailable.build()) + }; + + tokio::time::timeout(websocket_open_timeout, fut) + .await + .map_err(|_| { + tracing::warn!("timed out waiting for websocket open from envoy"); + + WebSocketServiceUnavailable.build() + })? +} + #[derive(Debug)] enum Metric { HttpIngress(usize), diff --git a/engine/packages/pegboard-gateway2/src/shared_state.rs b/engine/packages/pegboard-gateway2/src/shared_state.rs index 6f063ad9b1..48c4b085f0 100644 --- a/engine/packages/pegboard-gateway2/src/shared_state.rs +++ b/engine/packages/pegboard-gateway2/src/shared_state.rs @@ -29,6 +29,7 @@ struct InFlightRequest { receiver_subject: String, /// Sender for incoming messages to this request. msg_tx: mpsc::Sender, + buffered_inbound: Vec, /// Used to check if the request handler has been dropped. drop_tx: watch::Sender>, /// True once first message for this request has been sent (so envoy learned reply_to). @@ -43,6 +44,9 @@ struct InFlightRequest { struct HibernationState { total_pending_ws_msgs_size: u64, pending_ws_msgs: Vec, + /// True while the old request handler is parked in hibernation and the + /// replacement after-hibernation handler has not taken over yet. + hibernating: bool, // Used to keep hibernating websockets from being GC'd last_ping: Instant, } @@ -131,12 +135,14 @@ impl SharedState { ) -> InFlightRequestHandle { let (msg_tx, msg_rx) = mpsc::channel(128); let (drop_tx, drop_rx) = watch::channel(None); + let mut buffered_inbound = Vec::new(); let new = match self.in_flight_requests.entry_async(request_id).await { Entry::Vacant(entry) => { entry.insert_entry(InFlightRequest { receiver_subject, msg_tx, + buffered_inbound: Vec::new(), drop_tx, opened: false, message_index: 0, @@ -154,6 +160,10 @@ impl SharedState { entry.drop_tx = drop_tx; entry.opened = false; entry.last_pong = util::timestamp::now(); + buffered_inbound = std::mem::take(&mut entry.buffered_inbound); + if let Some(hs) = &mut entry.hibernation_state { + hs.hibernating = false; + } if entry.stopping { entry.hibernation_state = None; @@ -164,6 +174,16 @@ impl SharedState { } }; + if let Some(req) = self.in_flight_requests.get_async(&request_id).await { + for msg in buffered_inbound { + tracing::debug!( + request_id=%protocol::util::id_to_string(&request_id), + "replaying buffered inbound tunnel message" + ); + let _ = req.msg_tx.send(msg).await; + } + } + InFlightRequestHandle { msg_rx, drop_rx, @@ -293,6 +313,7 @@ impl SharedState { .context("request not in flight")?; if let Some(hs) = &mut req.hibernation_state { + hs.hibernating = true; hs.last_ping = Instant::now(); } else { tracing::warn!("should not call keepalive_hws for non-hibernating ws"); @@ -362,7 +383,7 @@ impl SharedState { Ok(protocol::ToGateway::ToRivetTunnelMessage(msg)) => { let message_id = msg.message_id; - let Some(in_flight) = self + let Some(mut in_flight) = self .in_flight_requests .get_async(&message_id.request_id) .await @@ -376,7 +397,6 @@ impl SharedState { continue; }; - // Send message to the request handler to emulate the real network action let inner_size = match &msg.message_kind { protocol::ToRivetTunnelMessageKind::ToRivetWebSocketMessage(ws_msg) => { ws_msg.data.len() @@ -391,18 +411,44 @@ impl SharedState { "forwarding message to request handler" ); - if in_flight + let should_buffer_restored_open = in_flight.hibernation_state.is_some() + && in_flight.opened && matches!( + msg.message_kind, + protocol::ToRivetTunnelMessageKind::ToRivetWebSocketOpen(_) + ); + let should_buffer = + should_buffer_restored_open + || in_flight + .hibernation_state + .as_ref() + .is_some_and(|hs| hs.hibernating) + && matches!( + msg.message_kind, + protocol::ToRivetTunnelMessageKind::ToRivetWebSocketOpen(_) + | protocol::ToRivetTunnelMessageKind::ToRivetWebSocketClose(_) + ); + + if should_buffer { + tracing::debug!( + gateway_id=%protocol::util::id_to_string(&message_id.gateway_id), + request_id=%protocol::util::id_to_string(&message_id.request_id), + message_index=message_id.message_index, + "buffering hibernating control tunnel message" + ); + in_flight.buffered_inbound.push(msg.message_kind.clone()); + } else if in_flight .msg_tx .send(msg.message_kind.clone()) .await .is_err() { - tracing::warn!( + tracing::debug!( gateway_id=%protocol::util::id_to_string(&message_id.gateway_id), request_id=%protocol::util::id_to_string(&message_id.request_id), - receiver_subject=%in_flight.receiver_subject, - "message handler channel closed", + message_index=message_id.message_index, + "buffering tunnel message after request receiver dropped" ); + in_flight.buffered_inbound.push(msg.message_kind.clone()); } } Err(err) => { @@ -432,6 +478,7 @@ impl SharedState { req.hibernation_state = Some(HibernationState { total_pending_ws_msgs_size: 0, pending_ws_msgs: Vec::new(), + hibernating: false, last_ping: Instant::now(), }); } diff --git a/engine/packages/pegboard-gateway2/src/tunnel_to_ws_task.rs b/engine/packages/pegboard-gateway2/src/tunnel_to_ws_task.rs index 7f160f0376..4ed4a113ec 100644 --- a/engine/packages/pegboard-gateway2/src/tunnel_to_ws_task.rs +++ b/engine/packages/pegboard-gateway2/src/tunnel_to_ws_task.rs @@ -23,54 +23,44 @@ pub async fn task( mut stopped_sub: message::SubscriptionHandle, mut msg_rx: mpsc::Receiver, mut drop_rx: watch::Receiver>, + initial_messages: Vec, can_hibernate: bool, egress_bytes: Arc, mut tunnel_to_ws_abort_rx: watch::Receiver<()>, ) -> Result { + let mut initial_messages = initial_messages.into_iter(); + loop { + if let Some(msg) = initial_messages.next() { + if let Some(result) = handle_tunnel_message( + &shared_state, + &client_ws, + request_id, + can_hibernate, + &egress_bytes, + msg, + ) + .await? + { + return Ok(result); + } + continue; + } + tokio::select! { res = msg_rx.recv() => { if let Some(msg) = res { - match msg { - protocol::ToRivetTunnelMessageKind::ToRivetWebSocketMessage(ws_msg) => { - tracing::trace!( - request_id=%protocol::util::id_to_string(&request_id), - data_len=ws_msg.data.len(), - binary=ws_msg.binary, - "forwarding websocket message to client" - ); - let msg = if ws_msg.binary { - Message::Binary(ws_msg.data.into()) - } else { - Message::Text( - String::from_utf8_lossy(&ws_msg.data).into_owned().into(), - ) - }; - - egress_bytes.fetch_add(msg.len() as u64, Ordering::AcqRel); - client_ws.send(msg).await?; - } - protocol::ToRivetTunnelMessageKind::ToRivetWebSocketMessageAck(ack) => { - tracing::debug!( - request_id=%protocol::util::id_to_string(&request_id), - ack_index=?ack.index, - "received WebSocketMessageAck from envoy" - ); - shared_state - .ack_pending_websocket_messages(request_id, ack.index) - .await?; - } - protocol::ToRivetTunnelMessageKind::ToRivetWebSocketClose(close) => { - tracing::debug!(?close, "server closed websocket"); - - if can_hibernate && close.hibernate { - return Err(WebSocketServiceHibernate.build()); - } else { - // Successful closure - return Ok(LifecycleResult::ServerClose(close)); - } - } - _ => {} + if let Some(result) = handle_tunnel_message( + &shared_state, + &client_ws, + request_id, + can_hibernate, + &egress_bytes, + msg, + ) + .await? + { + return Ok(result); } } else { tracing::debug!("tunnel sub closed"); @@ -97,3 +87,53 @@ pub async fn task( } } } + +async fn handle_tunnel_message( + shared_state: &SharedState, + client_ws: &WebSocketHandle, + request_id: protocol::RequestId, + can_hibernate: bool, + egress_bytes: &AtomicU64, + msg: protocol::ToRivetTunnelMessageKind, +) -> Result> { + match msg { + protocol::ToRivetTunnelMessageKind::ToRivetWebSocketMessage(ws_msg) => { + tracing::trace!( + request_id=%protocol::util::id_to_string(&request_id), + data_len=ws_msg.data.len(), + binary=ws_msg.binary, + "forwarding websocket message to client" + ); + let msg = if ws_msg.binary { + Message::Binary(ws_msg.data.into()) + } else { + Message::Text(String::from_utf8_lossy(&ws_msg.data).into_owned().into()) + }; + + egress_bytes.fetch_add(msg.len() as u64, Ordering::AcqRel); + client_ws.send(msg).await?; + } + protocol::ToRivetTunnelMessageKind::ToRivetWebSocketMessageAck(ack) => { + tracing::debug!( + request_id=%protocol::util::id_to_string(&request_id), + ack_index=?ack.index, + "received WebSocketMessageAck from envoy" + ); + shared_state + .ack_pending_websocket_messages(request_id, ack.index) + .await?; + } + protocol::ToRivetTunnelMessageKind::ToRivetWebSocketClose(close) => { + tracing::debug!(?close, "server closed websocket"); + + if can_hibernate && close.hibernate { + return Err(WebSocketServiceHibernate.build()); + } + + return Ok(Some(LifecycleResult::ServerClose(close))); + } + _ => {} + } + + Ok(None) +}