diff --git a/.gitignore b/.gitignore index 7dd3d65272..d4dc7ecb5a 100644 --- a/.gitignore +++ b/.gitignore @@ -78,3 +78,6 @@ examples/*/public/ !examples/multiplayer-game-patterns/public/** !examples/multiplayer-game-patterns-vercel/public/ !examples/multiplayer-game-patterns-vercel/public/** + +# Native addon binaries +*.node diff --git a/CLAUDE.md b/CLAUDE.md index cbd229a4d0..f3ec25a531 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -89,6 +89,9 @@ docker-compose up -d git commit -m "chore(my-pkg): foo bar" ``` +- We use Graphite for stacked PRs. Diff against the parent branch (`gt ls` to see the stack), not `main`. +- To revert a file to the version before this branch's changes, checkout from the first child branch (below in the stack), not from `main` or the parent. Child branches contain the pre-this-branch state of files modified by branches further down the stack. + **Never push to `main` unless explicitly specified by the user.** ## Dependency Management @@ -339,6 +342,10 @@ When making changes to the engine or RivetKit, ensure the corresponding document - **Landing page changes**: When updating the landing page (`website/src/pages/index.astro` and its section components in `website/src/components/marketing/sections/`), update `README.md` to reflect the same headlines, features, benchmarks, and talking points where applicable. - **Sandbox provider changes**: When adding, removing, or modifying sandbox providers in `rivetkit-typescript/packages/rivetkit/src/sandbox/providers/`, update `website/src/content/docs/actors/sandbox.mdx` to keep provider documentation, option tables, and custom provider guidance in sync. +### CLAUDE.md conventions + +- When adding entries to any CLAUDE.md file, keep them concise. Ideally a single bullet point or minimal bullet points. Do not write paragraphs. + ### Comments - Write comments as normal, complete sentences. Avoid fragmented structures with parentheticals and dashes like `// Spawn engine (if configured) - regardless of start kind`. Instead, write `// Spawn the engine if configured`. Especially avoid dashes (hyphens are OK). diff --git a/engine/CLAUDE.md b/engine/CLAUDE.md index 86a04d8d2f..19175eb9a4 100644 --- a/engine/CLAUDE.md +++ b/engine/CLAUDE.md @@ -33,6 +33,10 @@ When changing a versioned VBARE schema, follow the existing migration pattern. - When adding fields to epoxy workflow state structs, mark them `#[serde(default)]` so Gasoline can replay older serialized state. - Epoxy integration tests that spin up `tests/common::TestCtx` must call `shutdown()` before returning. +## Concurrent containers + +Never use `Mutex>` or `RwLock>`. Use `scc::HashMap` (preferred), `moka::Cache` (for TTL/bounded), or `DashMap`. Same for sets: use `scc::HashSet` instead of `Mutex>`. Note that `scc` async methods do not hold locks across `.await` points. Use `entry_async` for atomic read-then-write. + ## Test snapshots Use `test-snapshot-gen` to generate and load RocksDB snapshots of the full UDB KV store for migration and integration tests. Scenarios produce per-replica RocksDB checkpoints stored under `engine/packages/test-snapshot-gen/snapshots/` (git LFS tracked). In tests, use `test_snapshot::SnapshotTestCtx::from_snapshot("scenario-name")` to boot a cluster from snapshot data. See `docs-internal/engine/TEST_SNAPSHOTS.md` for the full guide. diff --git a/engine/packages/epoxy/src/ops/kv/read_value.rs b/engine/packages/epoxy/src/ops/kv/read_value.rs index 3a7875f032..3d36af6e0a 100644 --- a/engine/packages/epoxy/src/ops/kv/read_value.rs +++ b/engine/packages/epoxy/src/ops/kv/read_value.rs @@ -21,10 +21,7 @@ pub(crate) struct LocalValueRead { /// 1. **V2 value** (`EPOXY_V2/replica/{id}/kv/{key}/value`). The current write path. /// 2. **Legacy committed value** (`EPOXY_V1/replica/{id}/kv/{key}/committed_value`). Written by /// the original EPaxos protocol. Deserialized as raw bytes with version 0 and mutable=false. -/// 3. **Legacy v2-format value** (`EPOXY_V1/replica/{id}/kv/{key}/value`). Written during the -/// intermediate v1-to-v2 transition where the key layout matched v2 but the subspace was -/// still v1. -/// 4. **Optimistic cache** (`EPOXY_V2/replica/{id}/kv/{key}/cache`). Only checked when +/// 3. **Optimistic cache** (`EPOXY_V2/replica/{id}/kv/{key}/cache`). Only checked when /// `include_cache` is true. Contains values fetched from remote replicas for the optimistic /// read path. /// @@ -38,24 +35,20 @@ pub(crate) async fn read_local_value( ) -> Result { let value_key = KvValueKey::new(key.clone()); let legacy_value_key = LegacyCommittedValueKey::new(key.clone()); - let legacy_v2_value_key = KvValueKey::new(key.clone()); let cache_key = KvOptimisticCacheKey::new(key); let subspace = keys::subspace(replica_id); let legacy_subspace = keys::legacy_subspace(replica_id); let packed_value_key = subspace.pack(&value_key); let packed_legacy_value_key = legacy_subspace.pack(&legacy_value_key); - let packed_legacy_v2_value_key = legacy_subspace.pack(&legacy_v2_value_key); let packed_cache_key = subspace.pack(&cache_key); ctx.udb()? .run(|tx| { let packed_value_key = packed_value_key.clone(); let packed_legacy_value_key = packed_legacy_value_key.clone(); - let packed_legacy_v2_value_key = packed_legacy_v2_value_key.clone(); let packed_cache_key = packed_cache_key.clone(); let value_key = value_key.clone(); let legacy_value_key = legacy_value_key.clone(); - let legacy_v2_value_key = legacy_v2_value_key.clone(); let cache_key = cache_key.clone(); async move { @@ -79,14 +72,6 @@ pub(crate) async fn read_local_value( }); } - // Legacy v2-format value (v1 subspace, v2 key layout) - if let Some(value) = tx.get(&packed_legacy_v2_value_key, Serializable).await? { - return Ok(LocalValueRead { - value: Some(legacy_v2_value_key.deserialize(&value)?), - cache_value: None, - }); - } - let cache_value = if include_cache { tx.get(&packed_cache_key, Serializable) .await? diff --git a/engine/packages/pegboard-envoy/src/conn.rs b/engine/packages/pegboard-envoy/src/conn.rs index 723fb6be30..768e5e9ce2 100644 --- a/engine/packages/pegboard-envoy/src/conn.rs +++ b/engine/packages/pegboard-envoy/src/conn.rs @@ -147,7 +147,7 @@ pub async fn handle_init( // Read existing data let (create_ts_entry, old_last_ping_ts_entry, version_entry) = tokio::try_join!( tx.read_opt(&create_ts_key, Serializable), - tx.read_opt(&create_ts_key, Serializable), + tx.read_opt(&last_ping_ts_key, Serializable), tx.read_opt(&version_key, Serializable), )?; diff --git a/engine/packages/pegboard-kv-channel/src/lib.rs b/engine/packages/pegboard-kv-channel/src/lib.rs index 6354ca2d15..3a13b319c2 100644 --- a/engine/packages/pegboard-kv-channel/src/lib.rs +++ b/engine/packages/pegboard-kv-channel/src/lib.rs @@ -25,7 +25,6 @@ use rivet_guard_core::{ }; use tokio::sync::{Mutex, mpsc, watch}; use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame; -use uuid::Uuid; pub use rivet_kv_channel_protocol as protocol; @@ -40,27 +39,13 @@ const KEY_WRAPPER_OVERHEAD: usize = 2; /// Prevents a malicious client from exhausting memory via unbounded actor_channels. const MAX_ACTORS_PER_CONNECTION: usize = 1000; -/// Shared state across all KV channel connections. -pub struct KvChannelState { - /// Maps actor_id string to the connection_id holding the single-writer lock and a reference - /// to that connection's open_actors set. The Arc reference allows lock eviction to remove the - /// actor from the old connection's set without acquiring the global lock on the KV hot path. - actor_locks: Mutex>>)>>, -} - pub struct PegboardKvChannelCustomServe { ctx: StandaloneCtx, - state: Arc, } impl PegboardKvChannelCustomServe { pub fn new(ctx: StandaloneCtx) -> Self { - Self { - ctx, - state: Arc::new(KvChannelState { - actor_locks: Mutex::new(HashMap::new()), - }), - } + Self { ctx } } } @@ -89,7 +74,6 @@ impl CustomServeTrait for PegboardKvChannelCustomServe { _after_hibernation: bool, ) -> Result> { let ctx = self.ctx.with_ray(req_ctx.ray_id(), req_ctx.req_id())?; - let state = self.state.clone(); // Parse URL params. let url = url::Url::parse(&format!("ws://placeholder{}", req_ctx.path())) @@ -123,44 +107,24 @@ impl CustomServeTrait for PegboardKvChannelCustomServe { .ok_or_else(|| namespace::errors::Namespace::NotFound.build()) .with_context(|| format!("namespace not found: {namespace_name}"))?; - // Assign connection ID. Uses UUID to eliminate any possibility of ID collision. - let conn_id = Uuid::new_v4(); let namespace_id = namespace.namespace_id; - tracing::info!(%conn_id, %namespace_id, "kv channel connection established"); + tracing::info!(%namespace_id, "kv channel connection established"); - // Track actors opened by this connection for cleanup on disconnect. + // Track actors opened by this connection. let open_actors: Arc>> = Arc::new(Mutex::new(HashSet::new())); let last_pong_ts = Arc::new(AtomicI64::new(util::timestamp::now())); - // Run the connection loop. Any error triggers cleanup below. let result = run_connection( ctx.clone(), - state.clone(), ws_handle, - conn_id, namespace_id, - open_actors.clone(), + open_actors, last_pong_ts, ) .await; - // Release all locks held by this connection. Only remove entries where the lock is still - // held by this conn_id, since another connection may have evicted it via ActorOpenRequest. - { - let open = open_actors.lock().await; - let mut locks = state.actor_locks.lock().await; - for actor_id in open.iter() { - if let Some((lock_conn, _)) = locks.get(actor_id) { - if *lock_conn == conn_id { - locks.remove(actor_id); - tracing::debug!(%conn_id, %actor_id, "released actor lock on disconnect"); - } - } - } - } - - tracing::info!(%conn_id, "kv channel connection closed"); + tracing::info!("kv channel connection closed"); result.map(|_| None) } @@ -170,9 +134,7 @@ impl CustomServeTrait for PegboardKvChannelCustomServe { async fn run_connection( ctx: StandaloneCtx, - state: Arc, ws_handle: WebSocketHandle, - conn_id: Uuid, namespace_id: Id, open_actors: Arc>>, last_pong_ts: Arc, @@ -200,9 +162,7 @@ async fn run_connection( // Run message loop. let msg_result = message_loop( &ctx, - &state, &ws_handle, - conn_id, namespace_id, &open_actors, &last_pong_ts, @@ -250,9 +210,7 @@ async fn ping_task( async fn message_loop( ctx: &StandaloneCtx, - state: &Arc, ws_handle: &WebSocketHandle, - conn_id: Uuid, namespace_id: Id, open_actors: &Arc>>, last_pong_ts: &AtomicI64, @@ -297,9 +255,7 @@ async fn message_loop( Message::Binary(data) => { handle_binary_message( ctx, - state, ws_handle, - conn_id, namespace_id, open_actors, last_pong_ts, @@ -329,9 +285,7 @@ async fn message_loop( async fn handle_binary_message( ctx: &StandaloneCtx, - state: &Arc, ws_handle: &WebSocketHandle, - conn_id: Uuid, namespace_id: Id, open_actors: &Arc>>, last_pong_ts: &AtomicI64, @@ -366,9 +320,7 @@ async fn handle_binary_message( let (tx, rx) = mpsc::channel(64); actor_tasks.spawn(actor_request_task( Clone::clone(ctx), - Clone::clone(state), Clone::clone(ws_handle), - conn_id, namespace_id, Clone::clone(open_actors), rx, @@ -421,41 +373,36 @@ async fn handle_binary_message( /// dropped (connection end) or after processing an ActorCloseRequest. async fn actor_request_task( ctx: StandaloneCtx, - state: Arc, ws_handle: WebSocketHandle, - conn_id: Uuid, namespace_id: Id, open_actors: Arc>>, mut rx: mpsc::Receiver, ) { - // Cached actor resolution. Populated on first KV request, reused for all - // subsequent requests. Actor name is immutable so this never goes stale. - let mut cached_actor: Option<(Id, String)> = None; + // Cache keyed by actor id since a single connection multiplexes many actors. + let mut cached_actors: HashMap = HashMap::new(); while let Some(req) = rx.recv().await { let is_close = matches!(req.data, protocol::RequestData::ActorCloseRequest); let response_data = match &req.data { // Open/close are lifecycle ops that don't need a resolved actor. - protocol::RequestData::ActorOpenRequest | protocol::RequestData::ActorCloseRequest => { - handle_request(&ctx, &state, conn_id, namespace_id, &open_actors, &req).await + protocol::RequestData::ActorOpenRequest + | protocol::RequestData::ActorCloseRequest => { + handle_request(&open_actors, &req).await } // KV ops: resolve once, cache, reuse. _ => { let is_open = open_actors.lock().await.contains(&req.actor_id); if !is_open { - let locks = state.actor_locks.lock().await; - if locks.contains_key(&req.actor_id) { - error_response("actor_locked", "actor is locked by another connection") - } else { - error_response("actor_not_open", "actor is not opened on this connection") - } + error_response( + "actor_not_open", + "actor is not opened on this connection", + ) } else { - // Lazy-resolve and cache. - if cached_actor.is_none() { + if !cached_actors.contains_key(&req.actor_id) { match resolve_actor(&ctx, &req.actor_id, namespace_id).await { Ok(v) => { - cached_actor = Some(v); + cached_actors.insert(req.actor_id.clone(), v); } Err(resp) => { // Don't cache failures. Next request will retry. @@ -467,7 +414,8 @@ async fn actor_request_task( } } } - let (parsed_id, actor_name) = cached_actor.as_ref().unwrap(); + let (parsed_id, actor_name) = + cached_actors.get(&req.actor_id).unwrap(); let recipient = actor_kv::Recipient { actor_id: *parsed_id, @@ -526,19 +474,15 @@ async fn send_response(ws_handle: &WebSocketHandle, request_id: u32, data: proto /// Handles actor lifecycle requests (open/close). KV operations are handled /// directly in `actor_request_task` with cached actor resolution. async fn handle_request( - _ctx: &StandaloneCtx, - state: &KvChannelState, - conn_id: Uuid, - _namespace_id: Id, open_actors: &Arc>>, req: &protocol::ToRivetRequest, ) -> protocol::ResponseData { match &req.data { protocol::RequestData::ActorOpenRequest => { - handle_actor_open(state, conn_id, open_actors, &req.actor_id).await + handle_actor_open(open_actors, &req.actor_id).await } protocol::RequestData::ActorCloseRequest => { - handle_actor_close(state, conn_id, open_actors, &req.actor_id).await + handle_actor_close(open_actors, &req.actor_id).await } _ => unreachable!("KV operations are handled in actor_request_task"), } @@ -547,8 +491,6 @@ async fn handle_request( // MARK: Actor open/close async fn handle_actor_open( - state: &KvChannelState, - conn_id: Uuid, open_actors: &Arc>>, actor_id: &str, ) -> protocol::ResponseData { @@ -563,47 +505,17 @@ async fn handle_actor_open( } } - let mut locks = state.actor_locks.lock().await; - - // If the actor is locked by a different connection, unconditionally evict the old lock. - // This handles reconnection scenarios where the server hasn't detected the old connection's - // disconnect yet. The old connection's next KV request will fail the fast-path check - // (open_actors.contains) and return actor_not_open. - // See docs-internal/engine/NATIVE_SQLITE_REVIEW_FINDINGS.md Finding 4. - if let Some((existing_conn, old_open_actors)) = locks.get(actor_id) { - if *existing_conn != conn_id { - old_open_actors.lock().await.remove(actor_id); - tracing::info!( - %conn_id, - old_conn_id = %existing_conn, - %actor_id, - "evicted stale actor lock from old connection" - ); - } - } - - locks.insert(actor_id.to_string(), (conn_id, open_actors.clone())); open_actors.lock().await.insert(actor_id.to_string()); - tracing::debug!(%conn_id, %actor_id, "actor lock acquired"); + tracing::debug!(%actor_id, "actor opened"); protocol::ResponseData::ActorOpenResponse } async fn handle_actor_close( - state: &KvChannelState, - conn_id: Uuid, open_actors: &Arc>>, actor_id: &str, ) -> protocol::ResponseData { - let mut locks = state.actor_locks.lock().await; - - if let Some((lock_conn, _)) = locks.get(actor_id) { - if *lock_conn == conn_id { - locks.remove(actor_id); - open_actors.lock().await.remove(actor_id); - tracing::debug!(%conn_id, %actor_id, "actor lock released"); - } - } - + open_actors.lock().await.remove(actor_id); + tracing::debug!(%actor_id, "actor closed"); protocol::ResponseData::ActorCloseResponse } @@ -817,7 +729,7 @@ async fn handle_kv_delete_range( /// Look up an actor by ID and return the parsed ID and actor name. /// -/// Defense-in-depth: verifies the actor belongs to the authenticated namespace. +/// Verifies the actor belongs to the authenticated namespace. async fn resolve_actor( ctx: &StandaloneCtx, actor_id: &str, @@ -827,11 +739,15 @@ async fn resolve_actor( .map_err(|err| error_response("actor_not_found", &format!("invalid actor id: {err}")))?; let actor = ctx - .op(pegboard::ops::actor::get_for_runner::Input { - actor_id: parsed_id, + .op(pegboard::ops::actor::get::Input { + actor_ids: vec![parsed_id], + fetch_error: false, }) .await - .map_err(|err| internal_error(&err))?; + .map_err(|err| internal_error(&err))? + .actors + .into_iter() + .next(); match actor { Some(actor) => { diff --git a/engine/sdks/typescript/envoy-client/src/tasks/envoy/index.ts b/engine/sdks/typescript/envoy-client/src/tasks/envoy/index.ts index a93ad57a62..17cefb31c9 100644 --- a/engine/sdks/typescript/envoy-client/src/tasks/envoy/index.ts +++ b/engine/sdks/typescript/envoy-client/src/tasks/envoy/index.ts @@ -151,9 +151,14 @@ export function startEnvoySync(config: EnvoyConfig): EnvoyHandle { for await (const msg of envoyRx) { if (msg.type === "conn-message") { - await handleConnMessage(ctx, startTx, lostTimeout, msg.message); + lostTimeout = handleConnMessage( + ctx, + startTx, + lostTimeout, + msg.message, + ); } else if (msg.type === "conn-close") { - handleConnClose(ctx, lostTimeout); + lostTimeout = handleConnClose(ctx, lostTimeout); if (msg.evict) break; } else if (msg.type === "send-events") { handleSendEvents(ctx, msg.events); @@ -171,7 +176,12 @@ export function startEnvoySync(config: EnvoyConfig): EnvoyHandle { } // Cleanup + if (lostTimeout) { + clearTimeout(lostTimeout); + } ctx.shared.wsTx?.send({ type: "close", code: 1000, reason: "envoy.shutdown" }); + connHandle.abort(); + await connHandle.catch(() => undefined); clearInterval(ackInterval); clearInterval(kvCleanupInterval); @@ -204,7 +214,7 @@ function handleConnMessage( startTx: WatchSender, lostTimeout: NodeJS.Timeout | undefined, message: ToEnvoyFromConnMessage, -) { +): NodeJS.Timeout | undefined { if (message.tag === "ToEnvoyInit") { ctx.shared.protocolMetadata = message.val.metadata; log(ctx.shared)?.info({ @@ -212,7 +222,10 @@ function handleConnMessage( protocolMetadata: message.val.metadata, }); - clearTimeout(lostTimeout); + if (lostTimeout) { + clearTimeout(lostTimeout); + lostTimeout = undefined; + } resendUnacknowledgedEvents(ctx); processUnsentKvRequests(ctx); resendBufferedTunnelMessages(ctx); @@ -229,9 +242,14 @@ function handleConnMessage( } else { unreachable(message); } + + return lostTimeout; } -function handleConnClose(ctx: EnvoyContext, lostTimeout: NodeJS.Timeout | undefined) { +function handleConnClose( + ctx: EnvoyContext, + lostTimeout: NodeJS.Timeout | undefined, +): NodeJS.Timeout | undefined { if (!lostTimeout) { let lostThreshold = ctx.shared.protocolMetadata ? Number(ctx.shared.protocolMetadata.envoyLostThreshold) : 10000; log(ctx.shared)?.debug({ @@ -268,6 +286,8 @@ function handleConnClose(ctx: EnvoyContext, lostTimeout: NodeJS.Timeout | undefi lostThreshold, ); } + + return lostTimeout; } function handleShutdown(ctx: EnvoyContext) { @@ -346,7 +366,14 @@ function createHandle( return { shutdown(immediate: boolean) { - ctx.shared.envoyTx.send({ type: "shutdown" }); + if (immediate) { + log(ctx.shared)?.debug({ + msg: "envoy received immediate shutdown", + }); + ctx.shared.envoyTx.send({ type: "stop" }); + } else { + ctx.shared.envoyTx.send({ type: "shutdown" }); + } }, getProtocolMetadata(): protocol.ProtocolMetadata | undefined { diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts index dc0a1c4415..4b1ea2e60a 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts @@ -1,4 +1,5 @@ import * as cbor from "cbor-x"; +import { stringifyError } from "@/common/utils"; import type * as protocol from "@/schemas/client-protocol/mod"; import { CURRENT_VERSION as CLIENT_PROTOCOL_CURRENT_VERSION, @@ -53,6 +54,7 @@ export class Conn< Q extends QueueSchemaConfig = Record, > { #actor: ActorInstance; + #disconnectPromise?: Promise; get [CONN_ACTOR_SYMBOL](): ActorInstance { return this.#actor; @@ -261,28 +263,43 @@ export class Conn< * @param reason - The reason for disconnection. */ async disconnect(reason?: string) { - if (this[CONN_DRIVER_SYMBOL]) { - const driver = this[CONN_DRIVER_SYMBOL]; - if (driver.disconnect) { - await driver.disconnect(this.#actor, this, reason); - } else { - this.#actor.rLog.debug({ - msg: "no disconnect handler for conn driver", - conn: this.id, - }); - } + if (!this.#disconnectPromise) { + this.#disconnectPromise = (async () => { + if (this[CONN_DRIVER_SYMBOL]) { + const driver = this[CONN_DRIVER_SYMBOL]; + try { + if (driver.disconnect) { + try { + await driver.disconnect(this.#actor, this, reason); + } catch (error) { + this.#actor.rLog.warn({ + msg: "conn driver disconnect failed, continuing connection cleanup", + conn: this.id, + reason, + error: stringifyError(error), + }); + } + } else { + this.#actor.rLog.debug({ + msg: "no disconnect handler for conn driver", + conn: this.id, + }); + } - try { - await this.#actor.connectionManager.connDisconnected(this); - } finally { - this[CONN_DRIVER_SYMBOL] = undefined; - } - } else { - this.#actor.rLog.warn({ - msg: "missing connection driver state for disconnect", - conn: this.id, - }); - this[CONN_DRIVER_SYMBOL] = undefined; + await this.#actor.connectionManager.connDisconnected(this); + } finally { + this[CONN_DRIVER_SYMBOL] = undefined; + } + } else { + this.#actor.rLog.warn({ + msg: "missing connection driver state for disconnect", + conn: this.id, + }); + this[CONN_DRIVER_SYMBOL] = undefined; + } + })(); } + + await this.#disconnectPromise; } } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts index 9194699a39..22f4921b1c 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts @@ -136,6 +136,7 @@ const ACTIVE_ASYNC_REGION_ERROR_MESSAGES: Record< "active websocket callback count went below 0, this is a RivetKit bug", }; + /** Actor type alias with all `any` types. Used for `extends` in classes referencing this actor. */ export type AnyActorInstance = ActorInstance< any, @@ -1022,7 +1023,7 @@ export class ActorInstance< let spanEnded = false; try { - return await this.#traces.withSpan(actionSpan, async () => { + const output = await this.#traces.withSpan(actionSpan, async () => { this.#rLog.debug({ msg: "executing action", actionName, @@ -1060,6 +1061,8 @@ export class ActorInstance< return output; }); + + return output; } catch (error) { this.#metrics.actionErrors++; const isTimeout = error instanceof DeadlineError; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/router-websocket-endpoints.ts b/rivetkit-typescript/packages/rivetkit/src/actor/router-websocket-endpoints.ts index 6c7dc42224..969d272dcf 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/router-websocket-endpoints.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/router-websocket-endpoints.ts @@ -218,8 +218,13 @@ export async function handleWebSocketConnect( exposeInternalError, }: WebSocketHandlerOpts, ): Promise { - // Process WS messages in order to avoid races between subscription updates - // and subsequent action requests. + // Parse and apply subscription updates in order so subscribe/unsubscribe + // messages are visible to later messages deterministically. + // + // Action execution itself is intentionally not awaited in this chain. + // Actions are allowed to overlap on a single connection, and tests depend + // on a follow-up action being able to unblock a long-running action on the + // same WebSocket. let pendingMessage = Promise.resolve(); return { @@ -248,7 +253,25 @@ export async function handleWebSocketConnect( maxIncomingMessageSize: runConfig.maxIncomingMessageSize, }); - await actor.processMessage(message, conn); + + if (message.body.tag === "SubscriptionRequest") { + await actor.processMessage(message, conn); + return; + } + + void actor.processMessage(message, conn).catch((error) => { + const { group, code } = deconstructError( + error, + actor.rLog, + { + wsEvent: "message", + actionId: message.body.val.id, + actionName: message.body.val.name, + }, + exposeInternalError, + ); + ws.close(1011, `${group}.${code}`); + }); }) .catch((error) => { const { group, code } = deconstructError( @@ -313,7 +336,12 @@ export async function handleWebSocketConnect( export async function handleRawWebSocket( setWebSocket: (ws: UniversalWebSocket) => void, - { request, actor, closePromiseResolvers, conn }: WebSocketHandlerOpts, + { + request, + actor, + closePromiseResolvers, + conn, + }: WebSocketHandlerOpts, ): Promise { return { conn, @@ -343,17 +371,21 @@ export async function handleRawWebSocket( // this is called synchronously within onOpen. actor.handleRawWebSocket(conn, ws, request); }, - // Raw websocket messages are handled directly by the actor's event - // listeners on the WebSocket object, not through this callback - onMessage: (_evt: any, _ws: any) => {}, + onMessage: (_evt: any, _wsContext: any) => {}, onClose: (evt: any, ws: any) => { // Resolve the close promise closePromiseResolvers.resolve(); // Clean up the connection - conn.disconnect(evt?.reason); + void conn.disconnect(evt?.reason).catch((error) => { + actor.rLog.error({ + msg: "raw websocket disconnect failed", + error: String(error), + reason: evt?.reason, + }); + }); }, - onError: (error: any, ws: any) => {}, + onError: (_error: any, _ws: any) => {}, }; } diff --git a/rivetkit-typescript/packages/rivetkit/src/agent-os/actor/process.ts b/rivetkit-typescript/packages/rivetkit/src/agent-os/actor/process.ts index 0d2cb09484..1ba2e2e1ca 100644 --- a/rivetkit-typescript/packages/rivetkit/src/agent-os/actor/process.ts +++ b/rivetkit-typescript/packages/rivetkit/src/agent-os/actor/process.ts @@ -3,6 +3,7 @@ import type { ProcessTreeNode, SpawnedProcessInfo, } from "@rivet-dev/agent-os-core"; +import { ActorStopping } from "@/actor/errors"; import type { AgentOsActorConfig } from "../config"; import type { AgentOsActionContext } from "../types"; import { ensureVm, syncPreventSleep } from "./index"; @@ -16,6 +17,21 @@ type SpawnOptions = Parameters< import("@rivet-dev/agent-os-core").AgentOs["spawn"] >[2]; +function broadcastProcessEvent( + c: AgentOsActionContext, + name: "processOutput" | "processExit", + payload: unknown, +) { + try { + c.broadcast(name, payload); + } catch (error) { + if (error instanceof ActorStopping) { + return; + } + throw error; + } +} + // Build process execution actions for the actor factory. export function buildProcessActions( config: AgentOsActorConfig, @@ -40,11 +56,19 @@ export function buildProcessActions( const { pid } = agentOs.spawn(command, args, { ...options, onStdout: (data: Uint8Array) => { - c.broadcast("processOutput", { pid, stream: "stdout" as const, data }); + broadcastProcessEvent(c, "processOutput", { + pid, + stream: "stdout" as const, + data, + }); options?.onStdout?.(data); }, onStderr: (data: Uint8Array) => { - c.broadcast("processOutput", { pid, stream: "stderr" as const, data }); + broadcastProcessEvent(c, "processOutput", { + pid, + stream: "stderr" as const, + data, + }); options?.onStderr?.(data); }, }); @@ -59,7 +83,7 @@ export function buildProcessActions( agentOs.waitProcess(pid) .then((exitCode) => { - c.broadcast("processExit", { pid, exitCode }); + broadcastProcessEvent(c, "processExit", { pid, exitCode }); c.log.info({ msg: "agent-os process exited", pid, diff --git a/rivetkit-typescript/packages/rivetkit/src/db/drizzle/mod.ts b/rivetkit-typescript/packages/rivetkit/src/db/drizzle/mod.ts index 9f6160d8ff..70b52d4a2e 100644 --- a/rivetkit-typescript/packages/rivetkit/src/db/drizzle/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/db/drizzle/mod.ts @@ -5,6 +5,7 @@ import { type SqliteRemoteDatabase, } from "drizzle-orm/sqlite-proxy"; import type { DatabaseProvider, RawAccess } from "../config"; +import { openActorDatabase } from "../open-database"; import { AsyncMutex, createActorKvStore, @@ -216,19 +217,7 @@ export function db< return { createClient: async (ctx) => { - // Construct KV-backed client using actor driver's KV operations - if (!ctx.sqliteVfs) { - throw new Error( - "SqliteVfs instance not provided in context. The driver must provide a sqliteVfs instance.", - ); - } - - const kvStore = createActorKvStore( - ctx.kv, - ctx.metrics, - ctx.preloadedEntries, - ); - const waDb = await ctx.sqliteVfs.open(ctx.actorId, kvStore); + const { database: waDb, kvStore } = await openActorDatabase(ctx); // Per-client mutex so actors of the same type do not serialize // against each other. Each actor has its own database handle and // its own closed flag, so there is no shared state to guard. @@ -344,7 +333,9 @@ export function db< } satisfies RawAccess); clientToRawDb.set(result, waDb); - clientToKvStore.set(result, kvStore); + if (kvStore) { + clientToKvStore.set(result, kvStore); + } return result; }, onMigrate: async (client) => { diff --git a/rivetkit-typescript/packages/rivetkit/src/db/mod.ts b/rivetkit-typescript/packages/rivetkit/src/db/mod.ts index 9fbe132299..207ee78d12 100644 --- a/rivetkit-typescript/packages/rivetkit/src/db/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/db/mod.ts @@ -1,8 +1,5 @@ import type { DatabaseProvider, RawAccess } from "./config"; -import { - nativeSqliteAvailable, - createNativeRawAccess, -} from "./native-sqlite"; +import { openActorDatabase } from "./open-database"; import { AsyncMutex, createActorKvStore, @@ -12,9 +9,6 @@ import { export type { RawAccess } from "./config"; -// Log the native SQLite fallback warning at most once per process. -let nativeFallbackWarned = false; - interface DatabaseFactoryConfig { onMigrate?: (db: RawAccess) => Promise | void; } @@ -54,41 +48,13 @@ export function db({ } satisfies RawAccess; } - // Use native SQLite when the addon is available. The native path - // routes KV operations over a WebSocket KV channel, bypassing - // the WASM VFS entirely. - if (nativeSqliteAvailable()) { - return await createNativeRawAccess( - ctx.actorId, - ctx.nativeSqliteConfig, - ); - } - - // Native addon not available. Fall back to WASM SQLite. - if (!nativeFallbackWarned) { - nativeFallbackWarned = true; - console.warn( - "native SQLite not available, falling back to WebAssembly. run npm rebuild to install native bindings.", - ); - } - - // Construct KV-backed client using actor driver's KV operations - if (!ctx.sqliteVfs) { - throw new Error( - "SqliteVfs instance not provided in context. The driver must provide a sqliteVfs instance.", - ); - } - + const { database: db, kvStore } = await openActorDatabase(ctx); let lastVfsError: unknown = null; - const kvStore = createActorKvStore( - ctx.kv, - ctx.metrics, - ctx.preloadedEntries, - ); - kvStore.onError = (error: unknown) => { - lastVfsError = error; - }; - const db = await ctx.sqliteVfs.open(ctx.actorId, kvStore); + if (kvStore) { + kvStore.onError = (error: unknown) => { + lastVfsError = error; + }; + } let closed = false; const mutex = new AsyncMutex(); const ensureOpen = () => { @@ -219,7 +185,7 @@ export function db({ // page read/write will fail immediately with a // descriptive error instead of hanging or producing a // cryptic "disk I/O error". - kvStore.poison(); + kvStore?.poison(); const shouldClose = await mutex.run(async () => { if (closed) return false; @@ -231,7 +197,9 @@ export function db({ } }, } satisfies RawAccess; - clientToKvStore.set(client, kvStore); + if (kvStore) { + clientToKvStore.set(client, kvStore); + } return client; }, onMigrate: async (client) => { diff --git a/rivetkit-typescript/packages/rivetkit/src/db/native-adapter.ts b/rivetkit-typescript/packages/rivetkit/src/db/native-adapter.ts new file mode 100644 index 0000000000..fee30e710f --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/db/native-adapter.ts @@ -0,0 +1,243 @@ +import type { IDatabase } from "@rivetkit/sqlite-vfs"; +import type { NativeSqliteConfig } from "./config"; +import { + getNativeModule, + disconnectKvChannelIfCurrent, + getOrCreateKvChannel, + toNativeBindings, + type NativeKvChannel, + type NativeDatabase, +} from "./native-sqlite"; +import { AsyncMutex } from "./shared"; + +function isPlainObject(value: unknown): value is Record { + return ( + !!value && + typeof value === "object" && + !Array.isArray(value) && + Object.getPrototypeOf(value) === Object.prototype + ); +} + +function extractNamedSqliteParameters(sql: string): string[] { + const names: string[] = []; + const pattern = /([:@$][A-Za-z_][A-Za-z0-9_]*)/g; + for (const match of sql.matchAll(pattern)) { + names.push(match[1]); + } + return names; +} + +function resolveNamedSqliteBinding( + params: Record, + name: string, +): unknown { + if (name in params) { + return params[name]; + } + + const bareName = name.slice(1); + if (bareName in params) { + return params[bareName]; + } + + for (const prefix of [":", "@", "$"] as const) { + const candidate = `${prefix}${bareName}`; + if (candidate in params) { + return params[candidate]; + } + } + + return undefined; +} + +function normalizeNativeBindings(sql: string, params?: unknown): unknown[] { + if (params === undefined || params === null) { + return []; + } + + if (Array.isArray(params)) { + return toNativeBindings(params); + } + + if (isPlainObject(params)) { + const names = extractNamedSqliteParameters(sql); + if (names.length === 0) { + throw new Error( + "native SQLite adapter only supports named parameter objects when the SQL statement uses named placeholders", + ); + } + + return toNativeBindings( + names.map((name) => { + const value = resolveNamedSqliteBinding(params, name); + if (value === undefined) { + throw new Error(`missing bind parameter: ${name}`); + } + return value; + }), + ); + } + + throw new Error( + "native SQLite adapter only supports positional parameter arrays or named parameter objects", + ); +} + +function isStaleKvChannelError(error: unknown): boolean { + const message = + error instanceof Error ? error.message : String(error); + return /kv channel (?:connection closed|shut down)/i.test(message); +} + +async function clearKvChannelForConfig( + channel: NativeKvChannel, + config?: NativeSqliteConfig, +): Promise { + try { + await disconnectKvChannelIfCurrent(channel, config); + } catch { + // Ignore disconnect errors. The cache entry is about to be replaced. + } +} + +async function openNativeDatabaseHandle( + actorId: string, + config?: NativeSqliteConfig, +): Promise<{ + nativeDb: NativeDatabase; + channel: NativeKvChannel; +}> { + const mod = getNativeModule(); + const maxAttempts = 3; + + for (let attempt = 0; attempt < maxAttempts; attempt++) { + const channel = getOrCreateKvChannel(config); + try { + return { + nativeDb: await mod.openDatabase(channel, actorId), + channel, + }; + } catch (error) { + if ( + !isStaleKvChannelError(error) || + attempt === maxAttempts - 1 + ) { + throw error; + } + + await clearKvChannelForConfig(channel, config); + } + } + + throw new Error("unreachable: native database open exhausted retries"); +} + +class NativeSqliteDatabase implements IDatabase { + #module = getNativeModule(); + #nativeDb: NativeDatabase; + #channel: NativeKvChannel; + #config?: NativeSqliteConfig; + #recoveryMutex = new AsyncMutex(); + readonly fileName: string; + + constructor( + nativeDb: NativeDatabase, + channel: NativeKvChannel, + fileName: string, + config?: NativeSqliteConfig, + ) { + this.#nativeDb = nativeDb; + this.#channel = channel; + this.fileName = fileName; + this.#config = config; + } + + async #recoverFromStaleKvChannel(error: unknown): Promise { + if (!isStaleKvChannelError(error)) { + return false; + } + + await this.#recoveryMutex.run(async () => { + await clearKvChannelForConfig(this.#channel, this.#config); + const reopened = await openNativeDatabaseHandle( + this.fileName, + this.#config, + ); + this.#nativeDb = reopened.nativeDb; + this.#channel = reopened.channel; + }); + + return true; + } + + async #runWithReconnect( + operation: (nativeDb: NativeDatabase) => Promise, + ): Promise { + try { + return await operation(this.#nativeDb); + } catch (error) { + const recovered = await this.#recoverFromStaleKvChannel(error); + if (!recovered) { + throw error; + } + + return await operation(this.#nativeDb); + } + } + + async exec( + sql: string, + callback?: (row: unknown[], columns: string[]) => void, + ): Promise { + const result = await this.#runWithReconnect((nativeDb) => { + return this.#module.exec(nativeDb, sql); + }); + if (!callback) { + return; + } + for (const row of result.rows) { + callback(row, result.columns); + } + } + + async run(sql: string, params?: unknown): Promise { + const bindings = normalizeNativeBindings(sql, params); + await this.#runWithReconnect((nativeDb) => { + return this.#module.execute(nativeDb, sql, bindings); + }); + } + + async query( + sql: string, + params?: unknown, + ): Promise<{ rows: unknown[][]; columns: string[] }> { + const bindings = normalizeNativeBindings(sql, params); + return await this.#runWithReconnect((nativeDb) => { + return this.#module.query(nativeDb, sql, bindings); + }); + } + + async close(): Promise { + try { + await this.#module.closeDatabase(this.#nativeDb); + } catch (error) { + await this.#recoverFromStaleKvChannel(error); + if (!isStaleKvChannelError(error)) { + throw error; + } + } + } +} + +export async function openNativeDatabase( + actorId: string, + config?: NativeSqliteConfig, +): Promise { + const { nativeDb, channel } = await openNativeDatabaseHandle( + actorId, + config, + ); + + return new NativeSqliteDatabase(nativeDb, channel, actorId, config); +} diff --git a/rivetkit-typescript/packages/rivetkit/src/db/native-sqlite.ts b/rivetkit-typescript/packages/rivetkit/src/db/native-sqlite.ts index b0da15ed54..aae01922f1 100644 --- a/rivetkit-typescript/packages/rivetkit/src/db/native-sqlite.ts +++ b/rivetkit-typescript/packages/rivetkit/src/db/native-sqlite.ts @@ -24,7 +24,7 @@ import { AsyncMutex } from "./shared"; // which may not be installed or compiled. /** Typed bind parameter matching the Rust BindParam napi struct. */ -interface NativeBindParam { +export interface NativeBindParam { kind: "null" | "int" | "float" | "text" | "blob"; intValue?: number; floatValue?: number; @@ -32,7 +32,7 @@ interface NativeBindParam { blobValue?: Buffer; } -interface NativeSqliteModule { +export interface NativeSqliteModule { connect(config: { url: string; token?: string; @@ -81,8 +81,8 @@ export interface KvChannelMetricsSnapshot { } // Opaque handles from the native addon. -type NativeKvChannel = object; -type NativeDatabase = object; +export type NativeKvChannel = object; +export type NativeDatabase = object; // Cached detection result. let nativeModule: NativeSqliteModule | null = null; @@ -153,7 +153,7 @@ export function nativeSqliteAvailable(): boolean { * Returns the loaded native module. Only valid after nativeSqliteAvailable() * returns true. */ -function getNativeModule(): NativeSqliteModule { +export function getNativeModule(): NativeSqliteModule { if (!nativeModule) { throw new Error("native SQLite module not loaded"); } @@ -226,7 +226,9 @@ function getKvChannelConfig(config?: NativeSqliteConfig) { }; } -function getOrCreateKvChannel(config?: NativeSqliteConfig): NativeKvChannel { +export function getOrCreateKvChannel( + config?: NativeSqliteConfig, +): NativeKvChannel { const mod = getNativeModule(); const channelConfig = getKvChannelConfig(config); const existing = kvChannels.get(channelConfig.key); @@ -244,36 +246,96 @@ function getOrCreateKvChannel(config?: NativeSqliteConfig): NativeKvChannel { return channel; } +function toNativeBinding(arg: unknown): NativeBindParam { + if (arg === null || arg === undefined) { + return { kind: "null" }; + } + if (typeof arg === "bigint") { + return { kind: "int", intValue: Number(arg) }; + } + if (typeof arg === "number") { + if (Number.isInteger(arg)) { + return { kind: "int", intValue: arg }; + } + return { kind: "float", floatValue: arg }; + } + if (typeof arg === "string") { + return { kind: "text", textValue: arg }; + } + if (typeof arg === "boolean") { + return { kind: "int", intValue: arg ? 1 : 0 }; + } + if (arg instanceof Uint8Array) { + return { kind: "blob", blobValue: Buffer.from(arg) }; + } + throw new Error(`unsupported bind parameter type: ${typeof arg}`); +} + /** * Convert binding values to typed BindParam objects for the native addon. * Uses Buffer for blobs instead of JSON arrays to avoid 20x serialization * overhead. See docs-internal/engine/NATIVE_SQLITE_REVIEW_FIXES.md M7. */ -function toNativeBindings(args: unknown[]): NativeBindParam[] { +export function toNativeBindings(args: unknown[]): NativeBindParam[] { return args.map((arg): NativeBindParam => { - if (arg === null || arg === undefined) { - return { kind: "null" }; - } - if (typeof arg === "bigint") { - return { kind: "int", intValue: Number(arg) }; - } - if (typeof arg === "number") { - if (Number.isInteger(arg)) { - return { kind: "int", intValue: arg }; - } - return { kind: "float", floatValue: arg }; - } - if (typeof arg === "string") { - return { kind: "text", textValue: arg }; + return toNativeBinding(arg); + }); +} + +function toNativeNamedBindings( + sql: string, + bindings: Record, +): NativeBindParam[] { + const orderedNames = extractNamedSqliteParameters(sql); + if (orderedNames.length === 0) { + return toNativeBindings(Object.values(bindings)); + } + + return orderedNames.map((name) => { + const value = getNamedSqliteBinding(bindings, name); + if (value === undefined) { + throw new Error(`missing bind parameter: ${name}`); } - if (typeof arg === "boolean") { - return { kind: "int", intValue: arg ? 1 : 0 }; + return toNativeBinding(value); + }); +} + +function extractNamedSqliteParameters(sql: string): string[] { + const orderedNames: string[] = []; + const seen = new Set(); + const pattern = /([:@$][A-Za-z_][A-Za-z0-9_]*)/g; + for (const match of sql.matchAll(pattern)) { + const name = match[1]; + if (seen.has(name)) { + continue; } - if (arg instanceof Uint8Array) { - return { kind: "blob", blobValue: Buffer.from(arg) }; + seen.add(name); + orderedNames.push(name); + } + return orderedNames; +} + +function getNamedSqliteBinding( + bindings: Record, + name: string, +): unknown { + if (name in bindings) { + return bindings[name]; + } + + const bareName = name.slice(1); + if (bareName in bindings) { + return bindings[bareName]; + } + + for (const prefix of [":", "@", "$"] as const) { + const candidate = `${prefix}${bareName}`; + if (candidate in bindings) { + return bindings[candidate]; } - throw new Error(`unsupported bind parameter type: ${typeof arg}`); - }); + } + + return undefined; } /** @@ -294,19 +356,40 @@ export function getKvChannelMetrics(): KvChannelMetricsSnapshot | undefined { */ export async function disconnectKvChannelForCurrentConfig( config?: NativeSqliteConfig, -): Promise { +): Promise { if (!nativeModule) { - return; + return 0; } const { key } = getKvChannelConfig(config); const channel = kvChannels.get(key); if (!channel) { - return; + return 0; } kvChannels.delete(key); await nativeModule.disconnect(channel); + return 1; +} + +/** + * Disconnect a specific KV channel instance and only clear the cached entry + * if it still points at that same channel. + */ +export async function disconnectKvChannelIfCurrent( + channel: NativeKvChannel, + config?: NativeSqliteConfig, +): Promise { + if (!nativeModule) { + return; + } + + const { key } = getKvChannelConfig(config); + if (kvChannels.get(key) === channel) { + kvChannels.delete(key); + } + + await nativeModule.disconnect(channel); } /** @@ -346,7 +429,17 @@ export async function createNativeRawAccess( // The native addon validates binding types in Rust // (bind_params). Convert bigint/Uint8Array to // JSON-compatible representations. - const bindings = toNativeBindings(args); + const bindings = + args.length === 1 && + args[0] !== null && + typeof args[0] === "object" && + !Array.isArray(args[0]) && + !(args[0] instanceof Uint8Array) + ? toNativeNamedBindings( + query, + args[0] as Record, + ) + : toNativeBindings(args); const token = query .trimStart() .slice(0, 16) diff --git a/rivetkit-typescript/packages/rivetkit/src/db/open-database.ts b/rivetkit-typescript/packages/rivetkit/src/db/open-database.ts new file mode 100644 index 0000000000..fc9f40af2e --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/db/open-database.ts @@ -0,0 +1,50 @@ +import type { IDatabase } from "@rivetkit/sqlite-vfs"; +import type { DatabaseProviderContext } from "./config"; +import { openNativeDatabase } from "./native-adapter"; +import { nativeSqliteAvailable } from "./native-sqlite"; +import { createActorKvStore } from "./shared"; + +type OpenedKvStore = ReturnType; + +export interface OpenedActorDatabase { + database: IDatabase; + kvStore?: OpenedKvStore; +} + +let nativeFallbackWarned = false; + +export async function openActorDatabase( + ctx: DatabaseProviderContext, +): Promise { + if (ctx.nativeSqliteConfig && nativeSqliteAvailable()) { + return { + database: await openNativeDatabase( + ctx.actorId, + ctx.nativeSqliteConfig, + ), + }; + } + + if (!nativeFallbackWarned) { + nativeFallbackWarned = true; + console.warn( + "native SQLite not available, falling back to WebAssembly. run npm rebuild to install native bindings.", + ); + } + + if (!ctx.sqliteVfs) { + throw new Error( + "SqliteVfs instance not provided in context. The driver must provide a sqliteVfs instance.", + ); + } + + const kvStore = createActorKvStore( + ctx.kv, + ctx.metrics, + ctx.preloadedEntries, + ); + return { + database: await ctx.sqliteVfs.open(ctx.actorId, kvStore), + kvStore, + }; +} diff --git a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts index fffadc7e69..91904d3572 100644 --- a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts @@ -57,6 +57,7 @@ import { logger } from "./log"; const ENVOY_SSE_PING_INTERVAL = 1000; const ENVOY_STOP_WAIT_MS = 15_000; +const INITIAL_SLEEP_TIMEOUT_MS = 250; // Message ack deadline is 30s on the gateway, but we will ack more frequently // in order to minimize the message buffer size on the gateway and to give @@ -293,6 +294,13 @@ export class EngineActorDriver implements ActorDriver { // No database overrides - will use KV-backed implementation from rivetkit/db + getInitialSleepTimeoutMs( + _actor: AnyActorInstance, + defaultTimeoutMs: number, + ): number { + return Math.max(defaultTimeoutMs, INITIAL_SLEEP_TIMEOUT_MS); + } + getNativeSqliteConfig() { return { endpoint: getEndpoint(this.#config), @@ -424,6 +432,8 @@ export class EngineActorDriver implements ActorDriver { }); await this.#discardCrashedActorState(actorId); + this.#actorStopIntent.set(actorId, "crash"); + this.#envoy.stopActor(actorId, undefined, "simulated hard crash"); } async shutdown(immediate: boolean): Promise { @@ -507,6 +517,10 @@ export class EngineActorDriver implements ActorDriver { } } + async waitForReady(): Promise { + await this.#envoy.started(); + } + async serverlessHandleStart(c: HonoContext): Promise { let payload = await c.req.arrayBuffer(); @@ -925,6 +939,7 @@ export class EngineActorDriver implements ActorDriver { requestIdBuf, isHibernatable, isRestoringHibernatable, + true, ); } catch (err) { logger().error({ msg: "building websocket handlers errored", err }); @@ -964,11 +979,10 @@ export class EngineActorDriver implements ActorDriver { wsHandler.onRestore?.(wsContext); } - websocket.addEventListener("open", (event) => { - wsHandler.onOpen(event, wsContext); - }); - - websocket.addEventListener("message", (event: RivetMessageEvent) => { + const isRawWebSocketPath = + requestPath === PATH_WEBSOCKET_BASE || + requestPath.startsWith(PATH_WEBSOCKET_PREFIX); + const handleMessageEvent = (event: RivetMessageEvent) => { logger().debug({ msg: "websocket message event listener triggered", connId: conn?.id, @@ -982,9 +996,6 @@ export class EngineActorDriver implements ActorDriver { eventTargetWsId: (event.target as any)?.__rivet_ws_id, }); - // Check if actor is stopping - if so, don't process new messages. - // These messages will be reprocessed when the actor wakes up from hibernation. - // TODO: This will never retransmit the socket and the socket will close if (actor?.isStopping) { logger().debug({ msg: "ignoring ws message, actor is stopping", @@ -995,87 +1006,108 @@ export class EngineActorDriver implements ActorDriver { return; } - // Process message - logger().debug({ - msg: "calling wsHandler.onMessage", - connId: conn?.id, - messageIndex: event.rivetMessageIndex, - }); - wsHandler.onMessage(event, wsContext); + const run = async () => { + logger().debug({ + msg: "calling wsHandler.onMessage", + connId: conn?.id, + messageIndex: event.rivetMessageIndex, + }); + wsHandler.onMessage(event, wsContext); - // Persist message index for hibernatable connections - const hibernate = connStateManager?.hibernatableData; + const hibernate = connStateManager?.hibernatableData; - if (hibernate && conn && actor) { - invariant( - typeof event.rivetMessageIndex === "number", - "missing event.rivetMessageIndex", - ); + if (hibernate && conn && actor) { + invariant( + typeof event.rivetMessageIndex === "number", + "missing event.rivetMessageIndex", + ); - // Persist message index - const previousMsgIndex = hibernate.serverMessageIndex; - hibernate.serverMessageIndex = event.rivetMessageIndex; - logger().info({ - msg: "persisting message index", - connId: conn.id, - previousMsgIndex, - newMsgIndex: event.rivetMessageIndex, - }); + const previousMsgIndex = hibernate.serverMessageIndex; + hibernate.serverMessageIndex = event.rivetMessageIndex; + logger().info({ + msg: "persisting message index", + connId: conn.id, + previousMsgIndex, + newMsgIndex: event.rivetMessageIndex, + }); - // Calculate message size and track cumulative size - const entry = this.#hwsMessageIndex.get(conn.id); - if (entry) { - // Track message length - const messageLength = getValueLength(event.data); - entry.bufferedMessageSize += messageLength; - - if ( - entry.bufferedMessageSize >= - CONN_BUFFERED_MESSAGE_SIZE_THRESHOLD - ) { - // Reset buffered message size immediately (instead - // of waiting for onAfterPersistConn) since we may - // receive more messages before onAfterPersistConn - // is called, which would called saveState - // immediate multiple times - entry.bufferedMessageSize = 0; - entry.pendingAckFromBufferSize = true; - - // Save state immediately if approaching buffer threshold - actor.stateManager.saveState({ - immediate: true, - }); + const entry = this.#hwsMessageIndex.get(conn.id); + if (entry) { + const messageLength = getValueLength(event.data); + entry.bufferedMessageSize += messageLength; + + if ( + entry.bufferedMessageSize >= + CONN_BUFFERED_MESSAGE_SIZE_THRESHOLD + ) { + entry.bufferedMessageSize = 0; + entry.pendingAckFromBufferSize = true; + + actor.stateManager.saveState({ + immediate: true, + }); + } else { + actor.stateManager.saveState({ + maxWait: CONN_MESSAGE_ACK_DEADLINE, + }); + } } else { - // Save message index. The maxWait is set to the ack deadline - // since we ack the message immediately after persisting the index. - // If cumulative size exceeds threshold, force immediate persist. - // - // This will call EngineActorDriver.onAfterPersistConn after - // persist to send the ack to the gateway. actor.stateManager.saveState({ maxWait: CONN_MESSAGE_ACK_DEADLINE, }); } - } else { - // Fallback if entry missing - actor.stateManager.saveState({ - maxWait: CONN_MESSAGE_ACK_DEADLINE, - }); } + }; + + if (isRawWebSocketPath && actor) { + void actor.internalKeepAwake(run); + } else { + void run(); } - }); + }; + const attachMessageListener = () => { + websocket.addEventListener("message", handleMessageEvent); + }; + let postOpenListenersAttached = false; + const attachPostOpenListeners = () => { + if (postOpenListenersAttached) { + return; + } + postOpenListenersAttached = true; - websocket.addEventListener("close", (event) => { - wsHandler.onClose(event, wsContext); + if (!isRawWebSocketPath) { + attachMessageListener(); + } - // NOTE: Persisted connection is removed when `conn.disconnect` - // is called by the WebSocket route - }); + websocket.addEventListener("close", (event) => { + if (isRawWebSocketPath && actor) { + void actor.internalKeepAwake(async () => { + await Promise.resolve(); + wsHandler.onClose(event, wsContext); + }); + } else { + wsHandler.onClose(event, wsContext); + } + }); + + websocket.addEventListener("error", (event) => { + wsHandler.onError(event, wsContext); + }); + }; - websocket.addEventListener("error", (event) => { - wsHandler.onError(event, wsContext); + websocket.addEventListener("open", (event) => { + if (isRawWebSocketPath) { + attachMessageListener(); + } + wsHandler.onOpen(event, wsContext); + + attachPostOpenListeners(); }); + if (!isRawWebSocketPath) { + attachPostOpenListeners(); + } + // Log event listener attachment for restored connections if (isRestoringHibernatable) { logger().info({ diff --git a/rivetkit-typescript/packages/rivetkit/src/remote-manager-driver/actor-websocket-client.ts b/rivetkit-typescript/packages/rivetkit/src/remote-manager-driver/actor-websocket-client.ts index 2234e15c7f..f9f335d0c9 100644 --- a/rivetkit-typescript/packages/rivetkit/src/remote-manager-driver/actor-websocket-client.ts +++ b/rivetkit-typescript/packages/rivetkit/src/remote-manager-driver/actor-websocket-client.ts @@ -156,11 +156,11 @@ export async function openWebSocketToGateway( buildWebSocketProtocols(runConfig, encoding, params), ); - // Set binary type to arraybuffer for proper encoding support + // The WebSocket is returned before the connection is open. This follows + // standard WebSocket behavior where the caller listens for the "open" + // event before sending messages. ws.binaryType = "arraybuffer"; - logger().debug({ msg: "websocket connection opened", gatewayUrl }); - return ws as UniversalWebSocket; } diff --git a/rivetkit-typescript/packages/rivetkit/tests/driver-engine.test.ts b/rivetkit-typescript/packages/rivetkit/tests/driver-engine.test.ts index dd381d40b1..c4ae6c67b0 100644 --- a/rivetkit-typescript/packages/rivetkit/tests/driver-engine.test.ts +++ b/rivetkit-typescript/packages/rivetkit/tests/driver-engine.test.ts @@ -8,6 +8,344 @@ import invariant from "invariant"; import { convertRegistryConfigToClientConfig } from "@/client/config"; import { describe, expect, test, vi } from "vitest"; +interface ActorListResponse { + actors?: Array<{ actor_id?: string }>; + pagination?: { cursor?: string | null }; +} + +interface ActorNamesResponse { + names?: Record; + pagination?: { cursor?: string | null }; +} + +interface SharedEngineRuntime { + endpoint: string; + namespace: string; + runnerName: string; + token: string; + driverConfig: ReturnType; + actorDriver: ReturnType["actor"]>; + forceDisconnectKvChannel: () => Promise; +} + +let sharedNamespacePromise: Promise | undefined; +let sharedRunnerConfigPromise: Promise | undefined; +let sharedEngineRuntimePromise: Promise | undefined; + +async function ensureSharedNamespace( + endpoint: string, + token: string, +): Promise { + if (!sharedNamespacePromise) { + sharedNamespacePromise = (async () => { + const namespace = `test-driver-engine-${crypto.randomUUID().slice(0, 8)}`; + const response = await fetch(`${endpoint}/namespaces`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify({ + name: namespace, + display_name: namespace, + }), + }); + if (!response.ok) { + const errorBody = await response.text().catch(() => ""); + throw new Error( + `Create shared namespace failed at ${endpoint}: ${response.status} ${response.statusText} ${errorBody}`, + ); + } + + return namespace; + })(); + } + + return await sharedNamespacePromise; +} + +async function ensureSharedRunnerConfig( + endpoint: string, + namespace: string, + token: string, +): Promise { + if (!sharedRunnerConfigPromise) { + sharedRunnerConfigPromise = (async () => { + const runnerName = `test-runner-${crypto.randomUUID().slice(0, 8)}`; + const response = await fetch( + `${endpoint}/runner-configs/${runnerName}?namespace=${namespace}`, + { + method: "PUT", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify({ + datacenters: { + default: { normal: {} }, + }, + }), + }, + ); + if (!response.ok) { + const errorBody = await response.text().catch(() => ""); + throw new Error( + `Create shared runner config failed: ${response.status} ${response.statusText} ${errorBody}`, + ); + } + + return runnerName; + })(); + } + + return await sharedRunnerConfigPromise; +} + +async function listAllActorNames( + endpoint: string, + namespace: string, + token: string, +): Promise { + const names: string[] = []; + let cursor: string | undefined; + + for (;;) { + const url = new URL("/actors/names", endpoint); + url.searchParams.set("namespace", namespace); + url.searchParams.set("limit", "100"); + if (cursor) { + url.searchParams.set("cursor", cursor); + } + + const response = await fetch(url, { + headers: { + Authorization: `Bearer ${token}`, + }, + }); + if (!response.ok) { + const errorBody = await response.text().catch(() => ""); + throw new Error( + `List actor names failed: ${response.status} ${response.statusText} ${errorBody}`, + ); + } + + const responseJson = (await response.json()) as ActorNamesResponse; + names.push(...Object.keys(responseJson.names ?? {})); + + const nextCursor = responseJson.pagination?.cursor ?? undefined; + if (!nextCursor) { + return names; + } + cursor = nextCursor; + } +} + +async function listActorIdsForName( + endpoint: string, + namespace: string, + name: string, + token: string, +): Promise { + const actorIds: string[] = []; + let cursor: string | undefined; + + for (;;) { + const url = new URL("/actors", endpoint); + url.searchParams.set("namespace", namespace); + url.searchParams.set("name", name); + url.searchParams.set("limit", "100"); + if (cursor) { + url.searchParams.set("cursor", cursor); + } + + const response = await fetch(url, { + headers: { + Authorization: `Bearer ${token}`, + }, + }); + if (!response.ok) { + const errorBody = await response.text().catch(() => ""); + throw new Error( + `List actors failed for ${name}: ${response.status} ${response.statusText} ${errorBody}`, + ); + } + + const responseJson = (await response.json()) as ActorListResponse; + actorIds.push( + ...((responseJson.actors ?? []) + .map((actor) => actor.actor_id) + .filter((actorId): actorId is string => !!actorId)), + ); + + const nextCursor = responseJson.pagination?.cursor ?? undefined; + if (!nextCursor) { + return actorIds; + } + cursor = nextCursor; + } +} + +async function destroyNamespaceActors( + endpoint: string, + namespace: string, + token: string, +): Promise { + const names = await listAllActorNames(endpoint, namespace, token); + for (const name of names) { + const actorIds = await listActorIdsForName( + endpoint, + namespace, + name, + token, + ); + await Promise.all( + actorIds.map(async (actorId) => { + const url = new URL(`/actors/${actorId}`, endpoint); + url.searchParams.set("namespace", namespace); + + const response = await fetch(url, { + method: "DELETE", + headers: { + Authorization: `Bearer ${token}`, + }, + }); + if (response.status === 404) { + return; + } + if (!response.ok) { + const errorBody = await response.text().catch(() => ""); + throw new Error( + `Delete actor ${actorId} failed: ${response.status} ${response.statusText} ${errorBody}`, + ); + } + }), + ); + } +} + +async function waitForEnvoyCount( + endpoint: string, + namespace: string, + runnerName: string, + token: string, + expectAtLeastOne: boolean, +): Promise { + const envoysUrl = new URL(`${endpoint.replace(/\/$/, "")}/envoys`); + envoysUrl.searchParams.set("namespace", namespace); + envoysUrl.searchParams.set("name", runnerName); + + let probeError: unknown; + for (let attempt = 0; attempt < 150; attempt++) { + try { + const envoyResponse = await fetch(envoysUrl, { + method: "GET", + headers: { + Authorization: `Bearer ${token}`, + }, + }); + if (!envoyResponse.ok) { + const errorBody = await envoyResponse.text().catch(() => ""); + probeError = new Error( + `List envoys failed: ${envoyResponse.status} ${envoyResponse.statusText} ${errorBody}`, + ); + } else { + const responseJson = + (await envoyResponse.json()) as { + envoys?: Array<{ pool_name?: string }>; + }; + const count = + responseJson.envoys?.filter( + (envoy) => envoy.pool_name === runnerName, + ).length ?? 0; + if (expectAtLeastOne ? count > 0 : count === 0) { + return; + } + + probeError = new Error( + expectAtLeastOne + ? `Envoy ${runnerName} not registered yet` + : `Envoy ${runnerName} is still connected`, + ); + } + } catch (err) { + probeError = err; + } + + if (attempt < 149) { + await new Promise((resolve) => setTimeout(resolve, 100)); + } + } + + throw probeError; +} + +async function ensureSharedEngineRuntime( + registry: any, + endpoint: string, + namespace: string, + runnerName: string, + token: string, +): Promise { + if (!sharedEngineRuntimePromise) { + sharedEngineRuntimePromise = (async () => { + const driverConfig = createEngineDriver(); + + registry.config.driver = driverConfig; + registry.config.endpoint = endpoint; + registry.config.namespace = namespace; + registry.config.token = token; + registry.config.envoy = { + ...registry.config.envoy, + poolName: runnerName, + }; + + const parsedConfig = registry.parseConfig(); + + const managerDriver = driverConfig.manager?.(parsedConfig); + invariant(managerDriver, "missing manager driver"); + const inlineClient = createClientWithDriver( + managerDriver, + convertRegistryConfigToClientConfig(parsedConfig), + ); + + const actorDriver = driverConfig.actor( + parsedConfig, + managerDriver, + inlineClient, + ); + + await actorDriver.waitForReady?.(); + await waitForEnvoyCount( + endpoint, + namespace, + runnerName, + token, + true, + ); + + return { + endpoint, + namespace, + runnerName, + token, + driverConfig, + actorDriver, + forceDisconnectKvChannel: async () => { + const { disconnectKvChannelForCurrentConfig } = + await import("@/db/native-sqlite"); + return await disconnectKvChannelForCurrentConfig({ + endpoint, + token, + namespace, + }); + }, + }; + })(); + } + + return await sharedEngineRuntimePromise; +} + const driverTestConfig = { // Use real timers for engine-runner tests useRealTimers: true, @@ -26,168 +364,87 @@ const driverTestConfig = { process.env.RIVET_NAMESPACE_ENDPOINT || process.env.RIVET_API_ENDPOINT || endpoint; - const namespace = `test-${crypto.randomUUID().slice(0, 8)}`; - const runnerName = "test-runner"; const token = "dev"; - - // Create namespace. - const response = await fetch( - `${namespaceEndpoint}/namespaces`, - { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: "Bearer dev", - }, - body: JSON.stringify({ - name: namespace, - display_name: namespace, - }), - }, - ); - if (!response.ok) { - const errorBody = await response.text().catch(() => ""); - throw new Error( - `Create namespace failed at ${namespaceEndpoint}: ${response.status} ${response.statusText} ${errorBody}`, - ); - } - - // Create driver config. - const driverConfig = createEngineDriver(); - - // Start the actor driver. - registry.config.driver = driverConfig; - registry.config.endpoint = endpoint; - registry.config.namespace = namespace; - registry.config.token = token; - registry.config.envoy = { - ...registry.config.envoy, - poolName: runnerName, - }; - - // Parse config only after mutating registry.config so the manager - // and actor drivers do not get stale namespace/runner values from - // previous tests. - const parsedConfig = registry.parseConfig(); - - const managerDriver = driverConfig.manager?.(parsedConfig); - invariant(managerDriver, "missing manager driver"); - const inlineClient = createClientWithDriver( - managerDriver, - convertRegistryConfigToClientConfig(parsedConfig), + const namespace = await ensureSharedNamespace( + namespaceEndpoint, + token, ); - - const actorDriver = driverConfig.actor( - parsedConfig, - managerDriver, - inlineClient, + const runnerName = await ensureSharedRunnerConfig( + namespaceEndpoint, + namespace, + token, ); - - // Wait for runner registration so tests do not race actor creation - // against asynchronous runner connect. - const runnersUrl = new URL( - `${endpoint.replace(/\/$/, "")}/runners`, + const runtime = await ensureSharedEngineRuntime( + registry, + endpoint, + namespace, + runnerName, + token, ); - runnersUrl.searchParams.set("namespace", namespace); - runnersUrl.searchParams.set("name", runnerName); - let probeError: unknown; - for (let attempt = 0; attempt < 120; attempt++) { - try { - const runnerResponse = await fetch(runnersUrl, { - method: "GET", - headers: { - Authorization: `Bearer ${token}`, - }, - }); - if (!runnerResponse.ok) { - const errorBody = await runnerResponse - .text() - .catch(() => ""); - probeError = new Error( - `List runners failed: ${runnerResponse.status} ${runnerResponse.statusText} ${errorBody}`, - ); - } else { - const responseJson = - (await runnerResponse.json()) as { - runners?: Array<{ name?: string }>; - }; - const hasRunner = !!responseJson.runners?.some( - (runner) => runner.name === runnerName, - ); - if (hasRunner) { - probeError = undefined; - break; - } - probeError = new Error( - `Runner ${runnerName} not registered yet`, - ); - } - } catch (err) { - probeError = err; - } - if (attempt < 119) { - await new Promise((resolve) => - setTimeout(resolve, 100), - ); - } - } - if (probeError) { - throw probeError; - } return { rivetEngine: { - endpoint, - namespace, - runnerName, - token, - }, - driver: driverConfig, - hardCrashActor: async (actorId: string) => { - await actorDriver.hardCrashActor?.(actorId); - }, - hardCrashPreservesData: true, - cleanup: async () => { - await actorDriver.shutdownRunner?.(true); + endpoint: runtime.endpoint, + namespace: runtime.namespace, + runnerName: runtime.runnerName, + token: runtime.token, }, - }; + driver: runtime.driverConfig, + hardCrashActor: async (actorId: string) => { + await runtime.actorDriver.hardCrashActor?.(actorId); + }, + hardCrashPreservesData: true, + forceDisconnectKvChannel: runtime.forceDisconnectKvChannel, + cleanup: async () => { + await destroyNamespaceActors( + namespaceEndpoint, + namespace, + token, + ); + }, + }; }, ); }, } satisfies Omit; -runDriverTests(driverTestConfig); +describe.sequential("engine driver", { timeout: 30_000 }, () => { + runDriverTests(driverTestConfig); -describe("engine startup kv preload", () => { - test("wakes actors with envoy-provided preloaded kv", async (c) => { - const { client } = await setupDriverTest(c, { - ...driverTestConfig, - clientType: "http", - encoding: "bare", - }); - const handle = client.sleep.getOrCreate(); + describe("engine startup kv preload", () => { + test("wakes actors with envoy-provided preloaded kv", async (c) => { + const { client } = await setupDriverTest(c, { + ...driverTestConfig, + clientType: "http", + encoding: "bare", + }); + const handle = client.sleep.getOrCreate(); - await handle.getCounts(); - await handle.triggerSleep(); + await handle.getCounts(); + await handle.triggerSleep(); - await vi.waitFor( - async () => { - const counts = await handle.getCounts(); - expect(counts.sleepCount).toBeGreaterThanOrEqual(1); - expect(counts.startCount).toBeGreaterThanOrEqual(2); - }, - { timeout: 5_000, interval: 100 }, - ); + await vi.waitFor( + async () => { + const counts = await handle.getCounts(); + expect(counts.sleepCount).toBeGreaterThanOrEqual(1); + expect(counts.startCount).toBeGreaterThanOrEqual(2); + }, + { timeout: 5_000, interval: 100 }, + ); - const gatewayUrl = await handle.getGatewayUrl(); - const response = await fetch(`${gatewayUrl}/inspector/metrics`, { - headers: { Authorization: "Bearer token" }, - }); - expect(response.status).toBe(200); + const actorId = await handle.resolve(); + const gatewayUrl = await client.sleep + .getForId(actorId) + .getGatewayUrl(); + const response = await fetch(`${gatewayUrl}/inspector/metrics`, { + headers: { Authorization: "Bearer token" }, + }); + expect(response.status).toBe(200); - const metrics: any = await response.json(); - expect(metrics.startup_is_new.value).toBe(0); - expect(metrics.startup_internal_preload_kv_entries.value).toBeGreaterThan(0); - expect(metrics.startup_kv_round_trips.value).toBe(0); + const metrics: any = await response.json(); + expect(metrics.startup_is_new.value).toBe(0); + expect(metrics.startup_internal_preload_kv_entries.value).toBeGreaterThan(0); + expect(metrics.startup_kv_round_trips.value).toBe(0); + }); }); }); diff --git a/rivetkit-typescript/packages/sqlite-native/sqlite-native.linux-x64-gnu.node b/rivetkit-typescript/packages/sqlite-native/sqlite-native.linux-x64-gnu.node deleted file mode 100755 index c108848343..0000000000 Binary files a/rivetkit-typescript/packages/sqlite-native/sqlite-native.linux-x64-gnu.node and /dev/null differ diff --git a/rivetkit-typescript/packages/sqlite-native/turbo.json b/rivetkit-typescript/packages/sqlite-native/turbo.json new file mode 100644 index 0000000000..1269f41c8d --- /dev/null +++ b/rivetkit-typescript/packages/sqlite-native/turbo.json @@ -0,0 +1,24 @@ +{ + "$schema": "https://turbo.build/schema.json", + "extends": ["//"], + "tasks": { + "build": { + "inputs": [ + "src/**", + "scripts/**", + "Cargo.toml", + "Cargo.lock", + "build.rs", + "package.json", + "index.js", + "index.d.ts" + ], + "outputs": [ + "sqlite-native.*.node", + "npm/**/*.node", + "target/release/librivetkit_sqlite_native.*", + "target/release/rivetkit_sqlite_native.dll" + ] + } + } +}