Skip to content

Commit 0cda947

Browse files
committed
refactor(pegboard-gateway2): consolidate hibernation in-flight and tunnel flush
1 parent de1ee1c commit 0cda947

3 files changed

Lines changed: 282 additions & 80 deletions

File tree

engine/packages/pegboard-gateway2/src/lib.rs

Lines changed: 149 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ use std::{
2020
sync::{Arc, atomic::AtomicU64},
2121
time::Duration,
2222
};
23-
use tokio::sync::{Mutex, watch};
23+
use tokio::sync::{Mutex, mpsc, watch};
2424
use tokio_tungstenite::tungstenite::{
2525
Message,
2626
protocol::frame::{CloseFrame, coding::CloseCode},
@@ -286,6 +286,7 @@ impl PegboardGateway2 {
286286
.pegboard()
287287
.gateway_websocket_open_timeout_ms(),
288288
),
289+
false,
289290
)
290291
.await?;
291292

@@ -321,6 +322,7 @@ impl PegboardGateway2 {
321322
.pegboard()
322323
.gateway_websocket_open_timeout_ms(),
323324
),
325+
true,
324326
)
325327
.await?;
326328

@@ -339,6 +341,7 @@ impl PegboardGateway2 {
339341
.resend_pending_websocket_messages(request_id)
340342
.await?;
341343

344+
let initial_tunnel_messages = drain_ready_tunnel_messages(request_id, &mut msg_rx);
342345
let ws_rx = client_ws.recv();
343346

344347
let (tunnel_to_ws_abort_tx, tunnel_to_ws_abort_rx) = watch::channel(());
@@ -354,6 +357,7 @@ impl PegboardGateway2 {
354357
stopped_sub,
355358
msg_rx,
356359
drop_rx,
360+
initial_tunnel_messages,
357361
can_hibernate,
358362
egress_bytes.clone(),
359363
tunnel_to_ws_abort_rx,
@@ -397,18 +401,25 @@ impl PegboardGateway2 {
397401
} else {
398402
None
399403
};
404+
let tunnel_to_ws_abort_handle = tunnel_to_ws.abort_handle();
405+
let ws_to_tunnel_abort_handle = ws_to_tunnel.abort_handle();
400406

401407
// Wait for all tasks to complete
402408
let (tunnel_to_ws_res, ws_to_tunnel_res, ping_res, keepalive_res, metrics_res) = tokio::join!(
403409
async {
404-
let res = tunnel_to_ws.await?;
410+
let res = match tunnel_to_ws.await {
411+
Ok(res) => res,
412+
Err(err) if err.is_cancelled() => Ok(LifecycleResult::Aborted),
413+
Err(err) => Err(err.into()),
414+
};
405415

406416
// Abort other if not aborted
407417
if !matches!(res, Ok(LifecycleResult::Aborted)) {
408418
tracing::debug!(?res, "tunnel to ws task completed, aborting counterpart");
409419

410420
let _ = ping_abort_tx.send(());
411421
let _ = ws_to_tunnel_abort_tx.send(());
422+
ws_to_tunnel_abort_handle.abort();
412423
let _ = keepalive_abort_tx.send(());
413424
let _ = metrics_abort_tx.send(());
414425
} else {
@@ -418,14 +429,19 @@ impl PegboardGateway2 {
418429
res
419430
},
420431
async {
421-
let res = ws_to_tunnel.await?;
432+
let res = match ws_to_tunnel.await {
433+
Ok(res) => res,
434+
Err(err) if err.is_cancelled() => Ok(LifecycleResult::Aborted),
435+
Err(err) => Err(err.into()),
436+
};
422437

423438
// Abort other if not aborted
424439
if !matches!(res, Ok(LifecycleResult::Aborted)) {
425440
tracing::debug!(?res, "ws to tunnel task completed, aborting counterpart");
426441

427442
let _ = ping_abort_tx.send(());
428443
let _ = tunnel_to_ws_abort_tx.send(());
444+
tunnel_to_ws_abort_handle.abort();
429445
let _ = keepalive_abort_tx.send(());
430446
let _ = metrics_abort_tx.send(());
431447
} else {
@@ -539,6 +555,15 @@ impl PegboardGateway2 {
539555
{
540556
tracing::error!(?err, "error sending close message");
541557
}
558+
559+
if matches!(lifecycle_res, Ok(LifecycleResult::ClientClose(_))) {
560+
ctx.op(pegboard::ops::actor::hibernating_request::delete::Input {
561+
actor_id: self.actor_id,
562+
gateway_id: self.shared_state.gateway_id(),
563+
request_id,
564+
})
565+
.await?;
566+
}
542567
}
543568

544569
// Send WebSocket close message to client
@@ -674,10 +699,27 @@ impl CustomServeTrait for PegboardGateway2 {
674699
.await?
675700
{
676701
tracing::debug!("exiting hibernating due to pending messages");
702+
tokio::try_join!(
703+
ctx.op(pegboard::ops::actor::hibernating_request::upsert::Input {
704+
actor_id: self.actor_id,
705+
gateway_id: self.shared_state.gateway_id(),
706+
request_id,
707+
}),
708+
self.shared_state.keepalive_hws(request_id),
709+
)?;
677710

678711
return Ok(HibernationResult::Continue);
679712
}
680713

714+
tokio::try_join!(
715+
ctx.op(pegboard::ops::actor::hibernating_request::upsert::Input {
716+
actor_id: self.actor_id,
717+
gateway_id: self.shared_state.gateway_id(),
718+
request_id,
719+
}),
720+
self.shared_state.keepalive_hws(request_id),
721+
)?;
722+
681723
// Start keepalive task
682724
let (keepalive_abort_tx, keepalive_abort_rx) = watch::channel(());
683725
let keepalive_handle = tokio::spawn(keepalive_task::task(
@@ -702,6 +744,17 @@ impl CustomServeTrait for PegboardGateway2 {
702744
let _ = keepalive_abort_tx.send(());
703745
let _ = keepalive_handle.await;
704746

747+
if matches!(res, Ok(HibernationResult::Continue)) {
748+
tokio::try_join!(
749+
ctx.op(pegboard::ops::actor::hibernating_request::upsert::Input {
750+
actor_id: self.actor_id,
751+
gateway_id: self.shared_state.gateway_id(),
752+
request_id,
753+
}),
754+
self.shared_state.keepalive_hws(request_id),
755+
)?;
756+
}
757+
705758
let (delete_res, _) = tokio::join!(
706759
async {
707760
match &res {
@@ -810,9 +863,12 @@ async fn hibernate_ws(ws_rx: Arc<Mutex<WebSocketReceiver>>) -> Result<Hibernatio
810863
Ok(Message::Binary(_)) | Ok(Message::Text(_)) => {
811864
return Ok(HibernationResult::Continue);
812865
}
813-
// We don't care about the close frame because we're currently hibernating; there is no
814-
// downstream to send the close frame to.
815-
Ok(Message::Close(_)) => return Ok(HibernationResult::Close),
866+
// Consume the close frame so the websocket stack can complete the
867+
// close handshake while the actor is hibernating.
868+
Ok(Message::Close(_)) => {
869+
pinned.try_next().await?;
870+
return Ok(HibernationResult::Close);
871+
}
816872
// Ignore rest
817873
_ => {
818874
pinned.try_next().await?;
@@ -824,47 +880,106 @@ async fn hibernate_ws(ws_rx: Arc<Mutex<WebSocketReceiver>>) -> Result<Hibernatio
824880
}
825881
}
826882

883+
fn drain_ready_tunnel_messages(
884+
request_id: protocol::RequestId,
885+
msg_rx: &mut mpsc::Receiver<protocol::ToRivetTunnelMessageKind>,
886+
) -> Vec<protocol::ToRivetTunnelMessageKind> {
887+
let mut messages = Vec::new();
888+
889+
loop {
890+
match msg_rx.try_recv() {
891+
Ok(message) => messages.push(message),
892+
Err(mpsc::error::TryRecvError::Empty) => break,
893+
Err(mpsc::error::TryRecvError::Disconnected) => break,
894+
}
895+
}
896+
897+
if !messages.is_empty() {
898+
tracing::debug!(
899+
request_id=%protocol::util::id_to_string(&request_id),
900+
message_count=messages.len(),
901+
"drained ready tunnel messages before websocket forwarding task"
902+
);
903+
}
904+
905+
messages
906+
}
907+
827908
async fn wait_for_envoy_websocket_open(
828909
msg_rx: &mut tokio::sync::mpsc::Receiver<protocol::ToRivetTunnelMessageKind>,
829910
drop_rx: &mut watch::Receiver<Option<crate::shared_state::MsgGcReason>>,
830911
stopped_sub: &mut message::SubscriptionHandle<pegboard::workflows::actor2::Stopped>,
831912
request_id: protocol::RequestId,
832913
websocket_open_timeout: Duration,
914+
stop_on_actor_stopped: bool,
833915
) -> Result<protocol::ToRivetWebSocketOpen> {
834916
let fut = async {
835917
loop {
836-
tokio::select! {
837-
res = msg_rx.recv() => {
838-
if let Some(msg) = res {
839-
match msg {
840-
protocol::ToRivetTunnelMessageKind::ToRivetWebSocketOpen(msg) => {
841-
return anyhow::Ok(msg);
842-
}
843-
protocol::ToRivetTunnelMessageKind::ToRivetWebSocketClose(close) => {
844-
tracing::warn!(?close, "websocket closed before opening");
845-
return Err(WebSocketServiceUnavailable.build());
846-
}
847-
_ => {
848-
tracing::warn!(
849-
"received unexpected message while waiting for websocket open"
850-
);
918+
if stop_on_actor_stopped {
919+
tokio::select! {
920+
res = msg_rx.recv() => {
921+
if let Some(msg) = res {
922+
match msg {
923+
protocol::ToRivetTunnelMessageKind::ToRivetWebSocketOpen(msg) => {
924+
return anyhow::Ok(msg);
925+
}
926+
protocol::ToRivetTunnelMessageKind::ToRivetWebSocketClose(close) => {
927+
tracing::warn!(?close, "websocket closed before opening");
928+
return Err(WebSocketServiceUnavailable.build());
929+
}
930+
_ => {
931+
tracing::warn!(
932+
"received unexpected message while waiting for websocket open"
933+
);
934+
}
851935
}
936+
} else {
937+
tracing::warn!(
938+
request_id=%protocol::util::id_to_string(&request_id),
939+
"received no message response during ws init",
940+
);
941+
break;
852942
}
853-
} else {
854-
tracing::warn!(
855-
request_id=%protocol::util::id_to_string(&request_id),
856-
"received no message response during ws init",
857-
);
858-
break;
943+
}
944+
_ = stopped_sub.next() => {
945+
tracing::debug!("actor stopped while waiting for websocket open");
946+
return Err(WebSocketServiceUnavailable.build());
947+
}
948+
_ = drop_rx.changed() => {
949+
tracing::warn!(reason=?drop_rx.borrow(), "websocket open timeout");
950+
return Err(WebSocketServiceUnavailable.build());
859951
}
860952
}
861-
_ = stopped_sub.next() => {
862-
tracing::debug!("actor stopped while waiting for websocket open");
863-
return Err(WebSocketServiceUnavailable.build());
864-
}
865-
_ = drop_rx.changed() => {
866-
tracing::warn!(reason=?drop_rx.borrow(), "websocket open timeout");
867-
return Err(WebSocketServiceUnavailable.build());
953+
} else {
954+
tokio::select! {
955+
res = msg_rx.recv() => {
956+
if let Some(msg) = res {
957+
match msg {
958+
protocol::ToRivetTunnelMessageKind::ToRivetWebSocketOpen(msg) => {
959+
return anyhow::Ok(msg);
960+
}
961+
protocol::ToRivetTunnelMessageKind::ToRivetWebSocketClose(close) => {
962+
tracing::warn!(?close, "websocket closed before opening");
963+
return Err(WebSocketServiceUnavailable.build());
964+
}
965+
_ => {
966+
tracing::warn!(
967+
"received unexpected message while waiting for websocket open"
968+
);
969+
}
970+
}
971+
} else {
972+
tracing::warn!(
973+
request_id=%protocol::util::id_to_string(&request_id),
974+
"received no message response during ws init",
975+
);
976+
break;
977+
}
978+
}
979+
_ = drop_rx.changed() => {
980+
tracing::warn!(reason=?drop_rx.borrow(), "websocket open timeout");
981+
return Err(WebSocketServiceUnavailable.build());
982+
}
868983
}
869984
}
870985
}

0 commit comments

Comments
 (0)