Skip to content

Commit c8f49d4

Browse files
committed
chore(rivetkit): rewrite work registry + fix waituntil not preventing sleep
1 parent 395aa83 commit c8f49d4

15 files changed

Lines changed: 654 additions & 244 deletions

File tree

rivetkit-rust/packages/rivetkit-core/src/actor/context.rs

Lines changed: 109 additions & 161 deletions
Large diffs are not rendered by default.

rivetkit-rust/packages/rivetkit-core/src/actor/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ pub(crate) mod work_registry;
2323
pub use action::ActionDispatchError;
2424
pub use config::{ActionDefinition, ActorConfig, ActorConfigOverrides, CanHibernateWebSocket};
2525
pub use connection::ConnHandle;
26-
pub use context::{ActorContext, KeepAwakeRegion, WebSocketCallbackRegion};
26+
pub use context::{ActorContext, ActorWorkRegion, KeepAwakeRegion, WebSocketCallbackRegion};
2727
pub use factory::{ActorEntryFn, ActorFactory};
2828
pub use kv::Kv;
2929
pub use lifecycle_hooks::{ActorEvents, ActorStart, Reply};
@@ -41,3 +41,4 @@ pub use task::{
4141
LifecycleEvent, LifecycleState,
4242
};
4343
pub use task_types::{ActorChildOutcome, ShutdownKind, StateMutationReason, UserTaskKind};
44+
pub use work_registry::{ActorWorkKind, ActorWorkPolicy};

rivetkit-rust/packages/rivetkit-core/src/actor/sleep.rs

Lines changed: 164 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,10 @@ use crate::actor::context::ActorContext;
1717
use crate::actor::task_types::ShutdownKind;
1818
#[cfg(feature = "wasm-runtime")]
1919
use crate::actor::work_registry::LocalShutdownTask;
20-
use crate::actor::work_registry::{CountGuard, RegionGuard, WorkRegistry};
20+
use crate::actor::work_registry::{ActorWorkKind, CountGuard, RegionGuard, WorkRegistry};
2121
#[cfg(feature = "wasm-runtime")]
2222
use crate::runtime::RuntimeSpawner;
23-
#[cfg(test)]
24-
use crate::time::sleep_until;
25-
use crate::time::{Instant, sleep};
23+
use crate::time::{Instant, sleep, sleep_until};
2624
#[cfg(test)]
2725
use crate::types::ActorKey;
2826
#[cfg(feature = "wasm-runtime")]
@@ -113,6 +111,10 @@ impl std::fmt::Debug for SleepState {
113111
"websocket_callback_count",
114112
&self.work.websocket_callback.load(),
115113
)
114+
.field(
115+
"disconnect_callback_count",
116+
&self.work.disconnect_callback.load(),
117+
)
116118
.finish()
117119
}
118120
}
@@ -381,7 +383,6 @@ impl ActorContext {
381383
}
382384
}
383385

384-
#[cfg(test)]
385386
pub(crate) async fn wait_for_shutdown_tasks(&self, deadline: Instant) -> bool {
386387
loop {
387388
let activity = self.sleep_activity_notify();
@@ -412,6 +413,29 @@ impl ActorContext {
412413
}
413414
}
414415

416+
/// Waits for tracked shutdown work to drain, bounded by the shutdown deadline.
///
/// Returns `true` when `wait_for_tracked_shutdown_work_drained()` completes
/// first, and `false` when the token obtained from
/// `shutdown_deadline_token()` is cancelled first.
pub async fn wait_for_tracked_shutdown_work(&self) -> bool {
417+
let shutdown_deadline = self.shutdown_deadline_token();
418+
// Race the drain loop against cancellation of the deadline token; the
// drain loop has no timeout of its own.
tokio::select! {
419+
_ = self.wait_for_tracked_shutdown_work_drained() => true,
420+
_ = shutdown_deadline.cancelled() => false,
421+
}
422+
}
423+
424+
/// Resolves once `shutdown_task_count()` and `websocket_callback_count()`
/// are both zero.
///
/// Only those two counters are inspected here. The loop never times out by
/// itself, so callers that need a bound must race it against something else.
async fn wait_for_tracked_shutdown_work_drained(&self) {
425+
loop {
426+
let activity = self.sleep_activity_notify();
427+
let notified = activity.notified();
428+
tokio::pin!(notified);
429+
// Enable the waiter *before* reading the counters so a notification
// fired between the check and the await below is not missed.
// NOTE(review): assumes `activity` is a `tokio::sync::Notify` — confirm.
notified.as_mut().enable();
430+
431+
if self.shutdown_task_count() == 0 && self.websocket_callback_count() == 0 {
432+
return;
433+
}
434+
435+
// Still busy: sleep until the next activity notification, then re-check.
notified.await;
436+
}
437+
}
438+
415439
pub(crate) async fn wait_for_http_requests_drained(&self, deadline: Instant) -> bool {
416440
let Some(counter) = self.http_request_counter() else {
417441
return true;
@@ -461,6 +485,119 @@ impl ActorContext {
461485
self.0.sleep.work.websocket_callback.load()
462486
}
463487

488+
/// Returns the `RegionGuard` produced by `disconnect_callback_guard()` on
/// `self.0.sleep.work`.
///
/// NOTE(review): presumably the guard keeps the disconnect-callback count
/// raised until it is dropped — confirm against `work_registry`.
pub(crate) fn disconnect_callback_region_state(&self) -> RegionGuard {
489+
self.0.sleep.work.disconnect_callback_guard()
490+
}
491+
492+
/// Spawns `fut` as tracked actor work of the given `kind` (non-wasm build).
///
/// Returns `false` without polling `fut` when no Tokio runtime handle is
/// available or when `teardown_started` is already set; a warning is logged
/// in both cases. Otherwise the future is spawned and `true` is returned.
///
/// The work region opened by `begin_work_region(kind)` is moved into the task
/// and released when the task finishes. When
/// `kind.policy().aborts_at_shutdown_deadline` is set, the future is raced
/// against the shutdown-deadline token and the task is placed in
/// `shutdown_tasks`; otherwise it is awaited to completion and placed in
/// `unabortable_shutdown_tasks`.
#[cfg(not(feature = "wasm-runtime"))]
493+
pub(crate) fn spawn_work_inner<F>(&self, kind: ActorWorkKind, fut: F) -> bool
494+
where
495+
F: Future<Output = ()> + Send + 'static,
496+
{
497+
// Spawning onto a task set needs an ambient Tokio runtime.
if Handle::try_current().is_err() {
498+
tracing::warn!(kind = kind.label(), "actor work spawned without tokio runtime");
499+
return false;
500+
}
501+
502+
// NOTE(review): this flag is read before either task-set lock is taken,
// while the wasm variant reads it under the `local_shutdown_tasks` lock.
// Teardown appears to set the flag while holding the task-set locks, so a
// task could be spawned just after teardown finished — confirm whether
// this window matters.
if self.0.sleep.work.teardown_started.load(Ordering::Acquire) {
503+
tracing::warn!(kind = kind.label(), "actor work spawned after teardown; aborting immediately");
504+
return false;
505+
}
506+
507+
let policy = kind.policy();
508+
let region = self.begin_work_region(kind);
509+
let ctx = self.clone();
510+
let task = async move {
511+
// Hold the region for the whole task; it is dropped when the task ends.
let _region = region;
512+
if policy.aborts_at_shutdown_deadline {
513+
let shutdown_deadline = ctx.shutdown_deadline_token();
514+
tokio::select! {
515+
_ = fut => {}
516+
_ = shutdown_deadline.cancelled() => {
517+
tracing::warn!(
518+
actor_id = %ctx.actor_id(),
519+
kind = kind.label(),
520+
reason = "shutdown_deadline_elapsed",
521+
"actor work cancelled by shutdown deadline"
522+
);
523+
}
524+
}
525+
} else {
526+
// Work that must not be cut short ignores the deadline token.
fut.await;
527+
}
528+
ctx.reset_sleep_timer();
529+
}
530+
.in_current_span();
531+
// Route the task to the set that matches its abort policy.
if policy.aborts_at_shutdown_deadline {
532+
self.0.sleep.work.shutdown_tasks.lock().spawn(task);
533+
} else {
534+
self.0
535+
.sleep
536+
.work
537+
.unabortable_shutdown_tasks
538+
.lock()
539+
.spawn(task);
540+
}
541+
self.reset_sleep_timer();
542+
true
543+
}
544+
545+
/// Spawns `fut` as tracked actor work of the given `kind` (wasm build).
///
/// Returns `false` without polling `fut` when `teardown_started` is already
/// set (a warning is logged); otherwise registers the task, spawns it with
/// `wasm_bindgen_futures::spawn_local`, and returns `true`.
///
/// Each task is recorded as a `LocalShutdownTask` carrying its abort handle,
/// a completion receiver, and the policy's `aborts_at_shutdown_deadline`
/// flag. When that flag is set the future is also raced against the
/// shutdown-deadline token; otherwise it is awaited to completion.
#[cfg(feature = "wasm-runtime")]
546+
pub(crate) fn spawn_work_inner<F>(&self, kind: ActorWorkKind, fut: F) -> bool
547+
where
548+
F: Future<Output = ()> + 'static,
549+
{
550+
// The task list is locked before the teardown flag is read and stays
// locked until the new entry has been pushed.
let mut local_shutdown_tasks = self.0.sleep.work.local_shutdown_tasks.lock();
551+
if self.0.sleep.work.teardown_started.load(Ordering::Acquire) {
552+
tracing::warn!(kind = kind.label(), "actor work spawned after teardown; aborting immediately");
553+
return false;
554+
}
555+
556+
let policy = kind.policy();
557+
let region = self.begin_work_region(kind);
558+
let ctx = self.clone();
559+
let (complete_tx, complete_rx) = futures_oneshot::channel();
560+
let (abort_handle, abort_registration) = AbortHandle::new_pair();
561+
local_shutdown_tasks.push(LocalShutdownTask {
562+
abort_handle,
563+
complete_rx,
564+
aborts_at_shutdown_deadline: policy.aborts_at_shutdown_deadline,
565+
});
566+
// Release the lock before handing the future to the local executor.
drop(local_shutdown_tasks);
567+
let ctx_for_task = ctx.clone();
568+
wasm_bindgen_futures::spawn_local(
569+
async move {
570+
let task = async move {
571+
// Hold the region for the whole task; it is dropped when the task ends.
let _region = region;
572+
if policy.aborts_at_shutdown_deadline {
573+
let shutdown_deadline = ctx_for_task.shutdown_deadline_token();
574+
tokio::select! {
575+
_ = fut => {}
576+
_ = shutdown_deadline.cancelled() => {
577+
tracing::warn!(
578+
actor_id = %ctx_for_task.actor_id(),
579+
kind = kind.label(),
580+
reason = "shutdown_deadline_elapsed",
581+
"actor work cancelled by shutdown deadline"
582+
);
583+
}
584+
}
585+
} else {
586+
fut.await;
587+
}
588+
// A send error only means the receiver is gone; it is ignored on purpose.
let _ = complete_tx.send(());
589+
ctx_for_task.reset_sleep_timer();
590+
};
591+
// If the task is aborted, the inner block never reaches its own
// `reset_sleep_timer()` call, so it is done here instead.
if Abortable::new(task, abort_registration).await.is_err() {
592+
ctx.reset_sleep_timer();
593+
}
594+
}
595+
.in_current_span(),
596+
);
597+
self.reset_sleep_timer();
598+
true
599+
}
600+
464601
#[cfg(not(feature = "wasm-runtime"))]
465602
pub(crate) fn track_shutdown_task<F>(&self, fut: F) -> bool
466603
where
@@ -519,6 +656,7 @@ impl ActorContext {
519656
local_shutdown_tasks.push(LocalShutdownTask {
520657
abort_handle,
521658
complete_rx,
659+
aborts_at_shutdown_deadline: true,
522660
});
523661
drop(local_shutdown_tasks);
524662
let ctx_for_task = ctx.clone();
@@ -605,7 +743,9 @@ impl ActorContext {
605743

606744
if abort_remaining {
607745
for task in local_shutdown_tasks {
608-
task.abort_handle.abort();
746+
if task.aborts_at_shutdown_deadline {
747+
task.abort_handle.abort();
748+
}
609749
if task.complete_rx.await.is_err() {
610750
tracing::debug!("aborted shutdown task during teardown");
611751
}
@@ -628,29 +768,35 @@ impl ActorContext {
628768

629769
#[cfg(not(feature = "wasm-runtime"))]
630770
loop {
631-
let mut shutdown_tasks = {
771+
let mut abortable_shutdown_tasks = {
632772
let mut guard = self.0.sleep.work.shutdown_tasks.lock();
633773
let taken = std::mem::take(&mut *guard);
634-
if taken.is_empty() {
774+
let mut unabortable_guard = self.0.sleep.work.unabortable_shutdown_tasks.lock();
775+
let unabortable_taken = std::mem::take(&mut *unabortable_guard);
776+
if taken.is_empty() && unabortable_taken.is_empty() {
635777
self.0
636778
.sleep
637779
.work
638780
.teardown_started
639781
.store(true, Ordering::Release);
640782
return;
641783
}
642-
taken
784+
(taken, unabortable_taken)
643785
};
644786

645-
if abort_remaining {
646-
shutdown_tasks.shutdown().await;
647-
} else {
648-
while let Some(result) = shutdown_tasks.join_next().await {
649-
if let Err(error) = result
650-
&& !error.is_cancelled()
651-
{
652-
tracing::error!(?error, "shutdown task join failed during teardown");
653-
}
787+
abortable_shutdown_tasks.0.shutdown().await;
788+
while let Some(result) = abortable_shutdown_tasks.0.join_next().await {
789+
if let Err(error) = result
790+
&& !error.is_cancelled()
791+
{
792+
tracing::error!(?error, "shutdown task join failed during teardown");
793+
}
794+
}
795+
while let Some(result) = abortable_shutdown_tasks.1.join_next().await {
796+
if let Err(error) = result
797+
&& !error.is_cancelled()
798+
{
799+
tracing::error!(?error, "shutdown task join failed during teardown");
654800
}
655801
}
656802
}

0 commit comments

Comments
 (0)