Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
277 changes: 214 additions & 63 deletions engine/packages/pegboard-gateway2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::{
sync::{Arc, atomic::AtomicU64},
time::Duration,
};
use tokio::sync::{Mutex, watch};
use tokio::sync::{Mutex, mpsc, watch};
use tokio_tungstenite::tungstenite::{
Message,
protocol::frame::{CloseFrame, coding::CloseCode},
Expand Down Expand Up @@ -270,9 +270,31 @@ impl PegboardGateway2 {
"should not be creating a new in flight entry after hibernation"
);

// If we are reconnecting after hibernation, don't send an open message
// If we are reconnecting after hibernation, the actor restore path
// re-sends the websocket-open ack once the connection is attached. Wait
// for that ack before replaying buffered client messages.
let can_hibernate = if after_hibernation {
true
tracing::debug!("gateway waiting for restored websocket open from tunnel");
let open_msg = wait_for_envoy_websocket_open(
&mut msg_rx,
&mut drop_rx,
&mut stopped_sub,
request_id,
Duration::from_millis(
self.ctx
.config()
.pegboard()
.gateway_websocket_open_timeout_ms(),
),
false,
)
.await?;

self.shared_state
.toggle_hibernation(request_id, open_msg.can_hibernate)
.await?;

open_msg.can_hibernate
} else {
// Send WebSocket open message
let open_message = protocol::ToEnvoyTunnelMessageKind::ToEnvoyWebSocketOpen(
Expand All @@ -289,61 +311,20 @@ impl PegboardGateway2 {

tracing::debug!("gateway waiting for websocket open from tunnel");

// Wait for WebSocket open acknowledgment
let fut = async {
loop {
tokio::select! {
res = msg_rx.recv() => {
if let Some(msg) = res {
match msg {
protocol::ToRivetTunnelMessageKind::ToRivetWebSocketOpen(msg) => {
return anyhow::Ok(msg);
}
protocol::ToRivetTunnelMessageKind::ToRivetWebSocketClose(close) => {
tracing::warn!(?close, "websocket closed before opening");
return Err(WebSocketServiceUnavailable.build());
}
_ => {
tracing::warn!(
"received unexpected message while waiting for websocket open"
);
}
}
} else {
tracing::warn!(
request_id=%protocol::util::id_to_string(&request_id),
"received no message response during ws init",
);
break;
}
}
_ = stopped_sub.next() => {
tracing::debug!("actor stopped while waiting for websocket open");
return Err(WebSocketServiceUnavailable.build());
}
_ = drop_rx.changed() => {
tracing::warn!(reason=?drop_rx.borrow(), "websocket open timeout");
return Err(WebSocketServiceUnavailable.build());
}
}
}

Err(WebSocketServiceUnavailable.build())
};

let websocket_open_timeout = Duration::from_millis(
self.ctx
.config()
.pegboard()
.gateway_websocket_open_timeout_ms(),
);
let open_msg = tokio::time::timeout(websocket_open_timeout, fut)
.await
.map_err(|_| {
tracing::warn!("timed out waiting for websocket open from envoy");

WebSocketServiceUnavailable.build()
})??;
let open_msg = wait_for_envoy_websocket_open(
&mut msg_rx,
&mut drop_rx,
&mut stopped_sub,
request_id,
Duration::from_millis(
self.ctx
.config()
.pegboard()
.gateway_websocket_open_timeout_ms(),
),
true,
)
.await?;

self.shared_state
.toggle_hibernation(request_id, open_msg.can_hibernate)
Expand All @@ -360,6 +341,7 @@ impl PegboardGateway2 {
.resend_pending_websocket_messages(request_id)
.await?;

let initial_tunnel_messages = drain_ready_tunnel_messages(request_id, &mut msg_rx);
let ws_rx = client_ws.recv();

let (tunnel_to_ws_abort_tx, tunnel_to_ws_abort_rx) = watch::channel(());
Expand All @@ -375,6 +357,7 @@ impl PegboardGateway2 {
stopped_sub,
msg_rx,
drop_rx,
initial_tunnel_messages,
can_hibernate,
egress_bytes.clone(),
tunnel_to_ws_abort_rx,
Expand Down Expand Up @@ -418,18 +401,25 @@ impl PegboardGateway2 {
} else {
None
};
let tunnel_to_ws_abort_handle = tunnel_to_ws.abort_handle();
let ws_to_tunnel_abort_handle = ws_to_tunnel.abort_handle();

// Wait for all tasks to complete
let (tunnel_to_ws_res, ws_to_tunnel_res, ping_res, keepalive_res, metrics_res) = tokio::join!(
async {
let res = tunnel_to_ws.await?;
let res = match tunnel_to_ws.await {
Ok(res) => res,
Err(err) if err.is_cancelled() => Ok(LifecycleResult::Aborted),
Err(err) => Err(err.into()),
};

// Abort other if not aborted
if !matches!(res, Ok(LifecycleResult::Aborted)) {
tracing::debug!(?res, "tunnel to ws task completed, aborting counterpart");

let _ = ping_abort_tx.send(());
let _ = ws_to_tunnel_abort_tx.send(());
ws_to_tunnel_abort_handle.abort();
let _ = keepalive_abort_tx.send(());
let _ = metrics_abort_tx.send(());
} else {
Expand All @@ -439,14 +429,19 @@ impl PegboardGateway2 {
res
},
async {
let res = ws_to_tunnel.await?;
let res = match ws_to_tunnel.await {
Ok(res) => res,
Err(err) if err.is_cancelled() => Ok(LifecycleResult::Aborted),
Err(err) => Err(err.into()),
};

// Abort other if not aborted
if !matches!(res, Ok(LifecycleResult::Aborted)) {
tracing::debug!(?res, "ws to tunnel task completed, aborting counterpart");

let _ = ping_abort_tx.send(());
let _ = tunnel_to_ws_abort_tx.send(());
tunnel_to_ws_abort_handle.abort();
let _ = keepalive_abort_tx.send(());
let _ = metrics_abort_tx.send(());
} else {
Expand Down Expand Up @@ -560,6 +555,15 @@ impl PegboardGateway2 {
{
tracing::error!(?err, "error sending close message");
}

if matches!(lifecycle_res, Ok(LifecycleResult::ClientClose(_))) {
ctx.op(pegboard::ops::actor::hibernating_request::delete::Input {
actor_id: self.actor_id,
gateway_id: self.shared_state.gateway_id(),
request_id,
})
.await?;
}
}

// Send WebSocket close message to client
Expand Down Expand Up @@ -695,10 +699,27 @@ impl CustomServeTrait for PegboardGateway2 {
.await?
{
tracing::debug!("exiting hibernating due to pending messages");
tokio::try_join!(
ctx.op(pegboard::ops::actor::hibernating_request::upsert::Input {
actor_id: self.actor_id,
gateway_id: self.shared_state.gateway_id(),
request_id,
}),
self.shared_state.keepalive_hws(request_id),
)?;

return Ok(HibernationResult::Continue);
}

tokio::try_join!(
ctx.op(pegboard::ops::actor::hibernating_request::upsert::Input {
actor_id: self.actor_id,
gateway_id: self.shared_state.gateway_id(),
request_id,
}),
self.shared_state.keepalive_hws(request_id),
)?;

// Start keepalive task
let (keepalive_abort_tx, keepalive_abort_rx) = watch::channel(());
let keepalive_handle = tokio::spawn(keepalive_task::task(
Expand All @@ -723,6 +744,17 @@ impl CustomServeTrait for PegboardGateway2 {
let _ = keepalive_abort_tx.send(());
let _ = keepalive_handle.await;

if matches!(res, Ok(HibernationResult::Continue)) {
tokio::try_join!(
ctx.op(pegboard::ops::actor::hibernating_request::upsert::Input {
actor_id: self.actor_id,
gateway_id: self.shared_state.gateway_id(),
request_id,
}),
self.shared_state.keepalive_hws(request_id),
)?;
}

let (delete_res, _) = tokio::join!(
async {
match &res {
Expand Down Expand Up @@ -831,9 +863,12 @@ async fn hibernate_ws(ws_rx: Arc<Mutex<WebSocketReceiver>>) -> Result<Hibernatio
Ok(Message::Binary(_)) | Ok(Message::Text(_)) => {
return Ok(HibernationResult::Continue);
}
// We don't care about the close frame because we're currently hibernating; there is no
// downstream to send the close frame to.
Ok(Message::Close(_)) => return Ok(HibernationResult::Close),
// Consume the close frame so the websocket stack can complete the
// close handshake while the actor is hibernating.
Ok(Message::Close(_)) => {
pinned.try_next().await?;
return Ok(HibernationResult::Close);
}
// Ignore rest
_ => {
pinned.try_next().await?;
Expand All @@ -845,6 +880,122 @@ async fn hibernate_ws(ws_rx: Arc<Mutex<WebSocketReceiver>>) -> Result<Hibernatio
}
}

/// Collects every tunnel message that is already queued on `msg_rx`, without awaiting.
///
/// The drain ends at the first `try_recv` error, which covers both an empty channel and a
/// disconnected one. When at least one message was taken, a debug event is logged with the
/// request id and the number of messages drained.
fn drain_ready_tunnel_messages(
    request_id: protocol::RequestId,
    msg_rx: &mut mpsc::Receiver<protocol::ToRivetTunnelMessageKind>,
) -> Vec<protocol::ToRivetTunnelMessageKind> {
    let mut drained = Vec::new();

    // `Empty` and `Disconnected` both terminate the drain, so any error stops the loop.
    while let Ok(next) = msg_rx.try_recv() {
        drained.push(next);
    }

    let drained_count = drained.len();
    if drained_count > 0 {
        tracing::debug!(
            request_id=%protocol::util::id_to_string(&request_id),
            message_count=drained_count,
            "drained ready tunnel messages before websocket forwarding task"
        );
    }

    drained
}

/// Waits for the `ToRivetWebSocketOpen` message on the tunnel channel.
///
/// Returns the open message on success. Fails with `WebSocketServiceUnavailable` when:
/// - a `ToRivetWebSocketClose` arrives before the open message,
/// - the tunnel channel is closed (`recv` yields `None`),
/// - `drop_rx` reports a change,
/// - `stop_on_actor_stopped` is `true` and `stopped_sub` yields, or
/// - `websocket_open_timeout` elapses first.
///
/// Any other tunnel message received while waiting is logged and skipped.
///
/// When `stop_on_actor_stopped` is `false`, `stopped_sub` is never polled here.
async fn wait_for_envoy_websocket_open(
    msg_rx: &mut tokio::sync::mpsc::Receiver<protocol::ToRivetTunnelMessageKind>,
    drop_rx: &mut watch::Receiver<Option<crate::shared_state::MsgGcReason>>,
    stopped_sub: &mut message::SubscriptionHandle<pegboard::workflows::actor2::Stopped>,
    request_id: protocol::RequestId,
    websocket_open_timeout: Duration,
    stop_on_actor_stopped: bool,
) -> Result<protocol::ToRivetWebSocketOpen> {
    let fut = async {
        loop {
            tokio::select! {
                res = msg_rx.recv() => {
                    match res {
                        Some(protocol::ToRivetTunnelMessageKind::ToRivetWebSocketOpen(msg)) => {
                            return anyhow::Ok(msg);
                        }
                        Some(protocol::ToRivetTunnelMessageKind::ToRivetWebSocketClose(close)) => {
                            tracing::warn!(?close, "websocket closed before opening");
                            return Err(WebSocketServiceUnavailable.build());
                        }
                        Some(_) => {
                            // Keep waiting; only open/close are meaningful at this stage.
                            tracing::warn!(
                                "received unexpected message while waiting for websocket open"
                            );
                        }
                        None => {
                            tracing::warn!(
                                request_id=%protocol::util::id_to_string(&request_id),
                                "received no message response during ws init",
                            );
                            break;
                        }
                    }
                }
                // The precondition disables this branch entirely when the caller does not
                // want an actor-stopped event to abort the wait, replacing what used to be
                // a second, duplicated `select!` without this arm.
                _ = stopped_sub.next(), if stop_on_actor_stopped => {
                    tracing::debug!("actor stopped while waiting for websocket open");
                    return Err(WebSocketServiceUnavailable.build());
                }
                _ = drop_rx.changed() => {
                    tracing::warn!(reason=?drop_rx.borrow(), "websocket open timeout");
                    return Err(WebSocketServiceUnavailable.build());
                }
            }
        }

        // Reached only when the tunnel channel closed before an open message arrived.
        Err(WebSocketServiceUnavailable.build())
    };

    // The outer `Result` is the timeout; `?` unwraps it and leaves the inner wait result
    // as the function's return value.
    tokio::time::timeout(websocket_open_timeout, fut)
        .await
        .map_err(|_| {
            tracing::warn!("timed out waiting for websocket open from envoy");

            WebSocketServiceUnavailable.build()
        })?
}

#[derive(Debug)]
enum Metric {
HttpIngress(usize),
Expand Down
Loading
Loading