Skip to content

Commit 61f01e4

Browse files
committed
chore(rivetkit): rewrite work registry + fix waituntil not preventing sleep
1 parent 395aa83 commit 61f01e4

15 files changed

Lines changed: 641 additions & 244 deletions

File tree

rivetkit-rust/packages/rivetkit-core/src/actor/context.rs

Lines changed: 109 additions & 161 deletions
Large diffs are not rendered by default.

rivetkit-rust/packages/rivetkit-core/src/actor/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ pub(crate) mod work_registry;
2323
pub use action::ActionDispatchError;
2424
pub use config::{ActionDefinition, ActorConfig, ActorConfigOverrides, CanHibernateWebSocket};
2525
pub use connection::ConnHandle;
26-
pub use context::{ActorContext, KeepAwakeRegion, WebSocketCallbackRegion};
26+
pub use context::{ActorContext, ActorWorkRegion, KeepAwakeRegion, WebSocketCallbackRegion};
2727
pub use factory::{ActorEntryFn, ActorFactory};
2828
pub use kv::Kv;
2929
pub use lifecycle_hooks::{ActorEvents, ActorStart, Reply};
@@ -41,3 +41,4 @@ pub use task::{
4141
LifecycleEvent, LifecycleState,
4242
};
4343
pub use task_types::{ActorChildOutcome, ShutdownKind, StateMutationReason, UserTaskKind};
44+
pub use work_registry::{ActorWorkKind, ActorWorkPolicy};

rivetkit-rust/packages/rivetkit-core/src/actor/sleep.rs

Lines changed: 151 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::sync::Arc;
66
#[cfg(test)]
77
use std::sync::atomic::AtomicUsize as TestAtomicUsize;
88
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
9+
use std::time::Duration;
910
#[cfg(not(feature = "wasm-runtime"))]
1011
use tokio::runtime::Handle;
1112
use tokio::sync::Notify;
@@ -17,12 +18,10 @@ use crate::actor::context::ActorContext;
1718
use crate::actor::task_types::ShutdownKind;
1819
#[cfg(feature = "wasm-runtime")]
1920
use crate::actor::work_registry::LocalShutdownTask;
20-
use crate::actor::work_registry::{CountGuard, RegionGuard, WorkRegistry};
21+
use crate::actor::work_registry::{ActorWorkKind, CountGuard, RegionGuard, WorkRegistry};
2122
#[cfg(feature = "wasm-runtime")]
2223
use crate::runtime::RuntimeSpawner;
23-
#[cfg(test)]
24-
use crate::time::sleep_until;
25-
use crate::time::{Instant, sleep};
24+
use crate::time::{Instant, sleep, sleep_until};
2625
#[cfg(test)]
2726
use crate::types::ActorKey;
2827
#[cfg(feature = "wasm-runtime")]
@@ -113,6 +112,10 @@ impl std::fmt::Debug for SleepState {
113112
"websocket_callback_count",
114113
&self.work.websocket_callback.load(),
115114
)
115+
.field(
116+
"disconnect_callback_count",
117+
&self.work.disconnect_callback.load(),
118+
)
116119
.finish()
117120
}
118121
}
@@ -381,7 +384,6 @@ impl ActorContext {
381384
}
382385
}
383386

384-
#[cfg(test)]
385387
pub(crate) async fn wait_for_shutdown_tasks(&self, deadline: Instant) -> bool {
386388
loop {
387389
let activity = self.sleep_activity_notify();
@@ -412,6 +414,15 @@ impl ActorContext {
412414
}
413415
}
414416

417+
/// Waits for tracked shutdown work, giving up once the shutdown deadline
/// token is cancelled.
///
/// Returns whatever `wait_for_shutdown_tasks` returns if that future
/// finishes first, or `false` if the shutdown deadline token is cancelled
/// first.
pub async fn wait_for_tracked_shutdown_work(&self) -> bool {
418+
let shutdown_deadline = self.shutdown_deadline_token();
419+
// One year from now. NOTE(review): presumably a stand-in for "no deadline"
// so that only the cancellation token below bounds the wait — confirm.
let deadline = Instant::now() + Duration::from_secs(60 * 60 * 24 * 365);
420+
// Race the task wait against cancellation of the shutdown deadline token.
tokio::select! {
421+
result = self.wait_for_shutdown_tasks(deadline) => result,
422+
_ = shutdown_deadline.cancelled() => false,
423+
}
424+
}
425+
415426
pub(crate) async fn wait_for_http_requests_drained(&self, deadline: Instant) -> bool {
416427
let Some(counter) = self.http_request_counter() else {
417428
return true;
@@ -461,6 +472,119 @@ impl ActorContext {
461472
self.0.sleep.work.websocket_callback.load()
462473
}
463474

475+
/// Returns the `RegionGuard` produced by the work registry's
/// `disconnect_callback_guard()`.
///
/// NOTE(review): presumably the guard keeps the `disconnect_callback`
/// counter raised while it is held — confirm in `work_registry`.
pub(crate) fn disconnect_callback_region_state(&self) -> RegionGuard {
476+
self.0.sleep.work.disconnect_callback_guard()
477+
}
478+
479+
/// Spawns `fut` as tracked actor work of the given `kind` (non-wasm build).
///
/// Returns `false` without spawning (after logging a warning) when no tokio
/// runtime is current or when `teardown_started` is already set; returns
/// `true` once the task has been handed to one of the shutdown task sets.
///
/// The region obtained from `begin_work_region(kind)` is moved into the task
/// and held until the task body ends. When the kind's policy has
/// `aborts_at_shutdown_deadline` set, `fut` is raced against the shutdown
/// deadline token and the task goes into `shutdown_tasks`; otherwise `fut`
/// is awaited to completion and the task goes into
/// `unabortable_shutdown_tasks`. The sleep timer is reset both when the task
/// body finishes and right after spawning.
#[cfg(not(feature = "wasm-runtime"))]
480+
pub(crate) fn spawn_work_inner<F>(&self, kind: ActorWorkKind, fut: F) -> bool
481+
where
482+
F: Future<Output = ()> + Send + 'static,
483+
{
484+
// Bail out early if no tokio runtime handle is available.
if Handle::try_current().is_err() {
485+
tracing::warn!(kind = kind.label(), "actor work spawned without tokio runtime");
486+
return false;
487+
}
488+
489+
// NOTE(review): this flag is read without holding either task-set lock,
// while the wasm variant of this function reads it under its task-list
// lock. A teardown that sets the flag between this check and the spawn
// below would not be observed here — confirm whether that window matters.
if self.0.sleep.work.teardown_started.load(Ordering::Acquire) {
490+
tracing::warn!(kind = kind.label(), "actor work spawned after teardown; aborting immediately");
491+
return false;
492+
}
493+
494+
let policy = kind.policy();
495+
let region = self.begin_work_region(kind);
496+
let ctx = self.clone();
497+
let task = async move {
498+
// Keep the region alive for the whole task body; dropped when the task ends.
let _region = region;
499+
if policy.aborts_at_shutdown_deadline {
500+
let shutdown_deadline = ctx.shutdown_deadline_token();
501+
// Whichever branch completes first wins; the other future is dropped.
tokio::select! {
502+
_ = fut => {}
503+
_ = shutdown_deadline.cancelled() => {
504+
tracing::warn!(
505+
actor_id = %ctx.actor_id(),
506+
kind = kind.label(),
507+
reason = "shutdown_deadline_elapsed",
508+
"actor work cancelled by shutdown deadline"
509+
);
510+
}
511+
}
512+
} else {
513+
// Policy does not abort at the deadline: run the work to completion.
fut.await;
514+
}
515+
ctx.reset_sleep_timer();
516+
}
517+
.in_current_span();
518+
// Route the task to the set that matches its abort policy.
if policy.aborts_at_shutdown_deadline {
519+
self.0.sleep.work.shutdown_tasks.lock().spawn(task);
520+
} else {
521+
self.0
522+
.sleep
523+
.work
524+
.unabortable_shutdown_tasks
525+
.lock()
526+
.spawn(task);
527+
}
528+
self.reset_sleep_timer();
529+
true
530+
}
531+
532+
/// Spawns `fut` as tracked actor work of the given `kind` (wasm build).
///
/// Unlike the non-wasm variant, `F` carries no `Send` bound and the task is
/// started with `wasm_bindgen_futures::spawn_local`.
///
/// Returns `false` without spawning (after logging a warning) when
/// `teardown_started` is already set; otherwise registers a
/// `LocalShutdownTask` for the work, spawns it, resets the sleep timer and
/// returns `true`.
#[cfg(feature = "wasm-runtime")]
533+
pub(crate) fn spawn_work_inner<F>(&self, kind: ActorWorkKind, fut: F) -> bool
534+
where
535+
F: Future<Output = ()> + 'static,
536+
{
537+
// The task-list lock is taken before reading `teardown_started`, so the
// check and the push below happen under the same lock.
let mut local_shutdown_tasks = self.0.sleep.work.local_shutdown_tasks.lock();
538+
if self.0.sleep.work.teardown_started.load(Ordering::Acquire) {
539+
tracing::warn!(kind = kind.label(), "actor work spawned after teardown; aborting immediately");
540+
return false;
541+
}
542+
543+
let policy = kind.policy();
544+
let region = self.begin_work_region(kind);
545+
let ctx = self.clone();
546+
// `complete_tx` is sent from inside the task once the work has finished;
// the receiving half is stored with the registered task entry.
let (complete_tx, complete_rx) = futures_oneshot::channel();
547+
let (abort_handle, abort_registration) = AbortHandle::new_pair();
548+
local_shutdown_tasks.push(LocalShutdownTask {
549+
abort_handle,
550+
complete_rx,
551+
aborts_at_shutdown_deadline: policy.aborts_at_shutdown_deadline,
552+
});
553+
// Release the lock before spawning the task.
drop(local_shutdown_tasks);
554+
let ctx_for_task = ctx.clone();
555+
wasm_bindgen_futures::spawn_local(
556+
async move {
557+
let task = async move {
558+
// Keep the region alive for the whole task body.
let _region = region;
559+
if policy.aborts_at_shutdown_deadline {
560+
let shutdown_deadline = ctx_for_task.shutdown_deadline_token();
561+
// Whichever branch completes first wins; the other future is dropped.
tokio::select! {
562+
_ = fut => {}
563+
_ = shutdown_deadline.cancelled() => {
564+
tracing::warn!(
565+
actor_id = %ctx_for_task.actor_id(),
566+
kind = kind.label(),
567+
reason = "shutdown_deadline_elapsed",
568+
"actor work cancelled by shutdown deadline"
569+
);
570+
}
571+
}
572+
} else {
573+
// Policy does not abort at the deadline: run the work to completion.
fut.await;
574+
}
575+
// The send result is ignored, so a dropped receiver is not an error here.
let _ = complete_tx.send(());
576+
ctx_for_task.reset_sleep_timer();
577+
};
578+
// If the wrapped task ends with an abort error, the inner body did not
// reach its own timer reset, so reset the sleep timer here instead.
if Abortable::new(task, abort_registration).await.is_err() {
579+
ctx.reset_sleep_timer();
580+
}
581+
}
582+
.in_current_span(),
583+
);
584+
self.reset_sleep_timer();
585+
true
586+
}
587+
464588
#[cfg(not(feature = "wasm-runtime"))]
465589
pub(crate) fn track_shutdown_task<F>(&self, fut: F) -> bool
466590
where
@@ -519,6 +643,7 @@ impl ActorContext {
519643
local_shutdown_tasks.push(LocalShutdownTask {
520644
abort_handle,
521645
complete_rx,
646+
aborts_at_shutdown_deadline: true,
522647
});
523648
drop(local_shutdown_tasks);
524649
let ctx_for_task = ctx.clone();
@@ -605,7 +730,9 @@ impl ActorContext {
605730

606731
if abort_remaining {
607732
for task in local_shutdown_tasks {
608-
task.abort_handle.abort();
733+
if task.aborts_at_shutdown_deadline {
734+
task.abort_handle.abort();
735+
}
609736
if task.complete_rx.await.is_err() {
610737
tracing::debug!("aborted shutdown task during teardown");
611738
}
@@ -628,29 +755,35 @@ impl ActorContext {
628755

629756
#[cfg(not(feature = "wasm-runtime"))]
630757
loop {
631-
let mut shutdown_tasks = {
758+
let mut abortable_shutdown_tasks = {
632759
let mut guard = self.0.sleep.work.shutdown_tasks.lock();
633760
let taken = std::mem::take(&mut *guard);
634-
if taken.is_empty() {
761+
let mut unabortable_guard = self.0.sleep.work.unabortable_shutdown_tasks.lock();
762+
let unabortable_taken = std::mem::take(&mut *unabortable_guard);
763+
if taken.is_empty() && unabortable_taken.is_empty() {
635764
self.0
636765
.sleep
637766
.work
638767
.teardown_started
639768
.store(true, Ordering::Release);
640769
return;
641770
}
642-
taken
771+
(taken, unabortable_taken)
643772
};
644773

645-
if abort_remaining {
646-
shutdown_tasks.shutdown().await;
647-
} else {
648-
while let Some(result) = shutdown_tasks.join_next().await {
649-
if let Err(error) = result
650-
&& !error.is_cancelled()
651-
{
652-
tracing::error!(?error, "shutdown task join failed during teardown");
653-
}
774+
abortable_shutdown_tasks.0.shutdown().await;
775+
while let Some(result) = abortable_shutdown_tasks.0.join_next().await {
776+
if let Err(error) = result
777+
&& !error.is_cancelled()
778+
{
779+
tracing::error!(?error, "shutdown task join failed during teardown");
780+
}
781+
}
782+
while let Some(result) = abortable_shutdown_tasks.1.join_next().await {
783+
if let Err(error) = result
784+
&& !error.is_cancelled()
785+
{
786+
tracing::error!(?error, "shutdown task join failed during teardown");
654787
}
655788
}
656789
}

0 commit comments

Comments
 (0)