Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 149 additions & 34 deletions engine/packages/pegboard-gateway2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::{
sync::{Arc, atomic::AtomicU64},
time::Duration,
};
use tokio::sync::{Mutex, watch};
use tokio::sync::{Mutex, mpsc, watch};
use tokio_tungstenite::tungstenite::{
Message,
protocol::frame::{CloseFrame, coding::CloseCode},
Expand Down Expand Up @@ -336,6 +336,7 @@ impl PegboardGateway2 {
.pegboard()
.gateway_websocket_open_timeout_ms(),
),
false,
)
.await?;

Expand Down Expand Up @@ -371,6 +372,7 @@ impl PegboardGateway2 {
.pegboard()
.gateway_websocket_open_timeout_ms(),
),
true,
)
.await?;

Expand All @@ -389,6 +391,7 @@ impl PegboardGateway2 {
.resend_pending_websocket_messages(request_id)
.await?;

let initial_tunnel_messages = drain_ready_tunnel_messages(request_id, &mut msg_rx);
let ws_rx = client_ws.recv();

let (tunnel_to_ws_abort_tx, tunnel_to_ws_abort_rx) = watch::channel(());
Expand All @@ -404,6 +407,7 @@ impl PegboardGateway2 {
stopped_sub,
msg_rx,
drop_rx,
initial_tunnel_messages,
can_hibernate,
egress_bytes.clone(),
tunnel_to_ws_abort_rx,
Expand Down Expand Up @@ -447,18 +451,25 @@ impl PegboardGateway2 {
} else {
None
};
let tunnel_to_ws_abort_handle = tunnel_to_ws.abort_handle();
let ws_to_tunnel_abort_handle = ws_to_tunnel.abort_handle();

// Wait for all tasks to complete
let (tunnel_to_ws_res, ws_to_tunnel_res, ping_res, keepalive_res, metrics_res) = tokio::join!(
async {
let res = tunnel_to_ws.await?;
let res = match tunnel_to_ws.await {
Ok(res) => res,
Err(err) if err.is_cancelled() => Ok(LifecycleResult::Aborted),
Err(err) => Err(err.into()),
};

// Abort other if not aborted
if !matches!(res, Ok(LifecycleResult::Aborted)) {
tracing::debug!(?res, "tunnel to ws task completed, aborting counterpart");

let _ = ping_abort_tx.send(());
let _ = ws_to_tunnel_abort_tx.send(());
ws_to_tunnel_abort_handle.abort();
let _ = keepalive_abort_tx.send(());
let _ = metrics_abort_tx.send(());
} else {
Expand All @@ -468,14 +479,19 @@ impl PegboardGateway2 {
res
},
async {
let res = ws_to_tunnel.await?;
let res = match ws_to_tunnel.await {
Ok(res) => res,
Err(err) if err.is_cancelled() => Ok(LifecycleResult::Aborted),
Err(err) => Err(err.into()),
};

// Abort other if not aborted
if !matches!(res, Ok(LifecycleResult::Aborted)) {
tracing::debug!(?res, "ws to tunnel task completed, aborting counterpart");

let _ = ping_abort_tx.send(());
let _ = tunnel_to_ws_abort_tx.send(());
tunnel_to_ws_abort_handle.abort();
let _ = keepalive_abort_tx.send(());
let _ = metrics_abort_tx.send(());
} else {
Expand Down Expand Up @@ -589,6 +605,15 @@ impl PegboardGateway2 {
{
tracing::error!(?err, "error sending close message");
}

if matches!(lifecycle_res, Ok(LifecycleResult::ClientClose(_))) {
ctx.op(pegboard::ops::actor::hibernating_request::delete::Input {
actor_id: self.actor_id,
gateway_id: self.shared_state.gateway_id(),
request_id,
})
.await?;
}
}

// Send WebSocket close message to client
Expand Down Expand Up @@ -724,10 +749,27 @@ impl CustomServeTrait for PegboardGateway2 {
.await?
{
tracing::debug!("exiting hibernating due to pending messages");
tokio::try_join!(
ctx.op(pegboard::ops::actor::hibernating_request::upsert::Input {
actor_id: self.actor_id,
gateway_id: self.shared_state.gateway_id(),
request_id,
}),
self.shared_state.keepalive_hws(request_id),
)?;

return Ok(HibernationResult::Continue);
}

tokio::try_join!(
ctx.op(pegboard::ops::actor::hibernating_request::upsert::Input {
actor_id: self.actor_id,
gateway_id: self.shared_state.gateway_id(),
request_id,
}),
self.shared_state.keepalive_hws(request_id),
)?;

// Start keepalive task
let (keepalive_abort_tx, keepalive_abort_rx) = watch::channel(());
let keepalive_handle = tokio::spawn(keepalive_task::task(
Expand All @@ -752,6 +794,17 @@ impl CustomServeTrait for PegboardGateway2 {
let _ = keepalive_abort_tx.send(());
let _ = keepalive_handle.await;

if matches!(res, Ok(HibernationResult::Continue)) {
tokio::try_join!(
ctx.op(pegboard::ops::actor::hibernating_request::upsert::Input {
actor_id: self.actor_id,
gateway_id: self.shared_state.gateway_id(),
request_id,
}),
self.shared_state.keepalive_hws(request_id),
)?;
}

let (delete_res, _) = tokio::join!(
async {
match &res {
Expand Down Expand Up @@ -860,9 +913,12 @@ async fn hibernate_ws(ws_rx: Arc<Mutex<WebSocketReceiver>>) -> Result<Hibernatio
Ok(Message::Binary(_)) | Ok(Message::Text(_)) => {
return Ok(HibernationResult::Continue);
}
// We don't care about the close frame because we're currently hibernating; there is no
// downstream to send the close frame to.
Ok(Message::Close(_)) => return Ok(HibernationResult::Close),
// Consume the close frame so the websocket stack can complete the
// close handshake while the actor is hibernating.
Ok(Message::Close(_)) => {
pinned.try_next().await?;
return Ok(HibernationResult::Close);
}
// Ignore rest
_ => {
pinned.try_next().await?;
Expand All @@ -874,47 +930,106 @@ async fn hibernate_ws(ws_rx: Arc<Mutex<WebSocketReceiver>>) -> Result<Hibernatio
}
}

/// Collects every tunnel message that is already queued on `msg_rx`.
///
/// The receiver is polled with `try_recv` until it returns an error, which happens
/// either when the queue is currently empty or when the channel is disconnected.
/// Messages are returned in the order they were pulled off the receiver. When at
/// least one message was collected, a debug event is emitted carrying the request
/// id and the number of drained messages.
fn drain_ready_tunnel_messages(
	request_id: protocol::RequestId,
	msg_rx: &mut mpsc::Receiver<protocol::ToRivetTunnelMessageKind>,
) -> Vec<protocol::ToRivetTunnelMessageKind> {
	// Both `Empty` and `Disconnected` map to `None` here, which ends the drain.
	let drained: Vec<_> = std::iter::from_fn(|| msg_rx.try_recv().ok()).collect();

	// Nothing was queued: skip the log line entirely.
	if drained.is_empty() {
		return drained;
	}

	tracing::debug!(
		request_id=%protocol::util::id_to_string(&request_id),
		message_count=drained.len(),
		"drained ready tunnel messages before websocket forwarding task"
	);

	drained
}

async fn wait_for_envoy_websocket_open(
msg_rx: &mut tokio::sync::mpsc::Receiver<protocol::ToRivetTunnelMessageKind>,
drop_rx: &mut watch::Receiver<Option<crate::shared_state::MsgGcReason>>,
stopped_sub: &mut message::SubscriptionHandle<pegboard::workflows::actor2::Stopped>,
request_id: protocol::RequestId,
websocket_open_timeout: Duration,
stop_on_actor_stopped: bool,
) -> Result<protocol::ToRivetWebSocketOpen> {
let fut = async {
loop {
tokio::select! {
res = msg_rx.recv() => {
if let Some(msg) = res {
match msg {
protocol::ToRivetTunnelMessageKind::ToRivetWebSocketOpen(msg) => {
return anyhow::Ok(msg);
}
protocol::ToRivetTunnelMessageKind::ToRivetWebSocketClose(close) => {
tracing::warn!(?close, "websocket closed before opening");
return Err(WebSocketServiceUnavailable.build());
}
_ => {
tracing::warn!(
"received unexpected message while waiting for websocket open"
);
if stop_on_actor_stopped {
tokio::select! {
res = msg_rx.recv() => {
if let Some(msg) = res {
match msg {
protocol::ToRivetTunnelMessageKind::ToRivetWebSocketOpen(msg) => {
return anyhow::Ok(msg);
}
protocol::ToRivetTunnelMessageKind::ToRivetWebSocketClose(close) => {
tracing::warn!(?close, "websocket closed before opening");
return Err(WebSocketServiceUnavailable.build());
}
_ => {
tracing::warn!(
"received unexpected message while waiting for websocket open"
);
}
}
} else {
tracing::warn!(
request_id=%protocol::util::id_to_string(&request_id),
"received no message response during ws init",
);
break;
}
} else {
tracing::warn!(
request_id=%protocol::util::id_to_string(&request_id),
"received no message response during ws init",
);
break;
}
_ = stopped_sub.next() => {
tracing::debug!("actor stopped while waiting for websocket open");
return Err(WebSocketServiceUnavailable.build());
}
_ = drop_rx.changed() => {
tracing::warn!(reason=?drop_rx.borrow(), "websocket open timeout");
return Err(WebSocketServiceUnavailable.build());
}
}
_ = stopped_sub.next() => {
tracing::debug!("actor stopped while waiting for websocket open");
return Err(WebSocketServiceUnavailable.build());
}
_ = drop_rx.changed() => {
tracing::warn!(reason=?drop_rx.borrow(), "websocket open timeout");
return Err(WebSocketServiceUnavailable.build());
} else {
tokio::select! {
res = msg_rx.recv() => {
if let Some(msg) = res {
match msg {
protocol::ToRivetTunnelMessageKind::ToRivetWebSocketOpen(msg) => {
return anyhow::Ok(msg);
}
protocol::ToRivetTunnelMessageKind::ToRivetWebSocketClose(close) => {
tracing::warn!(?close, "websocket closed before opening");
return Err(WebSocketServiceUnavailable.build());
}
_ => {
tracing::warn!(
"received unexpected message while waiting for websocket open"
);
}
}
} else {
tracing::warn!(
request_id=%protocol::util::id_to_string(&request_id),
"received no message response during ws init",
);
break;
}
}
_ = drop_rx.changed() => {
tracing::warn!(reason=?drop_rx.borrow(), "websocket open timeout");
return Err(WebSocketServiceUnavailable.build());
}
}
}
}
Expand Down
Loading
Loading