Skip to content

Commit 7a5e467

Browse files
committed
refactor(pegboard-gateway2): consolidate hibernation in-flight and tunnel flush
1 parent 5ef7bf6 commit 7a5e467

3 files changed

Lines changed: 347 additions & 109 deletions

File tree

engine/packages/pegboard-gateway2/src/lib.rs

Lines changed: 214 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use std::{
2020
sync::{Arc, atomic::AtomicU64},
2121
time::Duration,
2222
};
23-
use tokio::sync::{Mutex, watch};
23+
use tokio::sync::{Mutex, mpsc, watch};
2424
use tokio_tungstenite::tungstenite::{
2525
Message,
2626
protocol::frame::{CloseFrame, coding::CloseCode},
@@ -270,9 +270,31 @@ impl PegboardGateway2 {
270270
"should not be creating a new in flight entry after hibernation"
271271
);
272272

273-
// If we are reconnecting after hibernation, don't send an open message
273+
// If we are reconnecting after hibernation, the actor restore path
274+
// re-sends the websocket-open ack once the connection is attached. Wait
275+
// for that ack before replaying buffered client messages.
274276
let can_hibernate = if after_hibernation {
275-
true
277+
tracing::debug!("gateway waiting for restored websocket open from tunnel");
278+
let open_msg = wait_for_envoy_websocket_open(
279+
&mut msg_rx,
280+
&mut drop_rx,
281+
&mut stopped_sub,
282+
request_id,
283+
Duration::from_millis(
284+
self.ctx
285+
.config()
286+
.pegboard()
287+
.gateway_websocket_open_timeout_ms(),
288+
),
289+
false,
290+
)
291+
.await?;
292+
293+
self.shared_state
294+
.toggle_hibernation(request_id, open_msg.can_hibernate)
295+
.await?;
296+
297+
open_msg.can_hibernate
276298
} else {
277299
// Send WebSocket open message
278300
let open_message = protocol::ToEnvoyTunnelMessageKind::ToEnvoyWebSocketOpen(
@@ -289,61 +311,20 @@ impl PegboardGateway2 {
289311

290312
tracing::debug!("gateway waiting for websocket open from tunnel");
291313

292-
// Wait for WebSocket open acknowledgment
293-
let fut = async {
294-
loop {
295-
tokio::select! {
296-
res = msg_rx.recv() => {
297-
if let Some(msg) = res {
298-
match msg {
299-
protocol::ToRivetTunnelMessageKind::ToRivetWebSocketOpen(msg) => {
300-
return anyhow::Ok(msg);
301-
}
302-
protocol::ToRivetTunnelMessageKind::ToRivetWebSocketClose(close) => {
303-
tracing::warn!(?close, "websocket closed before opening");
304-
return Err(WebSocketServiceUnavailable.build());
305-
}
306-
_ => {
307-
tracing::warn!(
308-
"received unexpected message while waiting for websocket open"
309-
);
310-
}
311-
}
312-
} else {
313-
tracing::warn!(
314-
request_id=%protocol::util::id_to_string(&request_id),
315-
"received no message response during ws init",
316-
);
317-
break;
318-
}
319-
}
320-
_ = stopped_sub.next() => {
321-
tracing::debug!("actor stopped while waiting for websocket open");
322-
return Err(WebSocketServiceUnavailable.build());
323-
}
324-
_ = drop_rx.changed() => {
325-
tracing::warn!(reason=?drop_rx.borrow(), "websocket open timeout");
326-
return Err(WebSocketServiceUnavailable.build());
327-
}
328-
}
329-
}
330-
331-
Err(WebSocketServiceUnavailable.build())
332-
};
333-
334-
let websocket_open_timeout = Duration::from_millis(
335-
self.ctx
336-
.config()
337-
.pegboard()
338-
.gateway_websocket_open_timeout_ms(),
339-
);
340-
let open_msg = tokio::time::timeout(websocket_open_timeout, fut)
341-
.await
342-
.map_err(|_| {
343-
tracing::warn!("timed out waiting for websocket open from envoy");
344-
345-
WebSocketServiceUnavailable.build()
346-
})??;
314+
let open_msg = wait_for_envoy_websocket_open(
315+
&mut msg_rx,
316+
&mut drop_rx,
317+
&mut stopped_sub,
318+
request_id,
319+
Duration::from_millis(
320+
self.ctx
321+
.config()
322+
.pegboard()
323+
.gateway_websocket_open_timeout_ms(),
324+
),
325+
true,
326+
)
327+
.await?;
347328

348329
self.shared_state
349330
.toggle_hibernation(request_id, open_msg.can_hibernate)
@@ -360,6 +341,7 @@ impl PegboardGateway2 {
360341
.resend_pending_websocket_messages(request_id)
361342
.await?;
362343

344+
let initial_tunnel_messages = drain_ready_tunnel_messages(request_id, &mut msg_rx);
363345
let ws_rx = client_ws.recv();
364346

365347
let (tunnel_to_ws_abort_tx, tunnel_to_ws_abort_rx) = watch::channel(());
@@ -375,6 +357,7 @@ impl PegboardGateway2 {
375357
stopped_sub,
376358
msg_rx,
377359
drop_rx,
360+
initial_tunnel_messages,
378361
can_hibernate,
379362
egress_bytes.clone(),
380363
tunnel_to_ws_abort_rx,
@@ -418,18 +401,25 @@ impl PegboardGateway2 {
418401
} else {
419402
None
420403
};
404+
let tunnel_to_ws_abort_handle = tunnel_to_ws.abort_handle();
405+
let ws_to_tunnel_abort_handle = ws_to_tunnel.abort_handle();
421406

422407
// Wait for all tasks to complete
423408
let (tunnel_to_ws_res, ws_to_tunnel_res, ping_res, keepalive_res, metrics_res) = tokio::join!(
424409
async {
425-
let res = tunnel_to_ws.await?;
410+
let res = match tunnel_to_ws.await {
411+
Ok(res) => res,
412+
Err(err) if err.is_cancelled() => Ok(LifecycleResult::Aborted),
413+
Err(err) => Err(err.into()),
414+
};
426415

427416
// Abort other if not aborted
428417
if !matches!(res, Ok(LifecycleResult::Aborted)) {
429418
tracing::debug!(?res, "tunnel to ws task completed, aborting counterpart");
430419

431420
let _ = ping_abort_tx.send(());
432421
let _ = ws_to_tunnel_abort_tx.send(());
422+
ws_to_tunnel_abort_handle.abort();
433423
let _ = keepalive_abort_tx.send(());
434424
let _ = metrics_abort_tx.send(());
435425
} else {
@@ -439,14 +429,19 @@ impl PegboardGateway2 {
439429
res
440430
},
441431
async {
442-
let res = ws_to_tunnel.await?;
432+
let res = match ws_to_tunnel.await {
433+
Ok(res) => res,
434+
Err(err) if err.is_cancelled() => Ok(LifecycleResult::Aborted),
435+
Err(err) => Err(err.into()),
436+
};
443437

444438
// Abort other if not aborted
445439
if !matches!(res, Ok(LifecycleResult::Aborted)) {
446440
tracing::debug!(?res, "ws to tunnel task completed, aborting counterpart");
447441

448442
let _ = ping_abort_tx.send(());
449443
let _ = tunnel_to_ws_abort_tx.send(());
444+
tunnel_to_ws_abort_handle.abort();
450445
let _ = keepalive_abort_tx.send(());
451446
let _ = metrics_abort_tx.send(());
452447
} else {
@@ -560,6 +555,15 @@ impl PegboardGateway2 {
560555
{
561556
tracing::error!(?err, "error sending close message");
562557
}
558+
559+
if matches!(lifecycle_res, Ok(LifecycleResult::ClientClose(_))) {
560+
ctx.op(pegboard::ops::actor::hibernating_request::delete::Input {
561+
actor_id: self.actor_id,
562+
gateway_id: self.shared_state.gateway_id(),
563+
request_id,
564+
})
565+
.await?;
566+
}
563567
}
564568

565569
// Send WebSocket close message to client
@@ -695,10 +699,27 @@ impl CustomServeTrait for PegboardGateway2 {
695699
.await?
696700
{
697701
tracing::debug!("exiting hibernating due to pending messages");
702+
tokio::try_join!(
703+
ctx.op(pegboard::ops::actor::hibernating_request::upsert::Input {
704+
actor_id: self.actor_id,
705+
gateway_id: self.shared_state.gateway_id(),
706+
request_id,
707+
}),
708+
self.shared_state.keepalive_hws(request_id),
709+
)?;
698710

699711
return Ok(HibernationResult::Continue);
700712
}
701713

714+
tokio::try_join!(
715+
ctx.op(pegboard::ops::actor::hibernating_request::upsert::Input {
716+
actor_id: self.actor_id,
717+
gateway_id: self.shared_state.gateway_id(),
718+
request_id,
719+
}),
720+
self.shared_state.keepalive_hws(request_id),
721+
)?;
722+
702723
// Start keepalive task
703724
let (keepalive_abort_tx, keepalive_abort_rx) = watch::channel(());
704725
let keepalive_handle = tokio::spawn(keepalive_task::task(
@@ -723,6 +744,17 @@ impl CustomServeTrait for PegboardGateway2 {
723744
let _ = keepalive_abort_tx.send(());
724745
let _ = keepalive_handle.await;
725746

747+
if matches!(res, Ok(HibernationResult::Continue)) {
748+
tokio::try_join!(
749+
ctx.op(pegboard::ops::actor::hibernating_request::upsert::Input {
750+
actor_id: self.actor_id,
751+
gateway_id: self.shared_state.gateway_id(),
752+
request_id,
753+
}),
754+
self.shared_state.keepalive_hws(request_id),
755+
)?;
756+
}
757+
726758
let (delete_res, _) = tokio::join!(
727759
async {
728760
match &res {
@@ -831,9 +863,12 @@ async fn hibernate_ws(ws_rx: Arc<Mutex<WebSocketReceiver>>) -> Result<Hibernatio
831863
Ok(Message::Binary(_)) | Ok(Message::Text(_)) => {
832864
return Ok(HibernationResult::Continue);
833865
}
834-
// We don't care about the close frame because we're currently hibernating; there is no
835-
// downstream to send the close frame to.
836-
Ok(Message::Close(_)) => return Ok(HibernationResult::Close),
866+
// Consume the close frame so the websocket stack can complete the
867+
// close handshake while the actor is hibernating.
868+
Ok(Message::Close(_)) => {
869+
pinned.try_next().await?;
870+
return Ok(HibernationResult::Close);
871+
}
837872
// Ignore rest
838873
_ => {
839874
pinned.try_next().await?;
@@ -845,6 +880,122 @@ async fn hibernate_ws(ws_rx: Arc<Mutex<WebSocketReceiver>>) -> Result<Hibernatio
845880
}
846881
}
847882

883+
fn drain_ready_tunnel_messages(
884+
request_id: protocol::RequestId,
885+
msg_rx: &mut mpsc::Receiver<protocol::ToRivetTunnelMessageKind>,
886+
) -> Vec<protocol::ToRivetTunnelMessageKind> {
887+
let mut messages = Vec::new();
888+
889+
loop {
890+
match msg_rx.try_recv() {
891+
Ok(message) => messages.push(message),
892+
Err(mpsc::error::TryRecvError::Empty) => break,
893+
Err(mpsc::error::TryRecvError::Disconnected) => break,
894+
}
895+
}
896+
897+
if !messages.is_empty() {
898+
tracing::debug!(
899+
request_id=%protocol::util::id_to_string(&request_id),
900+
message_count=messages.len(),
901+
"drained ready tunnel messages before websocket forwarding task"
902+
);
903+
}
904+
905+
messages
906+
}
907+
908+
async fn wait_for_envoy_websocket_open(
909+
msg_rx: &mut tokio::sync::mpsc::Receiver<protocol::ToRivetTunnelMessageKind>,
910+
drop_rx: &mut watch::Receiver<Option<crate::shared_state::MsgGcReason>>,
911+
stopped_sub: &mut message::SubscriptionHandle<pegboard::workflows::actor2::Stopped>,
912+
request_id: protocol::RequestId,
913+
websocket_open_timeout: Duration,
914+
stop_on_actor_stopped: bool,
915+
) -> Result<protocol::ToRivetWebSocketOpen> {
916+
let fut = async {
917+
loop {
918+
if stop_on_actor_stopped {
919+
tokio::select! {
920+
res = msg_rx.recv() => {
921+
if let Some(msg) = res {
922+
match msg {
923+
protocol::ToRivetTunnelMessageKind::ToRivetWebSocketOpen(msg) => {
924+
return anyhow::Ok(msg);
925+
}
926+
protocol::ToRivetTunnelMessageKind::ToRivetWebSocketClose(close) => {
927+
tracing::warn!(?close, "websocket closed before opening");
928+
return Err(WebSocketServiceUnavailable.build());
929+
}
930+
_ => {
931+
tracing::warn!(
932+
"received unexpected message while waiting for websocket open"
933+
);
934+
}
935+
}
936+
} else {
937+
tracing::warn!(
938+
request_id=%protocol::util::id_to_string(&request_id),
939+
"received no message response during ws init",
940+
);
941+
break;
942+
}
943+
}
944+
_ = stopped_sub.next() => {
945+
tracing::debug!("actor stopped while waiting for websocket open");
946+
return Err(WebSocketServiceUnavailable.build());
947+
}
948+
_ = drop_rx.changed() => {
949+
tracing::warn!(reason=?drop_rx.borrow(), "websocket open timeout");
950+
return Err(WebSocketServiceUnavailable.build());
951+
}
952+
}
953+
} else {
954+
tokio::select! {
955+
res = msg_rx.recv() => {
956+
if let Some(msg) = res {
957+
match msg {
958+
protocol::ToRivetTunnelMessageKind::ToRivetWebSocketOpen(msg) => {
959+
return anyhow::Ok(msg);
960+
}
961+
protocol::ToRivetTunnelMessageKind::ToRivetWebSocketClose(close) => {
962+
tracing::warn!(?close, "websocket closed before opening");
963+
return Err(WebSocketServiceUnavailable.build());
964+
}
965+
_ => {
966+
tracing::warn!(
967+
"received unexpected message while waiting for websocket open"
968+
);
969+
}
970+
}
971+
} else {
972+
tracing::warn!(
973+
request_id=%protocol::util::id_to_string(&request_id),
974+
"received no message response during ws init",
975+
);
976+
break;
977+
}
978+
}
979+
_ = drop_rx.changed() => {
980+
tracing::warn!(reason=?drop_rx.borrow(), "websocket open timeout");
981+
return Err(WebSocketServiceUnavailable.build());
982+
}
983+
}
984+
}
985+
}
986+
987+
Err(WebSocketServiceUnavailable.build())
988+
};
989+
990+
tokio::time::timeout(websocket_open_timeout, fut)
991+
.await
992+
.map_err(|_| {
993+
tracing::warn!("timed out waiting for websocket open from envoy");
994+
995+
WebSocketServiceUnavailable.build()
996+
})?
997+
}
998+
848999
#[derive(Debug)]
8491000
enum Metric {
8501001
HttpIngress(usize),

0 commit comments

Comments
 (0)