Skip to content

Commit 2497d7d

Browse files
committed
feat: [US-008] - [Explicit wake-time liveness query for hibernated conns (follow-up from US-007 audit)]
1 parent dd1a6d6 commit 2497d7d

5 files changed

Lines changed: 254 additions & 39 deletions

File tree

rivetkit-rust/packages/rivetkit-core/src/actor/context.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::collections::BTreeSet;
12
use std::future::Future;
23
use std::sync::Arc;
34
use std::sync::Weak;
@@ -66,6 +67,8 @@ pub(crate) struct ActorContextInner {
6667
inspector: std::sync::RwLock<Option<Inspector>>,
6768
actor_events: std::sync::RwLock<Option<mpsc::Sender<ActorEvent>>>,
6869
lifecycle_events: std::sync::RwLock<Option<mpsc::Sender<LifecycleEvent>>>,
70+
hibernated_connection_liveness_override:
71+
std::sync::RwLock<Option<BTreeSet<(Vec<u8>, Vec<u8>)>>>,
6972
lifecycle_event_inbox_capacity: usize,
7073
metrics: ActorMetrics,
7174
diagnostics: ActorDiagnostics,
@@ -206,6 +209,7 @@ impl ActorContext {
206209
inspector: std::sync::RwLock::new(None),
207210
actor_events: std::sync::RwLock::new(None),
208211
lifecycle_events: std::sync::RwLock::new(None),
212+
hibernated_connection_liveness_override: std::sync::RwLock::new(None),
209213
lifecycle_event_inbox_capacity,
210214
metrics,
211215
diagnostics,
@@ -801,6 +805,47 @@ impl ActorContext {
801805
self.request_save(false);
802806
}
803807

808+
pub(crate) fn hibernated_connection_is_live(
809+
&self,
810+
gateway_id: &[u8],
811+
request_id: &[u8],
812+
) -> Result<bool> {
813+
if let Some(override_pairs) = self
814+
.0
815+
.hibernated_connection_liveness_override
816+
.read()
817+
.expect("hibernated connection liveness override lock poisoned")
818+
.as_ref()
819+
{
820+
return Ok(
821+
override_pairs.contains(&(gateway_id.to_vec(), request_id.to_vec()))
822+
);
823+
}
824+
825+
// TODO(hibernation-liveness): Replace this with an envoy-client query for
826+
// active gateway_id/request_id membership once that API exists. `EnvoyHandle`
827+
// currently exposes actor metadata and hibernation restore/ack hooks, but not
828+
// a way to ask whether a persisted hibernating request is still live.
829+
todo!(
830+
"explicit wake-time hibernated connection liveness query requires envoy-client gateway_id/request_id membership support"
831+
);
832+
}
833+
834+
/// Test-only hook: installs the set of `(gateway_id, request_id)` pairs that
/// `hibernated_connection_is_live` should report as live.
///
/// Any previously installed override is replaced. Passing an empty iterator
/// installs an empty set, which makes every pair report as not live.
#[cfg(test)]
pub(crate) fn set_hibernated_connection_liveness_override<I>(&self, pairs: I)
where
    I: IntoIterator<Item = (Vec<u8>, Vec<u8>)>,
{
    // Collect before taking the write lock so the guard is held only for the swap.
    let live_pairs = Some(pairs.into_iter().collect());
    let mut override_slot = self
        .0
        .hibernated_connection_liveness_override
        .write()
        .expect("hibernated connection liveness override lock poisoned");
    *override_slot = live_pairs;
}
848+
804849
fn prepare_state_deltas(
805850
&self,
806851
deltas: Vec<StateDelta>,

rivetkit-rust/packages/rivetkit-core/src/actor/task.rs

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use std::sync::Arc;
55
use anyhow::{Context, Result, anyhow};
66
use futures::FutureExt;
77
use tokio::sync::{mpsc, oneshot};
8-
use tokio::task::{JoinError, JoinHandle, yield_now};
8+
use tokio::task::{JoinError, JoinHandle};
99
use tokio::time::{Duration, Instant, sleep, sleep_until, timeout};
1010

1111
use crate::actor::action::ActionDispatchError;
@@ -506,7 +506,10 @@ impl ActorTask {
506506
.restore_hibernatable_connections()
507507
.await
508508
.context("restore hibernatable connections")?;
509-
self.settle_hibernated_connections().await;
509+
self
510+
.settle_hibernated_connections()
511+
.await
512+
.context("settle hibernated connections")?;
510513
self.ctx.schedule().sync_future_alarm_logged();
511514

512515
self.transition_to(LifecycleState::Started);
@@ -565,30 +568,39 @@ impl ActorTask {
565568
}));
566569
}
567570

568-
async fn settle_hibernated_connections(&self) {
569-
let mut last_count = self
571+
// NOTE(review): this diff hunk interleaves the removed `yield_now`-based polling
// loop (lines whose marker ends in `-`) with the new implementation (markers
// ending in `+`). The comments below describe the new implementation only.
//
/// Reaps hibernatable connections that are no longer live.
///
/// Scans every connection for which `is_hibernatable()` is true. A connection is
/// queued as dead when it has no hibernation metadata, or when
/// `hibernated_connection_is_live` returns `false` for its gateway/request ids.
/// Each queued id is then passed to `request_hibernation_transport_removal` and
/// `remove_conn`.
///
/// # Errors
///
/// Propagates any error returned by `hibernated_connection_is_live`. Removal
/// only starts after the scan completes, so no connection is removed in that case.
async fn settle_hibernated_connections(&self) -> Result<()> {
572+
let mut dead_conn_ids = Vec::new();
573+
574+
for conn in self
570575
.ctx
571576
.conns()
572577
.into_iter()
573578
.filter(|conn| conn.is_hibernatable())
574-
.count();
575-
if last_count == 0 {
576-
return;
579+
{
580+
// Without hibernation metadata there are no ids to query liveness with, so
// the connection is queued for removal.
let Some(hibernation) = conn.hibernation() else {
581+
dead_conn_ids.push(conn.id().to_owned());
582+
continue;
583+
};
584+
if self
585+
.ctx
586+
.hibernated_connection_is_live(
587+
&hibernation.gateway_id,
588+
&hibernation.request_id,
589+
)?
590+
{
591+
continue;
592+
}
593+
dead_conn_ids.push(conn.id().to_owned());
577594
}
578595

579-
for _ in 0..8 {
580-
yield_now().await;
581-
let count = self
596+
// Second pass: request transport removal for each dead id, then remove the
// connection from the context.
for conn_id in dead_conn_ids {
597+
self
582598
.ctx
583-
.conns()
584-
.into_iter()
585-
.filter(|conn| conn.is_hibernatable())
586-
.count();
587-
if count == last_count {
588-
break;
589-
}
590-
last_count = count;
599+
.request_hibernation_transport_removal(conn_id.clone());
600+
self.ctx.remove_conn(&conn_id);
591601
}
602+
603+
Ok(())
592604
}
593605

594606
async fn fire_due_alarms(&mut self) -> Result<()> {

rivetkit-rust/packages/rivetkit-core/tests/modules/task.rs

Lines changed: 127 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,17 @@ mod moved_tests {
9999
conn
100100
}
101101

102+
fn configure_live_hibernated_pairs(
103+
ctx: &ActorContext,
104+
pairs: impl IntoIterator<Item = (&'static [u8], &'static [u8])>,
105+
) {
106+
ctx.set_hibernated_connection_liveness_override(
107+
pairs
108+
.into_iter()
109+
.map(|(gateway_id, request_id)| (gateway_id.to_vec(), request_id.to_vec())),
110+
);
111+
}
112+
102113
#[tokio::test]
103114
async fn save_tick_respects_debounce_and_immediate_requests() {
104115
let ctx = new_with_kv("actor-1", "task-save", Vec::new(), "local", new_in_memory());
@@ -272,6 +283,7 @@ mod moved_tests {
272283
request_headers: BTreeMap::from([("x-test".to_owned(), "true".to_owned())]),
273284
}));
274285
ctx.add_conn(hibernating_conn.clone());
286+
configure_live_hibernated_pairs(&ctx, [(b"gateway".as_slice(), b"request".as_slice())]);
275287

276288
let hibernating_conn_id = hibernating_conn.id().to_owned();
277289
let factory = Arc::new(ActorFactory::new(
@@ -381,6 +393,7 @@ mod moved_tests {
381393
request_headers: BTreeMap::new(),
382394
}));
383395
ctx.add_conn(hibernating_conn.clone());
396+
configure_live_hibernated_pairs(&ctx, [(b"gateway".as_slice(), b"request".as_slice())]);
384397

385398
let hibernating_conn_id = hibernating_conn.id().to_owned();
386399
let factory = Arc::new(ActorFactory::new(
@@ -493,6 +506,7 @@ mod moved_tests {
493506
let (_dispatch_tx, dispatch_rx) = mpsc::channel(4);
494507
let (events_tx, events_rx) = mpsc::channel(4);
495508
ctx.configure_lifecycle_events(Some(events_tx));
509+
configure_live_hibernated_pairs(&ctx, [(b"gateway".as_slice(), b"request".as_slice())]);
496510
let (started_tx, started_rx) = oneshot::channel();
497511
let started_tx = Arc::new(Mutex::new(Some(started_tx)));
498512
let factory = Arc::new(ActorFactory::new(Default::default(), move |start| {
@@ -871,23 +885,10 @@ mod moved_tests {
871885
let (_dispatch_tx, dispatch_rx) = mpsc::channel(4);
872886
let (events_tx, events_rx) = mpsc::channel(4);
873887
ctx.configure_lifecycle_events(Some(events_tx));
874-
875-
let stale_ctx = ctx.clone();
876-
tokio::spawn(async move {
877-
for _ in 0..50 {
878-
if stale_ctx.conns().len() == 2 {
879-
break;
880-
}
881-
sleep(Duration::from_millis(5)).await;
882-
}
883-
if let Some(conn) = stale_ctx
884-
.conns()
885-
.into_iter()
886-
.find(|conn| conn.id() == "conn-stale")
887-
{
888-
let _ = conn.disconnect(Some("stale")).await;
889-
}
890-
});
888+
configure_live_hibernated_pairs(
889+
&ctx,
890+
[(b"gateway-live".as_slice(), b"request-live".as_slice())],
891+
);
891892

892893
let (started_tx, started_rx) = oneshot::channel();
893894
let started_tx = Arc::new(Mutex::new(Some(started_tx)));
@@ -971,4 +972,113 @@ mod moved_tests {
971972
.await
972973
.expect("sleep stop should succeed");
973974
}
975+
976+
#[tokio::test]
977+
async fn wake_start_reaps_dead_hibernated_connections_without_engine_registration() {
978+
let kv = new_in_memory();
979+
let seed_ctx =
980+
new_with_kv("actor-wake-dead", "task-wake", Vec::new(), "local", kv.clone());
981+
let dead_conn = ConnHandle::new("conn-dead", Vec::new(), Vec::new(), true);
982+
dead_conn.configure_hibernation(Some(HibernatableConnectionMetadata {
983+
gateway_id: b"gateway-dead".to_vec(),
984+
request_id: b"request-dead".to_vec(),
985+
server_message_index: 7,
986+
client_message_index: 11,
987+
request_path: "/ws".to_owned(),
988+
request_headers: BTreeMap::new(),
989+
}));
990+
seed_ctx.add_conn(dead_conn.clone());
991+
seed_ctx
992+
.save_state(vec![StateDelta::ConnHibernation {
993+
conn: dead_conn.id().into(),
994+
bytes: vec![9, 8, 7],
995+
}])
996+
.await
997+
.expect("seed hibernation should persist");
998+
999+
let ctx =
1000+
new_with_kv("actor-wake-dead", "task-wake", Vec::new(), "local", kv.clone());
1001+
let (_lifecycle_tx, lifecycle_rx) = mpsc::channel(4);
1002+
let (_dispatch_tx, dispatch_rx) = mpsc::channel(4);
1003+
let (events_tx, events_rx) = mpsc::channel(4);
1004+
ctx.configure_lifecycle_events(Some(events_tx));
1005+
ctx.set_hibernated_connection_liveness_override(std::iter::empty());
1006+
1007+
let (started_tx, started_rx) = oneshot::channel();
1008+
let started_tx = Arc::new(Mutex::new(Some(started_tx)));
1009+
let factory = Arc::new(ActorFactory::new(Default::default(), move |start| {
1010+
let started_tx = started_tx.clone();
1011+
Box::pin(async move {
1012+
let mut events = start.events;
1013+
started_tx
1014+
.lock()
1015+
.expect("started sender lock poisoned")
1016+
.take()
1017+
.expect("started sender should exist")
1018+
.send(start.hibernated.into_iter().map(|(conn, _)| conn.id().to_owned()).collect::<Vec<_>>())
1019+
.expect("started info should send");
1020+
while let Some(event) = events.recv().await {
1021+
match event {
1022+
ActorEvent::SaveTick { reply } => {
1023+
reply.send(Ok(Vec::new()));
1024+
}
1025+
ActorEvent::Sleep { reply } | ActorEvent::Destroy { reply } => {
1026+
reply.send(Ok(Vec::new()));
1027+
break;
1028+
}
1029+
ActorEvent::ConnectionOpen { .. } => {
1030+
panic!("dead hibernated connection should not refire ConnectionOpen");
1031+
}
1032+
_ => {}
1033+
}
1034+
}
1035+
Ok(())
1036+
})
1037+
}));
1038+
1039+
let mut task = ActorTask::new(
1040+
"actor-wake-dead".into(),
1041+
0,
1042+
lifecycle_rx,
1043+
dispatch_rx,
1044+
events_rx,
1045+
factory,
1046+
ctx.clone(),
1047+
None,
1048+
None,
1049+
);
1050+
let (start_tx, start_rx) = oneshot::channel();
1051+
task
1052+
.handle_lifecycle(LifecycleCommand::Start { reply: start_tx })
1053+
.await;
1054+
start_rx
1055+
.await
1056+
.expect("start reply should send")
1057+
.expect("start should succeed");
1058+
1059+
assert_eq!(started_rx.await.expect("start info should send"), Vec::<String>::new());
1060+
assert!(ctx.conns().is_empty());
1061+
1062+
task
1063+
.handle_event(crate::actor::task::LifecycleEvent::SaveRequested {
1064+
immediate: false,
1065+
})
1066+
.await;
1067+
task.on_state_save_tick().await;
1068+
1069+
let last_batch = kv
1070+
.test_last_apply_batch()
1071+
.expect("last apply batch should be recorded");
1072+
assert_eq!(last_batch.deletes, vec![make_connection_key("conn-dead")]);
1073+
assert!(
1074+
kv.get(&make_connection_key("conn-dead"))
1075+
.await
1076+
.expect("persisted connection lookup should succeed")
1077+
.is_none()
1078+
);
1079+
1080+
task.handle_stop(StopReason::Sleep)
1081+
.await
1082+
.expect("sleep stop should succeed");
1083+
}
9741084
}

0 commit comments

Comments
 (0)