Skip to content

Commit d3d9f0e

Browse files
committed
fix(rivetkit): keep actors awake until keepAwake work finishes
1 parent b40eb33 commit d3d9f0e

14 files changed

Lines changed: 418 additions & 67 deletions

File tree

rivetkit-rust/packages/rivetkit-core/src/actor/context.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,12 @@ impl ActorContext {
619619
future.await
620620
}
621621

622+
pub fn keep_awake_region(&self) -> KeepAwakeRegion {
623+
KeepAwakeRegion {
624+
guard: Some(self.keep_awake_guard()),
625+
}
626+
}
627+
622628
pub async fn internal_keep_awake<F>(&self, future: F) -> F::Output
623629
where
624630
F: Future,
@@ -1327,7 +1333,7 @@ impl ActorContext {
13271333

13281334
fn keep_awake_guard(&self) -> KeepAwakeGuard {
13291335
let region = self
1330-
.keep_awake_region()
1336+
.keep_awake_region_state()
13311337
.with_log_fields("keep_awake", Some(self.actor_id().to_owned()));
13321338
let guard = KeepAwakeGuard::new(self.clone(), region);
13331339
self.reset_sleep_timer();
@@ -1639,6 +1645,10 @@ pub struct WebSocketCallbackRegion {
16391645
guard: Option<WebSocketCallbackGuard>,
16401646
}
16411647

1648+
pub struct KeepAwakeRegion {
1649+
guard: Option<KeepAwakeGuard>,
1650+
}
1651+
16421652
impl WebSocketCallbackGuard {
16431653
fn new(ctx: ActorContext, kind: UserTaskKind, region: RegionGuard) -> Self {
16441654
Self {
@@ -1665,6 +1675,13 @@ impl Drop for WebSocketCallbackRegion {
16651675
}
16661676
}
16671677

1678+
impl Drop for KeepAwakeRegion {
1679+
fn drop(&mut self) {
1680+
// Take the guard explicitly to mirror WebSocketCallbackRegion.
1681+
self.guard.take();
1682+
}
1683+
}
1684+
16681685
impl std::fmt::Debug for ActorContext {
16691686
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16701687
f.debug_struct("ActorContext")

rivetkit-rust/packages/rivetkit-core/src/actor/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ pub(crate) mod work_registry;
2323
pub use action::ActionDispatchError;
2424
pub use config::{ActionDefinition, ActorConfig, ActorConfigOverrides, CanHibernateWebSocket};
2525
pub use connection::ConnHandle;
26-
pub use context::{ActorContext, WebSocketCallbackRegion};
26+
pub use context::{ActorContext, KeepAwakeRegion, WebSocketCallbackRegion};
2727
pub use factory::{ActorEntryFn, ActorFactory};
2828
pub use kv::Kv;
2929
pub use lifecycle_hooks::{ActorEvents, ActorStart, Reply};

rivetkit-rust/packages/rivetkit-core/src/actor/sleep.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ impl ActorContext {
433433
}
434434
}
435435

436-
pub(crate) fn keep_awake_region(&self) -> RegionGuard {
436+
pub(crate) fn keep_awake_region_state(&self) -> RegionGuard {
437437
self.0.sleep.work.keep_awake_guard()
438438
}
439439

rivetkit-rust/packages/rivetkit-core/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ pub use actor::config::{
118118
ActionDefinition, ActorConfig, ActorConfigInput, ActorConfigOverrides, CanHibernateWebSocket,
119119
};
120120
pub use actor::connection::ConnHandle;
121-
pub use actor::context::{ActorContext, WebSocketCallbackRegion};
121+
pub use actor::context::{ActorContext, KeepAwakeRegion, WebSocketCallbackRegion};
122122
pub use actor::factory::{ActorEntryFn, ActorFactory};
123123
pub use actor::kv::Kv;
124124
pub use actor::lifecycle_hooks::{ActorEvents, ActorStart, Reply};

rivetkit-typescript/packages/rivetkit-napi/index.d.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,9 @@ export declare class ActorContext {
228228
aborted(): boolean
229229
runHandlerActive(): boolean
230230
restartRunHandler(): void
231+
beginKeepAwake(): number
232+
endKeepAwake(regionId: number): void
233+
keepAwake(promise: Promise<any>): void
231234
beginWebsocketCallback(): number
232235
endWebsocketCallback(regionId: number): void
233236
abortSignal(): AbortSignal
@@ -237,7 +240,6 @@ export declare class ActorContext {
237240
disconnectConns(predicate: (...args: any[]) => any): Promise<void>
238241
broadcast(name: string, args: Buffer): void
239242
waitUntil(promise: Promise<any>): void
240-
keepAwake(promise: Promise<any>): Promise<any>
241243
registerTask(promise: Promise<any>): void
242244
runtimeState(): object
243245
clearRuntimeState(): void

rivetkit-typescript/packages/rivetkit-napi/src/actor_context.rs

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use parking_lot::Mutex;
1919
use rivetkit_core::types::ActorKeySegment;
2020
use rivetkit_core::{
2121
ActorContext as CoreActorContext, ConnHandle as CoreConnHandle, Request as CoreRequest,
22-
RequestSaveOpts, StateDelta, WebSocketCallbackRegion,
22+
KeepAwakeRegion, RequestSaveOpts, StateDelta, WebSocketCallbackRegion,
2323
};
2424
use scc::HashMap as SccHashMap;
2525
use tokio::sync::mpsc::UnboundedSender;
@@ -59,7 +59,9 @@ struct ActorContextShared {
5959
task_sender: Mutex<Option<UnboundedSender<RegisteredTask>>>,
6060
runtime_state: Mutex<Option<Ref<()>>>,
6161
end_reason: Mutex<Option<EndReason>>,
62+
keep_awake_regions: Mutex<BTreeMap<u32, KeepAwakeRegion>>,
6263
websocket_callback_regions: Mutex<BTreeMap<u32, WebSocketCallbackRegion>>,
64+
next_keep_awake_region_id: AtomicU32,
6365
next_websocket_callback_region_id: AtomicU32,
6466
}
6567

@@ -464,6 +466,28 @@ impl ActorContext {
464466
self.shared.run_restart().map_err(napi_anyhow_error)
465467
}
466468

469+
#[napi]
470+
pub fn begin_keep_awake(&self) -> u32 {
471+
self.shared.begin_keep_awake(self.inner.keep_awake_region())
472+
}
473+
474+
#[napi]
475+
pub fn end_keep_awake(&self, region_id: u32) {
476+
self.shared.end_keep_awake(region_id);
477+
}
478+
479+
#[napi]
480+
pub fn keep_awake(&self, promise: Promise<serde_json::Value>) -> napi::Result<()> {
481+
let region = self.inner.keep_awake_region();
482+
self.inner.wait_until(async move {
483+
let _region = region;
484+
if let Err(error) = promise.await {
485+
tracing::warn!(?error, "actor keep_awake promise rejected");
486+
}
487+
});
488+
Ok(())
489+
}
490+
467491
#[napi]
468492
pub fn begin_websocket_callback(&self) -> u32 {
469493
self.shared
@@ -583,14 +607,6 @@ impl ActorContext {
583607
Ok(())
584608
}
585609

586-
#[napi]
587-
pub async fn keep_awake(
588-
&self,
589-
promise: Promise<serde_json::Value>,
590-
) -> napi::Result<serde_json::Value> {
591-
self.inner.keep_awake(promise).await
592-
}
593-
594610
#[napi]
595611
pub fn register_task(&self, promise: Promise<serde_json::Value>) -> napi::Result<()> {
596612
self.shared
@@ -708,6 +724,39 @@ impl ActorContextShared {
708724
id
709725
}
710726

727+
fn begin_keep_awake(&self, region: KeepAwakeRegion) -> u32 {
728+
let mut regions = self.keep_awake_regions.lock();
729+
let Some(id) = self.allocate_keep_awake_region_id(&regions) else {
730+
tracing::error!("failed to begin keep-awake region: no region ids available");
731+
return 0;
732+
};
733+
regions.insert(id, region);
734+
id
735+
}
736+
737+
fn end_keep_awake(&self, region_id: u32) {
738+
if region_id == 0 {
739+
return;
740+
}
741+
self.keep_awake_regions.lock().remove(&region_id);
742+
}
743+
744+
fn allocate_keep_awake_region_id(
745+
&self,
746+
regions: &BTreeMap<u32, KeepAwakeRegion>,
747+
) -> Option<u32> {
748+
for _ in 0..=u32::MAX {
749+
let next = self
750+
.next_keep_awake_region_id
751+
.fetch_add(1, Ordering::SeqCst)
752+
.wrapping_add(1);
753+
if next != 0 && !regions.contains_key(&next) {
754+
return Some(next);
755+
}
756+
}
757+
None
758+
}
759+
711760
fn end_websocket_callback(&self, region_id: u32) {
712761
self.websocket_callback_regions.lock().remove(&region_id);
713762
}
@@ -734,7 +783,9 @@ impl ActorContextShared {
734783
std::mem::forget(old);
735784
}
736785
*self.end_reason.lock() = None;
786+
*self.keep_awake_regions.lock() = BTreeMap::new();
737787
*self.websocket_callback_regions.lock() = BTreeMap::new();
788+
self.next_keep_awake_region_id.store(0, Ordering::SeqCst);
738789
self.next_websocket_callback_region_id
739790
.store(0, Ordering::SeqCst);
740791
}

rivetkit-typescript/packages/rivetkit-wasm/src/lib.rs

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use rivetkit_core::{
1616
EnqueueAndWaitOpts, ListOpts, QueueMessage, QueueNextBatchOpts, QueueSendResult,
1717
QueueSendStatus, QueueTryNextBatchOpts, QueueWaitOpts, Request, RequestSaveOpts, Response,
1818
RuntimeSpawner, SerializeStateReason, ServeConfig, ServerlessRequest, StateDelta, WebSocket,
19-
WebSocketCallbackRegion, WsMessage,
19+
KeepAwakeRegion, WebSocketCallbackRegion, WsMessage,
2020
};
2121
use scc::HashMap as SccHashMap;
2222
use tokio::sync::oneshot;
@@ -1058,7 +1058,9 @@ pub struct WasmActorContext {
10581058
inner: rivetkit_core::ActorContext,
10591059
callbacks: WasmCallbacks,
10601060
runtime_state: JsValue,
1061+
keep_awake_regions: Rc<RefCell<HashMap<u32, KeepAwakeRegion>>>,
10611062
websocket_callback_regions: Rc<RefCell<HashMap<u32, WebSocketCallbackRegion>>>,
1063+
next_keep_awake_region_id: Rc<Cell<u32>>,
10621064
next_websocket_callback_region_id: Rc<Cell<u32>>,
10631065
}
10641066

@@ -1068,7 +1070,9 @@ impl WasmActorContext {
10681070
inner,
10691071
callbacks,
10701072
runtime_state: Object::new().into(),
1073+
keep_awake_regions: Rc::new(RefCell::new(HashMap::new())),
10711074
websocket_callback_regions: Rc::new(RefCell::new(HashMap::new())),
1075+
next_keep_awake_region_id: Rc::new(Cell::new(0)),
10721076
next_websocket_callback_region_id: Rc::new(Cell::new(0)),
10731077
}
10741078
}
@@ -1121,6 +1125,20 @@ impl WasmActorContext {
11211125
}
11221126
}
11231127

1128+
fn allocate_keep_awake_region_id(
1129+
&self,
1130+
regions: &HashMap<u32, KeepAwakeRegion>,
1131+
) -> Option<u32> {
1132+
for _ in 0..=u32::MAX {
1133+
let next = self.next_keep_awake_region_id.get().wrapping_add(1);
1134+
self.next_keep_awake_region_id.set(next);
1135+
if next != 0 && !regions.contains_key(&next) {
1136+
return Some(next);
1137+
}
1138+
}
1139+
None
1140+
}
1141+
11241142
#[wasm_bindgen]
11251143
pub fn kv(&self) -> WasmKv {
11261144
WasmKv {
@@ -1359,11 +1377,19 @@ impl WasmActorContext {
13591377
}
13601378

13611379
#[wasm_bindgen(js_name = keepAwake)]
1362-
pub async fn keep_awake(&self, promise: Promise) -> Result<JsValue, JsValue> {
1363-
self.inner
1364-
.keep_awake(JsFuture::from(promise))
1365-
.await
1366-
.map_err(|error| error)
1380+
pub fn keep_awake(&self, promise: Promise) {
1381+
console_error("keepAwake binding is deprecated; use beginKeepAwake/endKeepAwake");
1382+
let region = self.inner.keep_awake_region();
1383+
let actor_id = self.inner.actor_id().to_owned();
1384+
self.inner.register_task(async move {
1385+
let _region = region;
1386+
if let Err(error) = JsFuture::from(promise).await {
1387+
console_error(&format!(
1388+
"actor keepAwake promise rejected for actor {actor_id}: {}",
1389+
js_value_to_anyhow(error)
1390+
));
1391+
}
1392+
});
13671393
}
13681394

13691395
#[wasm_bindgen(js_name = registerTask)]
@@ -1384,6 +1410,25 @@ impl WasmActorContext {
13841410
start_run_handler(&self.callbacks, self);
13851411
}
13861412

1413+
#[wasm_bindgen(js_name = beginKeepAwake)]
1414+
pub fn begin_keep_awake(&self) -> u32 {
1415+
let mut regions = self.keep_awake_regions.borrow_mut();
1416+
let Some(region_id) = self.allocate_keep_awake_region_id(&regions) else {
1417+
console_error("failed to begin keep-awake region: no region ids available");
1418+
return 0;
1419+
};
1420+
regions.insert(region_id, self.inner.keep_awake_region());
1421+
region_id
1422+
}
1423+
1424+
#[wasm_bindgen(js_name = endKeepAwake)]
1425+
pub fn end_keep_awake(&self, region_id: u32) {
1426+
if region_id == 0 {
1427+
return;
1428+
}
1429+
self.keep_awake_regions.borrow_mut().remove(&region_id);
1430+
}
1431+
13871432
#[wasm_bindgen(js_name = beginWebsocketCallback)]
13881433
pub fn begin_websocket_callback(&self) -> u32 {
13891434
let mut regions = self.websocket_callback_regions.borrow_mut();

rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry-static.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ import {
126126
sleepWaitUntilState,
127127
sleepWithRawWs,
128128
sleepWsActiveDbExceedsGrace,
129+
sleepKeepAwakeUntilIdle,
129130
} from "./sleep-db";
130131
import { saveStateActor, saveStateObserver } from "./save-state";
131132
import { lifecycleObserver, startStopRaceActor } from "./start-stop-race";
@@ -221,6 +222,7 @@ export const registry = setup({
221222
sleepWsMessageExceedsGrace,
222223
sleepWsConcurrentDbExceedsGrace,
223224
sleepWsActiveDbExceedsGrace,
225+
sleepKeepAwakeUntilIdle,
224226
saveStateActor,
225227
saveStateObserver,
226228
// From error-handling.ts

0 commit comments

Comments
 (0)