diff --git a/engine/sdks/rust/envoy-client/src/actor.rs b/engine/sdks/rust/envoy-client/src/actor.rs index 35e2214806..5db36395f1 100644 --- a/engine/sdks/rust/envoy-client/src/actor.rs +++ b/engine/sdks/rust/envoy-client/src/actor.rs @@ -9,6 +9,7 @@ use tokio::sync::mpsc; use tokio::sync::oneshot; use tokio::sync::oneshot::error::TryRecvError; use tokio::task::{JoinError, JoinSet}; +use tracing::Instrument; use crate::config::{HttpRequest, HttpResponse, WebSocketMessage}; use crate::connection::ws_send; @@ -144,6 +145,14 @@ pub fn create_actor( (tx, active_http_request_count) } +#[tracing::instrument( + skip_all, + fields( + actor_id = %actor_id, + generation = generation, + actor_key = %config.key.as_deref().unwrap_or(""), + ), +)] async fn actor_inner( shared: Arc, actor_id: String, @@ -194,7 +203,7 @@ async fn actor_inner( .await; if let Err(error) = start_result { - tracing::error!(actor_id = %ctx.actor_id, ?error, "actor start failed"); + tracing::error!(?error, "actor start failed"); send_event( &mut ctx, protocol::Event::EventActorStateUpdate(protocol::EventActorStateUpdate { @@ -209,7 +218,7 @@ async fn actor_inner( if let Some(meta_entries) = handle.take_pending_hibernation_restore(&actor_id) { if let Err(error) = handle_hws_restore(&mut ctx, &handle, meta_entries).await { - tracing::error!(actor_id = %ctx.actor_id, ?error, "actor hibernation restore failed"); + tracing::error!(?error, "actor hibernation restore failed"); send_event( &mut ctx, protocol::Event::EventActorStateUpdate(protocol::EventActorStateUpdate { @@ -241,7 +250,7 @@ async fn actor_inner( } } => { if let Some(result) = maybe_task { - handle_http_request_task_result(&ctx, result); + handle_http_request_task_result(result); } } msg = async { @@ -275,7 +284,6 @@ async fn actor_inner( } => { if pending_stop.is_some() { tracing::warn!( - actor_id = %ctx.actor_id, command_idx, "ignoring duplicate stop while actor teardown is in progress" ); @@ -294,7 +302,6 @@ async fn actor_inner( 
ToActor::Lost => { if pending_stop.is_some() { tracing::warn!( - actor_id = %ctx.actor_id, "ignoring lost signal while actor teardown is in progress" ); continue; @@ -368,7 +375,7 @@ async fn actor_inner( } abort_and_join_http_request_tasks(&mut ctx, &mut http_request_tasks).await; - tracing::debug!(actor_id = %ctx.actor_id, "envoy actor stopped"); + tracing::debug!("envoy actor stopped"); } fn send_event(ctx: &mut ActorContext, inner: protocol::Event) { @@ -409,7 +416,7 @@ async fn begin_stop( .await; if let Err(error) = stop_result { - tracing::error!(actor_id = %ctx.actor_id, ?error, "actor stop failed"); + tracing::error!(?error, "actor stop failed"); stop_code = protocol::StopCode::Error; if stop_message.is_none() { stop_message = Some(format!("{error:#}")); @@ -451,7 +458,6 @@ fn finalize_stop( } Err(error) => { tracing::warn!( - actor_id = %ctx.actor_id, ?error, "actor stop completion handle dropped before signaling teardown result" ); @@ -467,7 +473,7 @@ fn send_stopped_event_for_result( stop_result: anyhow::Result<()>, ) { if let Err(error) = stop_result { - tracing::error!(actor_id = %ctx.actor_id, ?error, "actor stop completion failed"); + tracing::error!(?error, "actor stop completion failed"); stop_code = protocol::StopCode::Error; if stop_message.is_none() { stop_message = Some(format!("{error:#}")); @@ -541,23 +547,26 @@ fn handle_req_start( let request_id = message_id.request_id; let request_guard = ActiveHttpRequestGuard::new(ctx.active_http_request_count.clone()); - http_request_tasks.spawn(async move { - let _request_guard = request_guard; - let response = shared - .config - .callbacks - .fetch(handle_clone, actor_id, gateway_id, request_id, request) - .await; + http_request_tasks.spawn( + async move { + let _request_guard = request_guard; + let response = shared + .config + .callbacks + .fetch(handle_clone, actor_id, gateway_id, request_id, request) + .await; - match response { - Ok(response) => { - send_response(&shared, gateway_id, request_id, 
response).await; - } - Err(error) => { - tracing::error!(?error, "fetch failed"); + match response { + Ok(response) => { + send_response(&shared, gateway_id, request_id, response).await; + } + Err(error) => { + tracing::error!(?error, "fetch failed"); + } } } - }); + .in_current_span(), + ); if !req.stream { ctx.pending_requests @@ -565,13 +574,13 @@ fn handle_req_start( } } -fn handle_http_request_task_result(ctx: &ActorContext, result: Result<(), JoinError>) { +fn handle_http_request_task_result(result: Result<(), JoinError>) { if let Err(error) = result { if error.is_cancelled() { return; } - tracing::error!(actor_id = %ctx.actor_id, ?error, "http request task failed"); + tracing::error!(?error, "http request task failed"); } } @@ -585,7 +594,6 @@ async fn abort_and_join_http_request_tasks( let active_http_request_count = ctx.active_http_request_count.load(); tracing::debug!( - actor_id = %ctx.actor_id, active_http_request_count, "aborting in-flight http request tasks" ); @@ -593,7 +601,7 @@ async fn abort_and_join_http_request_tasks( http_request_tasks.abort_all(); while let Some(result) = http_request_tasks.join_next().await { - handle_http_request_task_result(ctx, result); + handle_http_request_task_result(result); } } @@ -633,7 +641,7 @@ fn spawn_ws_outgoing_task( request_id: protocol::RequestId, mut outgoing_rx: mpsc::UnboundedReceiver, ) { - tokio::spawn(async move { + let ws_task = async move { let mut idx: u16 = 0; while let Some(msg) = outgoing_rx.recv().await { idx += 1; @@ -681,7 +689,8 @@ fn spawn_ws_outgoing_task( } } } - }); + }; + tokio::spawn(ws_task.in_current_span()); } async fn handle_ws_open( @@ -896,7 +905,6 @@ async fn handle_ws_message( if wrapping_lte_u16(received_index, previous_index) { tracing::info!( request_id = id_to_str(&message_id.request_id), - actor_id = %ctx.actor_id, previous_index, received_index, "received duplicate hibernating websocket message" @@ -908,7 +916,6 @@ async fn handle_ws_message( if received_index != 
expected_index { tracing::warn!( request_id = id_to_str(&message_id.request_id), - actor_id = %ctx.actor_id, previous_index, expected_index, received_index, @@ -1621,6 +1628,7 @@ mod tests { )), protocol_metadata: Arc::new(tokio::sync::Mutex::new(None)), shutting_down: std::sync::atomic::AtomicBool::new(false), + stopped_tx: tokio::sync::watch::channel(true).0, }); (shared, envoy_rx) } diff --git a/engine/sdks/rust/envoy-client/src/context.rs b/engine/sdks/rust/envoy-client/src/context.rs index c9a102172e..f1d69f0654 100644 --- a/engine/sdks/rust/envoy-client/src/context.rs +++ b/engine/sdks/rust/envoy-client/src/context.rs @@ -7,6 +7,7 @@ use rivet_envoy_protocol as protocol; use rivet_util::async_counter::AsyncCounter; use tokio::sync::Mutex; use tokio::sync::mpsc; +use tokio::sync::watch; use crate::actor::ToActor; use crate::config::EnvoyConfig; @@ -29,6 +30,11 @@ pub struct SharedContext { pub ws_tx: Arc>>>, pub protocol_metadata: Arc>>, pub shutting_down: AtomicBool, + // Latched signal fired by `envoy_loop` after its cleanup block completes. + // Waiters observing `true` are guaranteed that the loop has exited and + // every pending KV/SQLite request has been resolved (with `EnvoyShutdownError` + // if it didn't complete naturally). 
+ pub stopped_tx: watch::Sender, } #[derive(Debug)] diff --git a/engine/sdks/rust/envoy-client/src/envoy.rs b/engine/sdks/rust/envoy-client/src/envoy.rs index 7eba705891..b88de1d127 100644 --- a/engine/sdks/rust/envoy-client/src/envoy.rs +++ b/engine/sdks/rust/envoy-client/src/envoy.rs @@ -252,6 +252,7 @@ pub fn start_envoy_sync(config: EnvoyConfig) -> EnvoyHandle { fn start_envoy_sync_inner(config: EnvoyConfig) -> EnvoyHandle { let (envoy_tx, envoy_rx) = mpsc::unbounded_channel::(); let (start_tx, start_rx) = tokio::sync::watch::channel(()); + let (stopped_tx, _stopped_rx) = tokio::sync::watch::channel(false); let envoy_key = uuid::Uuid::new_v4().to_string(); let shared = Arc::new(SharedContext { @@ -264,6 +265,7 @@ fn start_envoy_sync_inner(config: EnvoyConfig) -> EnvoyHandle { ws_tx: Arc::new(tokio::sync::Mutex::new(None)), protocol_metadata: Arc::new(tokio::sync::Mutex::new(None)), shutting_down: std::sync::atomic::AtomicBool::new(false), + stopped_tx, }); let handle = EnvoyHandle { @@ -271,13 +273,6 @@ fn start_envoy_sync_inner(config: EnvoyConfig) -> EnvoyHandle { started_rx: start_rx, }; - // Start signal handler - let handle2 = handle.clone(); - tokio::spawn(async move { - let _ = tokio::signal::ctrl_c().await; - handle2.shutdown(false); - }); - start_connection(shared.clone()); let ctx = EnvoyContext { @@ -459,6 +454,11 @@ async fn envoy_loop( tracing::info!("envoy stopped"); ctx.shared.config.callbacks.on_shutdown(); + + // Latched signal: waiters on `EnvoyHandle::wait_stopped` observe this and + // any future callers of `wait_stopped` resolve immediately because watch + // retains the last value. 
+ let _ = ctx.shared.stopped_tx.send_replace(true); } async fn handle_conn_message( diff --git a/engine/sdks/rust/envoy-client/src/events.rs b/engine/sdks/rust/envoy-client/src/events.rs index 467e4634c8..f9b9044fb1 100644 --- a/engine/sdks/rust/envoy-client/src/events.rs +++ b/engine/sdks/rust/envoy-client/src/events.rs @@ -164,6 +164,7 @@ mod tests { )), protocol_metadata: Arc::new(tokio::sync::Mutex::new(None)), shutting_down: std::sync::atomic::AtomicBool::new(false), + stopped_tx: tokio::sync::watch::channel(true).0, }); let handle = EnvoyHandle { shared: shared.clone(), diff --git a/engine/sdks/rust/envoy-client/src/handle.rs b/engine/sdks/rust/envoy-client/src/handle.rs index 85624ef6f0..9c2e8b83a8 100644 --- a/engine/sdks/rust/envoy-client/src/handle.rs +++ b/engine/sdks/rust/envoy-client/src/handle.rs @@ -36,6 +36,29 @@ impl EnvoyHandle { } } + /// Resolves when the envoy loop has finished its cleanup block. + /// + /// Returning does NOT imply successful delivery of pending KV/SQLite/tunnel + /// requests. The cleanup block errors out every outstanding request with + /// `"envoy shutting down"`. Callers needing durability must wait on individual + /// request acks before invoking shutdown. + /// + /// Latched: safe to call before, during, or after the envoy loop exits. + /// A waiter arriving after the loop already exited resolves immediately. + pub async fn wait_stopped(&self) { + let mut rx = self.shared.stopped_tx.subscribe(); + if *rx.borrow_and_update() { + return; + } + let _ = rx.changed().await; + } + + /// Convenience: signal shutdown then await `wait_stopped`. 
+ pub async fn shutdown_and_wait(&self, immediate: bool) { + self.shutdown(immediate); + self.wait_stopped().await; + } + pub async fn get_protocol_metadata(&self) -> Option { self.shared.protocol_metadata.lock().await.clone() } diff --git a/rivetkit-rust/engine/artifacts/errors/registry.shut_down.json b/rivetkit-rust/engine/artifacts/errors/registry.shut_down.json new file mode 100644 index 0000000000..95304d6a26 --- /dev/null +++ b/rivetkit-rust/engine/artifacts/errors/registry.shut_down.json @@ -0,0 +1,5 @@ +{ + "code": "shut_down", + "group": "registry", + "message": "Registry is shut down." +} \ No newline at end of file diff --git a/rivetkit-rust/packages/rivetkit-core/CLAUDE.md b/rivetkit-rust/packages/rivetkit-core/CLAUDE.md index c161b5462e..e2b3b9fc46 100644 --- a/rivetkit-rust/packages/rivetkit-core/CLAUDE.md +++ b/rivetkit-rust/packages/rivetkit-core/CLAUDE.md @@ -24,6 +24,13 @@ - Flush the actor-connect `WebSocketSender` after queuing a setup `Error` frame and before closing so the envoy writer handles the error before the close terminates the connection. - Bound actor-connect websocket setup at the registry boundary as well as inside the actor task. The HTTP upgrade can complete before `connection_open` replies, so a missing reply must still close the socket instead of idling until the client test timeout. +## Run modes + +- Two run modes exist for `CoreRegistry`. **Persistent envoy**: `serve_with_config(...)` starts one outbound envoy via `start_envoy` and holds it for the process lifetime; used by standalone Rust binaries and TS `registry.start()`. **Serverless request**: `into_serverless_runtime(...)` returns a `CoreServerlessRuntime` whose `handle_request(...)` lazily starts an envoy on first request via `ensure_envoy(...)` and caches it; used by Node/Bun/Deno HTTP hosts and platform fetch handlers. Both modes end up holding a long-lived `EnvoyHandle`. 
+- Shutdown is a property of the host's `CoreRegistry` handle, not of whichever entrypoint ran first. Route process-level shutdown (SIGINT/SIGTERM) through a single `CoreRegistry::shutdown()` that trips one shared cancel token observed by `serve_with_config` and calls `CoreServerlessRuntime::shutdown()` on the cached runtime. Never attach the shutdown signal to `serve_with_config`'s config parameter — that misses the serverless request mode entirely. +- `rivetkit-core` and `rivet-envoy-client` must not install process signal handlers (no `tokio::signal::ctrl_c()` in library code). `tokio::signal::ctrl_c()` calls `sigaction(SIGINT, ...)` at the POSIX level and prevents Node from exiting when rivetkit is embedded via NAPI. Signal policy belongs to the host binary or the TS registry layer. +- Per-request `CancellationToken` on `handle_serverless_request` cancels a single in-flight request and does not tear down the cached envoy. Do not overload it with registry shutdown. + ## Test harness - `tests/modules/task.rs` tests that install a tracing subscriber with `set_default(...)` must take `test_hook_lock()` first, or full `cargo test` parallelism makes the log-capture assertions flaky. diff --git a/rivetkit-rust/packages/rivetkit-core/examples/counter.rs b/rivetkit-rust/packages/rivetkit-core/examples/counter.rs index 8b4da793ad..30ffece301 100644 --- a/rivetkit-rust/packages/rivetkit-core/examples/counter.rs +++ b/rivetkit-rust/packages/rivetkit-core/examples/counter.rs @@ -99,8 +99,19 @@ fn counter_factory() -> ActorFactory { async fn main() -> Result<()> { let mut registry = CoreRegistry::new(); registry.register("counter", counter_factory()); - tokio::select! 
{ - res = registry.serve() => res, - _ = tokio::signal::ctrl_c() => Ok(()), + let token = tokio_util::sync::CancellationToken::new(); + let serve = tokio::spawn({ + let token = token.clone(); + async move { registry.serve(token).await } + }); + match tokio::signal::ctrl_c().await { + Ok(()) => {} + Err(err) => tracing::warn!(?err, "ctrl_c install failed; cancelling anyway"), + } + token.cancel(); + match serve.await { + Ok(Ok(())) => Ok(()), + Ok(Err(err)) => Err(err), + Err(join_err) => Err(join_err.into()), } } diff --git a/rivetkit-rust/packages/rivetkit-core/src/actor/schedule.rs b/rivetkit-rust/packages/rivetkit-core/src/actor/schedule.rs index c6d1e738cf..3591561e7f 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/actor/schedule.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/actor/schedule.rs @@ -7,6 +7,7 @@ use futures::future::BoxFuture; use rivet_envoy_client::handle::EnvoyHandle; use tokio::runtime::Handle; use tokio::sync::oneshot; +use tracing::Instrument; use uuid::Uuid; use crate::actor::context::ActorContext; @@ -32,6 +33,7 @@ impl ActorContext { pub fn at(&self, timestamp_ms: i64, action_name: &str, args: &[u8]) { if let Err(error) = self.schedule_event(timestamp_ms, action_name, args) { tracing::error!( + actor_id = %self.actor_id(), ?error, action_name, timestamp_ms, @@ -307,17 +309,20 @@ impl ActorContext { if let Ok(handle) = Handle::try_current() { let state_ctx = self.clone(); let (persist_done_tx, persist_done_rx) = oneshot::channel(); - handle.spawn(async move { - let _ = ack_rx.await; - if let Err(error) = state_ctx.persist_last_pushed_alarm(timestamp_ms).await { - tracing::error!( - ?error, - ?timestamp_ms, - "failed to persist last pushed actor alarm" - ); + handle.spawn( + async move { + let _ = ack_rx.await; + if let Err(error) = state_ctx.persist_last_pushed_alarm(timestamp_ms).await { + tracing::error!( + ?error, + ?timestamp_ms, + "failed to persist last pushed actor alarm" + ); + } + let _ = persist_done_tx.send(()); } - let 
_ = persist_done_tx.send(()); - }); + .in_current_span(), + ); self.0 .schedule_pending_alarm_writes .lock() @@ -356,35 +361,46 @@ impl ActorContext { ); // Intentionally detached but abortable: the handle is stored in // `local_alarm_task` and cancelled when alarms are resynced or stopped. - let handle = tokio_handle.spawn(async move { - tokio::time::sleep(Duration::from_millis(delay_ms)).await; - if schedule.0.schedule_local_alarm_epoch.load(Ordering::SeqCst) != local_alarm_epoch { - return; + let handle = tokio_handle.spawn( + async move { + tokio::time::sleep(Duration::from_millis(delay_ms)).await; + if schedule.0.schedule_local_alarm_epoch.load(Ordering::SeqCst) != local_alarm_epoch + { + return; + } + tracing::debug!( + timestamp_ms = next_alarm, + local_alarm_epoch, + "local actor alarm fired" + ); + let Some(callback) = schedule.0.schedule_local_alarm_callback.lock().clone() else { + return; + }; + callback().await; } - tracing::debug!( - actor_id = %schedule.actor_id(), - timestamp_ms = next_alarm, - local_alarm_epoch, - "local actor alarm fired" - ); - let Some(callback) = schedule.0.schedule_local_alarm_callback.lock().clone() else { - return; - }; - callback().await; - }); + .in_current_span(), + ); *self.0.schedule_local_alarm_task.lock() = Some(handle); } pub(crate) fn sync_alarm_logged(&self) { if let Err(error) = self.sync_alarm() { - tracing::error!(?error, "failed to sync scheduled actor alarm"); + tracing::error!( + actor_id = %self.actor_id(), + ?error, + "failed to sync scheduled actor alarm" + ); } } pub(crate) fn sync_future_alarm_logged(&self) { if let Err(error) = self.sync_future_alarm() { - tracing::error!(?error, "failed to sync future scheduled actor alarm"); + tracing::error!( + actor_id = %self.actor_id(), + ?error, + "failed to sync future scheduled actor alarm" + ); } } @@ -502,6 +518,7 @@ mod tests { )), protocol_metadata: Arc::new(tokio::sync::Mutex::new(None)), shutting_down: AtomicBool::new(false), + stopped_tx: 
tokio::sync::watch::channel(true).0, }); (EnvoyHandle::from_shared(shared), envoy_rx) diff --git a/rivetkit-rust/packages/rivetkit-core/src/actor/sleep.rs b/rivetkit-rust/packages/rivetkit-core/src/actor/sleep.rs index ac069b3e78..1b79d8a874 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/actor/sleep.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/actor/sleep.rs @@ -12,6 +12,7 @@ use tokio::task::JoinHandle; #[cfg(test)] use tokio::time::sleep_until; use tokio::time::{Instant, sleep}; +use tracing::Instrument; use crate::actor::config::ActorConfig; use crate::actor::context::ActorContext; @@ -465,10 +466,13 @@ impl ActorContext { let counter = self.0.sleep.work.shutdown_counter.clone(); counter.increment(); let guard = CountGuard::from_incremented(counter); - shutdown_tasks.spawn(async move { - let _guard = guard; - fut.await; - }); + shutdown_tasks.spawn( + async move { + let _guard = guard; + fut.await; + } + .in_current_span(), + ); true } diff --git a/rivetkit-rust/packages/rivetkit-core/src/actor/state.rs b/rivetkit-rust/packages/rivetkit-core/src/actor/state.rs index 270634ff5e..7949798dbb 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/actor/state.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/actor/state.rs @@ -9,6 +9,7 @@ use tokio::sync::mpsc; use tokio::task::JoinHandle; #[cfg(test)] use tokio::time::timeout; +use tracing::Instrument; use crate::actor::connection::make_connection_key; use crate::actor::context::ActorContext; @@ -573,17 +574,20 @@ impl ActorContext { // Intentionally detached but abortable: pending delayed saves are // retained in `pending_save`, replaced by newer saves, and awaited at // shutdown through the state save guard. 
- let handle = tokio_handle.spawn(async move { - if !delay.is_zero() { - tokio::time::sleep(delay).await; - } + let handle = tokio_handle.spawn( + async move { + if !delay.is_zero() { + tokio::time::sleep(delay).await; + } - state.take_pending_save(); + state.take_pending_save(); - if let Err(error) = state.persist_if_dirty().await { - tracing::error!(?error, "failed to persist actor state"); + if let Err(error) = state.persist_if_dirty().await { + tracing::error!(?error, "failed to persist actor state"); + } } - }); + .in_current_span(), + ); *pending_save = Some(PendingSave { scheduled_at, @@ -611,15 +615,18 @@ impl ActorContext { let state = self.clone(); let mut tracked_persist = self.0.tracked_persist.lock(); let previous = tracked_persist.take(); - let handle = tokio_handle.spawn(async move { - if let Some(previous) = previous { - let _ = previous.await; - } + let handle = tokio_handle.spawn( + async move { + if let Some(previous) = previous { + let _ = previous.await; + } - if let Err(error) = state.persist_state(SaveStateOpts { immediate: true }).await { - tracing::error!(?error, description, "failed to persist actor state"); + if let Err(error) = state.persist_state(SaveStateOpts { immediate: true }).await { + tracing::error!(?error, description, "failed to persist actor state"); + } } - }); + .in_current_span(), + ); *tracked_persist = Some(handle); } diff --git a/rivetkit-rust/packages/rivetkit-core/src/actor/task.rs b/rivetkit-rust/packages/rivetkit-core/src/actor/task.rs index a3c623d839..1011d161ae 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/actor/task.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/actor/task.rs @@ -45,6 +45,7 @@ use parking_lot::Mutex; use tokio::sync::{broadcast, mpsc, oneshot}; use tokio::task::{JoinError, JoinHandle}; use tokio::time::{Duration, Instant, sleep_until, timeout}; +use tracing::Instrument; use crate::actor::action::ActionDispatchError; use crate::actor::connection::ConnHandle; @@ -63,7 +64,7 @@ use 
crate::actor::state::{ }; use crate::actor::task_types::StopReason; use crate::error::{ActorLifecycle as ActorLifecycleError, ActorRuntime}; -use crate::types::SaveStateOpts; +use crate::types::{SaveStateOpts, format_actor_key}; use crate::websocket::WebSocket; pub type ActionDispatchResult = std::result::Result, ActionDispatchError>; @@ -498,6 +499,14 @@ impl ActorTask { self } + #[tracing::instrument( + skip_all, + fields( + actor_id = %self.actor_id, + generation = self.generation, + actor_key = %format_actor_key(self.ctx.key()), + ), + )] pub async fn run(mut self) -> Result<()> { let exit = self.run_live().await; let LiveExit::Shutdown { reason } = exit else { @@ -852,18 +861,21 @@ impl ActorTask { fn core_dispatched_hook_reply(&self, operation: &'static str) -> Reply<()> { let (tx, rx) = oneshot::channel(); let ctx = self.ctx.clone(); - tokio::spawn(async move { - match rx.await { - Ok(Ok(())) => {} - Ok(Err(error)) => { - tracing::error!(?error, operation, "core dispatched hook failed"); - } - Err(error) => { - tracing::error!(?error, operation, "core dispatched hook reply dropped"); + tokio::spawn( + async move { + match rx.await { + Ok(Ok(())) => {} + Ok(Err(error)) => { + tracing::error!(?error, operation, "core dispatched hook failed"); + } + Err(error) => { + tracing::error!(?error, operation, "core dispatched hook reply dropped"); + } } + ctx.mark_core_dispatched_hook_completed(); } - ctx.mark_core_dispatched_hook_completed(); - }); + .in_current_span(), + ); tx.into() } @@ -908,7 +920,15 @@ impl ActorTask { conn, reply, } => { + tracing::info!( + actor_id = %self.ctx.actor_id(), + action_name = %name, + conn_id = ?conn.id(), + args_len = args.len(), + "actor task: handling DispatchCommand::Action" + ); let (tracked_reply_tx, tracked_reply_rx) = oneshot::channel(); + let action_name_for_log = name.clone(); match self.send_actor_event( "dispatch_action", ActorEvent::Action { @@ -919,13 +939,30 @@ impl ActorTask { }, ) { Ok(()) => { + tracing::info!( + 
actor_id = %self.ctx.actor_id(), + action_name = %action_name_for_log, + "actor task: ActorEvent::Action enqueued" + ); self.log_dispatch_command_handled(command_kind, "enqueued"); + let actor_id = self.ctx.actor_id().to_owned(); self.ctx.wait_until(async move { match tracked_reply_rx.await { Ok(result) => { + tracing::info!( + actor_id = %actor_id, + action_name = %action_name_for_log, + ok = result.is_ok(), + "actor task: tracked reply received, forwarding" + ); let _ = reply.send(result); } Err(_) => { + tracing::warn!( + actor_id = %actor_id, + action_name = %action_name_for_log, + "actor task: tracked reply dropped before completion" + ); let _ = reply.send(Err(ActorLifecycleError::DroppedReply.build())); } @@ -933,6 +970,12 @@ impl ActorTask { }); } Err(error) => { + tracing::warn!( + actor_id = %self.ctx.actor_id(), + action_name = %action_name_for_log, + ?error, + "actor task: failed to enqueue ActorEvent::Action" + ); let _ = reply.send(Err(error)); self.log_dispatch_command_handled(command_kind, "enqueue_failed"); } @@ -1260,15 +1303,18 @@ impl ActorTask { startup_ready: startup_ready_tx, }; let factory = self.factory.clone(); - self.run_handle = Some(tokio::spawn(async move { - match AssertUnwindSafe(factory.start(start)).catch_unwind().await { - Ok(result) => result, - Err(_) => Err(ActorRuntime::Panicked { - operation: "run handler".to_owned(), + self.run_handle = Some(tokio::spawn( + async move { + match AssertUnwindSafe(factory.start(start)).catch_unwind().await { + Ok(result) => result, + Err(_) => Err(ActorRuntime::Panicked { + operation: "run handler".to_owned(), + } + .build()), } - .build()), } - })); + .in_current_span(), + )); if let Some(startup_ready_rx) = startup_ready_rx { startup_ready_rx .await diff --git a/rivetkit-rust/packages/rivetkit-core/src/lib.rs b/rivetkit-rust/packages/rivetkit-core/src/lib.rs index d419086246..593f6b5270 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/lib.rs +++ 
b/rivetkit-rust/packages/rivetkit-core/src/lib.rs @@ -36,5 +36,7 @@ pub use error::ActorLifecycle; pub use inspector::{Inspector, InspectorSnapshot}; pub use registry::{CoreRegistry, ServeConfig}; pub use serverless::{CoreServerlessRuntime, ServerlessRequest, ServerlessResponse}; -pub use types::{ActorKey, ActorKeySegment, ConnId, ListOpts, SaveStateOpts, WsMessage}; +pub use types::{ + ActorKey, ActorKeySegment, ConnId, ListOpts, SaveStateOpts, WsMessage, format_actor_key, +}; pub use websocket::WebSocket; diff --git a/rivetkit-rust/packages/rivetkit-core/src/registry/dispatch.rs b/rivetkit-rust/packages/rivetkit-core/src/registry/dispatch.rs index a2176b4b84..d642e2a9c9 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/registry/dispatch.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/registry/dispatch.rs @@ -9,12 +9,17 @@ pub(super) async fn dispatch_action_through_task( args: Vec, ) -> std::result::Result, ActionDispatchError> { let (reply_tx, reply_rx) = oneshot::channel(); + tracing::info!( + action_name = %name, + conn_id = ?conn.id(), + "dispatch_action: sending DispatchCommand::Action" + ); try_send_dispatch_command( dispatch, capacity, "dispatch_action", DispatchCommand::Action { - name, + name: name.clone(), args, conn, reply: reply_tx, @@ -22,9 +27,29 @@ pub(super) async fn dispatch_action_through_task( None, ) .map_err(ActionDispatchError::from_anyhow)?; + tracing::info!( + action_name = %name, + "dispatch_action: command queued, awaiting reply" + ); - reply_rx - .await + let result = reply_rx.await; + match &result { + Ok(Ok(bytes)) => tracing::info!( + action_name = %name, + output_len = bytes.len(), + "dispatch_action: reply received" + ), + Ok(Err(error)) => tracing::warn!( + action_name = %name, + ?error, + "dispatch_action: reply was an error" + ), + Err(_) => tracing::warn!( + action_name = %name, + "dispatch_action: reply channel dropped" + ), + } + result .map_err(|_| 
ActionDispatchError::from_anyhow(ActorLifecycleError::DroppedReply.build()))? .map_err(ActionDispatchError::from_anyhow) } diff --git a/rivetkit-rust/packages/rivetkit-core/src/registry/envoy_callbacks.rs b/rivetkit-rust/packages/rivetkit-core/src/registry/envoy_callbacks.rs index e154584b7a..5cf408b818 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/registry/envoy_callbacks.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/registry/envoy_callbacks.rs @@ -1,3 +1,5 @@ +use tracing::Instrument; + use super::*; use crate::error::ActorRuntime; @@ -63,16 +65,18 @@ impl EnvoyCallbacks for RegistryCallbacks { ) -> EnvoyBoxFuture> { let dispatcher = self.dispatcher.clone(); Box::pin(async move { - tokio::spawn(async move { - if let Err(error) = dispatcher.stop_actor(&actor_id, reason, stop_handle).await { - tracing::error!( - actor_id, - generation, - ?error, - "actor stop failed after asynchronous completion handoff" - ); + tokio::spawn( + async move { + if let Err(error) = dispatcher.stop_actor(&actor_id, reason, stop_handle).await + { + tracing::error!( + ?error, + "actor stop failed after asynchronous completion handoff", + ); + } } - }); + .in_current_span(), + ); Ok(()) }) } @@ -87,6 +91,11 @@ impl EnvoyCallbacks for RegistryCallbacks { _request_id: protocol::RequestId, request: HttpRequest, ) -> EnvoyBoxFuture> { + tracing::info!( + method = %request.method, + path = %request.path, + "envoy callback: fetch request" + ); let dispatcher = self.dispatcher.clone(); Box::pin(async move { dispatcher.handle_fetch(&actor_id, request).await }) } @@ -104,6 +113,12 @@ impl EnvoyCallbacks for RegistryCallbacks { is_restoring_hibernatable: bool, sender: WebSocketSender, ) -> EnvoyBoxFuture> { + tracing::info!( + path = %_path, + is_hibernatable = _is_hibernatable, + is_restoring_hibernatable, + "envoy callback: websocket request" + ); let dispatcher = self.dispatcher.clone(); Box::pin(async move { dispatcher diff --git 
a/rivetkit-rust/packages/rivetkit-core/src/registry/inspector.rs b/rivetkit-rust/packages/rivetkit-core/src/registry/inspector.rs index 1c4fcd51c6..f04a65aff4 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/registry/inspector.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/registry/inspector.rs @@ -238,14 +238,27 @@ impl RegistryDispatcher { action_name: &str, args: Vec, ) -> std::result::Result, ActionDispatchError> { + tracing::info!( + action_name, + args_len = args.len(), + "inspector RPC: connecting transient conn" + ); let conn = match instance .ctx .connect_conn(Vec::new(), false, None, None, async { Ok(Vec::new()) }) .await { Ok(conn) => conn, - Err(error) => return Err(ActionDispatchError::from_anyhow(error)), + Err(error) => { + tracing::warn!(action_name, ?error, "inspector RPC: connect_conn failed"); + return Err(ActionDispatchError::from_anyhow(error)); + } }; + tracing::info!( + action_name, + conn_id = ?conn.id(), + "inspector RPC: dispatching to actor task" + ); let output = dispatch_action_through_task( &instance.dispatch, instance.factory.config().dispatch_command_inbox_capacity, @@ -254,6 +267,14 @@ impl RegistryDispatcher { args, ) .await; + match &output { + Ok(bytes) => tracing::info!( + action_name, + output_len = bytes.len(), + "inspector RPC: action returned" + ), + Err(error) => tracing::warn!(action_name, ?error, "inspector RPC: action errored"), + } if let Err(error) = conn.disconnect(None).await { tracing::warn!( ?error, diff --git a/rivetkit-rust/packages/rivetkit-core/src/registry/inspector_ws.rs b/rivetkit-rust/packages/rivetkit-core/src/registry/inspector_ws.rs index a52ec442b1..f71f2607e6 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/registry/inspector_ws.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/registry/inspector_ws.rs @@ -3,6 +3,7 @@ use super::http::authorization_bearer_token_map; use super::inspector::*; use super::websocket::{closing_websocket_handler, websocket_inspector_token}; use super::*; +use 
tracing::Instrument; /// Aborts the wrapped task on drop. Ensures the overlay task cannot outlive /// the websocket handler even if `on_close` never fires (for example when the @@ -23,6 +24,7 @@ impl RegistryDispatcher { _request: &HttpRequest, headers: &HashMap, ) -> Result { + tracing::info!(actor_id, "inspector WS: handler invoked, verifying auth"); if InspectorAuth::new() .verify( &instance.ctx, @@ -38,6 +40,7 @@ impl RegistryDispatcher { ); return Ok(closing_websocket_handler(1008, "inspector.unauthorized")); } + tracing::info!(actor_id, "inspector WS: auth passed, building handler"); let dispatcher = self.clone(); // Forced-sync: inspector websocket slots are filled/cleared inside @@ -53,11 +56,20 @@ impl RegistryDispatcher { let on_message_instance = instance.clone(); let on_message_dispatcher = dispatcher.clone(); + let on_message_actor_id = actor_id.to_owned(); + let on_close_actor_id = actor_id.to_owned(); + let on_open_actor_id = actor_id.to_owned(); Ok(WebSocketHandler { on_message: Box::new(move |message: WebSocketMessage| { let dispatcher = on_message_dispatcher.clone(); let instance = on_message_instance.clone(); + let actor_id = on_message_actor_id.clone(); Box::pin(async move { + tracing::info!( + actor_id = %actor_id, + bytes = message.data.len(), + "inspector WS: on_message fired" + ); dispatcher .handle_inspector_websocket_message( &instance, @@ -67,11 +79,18 @@ impl RegistryDispatcher { .await; }) }), - on_close: Box::new(move |_code, _reason| { + on_close: Box::new(move |code, reason| { let slot = subscription_slot.clone(); let overlay_slot = overlay_task_slot.clone(); let attach_slot = attach_guard_slot.clone(); + let actor_id = on_close_actor_id.clone(); Box::pin(async move { + tracing::info!( + actor_id = %actor_id, + ?code, + ?reason, + "inspector WS: on_close fired" + ); let mut guard = slot.lock(); guard.take(); let mut overlay_guard = overlay_slot.lock(); @@ -81,7 +100,9 @@ impl RegistryDispatcher { }) }), on_open: Some(Box::new(move 
|open_sender| { + let actor_id = on_open_actor_id.clone(); Box::pin(async move { + tracing::info!(actor_id = %actor_id, "inspector WS: on_open fired, building init"); match on_open_dispatcher .inspector_init_message(&on_open_instance) .await @@ -93,6 +114,7 @@ impl RegistryDispatcher { .close(Some(1011), Some("inspector.init_error".to_owned())); return; } + tracing::info!(actor_id = %actor_id, "inspector WS: init message sent"); } Err(error) => { tracing::error!(?error, "failed to build inspector init message"); @@ -118,42 +140,49 @@ impl RegistryDispatcher { *guard = Some(attach_guard); } let overlay_sender = open_sender.clone(); - let overlay_task = tokio::spawn(async move { - loop { - match overlay_rx.recv().await { - Ok(payload) => match decode_inspector_overlay_state(&payload) { - Ok(Some(state)) => { - if let Err(error) = send_inspector_message( - &overlay_sender, - &InspectorServerMessage::StateUpdated( - inspector_protocol::StateUpdated { state }, - ), - ) { + let overlay_actor_id = on_open_instance.ctx.actor_id().to_owned(); + let overlay_task = tokio::spawn( + async move { + loop { + match overlay_rx.recv().await { + Ok(payload) => match decode_inspector_overlay_state(&payload) { + Ok(Some(state)) => { + if let Err(error) = send_inspector_message( + &overlay_sender, + &InspectorServerMessage::StateUpdated( + inspector_protocol::StateUpdated { state }, + ), + ) { + tracing::error!( + ?error, + "failed to push inspector overlay update" + ); + break; + } + } + Ok(None) => {} + Err(error) => { tracing::error!( ?error, - "failed to push inspector overlay update" + "failed to decode inspector overlay update" ); - break; } - } - Ok(None) => {} - Err(error) => { - tracing::error!( - ?error, - "failed to decode inspector overlay update" + }, + Err(broadcast::error::RecvError::Lagged(skipped)) => { + tracing::warn!( + skipped, + "inspector overlay subscriber lagged; waiting for next sync" ); } - }, - Err(broadcast::error::RecvError::Lagged(skipped)) => { - 
tracing::warn!( - skipped, - "inspector overlay subscriber lagged; waiting for next sync" - ); + Err(broadcast::error::RecvError::Closed) => break, } - Err(broadcast::error::RecvError::Closed) => break, } } - }); + .instrument(tracing::info_span!( + "inspector_ws", + actor_id = %overlay_actor_id, + )), + ); let mut overlay_guard = on_open_overlay_slot.lock(); *overlay_guard = Some(AbortOnDropTask(overlay_task)); @@ -170,32 +199,39 @@ impl RegistryDispatcher { let dispatcher = listener_dispatcher.clone(); let instance = listener_instance.clone(); let sender = listener_sender.clone(); - tokio::spawn(async move { - match dispatcher - .inspector_push_message_for_signal(&instance, signal) - .await - { - Ok(Some(message)) => { - if let Err(error) = - send_inspector_message(&sender, &message) - { + let actor_id = instance.ctx.actor_id().to_owned(); + tokio::spawn( + async move { + match dispatcher + .inspector_push_message_for_signal(&instance, signal) + .await + { + Ok(Some(message)) => { + if let Err(error) = + send_inspector_message(&sender, &message) + { + tracing::error!( + ?error, + ?signal, + "failed to push inspector websocket update" + ); + } + } + Ok(None) => {} + Err(error) => { tracing::error!( ?error, ?signal, - "failed to push inspector websocket update" + "failed to build inspector websocket update" ); } } - Ok(None) => {} - Err(error) => { - tracing::error!( - ?error, - ?signal, - "failed to build inspector websocket update" - ); - } } - }); + .instrument(tracing::info_span!( + "inspector_ws", + actor_id = %actor_id, + )), + ); })); let mut guard = on_open_slot.lock(); *guard = Some(subscription); @@ -211,28 +247,67 @@ impl RegistryDispatcher { payload: &[u8], ) { let response = match inspector_protocol::decode_client_message(payload) { - Ok(message) => match self - .process_inspector_websocket_message(instance, message) - .await - { - Ok(response) => response, - Err(error) => Some(InspectorServerMessage::Error( + Ok(message) => { + tracing::info!( + 
actor_id = %instance.ctx.actor_id(), + message_kind = client_message_kind(&message), + payload_len = payload.len(), + "inspector WS: decoded client message" + ); + match self + .process_inspector_websocket_message(instance, message) + .await + { + Ok(response) => { + tracing::info!( + actor_id = %instance.ctx.actor_id(), + response_kind = response.as_ref().map(server_message_kind).unwrap_or("None"), + "inspector WS: processed client message" + ); + response + } + Err(error) => { + tracing::warn!( + actor_id = %instance.ctx.actor_id(), + ?error, + "inspector WS: process_inspector_websocket_message returned error" + ); + Some(InspectorServerMessage::Error( + inspector_protocol::ErrorMessage { + message: error.to_string(), + }, + )) + } + } + } + Err(error) => { + tracing::warn!( + actor_id = %instance.ctx.actor_id(), + payload_len = payload.len(), + ?error, + "inspector WS: failed to decode client message" + ); + Some(InspectorServerMessage::Error( inspector_protocol::ErrorMessage { message: error.to_string(), }, - )), - }, - Err(error) => Some(InspectorServerMessage::Error( - inspector_protocol::ErrorMessage { - message: error.to_string(), - }, - )), + )) + } }; - if let Some(response) = response - && let Err(error) = send_inspector_message(sender, &response) - { - tracing::error!(?error, "failed to send inspector websocket response"); + if let Some(response) = response { + match send_inspector_message(sender, &response) { + Ok(()) => tracing::debug!( + actor_id = %instance.ctx.actor_id(), + response_kind = server_message_kind(&response), + "inspector WS: sent response" + ), + Err(error) => tracing::error!( + ?error, + response_kind = server_message_kind(&response), + "failed to send inspector websocket response" + ), + } } } @@ -264,10 +339,22 @@ impl RegistryDispatcher { ))) } inspector_protocol::ClientMessage::ActionRequest(request) => { + tracing::info!( + rid = ?request.id, + action_name = %request.name, + args_len = request.args.len(), + "inspector WS: 
ActionRequest received" + ); let output = self .execute_inspector_action_bytes(instance, &request.name, request.args) .await .map_err(ActionDispatchError::into_anyhow)?; + tracing::info!( + rid = ?request.id, + action_name = %request.name, + output_len = output.len(), + "inspector WS: ActionResponse ready to send" + ); Ok(Some(InspectorServerMessage::ActionResponse( inspector_protocol::ActionResponse { rid: request.id, @@ -482,3 +569,41 @@ fn inspector_state_payload(ctx: &ActorContext, is_state_enabled: bool) -> Option let state = ctx.state(); if state.is_empty() { None } else { Some(state) } } + +fn client_message_kind(message: &inspector_protocol::ClientMessage) -> &'static str { + use inspector_protocol::ClientMessage as C; + match message { + C::PatchStateRequest(_) => "PatchStateRequest", + C::StateRequest(_) => "StateRequest", + C::ConnectionsRequest(_) => "ConnectionsRequest", + C::ActionRequest(_) => "ActionRequest", + C::RpcsListRequest(_) => "RpcsListRequest", + C::TraceQueryRequest(_) => "TraceQueryRequest", + C::QueueRequest(_) => "QueueRequest", + C::WorkflowHistoryRequest(_) => "WorkflowHistoryRequest", + C::WorkflowReplayRequest(_) => "WorkflowReplayRequest", + C::DatabaseSchemaRequest(_) => "DatabaseSchemaRequest", + C::DatabaseTableRowsRequest(_) => "DatabaseTableRowsRequest", + } +} + +fn server_message_kind(message: &InspectorServerMessage) -> &'static str { + match message { + InspectorServerMessage::Init(_) => "Init", + InspectorServerMessage::StateResponse(_) => "StateResponse", + InspectorServerMessage::StateUpdated(_) => "StateUpdated", + InspectorServerMessage::ConnectionsResponse(_) => "ConnectionsResponse", + InspectorServerMessage::ConnectionsUpdated(_) => "ConnectionsUpdated", + InspectorServerMessage::ActionResponse(_) => "ActionResponse", + InspectorServerMessage::RpcsListResponse(_) => "RpcsListResponse", + InspectorServerMessage::TraceQueryResponse(_) => "TraceQueryResponse", + InspectorServerMessage::QueueResponse(_) => 
"QueueResponse", + InspectorServerMessage::QueueUpdated(_) => "QueueUpdated", + InspectorServerMessage::WorkflowHistoryResponse(_) => "WorkflowHistoryResponse", + InspectorServerMessage::WorkflowHistoryUpdated(_) => "WorkflowHistoryUpdated", + InspectorServerMessage::WorkflowReplayResponse(_) => "WorkflowReplayResponse", + InspectorServerMessage::DatabaseSchemaResponse(_) => "DatabaseSchemaResponse", + InspectorServerMessage::DatabaseTableRowsResponse(_) => "DatabaseTableRowsResponse", + InspectorServerMessage::Error(_) => "Error", + } +} diff --git a/rivetkit-rust/packages/rivetkit-core/src/registry/mod.rs b/rivetkit-rust/packages/rivetkit-core/src/registry/mod.rs index 07fafb5618..ace1536eaa 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/registry/mod.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/registry/mod.rs @@ -4,7 +4,7 @@ use std::io::Cursor; use std::path::PathBuf; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; -use std::time::Instant; +use std::time::{Duration, Instant}; use ::http::StatusCode; use anyhow::{Context, Result}; @@ -25,6 +25,7 @@ use serde_bytes::ByteBuf; use serde_json::{Value as JsonValue, json}; use tokio::sync::{Mutex as TokioMutex, Notify, broadcast, mpsc, oneshot}; use tokio::task::JoinHandle; +use tokio_util::sync::CancellationToken; use vbare::OwnedVersionedData; use crate::actor::action::ActionDispatchError; @@ -69,6 +70,11 @@ mod websocket; use inspector::build_actor_inspector; use websocket::is_actor_connect_path; +/// Bound on `handle.shutdown_and_wait` inside `serve_with_config` teardown. +/// Protects against indefinite hangs if the envoy reconnect loop is stuck; +/// the TS/outer-host grace period is the ultimate backstop. 
+const SHUTDOWN_DRAIN_TIMEOUT: Duration = Duration::from_secs(20); + #[derive(Debug, Default)] pub struct CoreRegistry { factories: HashMap>, @@ -410,11 +416,16 @@ impl CoreRegistry { self.factories.insert(name.to_owned(), factory); } - pub async fn serve(self) -> Result<()> { - self.serve_with_config(ServeConfig::from_env()).await + pub async fn serve(self, shutdown: CancellationToken) -> Result<()> { + self.serve_with_config(ServeConfig::from_env(), shutdown) + .await } - pub async fn serve_with_config(self, config: ServeConfig) -> Result<()> { + pub async fn serve_with_config( + self, + config: ServeConfig, + shutdown: CancellationToken, + ) -> Result<()> { let dispatcher = self.into_dispatcher(&config); let _engine_process = match config.engine_binary_path.as_ref() { Some(binary_path) => { @@ -427,7 +438,7 @@ impl CoreRegistry { dispatcher: dispatcher.clone(), }); - let _handle = start_envoy(rivet_envoy_client::config::EnvoyConfig { + let handle = start_envoy(rivet_envoy_client::config::EnvoyConfig { version: config.version, endpoint: config.endpoint, token: config.token, @@ -445,8 +456,22 @@ impl CoreRegistry { // `sigaction(SIGINT, ...)` at the POSIX level, which overrides the // host's default SIGINT handling when rivetkit-core is embedded in // Node via NAPI and leaves the host process unable to exit. Callers - // drive shutdown themselves by dropping the task. - std::future::pending::>().await + // trip the `shutdown` token instead. + shutdown.cancelled().await; + + // Bounded drain. If envoy cannot reach the engine (reconnect loop stuck), + // we fall back to immediate `Stop` rather than hanging indefinitely. + // The outer host (TS signal handler / Rust binary) is the backstop. 
+ match tokio::time::timeout(SHUTDOWN_DRAIN_TIMEOUT, handle.shutdown_and_wait(false)).await { + Ok(()) => {} + Err(_) => { + tracing::warn!("envoy shutdown drain exceeded timeout; forcing immediate stop"); + handle.shutdown(true); + handle.wait_stopped().await; + } + } + + Ok(()) } fn into_dispatcher(self, config: &ServeConfig) -> Arc { diff --git a/rivetkit-rust/packages/rivetkit-core/src/registry/websocket.rs b/rivetkit-rust/packages/rivetkit-core/src/registry/websocket.rs index a68d99e1e3..70be849e94 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/registry/websocket.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/registry/websocket.rs @@ -4,6 +4,7 @@ use super::inspector::encode_json_as_cbor; use super::*; use crate::error::ProtocolError; use tokio::time::timeout; +use tracing::Instrument; impl RegistryDispatcher { pub(super) async fn handle_websocket( @@ -18,13 +19,22 @@ impl RegistryDispatcher { is_restoring_hibernatable: bool, sender: WebSocketSender, ) -> Result { + tracing::info!(actor_id, path, "handle_websocket: routing"); let instance = self.active_actor(actor_id).await?; if is_inspector_connect_path(path)? { + tracing::info!( + actor_id, + "handle_websocket: dispatching to inspector handler" + ); return self .handle_inspector_websocket(actor_id, instance, request, headers) .await; } if is_actor_connect_path(path)? 
{ + tracing::info!( + actor_id, + "handle_websocket: dispatching to actor-connect handler" + ); return self .handle_actor_connect_websocket( actor_id, @@ -40,6 +50,11 @@ impl RegistryDispatcher { ) .await; } + tracing::info!( + actor_id, + path, + "handle_websocket: dispatching to raw handler" + ); match self .handle_raw_websocket( actor_id, @@ -359,114 +374,123 @@ impl RegistryDispatcher { let ctx = ctx.clone(); let conn = conn.clone(); let message_index = message.message_index; - tokio::spawn(async move { - let response = match dispatch_action_through_task( - &dispatch, - on_message_dispatch_capacity, - conn.clone(), - request.name.clone(), - request.args.into_vec(), - ) - .await - { - Ok(output) => ActorConnectToClient::ActionResponse( - ActorConnectActionResponse { - id: request.id, - output: ByteBuf::from(output), - }, - ), - Err(error) => { - if conn.is_hibernatable() && ctx.sleep_requested() { - tracing::debug!( - conn_id = conn.id(), - message_index, - action_name = request.name, - "deferring hibernatable actor websocket action while actor is entering sleep" - ); - return; - } - ActorConnectToClient::Error(action_dispatch_error_response( - error, request.id, - )) - } - }; - - if conn.is_hibernatable() - && let Err(error) = persist_and_ack_hibernatable_actor_message( - &ctx, - &conn, - message_index, + let actor_id = ctx.actor_id().to_owned(); + tokio::spawn( + async move { + let response = match dispatch_action_through_task( + &dispatch, + on_message_dispatch_capacity, + conn.clone(), + request.name.clone(), + request.args.into_vec(), ) .await - { - tracing::warn!( - ?error, - conn_id = conn.id(), - "failed to persist and ack hibernatable actor websocket message" - ); - sender.close( - Some(1011), - Some("actor.hibernation_persist_failed".to_owned()), - ); - return; - } - - match send_actor_connect_message( - &sender, - encoding, - &response, - max_outgoing_message_size, - ) { - Ok(()) => {} - Err(ActorConnectSendError::OutgoingTooLong) => { - let 
error_response = - ActorConnectToClient::Error(ActorConnectError { - group: "message".to_owned(), - code: "outgoing_too_long".to_owned(), - message: "Outgoing message too long".to_owned(), - metadata: None, - action_id: Some(request.id), - }); - if let Err(error) = send_actor_connect_message( - &sender, - encoding, - &error_response, - usize::MAX, - ) { - match error { - ActorConnectSendError::OutgoingTooLong => { - sender.close( - Some(1011), - Some( - "message.outgoing_too_long".to_owned(), - ), - ); - } - ActorConnectSendError::Encode(error) => { - tracing::error!( - ?error, - "failed to send actor websocket outgoing-size error" - ); - sender.close( - Some(1011), - Some("actor.send_failed".to_owned()), - ); - } + { + Ok(output) => ActorConnectToClient::ActionResponse( + ActorConnectActionResponse { + id: request.id, + output: ByteBuf::from(output), + }, + ), + Err(error) => { + if conn.is_hibernatable() && ctx.sleep_requested() { + tracing::debug!( + conn_id = conn.id(), + message_index, + action_name = request.name, + "deferring hibernatable actor websocket action while actor is entering sleep" + ); + return; } + ActorConnectToClient::Error( + action_dispatch_error_response(error, request.id), + ) } - } - Err(ActorConnectSendError::Encode(error)) => { - tracing::error!( + }; + + if conn.is_hibernatable() + && let Err(error) = + persist_and_ack_hibernatable_actor_message( + &ctx, + &conn, + message_index, + ) + .await + { + tracing::warn!( ?error, - "failed to send actor websocket response" + conn_id = conn.id(), + "failed to persist and ack hibernatable actor websocket message" ); sender.close( Some(1011), - Some("actor.send_failed".to_owned()), + Some("actor.hibernation_persist_failed".to_owned()), ); + return; + } + + match send_actor_connect_message( + &sender, + encoding, + &response, + max_outgoing_message_size, + ) { + Ok(()) => {} + Err(ActorConnectSendError::OutgoingTooLong) => { + let error_response = + ActorConnectToClient::Error(ActorConnectError { + 
group: "message".to_owned(), + code: "outgoing_too_long".to_owned(), + message: "Outgoing message too long".to_owned(), + metadata: None, + action_id: Some(request.id), + }); + if let Err(error) = send_actor_connect_message( + &sender, + encoding, + &error_response, + usize::MAX, + ) { + match error { + ActorConnectSendError::OutgoingTooLong => { + sender.close( + Some(1011), + Some( + "message.outgoing_too_long" + .to_owned(), + ), + ); + } + ActorConnectSendError::Encode(error) => { + tracing::error!( + ?error, + "failed to send actor websocket outgoing-size error" + ); + sender.close( + Some(1011), + Some("actor.send_failed".to_owned()), + ); + } + } + } + } + Err(ActorConnectSendError::Encode(error)) => { + tracing::error!( + ?error, + "failed to send actor websocket response" + ); + sender.close( + Some(1011), + Some("actor.send_failed".to_owned()), + ); + } } } - }); + .instrument(tracing::info_span!( + "actor_connect_ws", + actor_id = %actor_id, + )), + ); } } }) diff --git a/rivetkit-rust/packages/rivetkit-core/src/serverless.rs b/rivetkit-rust/packages/rivetkit-core/src/serverless.rs index a5c248f6eb..97c67eab6a 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/serverless.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/serverless.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; use anyhow::{Context, Result}; @@ -23,6 +24,10 @@ use crate::registry::{RegistryCallbacks, RegistryDispatcher, ServeConfig}; const DEFAULT_BASE_PATH: &str = "/api/rivet"; const SSE_PING_INTERVAL: Duration = Duration::from_secs(1); +/// Bound on `handle.shutdown_and_wait` inside teardown paths. If envoy cannot +/// reach the engine (reconnect loop stuck), we fall back to immediate `Stop` +/// rather than hanging indefinitely. Must stay below the outer TS grace ceiling. 
+const SHUTDOWN_DRAIN_TIMEOUT: Duration = Duration::from_secs(20); #[derive(Clone)] pub struct CoreServerlessRuntime { @@ -30,6 +35,7 @@ pub struct CoreServerlessRuntime { dispatcher: Arc, envoy: Arc>>, _engine_process: Arc>>, + shutting_down: Arc, } #[derive(Clone, Debug)] @@ -128,6 +134,15 @@ struct IncomingMessageTooLong { limit: usize, } +#[derive(rivet_error::RivetError, Serialize)] +#[error( + "registry", + "shut_down", + "Registry is shut down.", + "Registry is shut down; no new requests can be accepted." +)] +struct RuntimeShutDown; + impl CoreServerlessRuntime { pub(crate) async fn new( factories: HashMap>, @@ -163,9 +178,32 @@ impl CoreServerlessRuntime { dispatcher, envoy: Arc::new(TokioMutex::new(None)), _engine_process: Arc::new(TokioMutex::new(engine_process)), + shutting_down: Arc::new(AtomicBool::new(false)), }) } + /// Tear down the cached envoy handle. Idempotent. + /// + /// Sets `shutting_down` so concurrent `ensure_envoy` callers short-circuit + /// instead of starting a fresh envoy after teardown, and waits (with a + /// bounded timeout) for `envoy_loop` to exit. If the drain exceeds the + /// timeout (e.g. engine unreachable), falls back to an immediate `Stop`. 
+ pub async fn shutdown(&self) { + self.shutting_down.store(true, Ordering::Release); + let handle = { self.envoy.lock().await.take() }; + let Some(handle) = handle else { return }; + match tokio::time::timeout(SHUTDOWN_DRAIN_TIMEOUT, handle.shutdown_and_wait(false)).await { + Ok(()) => {} + Err(_) => { + tracing::warn!( + "serverless runtime envoy drain exceeded timeout; forcing immediate stop" + ); + handle.shutdown(true); + handle.wait_stopped().await; + } + } + } + pub async fn handle_request(&self, req: ServerlessRequest) -> ServerlessResponse { let cors = cors_headers(&req); match self.handle_request_inner(req).await { @@ -399,6 +437,9 @@ impl CoreServerlessRuntime { } async fn ensure_envoy(&self, headers: &StartHeaders) -> Result { + if self.shutting_down.load(Ordering::Acquire) { + return Err(RuntimeShutDown.build()); + } let mut guard = self.envoy.lock().await; if let Some(handle) = guard.as_ref() { if !endpoints_match(handle.endpoint(), &headers.endpoint) @@ -413,6 +454,10 @@ impl CoreServerlessRuntime { let callbacks = Arc::new(RegistryCallbacks { dispatcher: self.dispatcher.clone(), }); + // not_global: true to avoid caching the handle in the process-wide + // `GLOBAL_ENVOY` OnceLock. Without this, a shutdown-during-build race + // (spec §3 step 7) leaves a dead handle cached for the life of the + // process and any subsequent consumer gets it back. let handle = start_envoy(EnvoyConfig { version: self.settings.version, endpoint: headers.endpoint.clone(), @@ -425,11 +470,27 @@ impl CoreServerlessRuntime { pool_name: headers.pool_name.clone(), prepopulate_actor_names: HashMap::new(), metadata: None, - not_global: false, + not_global: true, debug_latency_ms: None, callbacks, }) .await; + // Re-check under the lock: shutdown may have run while we were awaiting + // `start_envoy`. If so, tear down the freshly-built envoy rather than + // installing it into the cache. 
+ if self.shutting_down.load(Ordering::Acquire) { + drop(guard); + match tokio::time::timeout(SHUTDOWN_DRAIN_TIMEOUT, handle.shutdown_and_wait(false)) + .await + { + Ok(()) => {} + Err(_) => { + handle.shutdown(true); + handle.wait_stopped().await; + } + } + return Err(RuntimeShutDown.build()); + } *guard = Some(handle.clone()); Ok(handle) } diff --git a/rivetkit-rust/packages/rivetkit-core/src/types.rs b/rivetkit-rust/packages/rivetkit-core/src/types.rs index ad1e250830..533f1a928a 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/types.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/types.rs @@ -9,6 +9,34 @@ pub enum ActorKeySegment { Number(f64), } +pub fn format_actor_key(key: &ActorKey) -> String { + key.iter() + .map(|segment| match segment { + ActorKeySegment::String(value) => escape_actor_key_segment(value), + ActorKeySegment::Number(value) => escape_actor_key_segment(&value.to_string()), + }) + .collect::>() + .join("/") +} + +fn escape_actor_key_segment(segment: &str) -> String { + if segment.is_empty() { + return "\\0".to_owned(); + } + + let mut escaped = String::with_capacity(segment.len()); + for ch in segment.chars() { + match ch { + '\\' | '/' => { + escaped.push('\\'); + escaped.push(ch); + } + _ => escaped.push(ch), + } + } + escaped +} + #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub enum WsMessage { Text(String), diff --git a/rivetkit-rust/packages/rivetkit-core/tests/modules/context.rs b/rivetkit-rust/packages/rivetkit-core/tests/modules/context.rs index ccc1a4b190..394d850c1a 100644 --- a/rivetkit-rust/packages/rivetkit-core/tests/modules/context.rs +++ b/rivetkit-rust/packages/rivetkit-core/tests/modules/context.rs @@ -327,6 +327,7 @@ mod moved_tests { )), protocol_metadata: Arc::new(tokio::sync::Mutex::new(None)), shutting_down: std::sync::atomic::AtomicBool::new(false), + stopped_tx: tokio::sync::watch::channel(true).0, }); shared .actors @@ -371,6 +372,7 @@ mod moved_tests { )), protocol_metadata: 
Arc::new(tokio::sync::Mutex::new(None)), shutting_down: std::sync::atomic::AtomicBool::new(false), + stopped_tx: tokio::sync::watch::channel(true).0, }); EnvoyHandle::from_shared(shared) } diff --git a/rivetkit-rust/packages/rivetkit-core/tests/modules/task.rs b/rivetkit-rust/packages/rivetkit-core/tests/modules/task.rs index 3acef21cb6..f8556c1293 100644 --- a/rivetkit-rust/packages/rivetkit-core/tests/modules/task.rs +++ b/rivetkit-rust/packages/rivetkit-core/tests/modules/task.rs @@ -249,6 +249,7 @@ mod moved_tests { )), protocol_metadata: Arc::new(tokio::sync::Mutex::new(None)), shutting_down: AtomicBool::new(false), + stopped_tx: tokio::sync::watch::channel(true).0, }); (EnvoyHandle::from_shared(shared), envoy_rx) diff --git a/rivetkit-rust/packages/rivetkit/examples/chat.rs b/rivetkit-rust/packages/rivetkit/examples/chat.rs index 9f396b6427..86fab8ad70 100644 --- a/rivetkit-rust/packages/rivetkit/examples/chat.rs +++ b/rivetkit-rust/packages/rivetkit/examples/chat.rs @@ -106,8 +106,19 @@ async fn save_chat_state(ctx: &Ctx, state: &ChatState) -> Result<()> { async fn main() -> Result<()> { let mut registry = Registry::new(); registry.register::("chat", run); - tokio::select! 
{ - res = registry.serve() => res, - _ = tokio::signal::ctrl_c() => Ok(()), + let token = tokio_util::sync::CancellationToken::new(); + let serve = tokio::spawn({ + let token = token.clone(); + async move { registry.serve(token).await } + }); + match tokio::signal::ctrl_c().await { + Ok(()) => {} + Err(err) => tracing::warn!(?err, "ctrl_c install failed; cancelling anyway"), + } + token.cancel(); + match serve.await { + Ok(Ok(())) => Ok(()), + Ok(Err(err)) => Err(err), + Err(join_err) => Err(join_err.into()), } } diff --git a/rivetkit-rust/packages/rivetkit/examples/counter.rs b/rivetkit-rust/packages/rivetkit/examples/counter.rs index ebf66423c4..fa96d706f9 100644 --- a/rivetkit-rust/packages/rivetkit/examples/counter.rs +++ b/rivetkit-rust/packages/rivetkit/examples/counter.rs @@ -105,8 +105,19 @@ async fn main() -> Result<()> { .action("increment", Counter::increment) .action("get_count", Counter::get_count) .done(); - tokio::select! { - res = registry.serve() => res, - _ = tokio::signal::ctrl_c() => Ok(()), + let token = tokio_util::sync::CancellationToken::new(); + let serve = tokio::spawn({ + let token = token.clone(); + async move { registry.serve(token).await } + }); + match tokio::signal::ctrl_c().await { + Ok(()) => {} + Err(err) => tracing::warn!(?err, "ctrl_c install failed; cancelling anyway"), + } + token.cancel(); + match serve.await { + Ok(Ok(())) => Ok(()), + Ok(Err(err)) => Err(err), + Err(join_err) => Err(join_err.into()), } } diff --git a/rivetkit-rust/packages/rivetkit/src/registry.rs b/rivetkit-rust/packages/rivetkit/src/registry.rs index bb230a033a..6b81b4ce4f 100644 --- a/rivetkit-rust/packages/rivetkit/src/registry.rs +++ b/rivetkit-rust/packages/rivetkit/src/registry.rs @@ -5,6 +5,7 @@ use rivet_error::RivetError; use rivetkit_core::{ ActorConfig, ActorFactory as CoreActorFactory, ActorStart, CoreRegistry, ServeConfig, }; +use tokio_util::sync::CancellationToken; use crate::{ actor::Actor, @@ -47,12 +48,16 @@ impl Registry { self } - pub 
async fn serve(self) -> Result<()> { - self.inner.serve().await + pub async fn serve(self, shutdown: CancellationToken) -> Result<()> { + self.inner.serve(shutdown).await } - pub async fn serve_with_config(self, config: ServeConfig) -> Result<()> { - self.inner.serve_with_config(config).await + pub async fn serve_with_config( + self, + config: ServeConfig, + shutdown: CancellationToken, + ) -> Result<()> { + self.inner.serve_with_config(config, shutdown).await } } diff --git a/rivetkit-rust/packages/rivetkit/tests/client.rs b/rivetkit-rust/packages/rivetkit/tests/client.rs index 202bd27c1b..c219095058 100644 --- a/rivetkit-rust/packages/rivetkit/tests/client.rs +++ b/rivetkit-rust/packages/rivetkit/tests/client.rs @@ -161,6 +161,7 @@ fn test_envoy_handle(endpoint: String) -> EnvoyHandle { )), protocol_metadata: Arc::new(tokio::sync::Mutex::new(None)), shutting_down: AtomicBool::new(false), + stopped_tx: tokio::sync::watch::channel(true).0, }); EnvoyHandle::from_shared(shared) diff --git a/rivetkit-typescript/packages/rivetkit-napi/index.d.ts b/rivetkit-typescript/packages/rivetkit-napi/index.d.ts index 36766247d3..e2f02d7a7d 100644 --- a/rivetkit-typescript/packages/rivetkit-napi/index.d.ts +++ b/rivetkit-typescript/packages/rivetkit-napi/index.d.ts @@ -290,6 +290,14 @@ export declare class CoreRegistry { constructor() register(name: string, factory: NapiActorFactory): void serve(config: JsServeConfig): Promise + /** + * Trip the shutdown token and tear down any live serverless runtime. + * + * Idempotent. Safe to call when neither mode has been activated. + * Does not block on the `serve()` future; TS awaits that promise + * separately to avoid re-entrancy. 
+ */ + shutdown(): Promise handleServerlessRequest(req: JsServerlessRequest, onStreamEvent: (...args: any[]) => any, cancelToken: CancellationToken, config: JsServeConfig): Promise } export declare class Schedule { diff --git a/rivetkit-typescript/packages/rivetkit-napi/src/actor_context.rs b/rivetkit-typescript/packages/rivetkit-napi/src/actor_context.rs index fdc2919d3f..174c5d55f6 100644 --- a/rivetkit-typescript/packages/rivetkit-napi/src/actor_context.rs +++ b/rivetkit-typescript/packages/rivetkit-napi/src/actor_context.rs @@ -540,7 +540,7 @@ impl ActorContext { let status = abort.call(Ok(()), ThreadsafeFunctionCallMode::NonBlocking); tracing::debug!(kind = "abortSignal", ?status, "napi TSF callback returned"); if status != napi::Status::Ok { - tracing::warn!(?status, "failed to deliver abort signal"); + tracing::warn!(actor_id, ?status, "failed to deliver abort signal"); } }); diff --git a/rivetkit-typescript/packages/rivetkit-napi/src/lib.rs b/rivetkit-typescript/packages/rivetkit-napi/src/lib.rs index 6fec66d759..0abef32843 100644 --- a/rivetkit-typescript/packages/rivetkit-napi/src/lib.rs +++ b/rivetkit-typescript/packages/rivetkit-napi/src/lib.rs @@ -79,8 +79,11 @@ pub(crate) fn init_tracing(log_level: Option<&str>) { .unwrap_or_else(|| "warn".to_string()); tracing_subscriber::fmt() + .json() .with_env_filter(tracing_subscriber::EnvFilter::new(&filter)) .with_target(true) + .with_current_span(true) + .with_span_list(false) .with_writer(std::io::stdout) .init(); }); diff --git a/rivetkit-typescript/packages/rivetkit-napi/src/napi_actor_events.rs b/rivetkit-typescript/packages/rivetkit-napi/src/napi_actor_events.rs index 5fd570acec..f4f4222fb8 100644 --- a/rivetkit-typescript/packages/rivetkit-napi/src/napi_actor_events.rs +++ b/rivetkit-typescript/packages/rivetkit-napi/src/napi_actor_events.rs @@ -310,6 +310,10 @@ fn configure_run_handler(bindings: &CallbackBindings, ctx: &ActorContext) -> Run run_handler } +#[tracing::instrument( + skip_all, + 
fields(actor_id = %ctx.inner().actor_id()), +)] pub(crate) async fn dispatch_event( event: ActorEvent, bindings: &Arc, @@ -329,7 +333,19 @@ pub(crate) async fn dispatch_event( conn, reply, } => { + tracing::info!( + actor_id = %ctx.inner().actor_id(), + action_name = %name, + args_len = args.len(), + has_conn = conn.is_some(), + "napi: dispatching ActorEvent::Action to JS" + ); let Some(callback) = bindings.actions.get(&name).cloned() else { + tracing::warn!( + actor_id = %ctx.inner().actor_id(), + action_name = %name, + "napi: no action callback registered", + ); reply.send(Err(action_not_found(name))); return; }; @@ -338,6 +354,7 @@ pub(crate) async fn dispatch_event( let ctx = ctx.clone(); spawn_reply(tasks, abort.clone(), reply, async move { + tracing::info!(action_name = %name, "napi: invoking action JS callback"); let output = with_dispatch_cancel_token(|cancel_token| { with_structured_timeout( "actor", @@ -356,6 +373,11 @@ pub(crate) async fn dispatch_event( ) }) .await?; + tracing::info!( + action_name = %name, + output_len = output.len(), + "napi: action JS callback returned" + ); if let Some(callback) = on_before_action_response { with_structured_timeout( @@ -522,8 +544,9 @@ pub(crate) async fn dispatch_event( return; }; let ctx = ctx.clone(); + let actor_id = ctx.inner().actor_id().to_owned(); let timeout = config.on_connect_timeout; - spawn_task(tasks, abort.clone(), async move { + spawn_task(tasks, abort.clone(), actor_id, async move { with_timeout( "onDisconnect", timeout, @@ -577,7 +600,11 @@ pub(crate) async fn dispatch_event( } .await; if let Err(error) = result { - tracing::error!(?error, "graceful cleanup callback failed"); + tracing::error!( + actor_id = %ctx.inner().actor_id(), + ?error, + "graceful cleanup callback failed", + ); } reply.send(Ok(())); }); @@ -598,7 +625,11 @@ pub(crate) async fn dispatch_event( } .await; if let Err(error) = result { - tracing::error!(?error, "disconnect cleanup callback failed"); + tracing::error!( + actor_id = 
%ctx.inner().actor_id(), + ?error, + "disconnect cleanup callback failed", + ); } reply.send(Ok(())); }); @@ -712,7 +743,7 @@ pub(crate) fn spawn_reply( }); } -fn spawn_task(tasks: &mut JoinSet<()>, abort: CancellationToken, work: F) +fn spawn_task(tasks: &mut JoinSet<()>, abort: CancellationToken, actor_id: String, work: F) where F: std::future::Future> + Send + 'static, { @@ -721,7 +752,7 @@ where _ = abort.cancelled() => {} result = work => { if let Err(error) = result { - tracing::error!(?error, "napi background callback failed"); + tracing::error!(actor_id, ?error, "napi background callback failed"); } } } diff --git a/rivetkit-typescript/packages/rivetkit-napi/src/registry.rs b/rivetkit-typescript/packages/rivetkit-napi/src/registry.rs index 555d82cf05..cac97546ab 100644 --- a/rivetkit-typescript/packages/rivetkit-napi/src/registry.rs +++ b/rivetkit-typescript/packages/rivetkit-napi/src/registry.rs @@ -6,11 +6,12 @@ use napi::JsObject; use napi::bindgen_prelude::{Buffer, Env, Promise}; use napi::threadsafe_function::{ErrorStrategy, ThreadSafeCallContext, ThreadsafeFunction}; use napi_derive::napi; -use parking_lot::Mutex; use rivetkit_core::{ CoreRegistry as NativeCoreRegistry, CoreServerlessRuntime, ServeConfig, ServerlessRequest, serverless::ServerlessStreamError, }; +use tokio::sync::{Mutex as TokioMutex, Notify}; +use tokio_util::sync::CancellationToken as CoreCancellationToken; use crate::actor_factory::NapiActorFactory; use crate::cancellation_token::CancellationToken; @@ -66,13 +67,35 @@ enum ServerlessStreamEvent { }, } +/// Registry lifecycle state machine. +/// +/// Mode A (`serve`) and Mode B (`handle_serverless_request` -> `Serverless(...)`) +/// are mutually exclusive per instance: both transition out of `Registering`. 
+/// `BuildingServerless` is a sentinel held across the `into_serverless_runtime` +/// `.await` so a concurrent `shutdown()` can observe an in-flight build and +/// either wait for it to settle into `Serverless(_)` (then tear it down) or +/// transition directly to `ShutDown` while the build-side checks terminal +/// state before installing. +enum RegistryState { + Registering(NativeCoreRegistry), + BuildingServerless, + Serving, + Serverless(CoreServerlessRuntime), + ShuttingDown, + ShutDown, +} + #[napi] #[derive(Clone)] pub struct CoreRegistry { - // Registration is a synchronous N-API boundary; the lock is released before - // async serving begins. - inner: Arc>>, - serverless_runtime: Arc>>, + state: Arc>, + shutdown_token: CoreCancellationToken, + /// Notified whenever the state transitions out of `BuildingServerless` + /// (to `Serverless(_)` on success, or `ShutDown` on failure/shutdown). + /// Lets concurrent `ensure_serverless_runtime` callers that arrive during + /// a build wait for the build to settle and then re-check the fast path + /// instead of erroring with a misleading mode-conflict. + build_complete: Arc, } #[napi] @@ -82,19 +105,31 @@ impl CoreRegistry { crate::init_tracing(None); tracing::debug!(class = "CoreRegistry", "constructed napi class"); Self { - inner: Arc::new(Mutex::new(Some(NativeCoreRegistry::new()))), - serverless_runtime: Arc::new(Mutex::new(None)), + state: Arc::new(TokioMutex::new(RegistryState::Registering( + NativeCoreRegistry::new(), + ))), + shutdown_token: CoreCancellationToken::new(), + build_complete: Arc::new(Notify::new()), } } #[napi] pub fn register(&self, name: String, factory: &NapiActorFactory) -> napi::Result<()> { - let mut guard = self.inner.lock(); - let registry = guard - .as_mut() - .ok_or_else(|| registry_already_serving_error())?; - registry.register_shared(&name, factory.actor_factory()); - Ok(()) + // Registration runs on the sync N-API thread before any async work. 
+ // `try_lock` must always succeed here: no other path holds the lock at + // this point. If somehow contended, surface the structured error rather + // than blocking. + let mut guard = self + .state + .try_lock() + .map_err(|_| registry_register_busy_error())?; + match &mut *guard { + RegistryState::Registering(registry) => { + registry.register_shared(&name, factory.actor_factory()); + Ok(()) + } + _ => Err(registry_not_registering_error()), + } } #[napi] @@ -109,36 +144,92 @@ impl CoreRegistry { "serving native registry" ); let registry = { - let mut guard = self.inner.lock(); - guard - .take() - .ok_or_else(|| registry_already_serving_error())? + let mut guard = self.state.lock().await; + match std::mem::replace(&mut *guard, RegistryState::Serving) { + RegistryState::Registering(registry) => registry, + other => { + // Restore prior state so later shutdown sees the right variant. + *guard = other; + return Err(registry_not_registering_error()); + } + } }; registry - .serve_with_config(ServeConfig { - version: config.version, - endpoint: config.endpoint, - token: config.token, - namespace: config.namespace, - pool_name: config.pool_name, - engine_binary_path: config.engine_binary_path.map(PathBuf::from), - handle_inspector_http_in_runtime: config - .handle_inspector_http_in_runtime - .unwrap_or(false), - serverless_base_path: config.serverless_base_path, - serverless_package_version: config.serverless_package_version, - serverless_client_endpoint: config.serverless_client_endpoint, - serverless_client_namespace: config.serverless_client_namespace, - serverless_client_token: config.serverless_client_token, - serverless_validate_endpoint: config.serverless_validate_endpoint, - serverless_max_start_payload_bytes: config.serverless_max_start_payload_bytes - as usize, - }) + .serve_with_config( + ServeConfig { + version: config.version, + endpoint: config.endpoint, + token: config.token, + namespace: config.namespace, + pool_name: config.pool_name, + 
engine_binary_path: config.engine_binary_path.map(PathBuf::from), + handle_inspector_http_in_runtime: config + .handle_inspector_http_in_runtime + .unwrap_or(false), + serverless_base_path: config.serverless_base_path, + serverless_package_version: config.serverless_package_version, + serverless_client_endpoint: config.serverless_client_endpoint, + serverless_client_namespace: config.serverless_client_namespace, + serverless_client_token: config.serverless_client_token, + serverless_validate_endpoint: config.serverless_validate_endpoint, + serverless_max_start_payload_bytes: config.serverless_max_start_payload_bytes + as usize, + }, + self.shutdown_token.clone(), + ) .await .map_err(napi_anyhow_error) } + /// Trip the shutdown token and tear down any live serverless runtime. + /// + /// Idempotent. Safe to call when neither mode has been activated. + /// Does not block on the `serve()` future; TS awaits that promise + /// separately to avoid re-entrancy. + #[napi] + pub async fn shutdown(&self) -> napi::Result<()> { + tracing::debug!(class = "CoreRegistry", "shutdown requested"); + // Trip the cancel first, outside the lock, so a `serve_with_config` + // already past the state transition observes cancel promptly. + self.shutdown_token.cancel(); + + let (runtime, was_building) = { + let mut guard = self.state.lock().await; + match std::mem::replace(&mut *guard, RegistryState::ShuttingDown) { + RegistryState::Registering(_) | RegistryState::Serving => (None, false), + RegistryState::Serverless(runtime) => (Some(runtime), false), + RegistryState::BuildingServerless => { + // An `ensure_serverless_runtime` call is mid-build. Its + // post-build re-check will observe `shutdown_token` and + // tear down the runtime itself before settling state. + (None, true) + } + RegistryState::ShuttingDown | RegistryState::ShutDown => { + // Already in progress / done. 
+ *guard = RegistryState::ShutDown; + return Ok(()); + } + } + }; + + if let Some(runtime) = runtime { + runtime.shutdown().await; + } + + if !was_building { + let mut guard = self.state.lock().await; + *guard = RegistryState::ShutDown; + } + // Wake any `ensure_serverless_runtime` waiters parked on + // `BuildingServerless`. They re-check state and observe the shutdown. + // Also covers the case where `was_building` is true: the builder + // itself is not a waiter, but future callers that arrive while the + // builder is draining need to see `ShuttingDown` and error promptly. + self.build_complete.notify_waiters(); + Ok(()) + } + #[napi(ts_return_type = "Promise")] pub fn handle_serverless_request( &self, @@ -208,17 +299,63 @@ impl CoreRegistry { &self, config: JsServeConfig, ) -> napi::Result { - if let Some(runtime) = self.serverless_runtime.lock().as_ref().cloned() { - return Ok(runtime); + // Loop handles the "another caller is mid-build" case: arm the notify + // before re-checking so we can't miss a wakeup, then wait for the + // builder to transition out of `BuildingServerless`. + loop { + { + let guard = self.state.lock().await; + if let RegistryState::Serverless(runtime) = &*guard { + return Ok(runtime.clone()); + } + if matches!( + *guard, + RegistryState::ShuttingDown | RegistryState::ShutDown + ) { + return Err(registry_shut_down_error()); + } + if matches!(*guard, RegistryState::Serving) { + return Err(registry_wrong_mode_error()); + } + if matches!(*guard, RegistryState::BuildingServerless) { + // Another caller is building. Arm the notification before + // dropping the lock so a completion we race against still + // wakes us. + let notify = self.build_complete.clone(); + let notified = notify.notified(); + tokio::pin!(notified); + notified.as_mut().enable(); + drop(guard); + notified.await; + continue; + } + // RegistryState::Registering(_): fall through to build. 
+ } + + // Transition Registering -> BuildingServerless, drop lock, build, + // re-acquire, install or tear down based on terminal state. + let registry = { + let mut guard = self.state.lock().await; + match std::mem::replace(&mut *guard, RegistryState::BuildingServerless) { + RegistryState::Registering(registry) => registry, + other => { + // State changed under us between fast-path and here; + // restore and re-evaluate. + *guard = other; + continue; + } + } + }; + return self.build_serverless_runtime(registry, config).await; } + } - let registry = { - let mut guard = self.inner.lock(); - guard - .take() - .ok_or_else(|| registry_already_serving_error())? - }; - let runtime = registry + async fn build_serverless_runtime( + &self, + registry: NativeCoreRegistry, + config: JsServeConfig, + ) -> napi::Result { + let build_result = registry .into_serverless_runtime(ServeConfig { version: config.version, endpoint: config.endpoint, @@ -238,10 +375,48 @@ impl CoreRegistry { serverless_max_start_payload_bytes: config.serverless_max_start_payload_bytes as usize, }) - .await - .map_err(napi_anyhow_error)?; - *self.serverless_runtime.lock() = Some(runtime.clone()); - Ok(runtime) + .await; + + // Re-acquire the lock and re-check state. Shutdown may have run during + // the build. If so, tear down the freshly-built runtime rather than + // installing it, preventing an orphaned runtime post-shutdown. + let mut guard = self.state.lock().await; + let result = match build_result { + Ok(runtime) => { + if self.shutdown_token.is_cancelled() + || matches!( + *guard, + RegistryState::ShuttingDown | RegistryState::ShutDown + ) { + // Drop the lock while we drain the envoy. + drop(guard); + runtime.shutdown().await; + let mut guard = self.state.lock().await; + *guard = RegistryState::ShutDown; + drop(guard); + Err(registry_shut_down_error()) + } else { + *guard = RegistryState::Serverless(runtime.clone()); + drop(guard); + Ok(runtime) + } + } + Err(error) => { + // Build failed. 
The inner `NativeCoreRegistry` was consumed by + // `into_serverless_runtime` and cannot be restored. Any future + // call on this `CoreRegistry` must observe a terminal state + // with a clear error, not the misleading `wrong_mode` that + // leaving `BuildingServerless` would produce. + *guard = RegistryState::ShutDown; + drop(guard); + Err(napi_anyhow_error(error)) + } + }; + // Wake any `ensure_serverless_runtime` callers parked on + // `BuildingServerless`. They re-check state and either get the cached + // runtime or the shutdown error. + self.build_complete.notify_waiters(); + result } } @@ -291,11 +466,41 @@ impl From for JsServerlessStreamError { } } -fn registry_already_serving_error() -> napi::Error { +fn registry_not_registering_error() -> napi::Error { + napi_anyhow_error( + NapiInvalidState { + state: "core registry".to_owned(), + reason: "already serving or shut down".to_owned(), + } + .build(), + ) +} + +fn registry_wrong_mode_error() -> napi::Error { + napi_anyhow_error( + NapiInvalidState { + state: "core registry".to_owned(), + reason: "mode conflict: another run mode is already active".to_owned(), + } + .build(), + ) +} + +fn registry_shut_down_error() -> napi::Error { + napi_anyhow_error( + NapiInvalidState { + state: "core registry".to_owned(), + reason: "shut down".to_owned(), + } + .build(), + ) +} + +fn registry_register_busy_error() -> napi::Error { napi_anyhow_error( NapiInvalidState { state: "core registry".to_owned(), - reason: "already serving".to_owned(), + reason: "register called concurrently with serve or shutdown".to_owned(), } .build(), ) diff --git a/rivetkit-typescript/packages/rivetkit/src/common/log.ts b/rivetkit-typescript/packages/rivetkit/src/common/log.ts index 62dd7003d6..25a7853155 100644 --- a/rivetkit-typescript/packages/rivetkit/src/common/log.ts +++ b/rivetkit-typescript/packages/rivetkit/src/common/log.ts @@ -6,12 +6,6 @@ import { } from "pino"; import { z } from "zod/v4"; import { getLogLevel, getLogTarget, 
getLogTimestamp } from "@/utils/env-vars"; -import { - castToLogValue, - formatTimestamp, - LOGGER_CONFIG, - stringify, -} from "./logfmt"; export type { Logger } from "pino"; @@ -66,47 +60,6 @@ export function configureBaseLogger(logger: Logger): void { loggerCache.clear(); } -// TODO: This can be simplified in logfmt.ts -function customWrite(level: string, o: any) { - const entries: any = {}; - - // Add timestamp if enabled - if (getLogTimestamp() && o.time) { - const date = typeof o.time === "number" ? new Date(o.time) : new Date(); - entries.ts = formatTimestamp(date); - } - - // Add level - entries.level = level.toUpperCase(); - - // Add target if present - if (o.target) { - entries.target = o.target; - } - - // Add message - if (o.msg) { - entries.msg = o.msg; - } - - // Add other properties - for (const [key, value] of Object.entries(o)) { - if ( - key !== "time" && - key !== "level" && - key !== "target" && - key !== "msg" && - key !== "pid" && - key !== "hostname" - ) { - entries[key] = castToLogValue(value); - } - } - - const output = stringify(entries); - console.log(output); -} - /** * Configure the default logger with optional log level. */ @@ -128,69 +81,6 @@ export function configureDefaultLogger(logLevel?: LogLevel) { }, }, timestamp: getLogTimestamp() ? stdTimeFunctions.epochTime : false, - browser: { - write: { - fatal: customWrite.bind(null, "fatal"), - error: customWrite.bind(null, "error"), - warn: customWrite.bind(null, "warn"), - info: customWrite.bind(null, "info"), - debug: customWrite.bind(null, "debug"), - trace: customWrite.bind(null, "trace"), - }, - }, - hooks: { - logMethod(inputArgs, method, level) { - // TODO: This is a hack to not implement our own Pino transport target. We can get better perf if we have our own transport target. - - const levelMap: Record = { - 10: "trace", - 20: "debug", - 30: "info", - 40: "warn", - 50: "error", - 60: "fatal", - }; - const levelName = levelMap[level] || "info"; - const time = getLogTimestamp() ? 
Date.now() : undefined; - - // Get bindings from the logger instance (child logger fields) - const bindings = (this as any).bindings?.() || {}; - - // TODO: This can be simplified in logfmt.ts - if (inputArgs.length >= 2) { - const [objOrMsg, msg] = inputArgs; - if (typeof objOrMsg === "object" && objOrMsg !== null) { - customWrite(levelName, { - ...bindings, - ...objOrMsg, - msg, - time, - }); - } else { - customWrite(levelName, { - ...bindings, - msg: String(objOrMsg), - time, - }); - } - } else if (inputArgs.length === 1) { - const [objOrMsg] = inputArgs; - if (typeof objOrMsg === "object" && objOrMsg !== null) { - customWrite(levelName, { - ...bindings, - ...objOrMsg, - time, - }); - } else { - customWrite(levelName, { - ...bindings, - msg: String(objOrMsg), - time, - }); - } - } - }, - }, }); loggerCache.clear(); diff --git a/rivetkit-typescript/packages/rivetkit/src/common/logfmt.ts b/rivetkit-typescript/packages/rivetkit/src/common/logfmt.ts deleted file mode 100644 index 528e8b1419..0000000000 --- a/rivetkit-typescript/packages/rivetkit/src/common/logfmt.ts +++ /dev/null @@ -1,221 +0,0 @@ -import { type LogLevel, LogLevels } from "./log-levels"; - -const LOG_LEVEL_COLORS: Record = { - [LogLevels.CRITICAL]: "\x1b[31m", // Red - [LogLevels.ERROR]: "\x1b[31m", // Red - [LogLevels.WARN]: "\x1b[33m", // Yellow - [LogLevels.INFO]: "\x1b[32m", // Green - [LogLevels.DEBUG]: "\x1b[36m", // Cyan - [LogLevels.TRACE]: "\x1b[36m", // Cyan -}; - -const RESET_COLOR = "\x1b[0m"; - -/** - * Serializes logfmt line from an object. - * - * ## Styling Methodology - * - * The three things you need to know for every log line is the level, the - * message, and who called it. These properties are highlighted in different colros - * and sorted in th eorder that you usually read them. - * - * Once you've found a log line you care about, then you want to find the - * property you need to see. The property names are bolded and the default color - * while the rest of the data is dim. 
This lets you scan to find the property - * name quickly then look closer to read the data associated with the - * property. - */ -export function stringify(data: any) { - let line = ""; - const entries = Object.entries(data); - - for (let i = 0; i < entries.length; i++) { - const [key, valueRaw] = entries[i]; - - let isNull = false; - let valueString: string; - if (valueRaw == null) { - isNull = true; - valueString = ""; - } else { - valueString = valueRaw.toString(); - } - - // Clip value unless specifically the error message - if (valueString.length > 512 && key !== "msg" && key !== "error") - valueString = `${valueString.slice(0, 512)}...`; - - const needsQuoting = - valueString.indexOf(" ") > -1 || valueString.indexOf("=") > -1; - const needsEscaping = - valueString.indexOf('"') > -1 || valueString.indexOf("\\") > -1; - - valueString = valueString.replace(/\n/g, "\\n"); - if (needsEscaping) valueString = valueString.replace(/["\\]/g, "\\$&"); - if (needsQuoting || needsEscaping) valueString = `"${valueString}"`; - if (valueString === "" && !isNull) valueString = '""'; - - if (LOGGER_CONFIG.enableColor) { - // With color - - // Special message colors - let color = "\x1b[2m"; - if (key === "level") { - const level = LogLevels[valueString as LogLevel]; - const levelColor = LOG_LEVEL_COLORS[level]; - if (levelColor) { - color = levelColor; - } - } else if (key === "msg") { - color = "\x1b[32m"; - } else if (key === "trace") { - color = "\x1b[34m"; - } - - // Format line - line += `\x1b[0m\x1b[1m${key}\x1b[0m\x1b[2m=\x1b[0m${color}${valueString}${RESET_COLOR}`; - } else { - // No color - line += `${key}=${valueString}`; - } - - if (i !== entries.length - 1) { - line += " "; - } - } - - return line; -} - -export function formatTimestamp(date: Date): string { - const year = date.getUTCFullYear(); - const month = String(date.getUTCMonth() + 1).padStart(2, "0"); - const day = String(date.getUTCDate()).padStart(2, "0"); - const hours = 
String(date.getUTCHours()).padStart(2, "0"); - const minutes = String(date.getUTCMinutes()).padStart(2, "0"); - const seconds = String(date.getUTCSeconds()).padStart(2, "0"); - const milliseconds = String(date.getUTCMilliseconds()).padStart(3, "0"); - - return `${year}-${month}-${day}T${hours}:${minutes}:${seconds}.${milliseconds}Z`; -} - -export function castToLogValue(v: unknown): any { - if ( - typeof v === "string" || - typeof v === "number" || - typeof v === "bigint" || - typeof v === "boolean" || - v === null || - v === undefined - ) { - return v; - } - if (v instanceof Error) { - //args.push(...errorToLogEntries(k, v)); - return String(v); - } - try { - return JSON.stringify(v); - } catch { - return "[cannot stringify]"; - } -} - -// MARK: Config -interface GlobalLoggerConfig { - enableColor: boolean; - enableSpreadObject: boolean; - enableErrorStack: boolean; -} - -export const LOGGER_CONFIG: GlobalLoggerConfig = { - enableColor: false, - enableSpreadObject: false, - enableErrorStack: false, -}; - -// MARK: Utils -/** - * Converts an object in to an easier to read KV of entries. - */ -export function spreadObjectToLogEntries(base: string, data: unknown): any { - if ( - LOGGER_CONFIG.enableSpreadObject && - typeof data === "object" && - !Array.isArray(data) && - data !== null && - Object.keys(data).length !== 0 && - Object.keys(data).length < 16 - ) { - const logData: any = {}; - for (const key in data) { - Object.assign( - logData, - spreadObjectToLogEntries( - `${base}.${key}`, - // biome-ignore lint/suspicious/noExplicitAny: FIXME - (data as any)[key], - ), - ); - } - return logData; - } - - return { [base]: JSON.stringify(data) }; -} - -export function errorToLogEntries(base: string, error: unknown): any { - if (error instanceof Error) { - return { - [`${base}.message`]: error.message, - ...(LOGGER_CONFIG.enableErrorStack && error.stack - ? { [`${base}.stack`]: formatStackTrace(error.stack) } - : {}), - ...(error.cause - ? 
errorToLogEntries(`${base}.cause`, error.cause) - : {}), - }; - } - return { [base]: `${error}` }; -} - -// export function errorToLogEntries(base: string, error: unknown): LogEntry[] { -// if (error instanceof RuntimeError) { -// return [ -// [`${base}.code`, error.code], -// [`${base}.description`, error.errorConfig?.description], -// [`${base}.module`, error.moduleName], -// ...(error.trace ? [[`${base}.trace`, stringifyTrace(error.trace)] as LogEntry] : []), -// ...(LOGGER_CONFIG.enableErrorStack && error.stack -// ? [[`${base}.stack`, formatStackTrace(error.stack)] as LogEntry] -// : []), -// ...(error.meta ? [[`${base}.meta`, JSON.stringify(error.meta)] as LogEntry] : []), -// ...(error.cause ? errorToLogEntries(`${base}.cause`, error.cause) : []), -// ]; -// } else if (error instanceof Error) { -// return [ -// [`${base}.name`, error.name], -// [`${base}.message`, error.message], -// ...(LOGGER_CONFIG.enableErrorStack && error.stack -// ? [[`${base}.stack`, formatStackTrace(error.stack)] as LogEntry] -// : []), -// ...(error.cause ? errorToLogEntries(`${base}.cause`, error.cause) : []), -// ]; -// } else { -// return [ -// [base, `${error}`], -// ]; -// } -// } - -/** - * Formats a JS stack trace in to a legible one-liner. 
- */ -function formatStackTrace(stackTrace: string): string { - const regex = /at (.+?)$/gm; - const matches = [...stackTrace.matchAll(regex)]; - // Reverse array since the stack goes from top level -> bottom level - matches.reverse(); - return matches.map((match) => match[1].trim()).join(" > "); -} diff --git a/rivetkit-typescript/packages/rivetkit/src/registry/config/index.ts b/rivetkit-typescript/packages/rivetkit/src/registry/config/index.ts index ac74d434d0..4c0fb830a4 100644 --- a/rivetkit-typescript/packages/rivetkit/src/registry/config/index.ts +++ b/rivetkit-typescript/packages/rivetkit/src/registry/config/index.ts @@ -169,6 +169,40 @@ export const RegistryConfigSchema = z envoy: EnvoyConfigSchema.optional().default(() => EnvoyConfigSchema.parse({}), ), + + // MARK: Shutdown + /** + * Graceful shutdown configuration for SIGINT/SIGTERM. + * + * When a persistent envoy is running (Mode A, started via `registry.start()`), + * rivetkit installs Node SIGINT/SIGTERM handlers that call into core's + * `shutdown()` and wait up to `gracePeriodMs` for the envoy to drain + * before re-raising the signal to let Node exit via its default path. + * + * Handlers are NOT installed when `handler(request)` is used alone + * (Mode B / serverless): platform runtimes (Cloudflare Workers, Vercel, + * Deno Deploy) own their own signal policy there, and `process.on` may + * not exist. + */ + shutdown: z + .object({ + /** + * Wait this many milliseconds for the serve promise to resolve + * after calling `CoreRegistry::shutdown()`. Defaults to 30s, + * matching Kubernetes `terminationGracePeriodSeconds`. + * + * Must be >= rivetkit-core's drain timeout (20s) + margin. + */ + gracePeriodMs: z.number().int().min(1_000).optional().default(30_000), + /** + * If true, rivetkit will not install SIGINT/SIGTERM handlers. + * Use when the host application owns signal policy and will + * call `nativeRegistry.shutdown()` itself. 
+ */ + disableSignalHandlers: z.boolean().optional().default(false), + }) + .optional() + .default(() => ({ gracePeriodMs: 30_000, disableSignalHandlers: false })), }) .transform((config, ctx) => { const isDevEnv = isDev(); diff --git a/rivetkit-typescript/packages/rivetkit/src/registry/index.ts b/rivetkit-typescript/packages/rivetkit/src/registry/index.ts index 5820fbb231..a5d9e3f89a 100644 --- a/rivetkit-typescript/packages/rivetkit/src/registry/index.ts +++ b/rivetkit-typescript/packages/rivetkit/src/registry/index.ts @@ -5,10 +5,13 @@ import { RegistryConfigSchema, } from "./config"; import { ENGINE_ENDPOINT } from "@/common/engine"; +import { logger } from "./log"; import { buildNativeRegistry } from "./native"; import { configureServerlessPool } from "@/serverless/configure"; import { VERSION } from "@/utils"; +type ShutdownSignal = "SIGINT" | "SIGTERM"; + export type FetchHandler = ( request: Request, ...args: any @@ -33,6 +36,9 @@ export class Registry { #nativeServerlessPromise?: ReturnType; #configureServerlessPoolPromise?: Promise; #welcomePrinted = false; + #shutdownInstalled = false; + #shutdownInFlight: Promise | null = null; + #signalHandlers: Partial void>> = {}; constructor(config: RegistryConfigInput) { this.#config = config; @@ -132,50 +138,59 @@ export class Registry { headers[key] = value; }); - const head = await registry.handleServerlessRequest( - { - method: request.method, - url: request.url, - headers, - body: Buffer.from(requestBody), - }, - async ( - error: unknown, - event?: { - kind: "chunk" | "end"; - chunk?: Buffer; - error?: { - group: string; - code: string; - message: string; - }; + let head; + try { + head = await registry.handleServerlessRequest( + { + method: request.method, + url: request.url, + headers, + body: Buffer.from(requestBody), }, - ) => { - if (error) throw error; - if (!event || settled) return; - if (event.kind === "chunk") { - await waitForBackpressure(); - if (settled) return; - if (event.chunk) 
controllerRef?.enqueue(event.chunk); - return; - } + async ( + error: unknown, + event?: { + kind: "chunk" | "end"; + chunk?: Buffer; + error?: { + group: string; + code: string; + message: string; + }; + }, + ) => { + if (error) throw error; + if (!event || settled) return; + if (event.kind === "chunk") { + await waitForBackpressure(); + if (settled) return; + if (event.chunk) controllerRef?.enqueue(event.chunk); + return; + } - settled = true; - resolveBackpressure(); - request.signal.removeEventListener("abort", abort); - if (event.error) { - controllerRef?.error( - new Error( - `${event.error.group}.${event.error.code}: ${event.error.message}`, - ), - ); - } else { - controllerRef?.close(); - } - }, - cancelToken, - serveConfig, - ); + settled = true; + resolveBackpressure(); + request.signal.removeEventListener("abort", abort); + if (event.error) { + controllerRef?.error( + new Error( + `${event.error.group}.${event.error.code}: ${event.error.message}`, + ), + ); + } else { + controllerRef?.close(); + } + }, + cancelToken, + serveConfig, + ); + } catch (err) { + // The NAPI call itself rejected (e.g. `registry_shut_down_error`). + // Clean up the abort listener so it doesn't leak, then propagate. + request.signal.removeEventListener("abort", abort); + cancelToken.cancel(); + throw err; + } return new Response(stream, { status: head.status, @@ -202,17 +217,151 @@ export class Registry { */ #startEnvoy(config: RegistryConfig, printWelcome: boolean) { if (!this.#nativeServePromise) { - this.#nativeServePromise = buildNativeRegistry( - config, - ).then(async ({ registry, serveConfig }) => { - await registry.serve(serveConfig); - }); + const nativeRegistryPromise = buildNativeRegistry(config); + this.#nativeServePromise = nativeRegistryPromise + .then(async ({ registry, serveConfig }) => { + await registry.serve(serveConfig); + }) + .catch((err) => { + // Always-attached catch so the stored promise never leaves a + // rejection unhandled. Downstream awaits (e.g. 
#runShutdown's + // Promise.race) attach their own catches and still observe + // resolution via the race. + logger().warn({ err }, "native registry serve errored"); + }); + // Install signal handlers once an envoy lifecycle has begun. Only + // Mode A ever reaches here. Mode B (handler(request)) intentionally + // does not install handlers because it runs on Workers/Vercel/Deno + // Deploy where `process.on` is absent or forbidden; those platforms + // own their own signal policy. + this.#installSignalHandlers(config, nativeRegistryPromise); } if (printWelcome) { this.#printWelcome(config, "serverful"); } } + #installSignalHandlers( + config: RegistryConfig, + nativeRegistryPromise: ReturnType, + ): void { + if (this.#shutdownInstalled) return; + if (config.shutdown?.disableSignalHandlers) return; + // Guard against non-Node runtimes (Workers/Edge) where `process` may + // exist but `process.on` is unavailable or forbidden. + if ( + typeof process === "undefined" || + typeof process.on !== "function" || + typeof process.kill !== "function" + ) { + return; + } + this.#shutdownInstalled = true; + + const install = (signal: ShutdownSignal) => { + const handler = () => + this.#onShutdownSignal(signal, config, nativeRegistryPromise); + this.#signalHandlers[signal] = handler; + process.on(signal, handler); + }; + install("SIGINT"); + install("SIGTERM"); + } + + #onShutdownSignal( + signal: ShutdownSignal, + config: RegistryConfig, + nativeRegistryPromise: ReturnType, + ): void { + if (this.#shutdownInFlight !== null) { + // Second delivery of the same (or another) shutdown signal. + // Remove our handler only (preserving any user-installed listeners) + // and re-raise so Node proceeds with its default exit path. 
+ this.#removeSignalHandlers(); + process.kill(process.pid, signal); + return; + } + this.#shutdownInFlight = this.#runShutdown( + signal, + config, + nativeRegistryPromise, + ).catch((err) => { + logger().warn({ err }, "shutdown error"); + }); + } + + async #runShutdown( + signal: ShutdownSignal, + config: RegistryConfig, + nativeRegistryPromise: ReturnType, + ): Promise { + const gracePeriodMs = config.shutdown?.gracePeriodMs ?? 30_000; + // Race the entire drain sequence (both modes + serve promise) against + // a single grace ceiling. Without this, each mode's Rust-side drain + // (20s) could stack sequentially and blow past gracePeriodMs before + // we re-raise the signal. + const drain = async () => { + // Shut down every live `CoreRegistry` we know about. Mode A + // (`start()`) and Mode B (`handler()`) each build a separate + // native registry, so one signal handler fans out to both to + // honor the spec invariant "single shutdown tears down both modes". + const registries: Promise[] = [ + (async () => { + try { + const { registry } = await nativeRegistryPromise; + await registry.shutdown(); + } catch (err) { + logger().warn( + { err }, + "native registry shutdown errored (mode A)", + ); + } + })(), + ]; + if (this.#nativeServerlessPromise) { + registries.push( + (async () => { + try { + const { registry } = await this.#nativeServerlessPromise!; + await registry.shutdown(); + } catch (err) { + logger().warn( + { err }, + "native registry shutdown errored (mode B)", + ); + } + })(), + ); + } + await Promise.all(registries); + + if (this.#nativeServePromise) { + // Swallow rejection so the race doesn't itself reject; the + // always-attached `.catch` at the promise assignment site has + // already logged any serve-side error. 
+ await this.#nativeServePromise.catch(() => undefined); + } + }; + await Promise.race([ + drain(), + new Promise((resolve) => + setTimeout(resolve, gracePeriodMs).unref?.(), + ), + ]); + this.#removeSignalHandlers(); + process.kill(process.pid, signal); + } + + #removeSignalHandlers(): void { + for (const [signal, handler] of Object.entries(this.#signalHandlers) as [ + ShutdownSignal, + () => void, + ][]) { + if (handler) process.removeListener(signal, handler); + } + this.#signalHandlers = {}; + } + public startEnvoy() { this.#startEnvoy(this.parseConfig(), true); }