diff --git a/Cargo.lock b/Cargo.lock index 3451fc5c60..80b78289c7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4699,7 +4699,6 @@ name = "rivet-envoy-protocol" version = "2.2.1" dependencies = [ "anyhow", - "gasoline", "hex", "rand 0.8.5", "rivet-util", diff --git a/engine/artifacts/config-schema.json b/engine/artifacts/config-schema.json index 4d7e896f7d..ef248614ec 100644 --- a/engine/artifacts/config-schema.json +++ b/engine/artifacts/config-schema.json @@ -228,11 +228,21 @@ "description": "Configuration for the cache layer.", "type": "object", "required": [ - "driver" + "enabled" ], "properties": { "driver": { - "$ref": "#/definitions/CacheDriver" + "anyOf": [ + { + "$ref": "#/definitions/CacheDriver" + }, + { + "type": "null" + } + ] + }, + "enabled": { + "type": "boolean" } }, "additionalProperties": false diff --git a/engine/packages/api-public/src/runner_configs/serverless_health_check.rs b/engine/packages/api-public/src/runner_configs/serverless_health_check.rs index 5aeac485c9..56f4999252 100644 --- a/engine/packages/api-public/src/runner_configs/serverless_health_check.rs +++ b/engine/packages/api-public/src/runner_configs/serverless_health_check.rs @@ -10,7 +10,7 @@ use serde::{Deserialize, Serialize}; use utoipa::IntoParams; use utoipa::ToSchema; -use super::utils::{ServerlessMetadataError, fetch_serverless_runner_metadata}; +use super::utils::{ServerlessMetadataError, fetch_serverless_metadata}; use crate::ctx::ApiCtx; #[derive(Debug, Serialize, Deserialize, Clone, IntoParams)] @@ -72,7 +72,7 @@ async fn serverless_health_check_inner( let ServerlessHealthCheckRequest { url, headers } = body; - match fetch_serverless_runner_metadata(&ctx, url, headers).await { + match fetch_serverless_metadata(&ctx, url, headers).await { Ok(metadata) => Ok(ServerlessHealthCheckResponse::Success { version: metadata.version, }), diff --git a/engine/packages/api-public/src/runner_configs/utils.rs b/engine/packages/api-public/src/runner_configs/utils.rs index 04cdd35cd2..6df2dbaa9e 100644 --- a/engine/packages/api-public/src/runner_configs/utils.rs +++ b/engine/packages/api-public/src/runner_configs/utils.rs @@ -1,3 +1,4 @@ +use anyhow::Result; use std::collections::HashMap; use gas::prelude::*; @@ -35,11 +36,11 @@ impl From for ServerlessMetad /// /// Returns metadata including runtime, version, and actor names if available. #[tracing::instrument(skip_all)] -pub async fn fetch_serverless_runner_metadata( +pub async fn fetch_serverless_metadata( ctx: &ApiCtx, url: String, headers: HashMap, -) -> Result { +) -> std::result::Result { ctx.op(pegboard::ops::serverless_metadata::fetch::Input { url, headers }) .await .map_err(|_| ServerlessMetadataError::RequestFailed {})? @@ -54,49 +55,23 @@ pub async fn refresh_runner_config_metadata( runner_name: String, url: String, headers: HashMap, -) -> anyhow::Result<()> { +) -> Result<()> { tracing::debug!( ?namespace_id, ?runner_name, "refreshing runner config metadata" ); - // Fetch metadata using the op - let metadata = ctx - .op(pegboard::ops::serverless_metadata::fetch::Input { url, headers }) - .await? - .map_err(|e| { - pegboard::errors::ServerlessRunnerPool::FailedToFetchMetadata { reason: e }.build() - })?; - - if !metadata.actor_names.is_empty() { - tracing::debug!( - actor_names_count = metadata.actor_names.len(), - "storing actor names metadata" - ); - - // Convert and store actor names - let actor_names: Vec = metadata - .actor_names - .into_iter() - .map( - |a| pegboard::ops::actor_name::upsert_batch::ActorNameEntry { - name: a.name, - metadata: a.metadata, - }, - ) - .collect(); - - ctx.op(pegboard::ops::actor_name::upsert_batch::Input { - namespace_id, - actor_names, - }) - .await?; - - tracing::debug!("successfully stored actor names metadata"); - } else { - tracing::debug!("no actor names to store"); - } + ctx.op(pegboard::ops::runner_config::refresh_metadata::Input { + namespace_id, + runner_name, + url, + headers, + }) + .await? + .map_err(|e| { + pegboard::errors::ServerlessRunnerPool::FailedToFetchMetadata { reason: e }.build() + })?; Ok(()) } diff --git a/engine/packages/cache/src/inner.rs b/engine/packages/cache/src/inner.rs index d34e0e32c6..25afcc5b8c 100644 --- a/engine/packages/cache/src/inner.rs +++ b/engine/packages/cache/src/inner.rs @@ -14,7 +14,7 @@ pub type Cache = Arc; /// Utility type used to hold information relating to caching. pub struct CacheInner { - pub(crate) driver: Driver, + pub(crate) driver: Option, pub(crate) ups: Option, } @@ -33,7 +33,10 @@ impl CacheInner { let ups = pools.ups().ok(); match &config.cache().driver { - rivet_config::config::CacheDriver::InMemory => Ok(Self::new_in_memory(10000, ups)), + Some(rivet_config::config::CacheDriver::InMemory) => { + Ok(Self::new_in_memory(10000, ups)) + } + None => Ok(Self::new_disabled()), } } @@ -41,7 +44,17 @@ impl CacheInner { pub fn new_in_memory(max_capacity: u64, ups: Option) -> Cache { let driver = Driver::InMemory(InMemoryDriver::new(max_capacity)); - Arc::new(CacheInner { driver, ups }) + Arc::new(CacheInner { + driver: Some(driver), + ups, + }) + } + + pub fn new_disabled() -> Cache { + Arc::new(CacheInner { + driver: None, + ups: None, + }) } pub(crate) fn in_flight(&self) -> &scc::HashMap> { diff --git a/engine/packages/cache/src/req_config.rs b/engine/packages/cache/src/req_config.rs index ccacc5f94d..f9d84414d5 100644 --- a/engine/packages/cache/src/req_config.rs +++ b/engine/packages/cache/src/req_config.rs @@ -81,15 +81,27 @@ impl RequestConfig { let mut ctx = GetterCtx::new(keys); + // No driver (cache disabled) + let Some(driver) = &self.cache.driver else { + let keys = ctx.unresolved_keys(); + let ctx = getter(ctx, keys).await.map_err(Error::Getter)?; + + metrics::CACHE_VALUE_EMPTY_TOTAL + .with_label_values(&[&base_key]) + .inc_by(ctx.unresolved_keys().len() as u64); + + return Ok(ctx.into_values()); + }; + // Build driver-specific cache keys let (keys, cache_keys): (Vec<_>, Vec<_>) = ctx .entries() - .map(|(key, _)| (key.clone(), self.cache.driver.process_key(&base_key, key))) + .map(|(key, _)| (key.clone(), driver.process_key(&base_key, key))) .unzip(); let cache_keys_len = cache_keys.len(); // Attempt to fetch value from cache, fall back to getter - match self.cache.driver.get(&base_key, &cache_keys).await { + match driver.get(&base_key, &cache_keys).await { Ok(cached_values) => { debug_assert_eq!( cache_keys_len, @@ -128,7 +140,7 @@ impl RequestConfig { // Determine which keys are currently being fetched and not for key in remaining_keys { - let cache_key = self.cache.driver.process_key(&base_key, &key); + let cache_key = driver.process_key(&base_key, &key); match self.cache.in_flight().entry_async(cache_key).await { scc::hash_map::Entry::Occupied(broadcast) => { waiting_keys.push((key, broadcast.subscribe())); @@ -141,7 +153,6 @@ impl RequestConfig { } let getter2 = getter.clone(); - let cache = self.cache.clone(); let ctx2 = GetterCtx::new(leased_keys.clone()); let base_key2 = base_key.clone(); let leased_keys2 = leased_keys.clone(); @@ -178,8 +189,7 @@ impl RequestConfig { .into_iter() .partition_map(|(key, succeeded)| { if succeeded { - let cache_key = - cache.driver.process_key(&base_key2, &key); + let cache_key = driver.process_key(&base_key2, &key); Either::Left((key, cache_key)) } else { Either::Right(key) @@ -193,7 +203,7 @@ impl RequestConfig { if succeeded_cache_keys.is_empty() { Ok(Vec::new()) } else { - cache.driver.get(&base_key2, &succeeded_cache_keys).await + driver.get(&base_key2, &succeeded_cache_keys).await } }, async { @@ -255,7 +265,7 @@ impl RequestConfig { .into_iter() .filter_map(|(key, value)| { // Process the key with the appropriate driver - let cache_key = self.cache.driver.process_key(&base_key, key); + let cache_key = driver.process_key(&base_key, key); // Try to decode the value using the driver match encoder(value) { Ok(value_bytes) => Some((cache_key, value_bytes, expire_at)), @@ -269,10 +279,9 @@ impl RequestConfig { .collect::>(); if !entries_values.is_empty() { - let cache = self.cache.clone(); let base_key_clone = base_key.clone(); - if let Err(err) = cache.driver.set(&base_key_clone, entries_values).await { + if let Err(err) = driver.set(&base_key_clone, entries_values).await { tracing::error!(?err, "failed to write to cache"); } @@ -281,7 +290,7 @@ impl RequestConfig { // Release leases for key in leased_keys { - let cache_key = self.cache.driver.process_key(&base_key, &key); + let cache_key = driver.process_key(&base_key, &key); self.cache.in_flight().remove_async(&cache_key).await; } } @@ -298,13 +307,17 @@ impl RequestConfig { "failed to read batch keys from cache, falling back to getter" ); + let keys = ctx.unresolved_keys(); + metrics::CACHE_REQUEST_ERRORS .with_label_values(&[&base_key]) .inc(); + metrics::CACHE_VALUE_MISS_TOTAL + .with_label_values(&[base_key.as_str()]) + .inc_by(keys.len() as u64); // Fall back to the getter since we can't fetch the value from // the cache - let keys = ctx.unresolved_keys(); let ctx = getter(ctx, keys).await.map_err(Error::Getter)?; metrics::CACHE_VALUE_EMPTY_TOTAL @@ -325,11 +338,16 @@ impl RequestConfig { where Key: CacheKey + Send + Sync, { + // Cache disabled + let Some(driver) = &self.cache.driver else { + return Ok(()); + }; + // Build keys let base_key = base_key.as_ref(); let cache_keys = keys .into_iter() - .map(|key| self.cache.driver.process_key(base_key, &key)) + .map(|key| driver.process_key(base_key, &key)) .collect::>(); if cache_keys.is_empty() { @@ -375,6 +393,11 @@ impl RequestConfig { base_key: impl AsRef + Debug, keys: Vec, ) -> Result<()> { + // Cache disabled + let Some(driver) = &self.cache.driver else { + return Ok(()); + }; + let base_key = base_key.as_ref(); if keys.is_empty() { @@ -389,7 +412,7 @@ impl RequestConfig { .inc_by(keys.len() as u64); // Delete keys locally - match self.cache.driver.delete(base_key, keys).await { + match driver.delete(base_key, keys).await { Ok(_) => { tracing::trace!("successfully deleted keys"); } diff --git a/engine/packages/config/src/config/cache.rs b/engine/packages/config/src/config/cache.rs index b12117cc97..83351c2ae5 100644 --- a/engine/packages/config/src/config/cache.rs +++ b/engine/packages/config/src/config/cache.rs @@ -5,17 +5,25 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)] #[serde(deny_unknown_fields)] pub struct Cache { - pub driver: CacheDriver, + pub enabled: bool, + pub driver: Option, } impl Default for Cache { fn default() -> Cache { Self { - driver: CacheDriver::InMemory, + enabled: true, + driver: None, } } } +impl Cache { + pub fn driver(&self) -> CacheDriver { + self.driver.clone().unwrap_or(CacheDriver::InMemory) + } +} + #[derive(Debug, Serialize, Deserialize, Clone, JsonSchema)] #[serde(rename_all = "snake_case", deny_unknown_fields)] pub enum CacheDriver { diff --git a/engine/packages/engine/tests/common/test_envoy.rs b/engine/packages/engine/tests/common/test_envoy.rs index 1de3fc4f78..6f13429453 100644 --- a/engine/packages/engine/tests/common/test_envoy.rs +++ b/engine/packages/engine/tests/common/test_envoy.rs @@ -8,6 +8,7 @@ use std::collections::HashMap; use std::sync::Arc; // Re-export everything from the standalone package +pub use rivet_envoy_protocol::PROTOCOL_VERSION; pub use rivet_test_envoy::{ ActorConfig, ActorEvent, ActorLifecycleEvent, ActorStartResult, ActorStopResult, CountingCrashActor, CrashNTimesThenSucceedActor, CrashOnStartActor, CustomActor, @@ -15,7 +16,6 @@ pub use rivet_test_envoy::{ NotifyOnStartActor, SleepImmediatelyActor, StopImmediatelyActor, TestActor, TimeoutActor, VerifyInputActor, }; -pub use rivet_envoy_protocol::PROTOCOL_VERSION; // Type alias for backwards compatibility pub type TestEnvoy = Envoy; diff --git a/engine/packages/pegboard/src/ops/actor_name/mod.rs b/engine/packages/pegboard/src/ops/actor_name/mod.rs deleted file mode 100644 index 90013308ef..0000000000 --- a/engine/packages/pegboard/src/ops/actor_name/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod upsert_batch; diff --git a/engine/packages/pegboard/src/ops/actor_name/upsert_batch.rs b/engine/packages/pegboard/src/ops/actor_name/upsert_batch.rs deleted file mode 100644 index 3bcefb4e9e..0000000000 --- a/engine/packages/pegboard/src/ops/actor_name/upsert_batch.rs +++ /dev/null @@ -1,49 +0,0 @@ -use anyhow::Result; -use gas::prelude::*; -use rivet_data::converted::ActorNameKeyData; - -use crate::keys; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Input { - pub namespace_id: Id, - /// Actor names with their metadata. Metadata must be a JSON object. - pub actor_names: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ActorNameEntry { - pub name: String, - pub metadata: serde_json::Map, -} - -#[operation] -pub async fn pegboard_actor_name_upsert_batch(ctx: &OperationCtx, input: &Input) -> Result<()> { - if input.actor_names.is_empty() { - return Ok(()); - } - - ctx.udb()? - .run(|tx| { - let actor_names = input.actor_names.clone(); - let namespace_id = input.namespace_id; - async move { - let tx = tx.with_subspace(keys::subspace()); - - for entry in actor_names { - tx.write( - &keys::ns::ActorNameKey::new(namespace_id, entry.name), - ActorNameKeyData { - metadata: entry.metadata, - }, - )?; - } - - Ok(()) - } - }) - .custom_instrument(tracing::info_span!("actor_name_upsert_batch_tx")) - .await?; - - Ok(()) -} diff --git a/engine/packages/pegboard/src/ops/mod.rs b/engine/packages/pegboard/src/ops/mod.rs index 1e67ee5e44..d4cd7e5755 100644 --- a/engine/packages/pegboard/src/ops/mod.rs +++ b/engine/packages/pegboard/src/ops/mod.rs @@ -1,5 +1,4 @@ pub mod actor; -pub mod actor_name; pub mod envoy; pub mod runner; pub mod runner_config; diff --git a/engine/packages/pegboard/src/ops/runner_config/mod.rs b/engine/packages/pegboard/src/ops/runner_config/mod.rs index 9a0727da0c..693cf98fa9 100644 --- a/engine/packages/pegboard/src/ops/runner_config/mod.rs +++ b/engine/packages/pegboard/src/ops/runner_config/mod.rs @@ -3,4 +3,5 @@ pub mod ensure_normal_if_missing; pub mod get; pub mod get_error; pub mod list; +pub mod refresh_metadata; pub mod upsert; diff --git a/engine/packages/pegboard/src/ops/runner_config/refresh_metadata.rs b/engine/packages/pegboard/src/ops/runner_config/refresh_metadata.rs new file mode 100644 index 0000000000..8cd127a6b0 --- /dev/null +++ b/engine/packages/pegboard/src/ops/runner_config/refresh_metadata.rs @@ -0,0 +1,124 @@ +use anyhow::Result; +use gas::prelude::*; +use rivet_data::converted::ActorNameKeyData; +use rivet_types::actor::RunnerPoolError; +use std::collections::HashMap; +use universaldb::prelude::*; + +use crate::{ + keys, + ops::serverless_metadata::fetch::{Output, ServerlessMetadataError}, +}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Input { + pub namespace_id: Id, + pub runner_name: String, + pub url: String, + pub headers: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ActorNameEntry { + pub name: String, + pub metadata: serde_json::Map, +} + +#[operation] +pub async fn pegboard_runner_config_refresh_metadata( + ctx: &OperationCtx, + input: &Input, +) -> Result> { + let metadata = ctx + .op(crate::ops::serverless_metadata::fetch::Input { + url: input.url.clone(), + headers: input.headers.clone(), + }) + .await?; + + let metadata = match metadata { + Ok(x) => x, + Err(err) => return Ok(Err(err)), + }; + + // Save protocol to udb + let downgraded = ctx + .udb()? + .run(|tx| async move { + let tx = tx.with_subspace(namespace::keys::subspace()); + + let protocol_version_key = keys::runner_config::ProtocolVersionKey::new( + input.namespace_id, + input.runner_name.clone(), + ); + + if let Some(protocol_version) = metadata.envoy_protocol_version { + tx.write(&protocol_version_key, protocol_version)?; + + Ok(false) + } else if tx.exists(&protocol_version_key, Serializable).await? { + Ok(true) + } else { + Ok(false) + } + }) + .await?; + + if downgraded { + report_error( + ctx, + input.namespace_id, + &input.runner_name, + RunnerPoolError::Downgrade, + ) + .await; + } + + // Update actor names in DB if present + if !metadata.actor_names.is_empty() { + ctx.udb()? + .run(|tx| { + let metadata = &metadata; + let namespace_id = input.namespace_id; + async move { + let tx = tx.with_subspace(keys::subspace()); + + for entry in &metadata.actor_names { + tx.write( + &keys::ns::ActorNameKey::new(namespace_id, entry.name.clone()), + ActorNameKeyData { + metadata: entry.metadata.clone(), + }, + )?; + } + + Ok(()) + } + }) + .custom_instrument(tracing::info_span!("actor_name_upsert_batch_tx")) + .await?; + } + + Ok(Ok(metadata)) +} + +/// Report an error to the error tracker workflow. +async fn report_error( + ctx: &OperationCtx, + namespace_id: Id, + pool_name: &str, + error: RunnerPoolError, +) { + if let Err(err) = ctx + .signal(crate::workflows::runner_pool_error_tracker::ReportError { error }) + .bypass_signal_from_workflow_I_KNOW_WHAT_IM_DOING() + .to_workflow::() + .tag("namespace_id", namespace_id) + .tag("runner_name", pool_name) + .graceful_not_found() + .send() + .await + { + tracing::warn!(?err, "failed to report serverless error"); + } +} diff --git a/engine/packages/pegboard/src/workflows/actor/mod.rs b/engine/packages/pegboard/src/workflows/actor/mod.rs index 90da2f2833..283cc94da3 100644 --- a/engine/packages/pegboard/src/workflows/actor/mod.rs +++ b/engine/packages/pegboard/src/workflows/actor/mod.rs @@ -238,6 +238,12 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> .tag("actor_id", input.actor_id) .dispatch() .await?; + + ctx.msg(MigratedToV2 {}) + .topic(("actor_id", input.actor_id)) + .send() + .await?; + return Ok(()); } }; diff --git a/engine/packages/pegboard/src/workflows/runner_pool_metadata_poller.rs b/engine/packages/pegboard/src/workflows/runner_pool_metadata_poller.rs index 32b73c4a3b..bcb600f784 100644 --- a/engine/packages/pegboard/src/workflows/runner_pool_metadata_poller.rs +++ b/engine/packages/pegboard/src/workflows/runner_pool_metadata_poller.rs @@ -2,10 +2,7 @@ use std::time::Duration; use futures_util::FutureExt; use gas::prelude::*; -use rivet_types::{actor::RunnerPoolError, runner_configs::RunnerConfigKind}; -use universaldb::prelude::*; - -use crate::{keys, ops::actor_name::upsert_batch::ActorNameEntry}; +use rivet_types::runner_configs::RunnerConfigKind; #[derive(Debug, Serialize, Deserialize, Clone)] pub struct Input { @@ -176,9 +173,13 @@ async fn poll_metadata(ctx: &ActivityCtx, input: &PollMetadataInput) -> Result

Result

= metadata - .actor_names - .iter() - .map(|a| ActorNameEntry { - name: a.name.clone(), - metadata: a.metadata.clone(), - }) - .collect(); - - ctx.op(crate::ops::actor_name::upsert_batch::Input { - namespace_id: input.namespace_id, - actor_names, - }) - .await?; - } - // Drain older runners if runner_version is set let older_runner_workflow_ids = if let Some(version) = metadata.runner_version { ctx.op(crate::ops::runner::drain::Input { @@ -282,27 +232,6 @@ async fn poll_metadata(ctx: &ActivityCtx, input: &PollMetadataInput) -> Result

() - .tag("namespace_id", namespace_id) - .tag("runner_name", pool_name) - .graceful_not_found() - .send() - .await - { - tracing::warn!(?err, "failed to report serverless error"); - } -} - #[signal("pegboard_runner_pool_metadata_poller_endpoint_config_changed")] #[derive(Debug)] pub struct EndpointConfigChanged {} diff --git a/engine/packages/runner-protocol/build.rs b/engine/packages/runner-protocol/build.rs index f39d661e13..e7df077bbe 100644 --- a/engine/packages/runner-protocol/build.rs +++ b/engine/packages/runner-protocol/build.rs @@ -62,17 +62,15 @@ mod typescript { let output_path = src_dir.join("index.ts"); - let output = Command::new( - repo_root.join("node_modules/@bare-ts/tools/dist/bin/cli.js"), - ) - .arg("compile") - .arg("--generator") - .arg("ts") - .arg(highest_version_path) - .arg("-o") - .arg(&output_path) - .output() - .expect("Failed to execute bare compiler for TypeScript"); + let output = Command::new(repo_root.join("node_modules/@bare-ts/tools/dist/bin/cli.js")) + .arg("compile") + .arg("--generator") + .arg("ts") + .arg(highest_version_path) + .arg("-o") + .arg(&output_path) + .output() + .expect("Failed to execute bare compiler for TypeScript"); if !output.status.success() { panic!( diff --git a/engine/sdks/rust/envoy-client/src/actor.rs b/engine/sdks/rust/envoy-client/src/actor.rs index 9b9839fbfd..01538a52f0 100644 --- a/engine/sdks/rust/envoy-client/src/actor.rs +++ b/engine/sdks/rust/envoy-client/src/actor.rs @@ -11,7 +11,7 @@ use crate::connection::ws_send; use crate::context::SharedContext; use crate::handle::EnvoyHandle; use crate::stringify::stringify_to_rivet_tunnel_message_kind; -use crate::utils::{id_to_str, wrapping_add_u16, wrapping_lte_u16, wrapping_sub_u16, BufferMap}; +use crate::utils::{BufferMap, id_to_str, wrapping_add_u16, wrapping_lte_u16, wrapping_sub_u16}; pub enum ToActor { Intent { @@ -135,7 +135,13 @@ async fn actor_inner( let start_result = shared .config .callbacks - .on_actor_start(handle.clone(), actor_id.clone(), generation, config, preloaded_kv) + .on_actor_start( + handle.clone(), + actor_id.clone(), + generation, + config, + preloaded_kv, + ) .await; if let Err(error) = start_result { @@ -165,9 +171,7 @@ async fn actor_inner( ToActor::Intent { intent, error } => { send_event( &mut ctx, - protocol::Event::EventActorIntent(protocol::EventActorIntent { - intent, - }), + protocol::Event::EventActorIntent(protocol::EventActorIntent { intent }), ); if error.is_some() { ctx.error = error; @@ -192,9 +196,7 @@ async fn actor_inner( ToActor::SetAlarm { alarm_ts } => { send_event( &mut ctx, - protocol::Event::EventActorSetAlarm(protocol::EventActorSetAlarm { - alarm_ts, - }), + protocol::Event::EventActorSetAlarm(protocol::EventActorSetAlarm { alarm_ts }), ); } ToActor::ReqStart { message_id, req } => { @@ -237,9 +239,12 @@ async fn actor_inner( fn send_event(ctx: &mut ActorContext, inner: protocol::Event) { let checkpoint = increment_checkpoint(ctx); - let _ = ctx.shared.envoy_tx.send(crate::envoy::ToEnvoyMessage::SendEvents { - events: vec![protocol::EventWrapper { checkpoint, inner }], - }); + let _ = ctx + .shared + .envoy_tx + .send(crate::envoy::ToEnvoyMessage::SendEvents { + events: vec![protocol::EventWrapper { checkpoint, inner }], + }); } async fn handle_stop( @@ -258,12 +263,7 @@ async fn handle_stop( .shared .config .callbacks - .on_actor_stop( - handle.clone(), - ctx.actor_id.clone(), - ctx.generation, - reason, - ) + .on_actor_stop(handle.clone(), ctx.actor_id.clone(), ctx.generation, reason) .await; if let Err(error) = stop_result { @@ -298,7 +298,11 @@ fn handle_req_start( ctx.pending_requests .insert(&[&message_id.gateway_id, &message_id.request_id], pending); - let headers: HashMap = req.headers.iter().map(|(k, v)| (k.clone(), v.clone())).collect(); + let headers: HashMap = req + .headers + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); let body_stream = if req.stream { let (body_tx, body_rx) = mpsc::unbounded_channel::>(); @@ -440,7 +444,6 @@ async fn handle_ws_open( match ws_result { Ok(ws_handler) => { - ctx.ws_entries.insert( &[&message_id.gateway_id, &message_id.request_id], WsEntry { @@ -457,15 +460,10 @@ async fn handle_ws_open( let shared = ctx.shared.clone(); let gateway_id = message_id.gateway_id; let request_id = message_id.request_id; - let ws_msg_counter = std::sync::Arc::new(std::sync::atomic::AtomicU16::new(0)); - // Store counter ref on pending request so send_actor_message can coordinate - if let Some(req) = ctx.pending_requests.get_mut(&[&gateway_id, &request_id]) { - // The pending request's envoy_message_index will be managed separately; - // the outgoing task uses its own counter space starting from a high offset. - } tokio::spawn(async move { let mut idx: u16 = 0; while let Some(msg) = outgoing_rx.recv().await { + idx += 1; match msg { crate::config::WsOutgoing::Message { data, binary } => { ws_send( @@ -523,13 +521,17 @@ async fn handle_ws_open( .await; // Call on_open if provided - if let Some(ws_entry) = ctx + if let Some(ws) = ctx .ws_entries .get_mut(&[&message_id.gateway_id, &message_id.request_id]) { - if let Some(handler) = &mut ws_entry.ws_handler { + if let Some(handler) = &mut ws.ws_handler { if let Some(on_open) = handler.on_open.take() { - on_open().await; + let sender = crate::config::WebSocketSender { + tx: ws.outgoing_tx.clone(), + }; + + on_open(sender).await; } } } @@ -696,7 +698,9 @@ async fn handle_hws_restore( }; let (hws_outgoing_tx, _hws_outgoing_rx) = mpsc::unbounded_channel(); - let hws_sender = crate::config::WebSocketSender { tx: hws_outgoing_tx.clone() }; + let hws_sender = crate::config::WebSocketSender { + tx: hws_outgoing_tx.clone(), + }; let ws_result = ctx .shared @@ -782,9 +786,9 @@ async fn handle_hws_restore( // Process loaded but not connected (stale) for meta in &meta_entries { - let is_connected = hibernating_requests.iter().any(|req| { - req.gateway_id == meta.gateway_id && req.request_id == meta.request_id - }); + let is_connected = hibernating_requests + .iter() + .any(|req| req.gateway_id == meta.gateway_id && req.request_id == meta.request_id); if !is_connected { tracing::warn!( @@ -873,9 +877,7 @@ async fn send_actor_message( request_id: protocol::RequestId, message_kind: protocol::ToRivetTunnelMessageKind, ) { - let req = ctx - .pending_requests - .get_mut(&[&gateway_id, &request_id]); + let req = ctx.pending_requests.get_mut(&[&gateway_id, &request_id]); let envoy_message_index = if let Some(req) = req { let idx = req.envoy_message_index; req.envoy_message_index += 1; @@ -899,11 +901,7 @@ async fn send_actor_message( }; let buffer_msg = msg.clone(); - let failed = ws_send( - &ctx.shared, - protocol::ToRivet::ToRivetTunnelMessage(msg), - ) - .await; + let failed = ws_send(&ctx.shared, protocol::ToRivet::ToRivetTunnelMessage(msg)).await; if failed { if tracing::enabled!(tracing::Level::DEBUG) { diff --git a/engine/sdks/rust/envoy-client/src/config.rs b/engine/sdks/rust/envoy-client/src/config.rs index eef7ae4804..8b0a7a4131 100644 --- a/engine/sdks/rust/envoy-client/src/config.rs +++ b/engine/sdks/rust/envoy-client/src/config.rs @@ -47,9 +47,6 @@ pub struct EnvoyConfig { /// Optional envoy key. If not provided, a UUID will be generated. pub envoy_key: Option, - /// Whether to automatically restart actors that crash. - pub auto_restart: bool, - /// Debug option to inject artificial latency (in ms) into WebSocket communication. pub debug_latency_ms: Option, @@ -117,7 +114,7 @@ pub trait EnvoyCallbacks: Send + Sync + 'static { pub struct WebSocketHandler { pub on_message: Box BoxFuture<()> + Send + Sync>, pub on_close: Box BoxFuture<()> + Send + Sync>, - pub on_open: Option BoxFuture<()> + Send>>, + pub on_open: Option BoxFuture<()> + Send>>, } pub struct WebSocketMessage { @@ -137,8 +134,14 @@ pub struct WebSocketSender { } pub(crate) enum WsOutgoing { - Message { data: Vec, binary: bool }, - Close { code: Option, reason: Option }, + Message { + data: Vec, + binary: bool, + }, + Close { + code: Option, + reason: Option, + }, } impl WebSocketSender { diff --git a/engine/sdks/rust/envoy-client/src/connection.rs b/engine/sdks/rust/envoy-client/src/connection.rs index be365d2c68..1e00b3cca4 100644 --- a/engine/sdks/rust/envoy-client/src/connection.rs +++ b/engine/sdks/rust/envoy-client/src/connection.rs @@ -10,7 +10,7 @@ use vbare::OwnedVersionedData; use crate::context::{SharedContext, WsTxMessage}; use crate::envoy::ToEnvoyMessage; use crate::stringify::{stringify_to_envoy, stringify_to_rivet}; -use crate::utils::{calculate_backoff, parse_ws_close_reason, BackoffOptions}; +use crate::utils::{BackoffOptions, calculate_backoff, parse_ws_close_reason}; const STABLE_CONNECTION_MS: u64 = 60_000; @@ -29,15 +29,21 @@ async fn connection_loop(shared: Arc) { if let Some(reason) = &close_reason { if reason.group == "ws" && reason.error == "eviction" { tracing::debug!("connection evicted"); - let _ = shared.envoy_tx.send(ToEnvoyMessage::ConnClose { evict: true }); + let _ = shared + .envoy_tx + .send(ToEnvoyMessage::ConnClose { evict: true }); return; } } - let _ = shared.envoy_tx.send(ToEnvoyMessage::ConnClose { evict: false }); + let _ = shared + .envoy_tx + .send(ToEnvoyMessage::ConnClose { evict: false }); } Err(error) => { tracing::error!(?error, "connection failed"); - let _ = shared.envoy_tx.send(ToEnvoyMessage::ConnClose { evict: false }); + let _ = shared + .envoy_tx + .send(ToEnvoyMessage::ConnClose { evict: false }); } } @@ -107,17 +113,22 @@ async fn single_connection( } // Serialize metadata HashMap to JSON string for the protocol - let metadata_json = shared.config.metadata.as_ref().map(|m| { - serde_json::to_string(m).unwrap_or_else(|_| "{}".to_string()) - }); + let metadata_json = shared + .config + .metadata + .as_ref() + .map(|m| serde_json::to_string(m).unwrap_or_else(|_| "{}".to_string())); // Send init - ws_send(shared, protocol::ToRivet::ToRivetInit(protocol::ToRivetInit { - envoy_key: shared.envoy_key.clone(), - version: shared.config.version, - prepopulate_actor_names: Some(prepopulate_map), - metadata: metadata_json, - })) + ws_send( + shared, + protocol::ToRivet::ToRivetInit(protocol::ToRivetInit { + envoy_key: shared.envoy_key.clone(), + version: shared.config.version, + prepopulate_actor_names: Some(prepopulate_map), + metadata: metadata_json, + }), + ) .await; // Spawn write task @@ -154,8 +165,10 @@ async fn single_connection( Ok(tungstenite::Message::Binary(data)) => { crate::utils::inject_latency(debug_latency_ms).await; - let decoded = - crate::protocol::versioned::ToEnvoy::deserialize(&data, protocol::PROTOCOL_VERSION)?; + let decoded = crate::protocol::versioned::ToEnvoy::deserialize( + &data, + protocol::PROTOCOL_VERSION, + )?; if tracing::enabled!(tracing::Level::DEBUG) { tracing::debug!(data = stringify_to_envoy(&decoded), "received message"); @@ -223,10 +236,9 @@ pub async fn ws_send(shared: &SharedContext, message: protocol::ToRivet) -> bool return true; }; - let encoded = - crate::protocol::versioned::ToRivet::wrap_latest(message) - .serialize(protocol::PROTOCOL_VERSION) - .expect("failed to encode message"); + let encoded = crate::protocol::versioned::ToRivet::wrap_latest(message) + .serialize(protocol::PROTOCOL_VERSION) + .expect("failed to encode message"); let _ = tx.send(WsTxMessage::Send(encoded)); false } diff --git a/engine/sdks/rust/envoy-client/src/context.rs b/engine/sdks/rust/envoy-client/src/context.rs index f8a47a562f..f9d07c7dcf 100644 --- a/engine/sdks/rust/envoy-client/src/context.rs +++ b/engine/sdks/rust/envoy-client/src/context.rs @@ -2,8 +2,8 @@ use std::sync::Arc; use std::sync::atomic::AtomicBool; use rivet_envoy_protocol as protocol; -use tokio::sync::mpsc; use tokio::sync::Mutex; +use tokio::sync::mpsc; use crate::config::EnvoyConfig; use crate::envoy::ToEnvoyMessage; diff --git a/engine/sdks/rust/envoy-client/src/envoy.rs b/engine/sdks/rust/envoy-client/src/envoy.rs index 2518fe38a1..1c15d414c9 100644 --- a/engine/sdks/rust/envoy-client/src/envoy.rs +++ b/engine/sdks/rust/envoy-client/src/envoy.rs @@ -6,19 +6,19 @@ use tokio::sync::mpsc; use tokio::sync::oneshot; use crate::actor::ToActor; -use crate::commands::{handle_commands, send_command_ack, ACK_COMMANDS_INTERVAL_MS}; +use crate::commands::{ACK_COMMANDS_INTERVAL_MS, handle_commands, send_command_ack}; use crate::config::EnvoyConfig; use crate::connection::{start_connection, ws_send}; use crate::context::{SharedContext, WsTxMessage}; use crate::events::{handle_ack_events, handle_send_events, resend_unacknowledged_events}; use crate::handle::EnvoyHandle; use crate::kv::{ - cleanup_old_kv_requests, handle_kv_request, handle_kv_response, process_unsent_kv_requests, - KvRequestEntry, KV_CLEANUP_INTERVAL_MS, + KV_CLEANUP_INTERVAL_MS, KvRequestEntry, cleanup_old_kv_requests, handle_kv_request, + handle_kv_response, process_unsent_kv_requests, }; use crate::tunnel::{ - handle_tunnel_message, resend_buffered_tunnel_messages, send_hibernatable_ws_message_ack, - HibernatingWebSocketMetadata, + HibernatingWebSocketMetadata, handle_tunnel_message, resend_buffered_tunnel_messages, + send_hibernatable_ws_message_ack, }; use crate::utils::{BufferMap, EnvoyShutdownError}; @@ -94,11 +94,7 @@ pub struct ActorInfo { } impl EnvoyContext { - pub fn get_actor( - &self, - actor_id: &str, - generation: Option, - ) -> Option<&ActorEntry> { + pub fn get_actor(&self, actor_id: &str, generation: Option) -> Option<&ActorEntry> { let gens = self.actors.get(actor_id)?; if gens.is_empty() { return None; @@ -303,7 +299,9 @@ async fn envoy_loop( } for (_id, request) in ctx.kv_requests.drain() { - let _ = request.response_tx.send(Err(anyhow::anyhow!("envoy shutting down"))); + let _ = request + .response_tx + .send(Err(anyhow::anyhow!("envoy shutting down"))); } ctx.actors.clear(); @@ -372,7 +370,9 @@ fn handle_conn_close( tracing::debug!(ms = lost_threshold, "starting envoy lost timeout"); - Some(Box::pin(tokio::time::sleep(std::time::Duration::from_millis(lost_threshold)))) + Some(Box::pin(tokio::time::sleep( + std::time::Duration::from_millis(lost_threshold), + ))) } async fn handle_shutdown(ctx: &mut EnvoyContext) { @@ -383,16 +383,13 @@ async fn handle_shutdown(ctx: &mut EnvoyContext) { tracing::debug!("envoy received shutdown"); - ws_send( - &ctx.shared, - protocol::ToRivet::ToRivetStopping, - ) - .await; + ws_send(&ctx.shared, protocol::ToRivet::ToRivetStopping).await; // Check if any actors are still active - let has_actors = ctx.actors.values().any(|gens| { - gens.values().any(|entry| !entry.handle.is_closed()) - }); + let has_actors = ctx + .actors + .values() + .any(|gens| gens.values().any(|entry| !entry.handle.is_closed())); if !has_actors { let _ = ctx.shared.envoy_tx.send(ToEnvoyMessage::Stop); diff --git a/engine/sdks/rust/envoy-client/src/events.rs b/engine/sdks/rust/envoy-client/src/events.rs index 2d0a718bfe..c19cc649b3 100644 --- a/engine/sdks/rust/envoy-client/src/events.rs +++ b/engine/sdks/rust/envoy-client/src/events.rs @@ -6,10 +6,8 @@ use crate::envoy::EnvoyContext; pub async fn handle_send_events(ctx: &mut EnvoyContext, events: Vec) { // Record in history per actor for event in &events { - let entry = ctx.get_actor_entry_mut( - &event.checkpoint.actor_id, - event.checkpoint.generation, - ); + let entry = + ctx.get_actor_entry_mut(&event.checkpoint.actor_id, event.checkpoint.generation); if let Some(entry) = entry { entry.event_history.push(event.clone()); @@ -26,19 +24,16 @@ pub async fn handle_send_events(ctx: &mut EnvoyContext, events: Vec checkpoint.index); + entry + .event_history + .retain(|event| event.checkpoint.index > checkpoint.index); // Clean up fully acked stopped actors if entry.event_history.is_empty() && entry.handle.is_closed() { @@ -82,9 +77,5 @@ pub async fn resend_unacknowledged_events(ctx: &EnvoyContext) { tracing::info!(count = events.len(), "resending unacknowledged events"); - ws_send( - &ctx.shared, - protocol::ToRivet::ToRivetEvents(events), - ) - .await; + ws_send(&ctx.shared, protocol::ToRivet::ToRivetEvents(events)).await; } diff --git a/engine/sdks/rust/envoy-client/src/handle.rs b/engine/sdks/rust/envoy-client/src/handle.rs index e45a4569f0..867094de5d 100644 --- a/engine/sdks/rust/envoy-client/src/handle.rs +++ b/engine/sdks/rust/envoy-client/src/handle.rs @@ -2,8 +2,8 @@ use std::sync::Arc; use rivet_envoy_protocol as protocol; -use crate::envoy::{ActorInfo, ToEnvoyMessage}; use crate::context::SharedContext; +use crate::envoy::{ActorInfo, ToEnvoyMessage}; use crate::tunnel::HibernatingWebSocketMetadata; /// Handle for interacting with the envoy from callbacks. @@ -66,11 +66,7 @@ impl EnvoyHandle { }); } - pub async fn get_actor( - &self, - actor_id: &str, - generation: Option, - ) -> Option { + pub async fn get_actor(&self, actor_id: &str, generation: Option) -> Option { let (tx, rx) = tokio::sync::oneshot::channel(); self.shared .envoy_tx @@ -253,10 +249,7 @@ impl EnvoyHandle { pub async fn kv_drop(&self, actor_id: String) -> anyhow::Result<()> { let response = self - .send_kv_request( - actor_id, - protocol::KvRequestData::KvDropRequest, - ) + .send_kv_request(actor_id, protocol::KvRequestData::KvDropRequest) .await?; match response { protocol::KvResponseData::KvDropResponse => Ok(()), @@ -306,8 +299,7 @@ impl EnvoyHandle { ); } - let message = - crate::protocol::versioned::ToEnvoy::deserialize(&payload[2..], version)?; + let message = crate::protocol::versioned::ToEnvoy::deserialize(&payload[2..], version)?; let protocol::ToEnvoy::ToEnvoyCommands(ref commands) = message else { anyhow::bail!("invalid serverless payload: expected ToEnvoyCommands"); @@ -348,7 +340,8 @@ impl EnvoyHandle { response_tx: tx, }) .map_err(|_| anyhow::anyhow!("envoy channel closed"))?; - rx.await.map_err(|_| anyhow::anyhow!("kv response channel closed"))? + rx.await + .map_err(|_| anyhow::anyhow!("kv response channel closed"))? } } diff --git a/engine/sdks/rust/envoy-client/src/stringify.rs b/engine/sdks/rust/envoy-client/src/stringify.rs index 2be50c8451..b20821682e 100644 --- a/engine/sdks/rust/envoy-client/src/stringify.rs +++ b/engine/sdks/rust/envoy-client/src/stringify.rs @@ -8,7 +8,10 @@ fn stringify_bytes(data: &[u8]) -> String { } fn stringify_map(map: &HashableMap) -> String { - let entries: Vec = map.iter().map(|(k, v)| format!("\"{k}\": \"{v}\"")).collect(); + let entries: Vec = map + .iter() + .map(|(k, v)| format!("\"{k}\": \"{v}\"")) + .collect(); format!("Map({}){{{}}}", map.len(), entries.join(", ")) } @@ -47,7 +50,10 @@ pub fn stringify_to_rivet_tunnel_message_kind(kind: &protocol::ToRivetTunnelMess "ToRivetResponseAbort".to_string() } protocol::ToRivetTunnelMessageKind::ToRivetWebSocketOpen(val) => { - format!("ToRivetWebSocketOpen{{canHibernate: {}}}", val.can_hibernate) + format!( + "ToRivetWebSocketOpen{{canHibernate: {}}}", + val.can_hibernate + ) } protocol::ToRivetTunnelMessageKind::ToRivetWebSocketMessage(val) => { format!( @@ -76,9 +82,7 @@ pub fn stringify_to_rivet_tunnel_message_kind(kind: &protocol::ToRivetTunnelMess } } -pub fn stringify_to_envoy_tunnel_message_kind( - kind: &protocol::ToEnvoyTunnelMessageKind, -) -> String { +pub fn stringify_to_envoy_tunnel_message_kind(kind: &protocol::ToEnvoyTunnelMessageKind) -> String { match kind { protocol::ToEnvoyTunnelMessageKind::ToEnvoyRequestStart(val) => { let body_str = match &val.body { @@ -87,7 +91,12 @@ pub fn stringify_to_envoy_tunnel_message_kind( }; format!( "ToEnvoyRequestStart{{actorId: \"{}\", method: \"{}\", path: \"{}\", headers: {}, body: {}, stream: {}}}", - val.actor_id, val.method, val.path, stringify_map(&val.headers), body_str, val.stream + val.actor_id, + val.method, + val.path, + stringify_map(&val.headers), + body_str, + val.stream ) } protocol::ToEnvoyTunnelMessageKind::ToEnvoyRequestChunk(val) => { @@ -194,7 +203,10 @@ pub fn stringify_event(event: &protocol::Event) -> String { Some(m) => format!("\"{m}\""), None => "null".to_string(), }; - format!("Stopped{{code: {:?}, message: {message_str}}}", stopped.code) + format!( + "Stopped{{code: {:?}, message: {message_str}}}", + stopped.code + ) } }; format!("EventActorStateUpdate{{state: {state_str}}}") @@ -228,8 +240,7 @@ pub fn stringify_to_rivet(message: &protocol::ToRivet) -> String { ) } protocol::ToRivet::ToRivetEvents(events) => { - let event_strs: Vec = - events.iter().map(stringify_event_wrapper).collect(); + let event_strs: Vec = events.iter().map(stringify_event_wrapper).collect(); format!( "ToRivetEvents{{count: {}, events: [{}]}}", events.len(), @@ -240,14 +251,12 @@ pub fn stringify_to_rivet(message: &protocol::ToRivet) -> String { let checkpoints: Vec = val .last_command_checkpoints .iter() - .map(|cp| { - format!( - "{{actorId: \"{}\", index: {}}}", - cp.actor_id, cp.index - ) - }) + .map(|cp| format!("{{actorId: \"{}\", index: {}}}", cp.actor_id, cp.index)) .collect(); - format!("ToRivetAckCommands{{lastCommandCheckpoints: [{}]}}", checkpoints.join(", ")) + format!( + "ToRivetAckCommands{{lastCommandCheckpoints: [{}]}}", + checkpoints.join(", ") + ) } protocol::ToRivet::ToRivetStopping => "ToRivetStopping".to_string(), protocol::ToRivet::ToRivetPong(val) => { @@ -278,10 +287,7 @@ pub fn stringify_to_envoy(message: &protocol::ToEnvoy) -> String { ) } protocol::ToEnvoy::ToEnvoyCommands(commands) => { - let cmd_strs: Vec = commands - .iter() - .map(stringify_command_wrapper) - .collect(); + let cmd_strs: Vec = commands.iter().map(stringify_command_wrapper).collect(); format!( "ToEnvoyCommands{{count: {}, commands: [{}]}}", commands.len(), @@ -292,12 +298,7 @@ pub fn stringify_to_envoy(message: &protocol::ToEnvoy) -> String { let checkpoints: Vec = val .last_event_checkpoints .iter() - .map(|cp| { - format!( - "{{actorId: \"{}\", index: {}}}", - cp.actor_id, cp.index - ) - }) + .map(|cp| format!("{{actorId: \"{}\", index: {}}}", cp.actor_id, cp.index)) .collect(); format!( "ToEnvoyAckEvents{{lastEventCheckpoints: [{}]}}", diff --git a/engine/sdks/rust/envoy-client/src/tunnel.rs b/engine/sdks/rust/envoy-client/src/tunnel.rs index 633d6112a0..1405bc343f 100644 --- a/engine/sdks/rust/envoy-client/src/tunnel.rs +++ b/engine/sdks/rust/envoy-client/src/tunnel.rs @@ -56,7 +56,9 @@ async fn handle_request_start( ); let actor = ctx.get_actor(&actor_id, None).unwrap(); - let _ = actor.handle.send(crate::actor::ToActor::ReqStart { message_id, req }); + let _ = actor + .handle + .send(crate::actor::ToActor::ReqStart { message_id, req }); } fn handle_request_chunk( @@ -137,7 +139,11 @@ async fn handle_ws_open( ); // Convert HashableMap headers to BTreeMap for the actor message - let headers = open.headers.iter().map(|(k, v)| (k.clone(), v.clone())).collect(); + let headers = open + .headers + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); let actor = ctx.get_actor(&actor_id, None).unwrap(); let _ = actor.handle.send(crate::actor::ToActor::WsOpen { @@ -158,10 +164,9 @@ fn handle_ws_message( .cloned(); if let Some(actor_id) = &actor_id { if let Some(actor) = ctx.get_actor(actor_id, None) { - let _ = actor.handle.send(crate::actor::ToActor::WsMsg { - message_id, - msg, - }); + let _ = actor + .handle + .send(crate::actor::ToActor::WsMsg { message_id, msg }); } } } @@ -214,15 +219,14 @@ pub async fn resend_buffered_tunnel_messages(ctx: &mut EnvoyContext) { return; } - tracing::info!(count = ctx.buffered_messages.len(), "resending buffered tunnel messages"); + tracing::info!( + count = ctx.buffered_messages.len(), + "resending buffered tunnel messages" + ); let messages = std::mem::take(&mut ctx.buffered_messages); for msg in messages { - ws_send( - &ctx.shared, - protocol::ToRivet::ToRivetTunnelMessage(msg), - ) - .await; + ws_send(&ctx.shared, protocol::ToRivet::ToRivetTunnelMessage(msg)).await; } } @@ -233,7 +237,10 @@ async fn send_error_response( ) { let body = b"Actor not found".to_vec(); let mut headers = rivet_util::serde::HashableMap::new(); - headers.insert("x-rivet-error".to_string(), "envoy.actor_not_found".to_string()); + headers.insert( + "x-rivet-error".to_string(), + "envoy.actor_not_found".to_string(), + ); headers.insert("content-length".to_string(), body.len().to_string()); ws_send( diff --git a/engine/sdks/rust/envoy-protocol/Cargo.toml b/engine/sdks/rust/envoy-protocol/Cargo.toml index b20e73d5c2..f2137b94c8 100644 --- a/engine/sdks/rust/envoy-protocol/Cargo.toml +++ b/engine/sdks/rust/envoy-protocol/Cargo.toml @@ -7,7 +7,6 @@ edition.workspace = true [dependencies] anyhow.workspace = true -gas.workspace = true hex.workspace = true rand.workspace = true rivet-util.workspace = true diff --git a/engine/sdks/rust/test-envoy/src/server.rs b/engine/sdks/rust/test-envoy/src/server.rs index f37d135495..b8cd36fa03 100644 --- a/engine/sdks/rust/test-envoy/src/server.rs +++ b/engine/sdks/rust/test-envoy/src/server.rs @@ -6,7 +6,10 @@ use axum::{ Router, body::Bytes, extract::State, - response::{IntoResponse, Json, Sse, sse::{Event, KeepAlive}}, + response::{ + IntoResponse, Json, Sse, + sse::{Event, KeepAlive}, + }, routing::{get, post}, }; use rivet_envoy_protocol as protocol; @@ -110,7 +113,10 @@ async fn run_http_server(state: AppState) -> Result<()> { .await .with_context(|| format!("failed to bind {addr}"))?; - tracing::info!(port = state.settings.internal_server_port, "internal http server listening"); + tracing::info!( + port = state.settings.internal_server_port, + "internal http server listening" + ); axum::serve(listener, app) .await @@ -136,10 +142,7 @@ async fn metadata() -> Json { })) } -async fn start_serverless( - State(state): State, - body: Bytes, -) -> impl IntoResponse { +async fn start_serverless(State(state): State, body: Bytes) -> impl IntoResponse { tracing::info!("received serverless start request"); let handle = match create_envoy(&state.settings).await { @@ -184,7 +187,6 @@ async fn create_envoy(settings: &Settings) -> Result { prepopulate_actor_names: std::collections::HashMap::new(), metadata: None, envoy_key: None, - auto_restart: false, debug_latency_ms: None, callbacks: Arc::new(DefaultTestCallbacks), }; diff --git a/rivetkit-typescript/packages/rivetkit-native/index.d.ts b/rivetkit-typescript/packages/rivetkit-native/index.d.ts index 21ae47f8e6..4800224f74 100644 --- a/rivetkit-typescript/packages/rivetkit-native/index.d.ts +++ b/rivetkit-typescript/packages/rivetkit-native/index.d.ts @@ -13,6 +13,11 @@ export interface JsEnvoyConfig { poolName: string version: number metadata?: any + /** + * Log level for the Rust tracing subscriber (e.g. "trace", "debug", "info", "warn", "error"). + * Falls back to RIVET_LOG_LEVEL, then LOG_LEVEL, then RUST_LOG env vars. Defaults to "warn". + */ + logLevel?: string } /** Options for KV list operations. */ export interface JsKvListOptions { @@ -60,5 +65,9 @@ export declare class JsEnvoyHandle { restoreHibernatingRequests(actorId: string, requests: Array): void sendHibernatableWebSocketMessageAck(gatewayId: Buffer, requestId: Buffer, clientMessageIndex: number): void startServerless(payload: Buffer): Promise + /** Send a message on an open WebSocket connection. */ + sendWsMessage(gatewayId: Buffer, requestId: Buffer, data: Buffer, binary: boolean): void + /** Close an open WebSocket connection. */ + closeWebsocket(gatewayId: Buffer, requestId: Buffer, code?: number | undefined | null, reason?: string | undefined | null): void respondCallback(responseId: string, data: any): Promise } diff --git a/rivetkit-typescript/packages/rivetkit-native/src/bridge_actor.rs b/rivetkit-typescript/packages/rivetkit-native/src/bridge_actor.rs index 9894a6583d..f96e0b1a65 100644 --- a/rivetkit-typescript/packages/rivetkit-native/src/bridge_actor.rs +++ b/rivetkit-typescript/packages/rivetkit-native/src/bridge_actor.rs @@ -13,28 +13,41 @@ use tokio::sync::{Mutex, oneshot}; use crate::types; /// Type alias for the threadsafe event callback function. -pub type EventCallback = - napi::threadsafe_function::ThreadsafeFunction; +pub type EventCallback = napi::threadsafe_function::ThreadsafeFunction< + serde_json::Value, + napi::threadsafe_function::ErrorStrategy::Fatal, +>; /// Map of pending callback response channels, keyed by response ID. pub type ResponseMap = Arc>>>; -/// Map of WebSocket senders, keyed by hex-encoded messageId. -pub type WsSenderMap = Arc>>; +/// Map of open WebSocket senders, keyed by concatenated gateway_id + request_id (8 bytes). +pub type WsSenderMap = Arc>>; + +fn make_ws_key(gateway_id: &protocol::GatewayId, request_id: &protocol::RequestId) -> [u8; 8] { + let mut key = [0u8; 8]; + key[..4].copy_from_slice(gateway_id); + key[4..].copy_from_slice(request_id); + key +} /// Callbacks implementation that bridges envoy events to JavaScript via N-API. pub struct BridgeCallbacks { event_cb: EventCallback, response_map: ResponseMap, - pub ws_senders: WsSenderMap, + ws_sender_map: WsSenderMap, } impl BridgeCallbacks { - pub fn new(event_cb: EventCallback, response_map: ResponseMap) -> Self { + pub fn new( + event_cb: EventCallback, + response_map: ResponseMap, + ws_sender_map: WsSenderMap, + ) -> Self { Self { event_cb, response_map, - ws_senders: Arc::new(Mutex::new(HashMap::new())), + ws_sender_map, } } @@ -206,10 +219,10 @@ impl EnvoyCallbacks for BridgeCallbacks { headers: HashMap, _is_hibernatable: bool, _is_restoring_hibernatable: bool, - sender: WebSocketSender, + _sender: WebSocketSender, ) -> BoxFuture> { let event_cb = self.event_cb.clone(); - let ws_senders = self.ws_senders.clone(); + let ws_sender_map = self.ws_sender_map.clone(); Box::pin(async move { let msg_id = protocol::MessageId { @@ -217,30 +230,19 @@ impl EnvoyCallbacks for BridgeCallbacks { request_id, message_index: 0, }; - let msg_id_hex = hex::encode(types::encode_message_id(&msg_id)); - - // Store the sender so JS can call ws.send() via the native handle - { - let mut senders = ws_senders.lock().await; - senders.insert(msg_id_hex.clone(), sender); - } + let msg_id_bytes = types::encode_message_id(&msg_id); - let envelope = serde_json::json!({ - "kind": "websocket_open", - "actorId": actor_id, - "messageId": types::encode_message_id(&msg_id), - "messageIdHex": msg_id_hex, - "path": path, - "headers": headers, - }); - event_cb.call(envelope, ThreadsafeFunctionCallMode::NonBlocking); + let ws_key = make_ws_key(&gateway_id, &request_id); + let event_cb_open = event_cb.clone(); let event_cb_msg = event_cb.clone(); let event_cb_close = event_cb.clone(); + let actor_id_open = actor_id.clone(); let actor_id_msg = actor_id.clone(); let actor_id_close = actor_id; - let ws_senders_close = ws_senders.clone(); - let msg_id_hex_close = msg_id_hex; + let msg_id_bytes_close = msg_id_bytes.clone(); + let ws_sender_map_open = ws_sender_map.clone(); + let ws_sender_map_close = ws_sender_map.clone(); Ok(WebSocketHandler { on_message: Box::new(move |msg: WebSocketMessage| { @@ -260,21 +262,38 @@ impl EnvoyCallbacks for BridgeCallbacks { Box::pin(async {}) }), on_close: Box::new(move |code, reason| { - let ws_senders = ws_senders_close.clone(); - let msg_id_hex = msg_id_hex_close.clone(); let envelope = serde_json::json!({ "kind": "websocket_close", "actorId": actor_id_close, + "messageId": msg_id_bytes_close, "code": code, "reason": reason, }); event_cb_close.call(envelope, ThreadsafeFunctionCallMode::NonBlocking); + + let ws_sender_map_close = ws_sender_map_close.clone(); Box::pin(async move { - let mut senders = ws_senders.lock().await; - senders.remove(&msg_id_hex); + let mut senders = ws_sender_map_close.lock().await; + senders.remove(&ws_key); }) }), - on_open: None, + // on_open fires the websocket_open event only after the sender is stored, + // guaranteeing that ws.send() works as soon as JS receives the event. + on_open: Some(Box::new(move |sender: WebSocketSender| { + let envelope = serde_json::json!({ + "kind": "websocket_open", + "actorId": actor_id_open, + "messageId": msg_id_bytes, + "path": path, + "headers": headers, + }); + event_cb_open.call(envelope, ThreadsafeFunctionCallMode::NonBlocking); + + Box::pin(async move { + let mut senders = ws_sender_map_open.lock().await; + senders.insert(ws_key, sender); + }) + })), }) }) } diff --git a/rivetkit-typescript/packages/rivetkit-native/src/database.rs b/rivetkit-typescript/packages/rivetkit-native/src/database.rs index e1cb8549ae..d1941ee620 100644 --- a/rivetkit-typescript/packages/rivetkit-native/src/database.rs +++ b/rivetkit-typescript/packages/rivetkit-native/src/database.rs @@ -70,11 +70,7 @@ impl SqliteKv for EnvoyKv { .map_err(|e| SqliteKvError::new(e.to_string())) } - async fn batch_delete( - &self, - _actor_id: &str, - keys: Vec>, - ) -> Result<(), SqliteKvError> { + async fn batch_delete(&self, _actor_id: &str, keys: Vec>) -> Result<(), SqliteKvError> { self.handle .kv_delete(self.actor_id.clone(), keys) .await @@ -92,7 +88,6 @@ impl SqliteKv for EnvoyKv { .await .map_err(|e| SqliteKvError::new(e.to_string())) } - } /// Native SQLite database handle exposed to JavaScript. diff --git a/rivetkit-typescript/packages/rivetkit-native/src/envoy_handle.rs b/rivetkit-typescript/packages/rivetkit-native/src/envoy_handle.rs index b22c20eb1b..3469901efe 100644 --- a/rivetkit-typescript/packages/rivetkit-native/src/envoy_handle.rs +++ b/rivetkit-typescript/packages/rivetkit-native/src/envoy_handle.rs @@ -9,13 +9,24 @@ use tokio::runtime::Runtime; use crate::bridge_actor::{ResponseMap, WsSenderMap}; use crate::types::{self, JsKvEntry, JsKvListOptions}; +fn make_ws_key(gateway_id: &[u8], request_id: &[u8]) -> [u8; 8] { + let mut key = [0u8; 8]; + if gateway_id.len() >= 4 { + key[..4].copy_from_slice(&gateway_id[..4]); + } + if request_id.len() >= 4 { + key[4..].copy_from_slice(&request_id[..4]); + } + key +} + /// Native envoy handle exposed to JavaScript via N-API. #[napi] pub struct JsEnvoyHandle { pub(crate) runtime: Arc, pub(crate) handle: EnvoyHandle, pub(crate) response_map: ResponseMap, - pub(crate) ws_senders: WsSenderMap, + pub(crate) ws_sender_map: WsSenderMap, } impl JsEnvoyHandle { @@ -23,13 +34,13 @@ impl JsEnvoyHandle { runtime: Arc, handle: EnvoyHandle, response_map: ResponseMap, - ws_senders: WsSenderMap, + ws_sender_map: WsSenderMap, ) -> Self { Self { runtime, handle, response_map, - ws_senders, + ws_sender_map, } } } @@ -214,7 +225,11 @@ impl JsEnvoyHandle { let limit = options.as_ref().and_then(|o| o.limit).map(|l| l as u64); let result = self .runtime - .spawn(async move { handle.kv_list_prefix(actor_id, prefix_vec, reverse, limit).await }) + .spawn(async move { + handle + .kv_list_prefix(actor_id, prefix_vec, reverse, limit) + .await + }) .await .map_err(|e| napi::Error::from_reason(e.to_string()))? .map_err(|e| napi::Error::from_reason(e.to_string()))?; @@ -300,22 +315,40 @@ impl JsEnvoyHandle { #[napi] pub async fn send_ws_message( &self, - message_id_hex: String, + gateway_id: Buffer, + request_id: Buffer, data: Buffer, binary: bool, ) -> napi::Result<()> { - let senders = self.ws_senders.lock().await; - if let Some(sender) = senders.get(&message_id_hex) { + let key = make_ws_key(&gateway_id, &request_id); + let map = self.ws_sender_map.lock().await; + if let Some(sender) = map.get(&key) { sender.send(data.to_vec(), binary); Ok(()) } else { Err(napi::Error::from_reason(format!( - "no WebSocket sender for {}", - message_id_hex + "no WebSocket sender for {:?}", + key ))) } } + /// Close an open WebSocket connection. + #[napi] + pub async fn close_websocket( + &self, + gateway_id: Buffer, + request_id: Buffer, + code: Option, + reason: Option, + ) { + let key = make_ws_key(&gateway_id, &request_id); + let mut map = self.ws_sender_map.lock().await; + if let Some(sender) = map.remove(&key) { + sender.close(code.map(|c| c as u16), reason); + } + } + // -- Serverless -- #[napi] diff --git a/rivetkit-typescript/packages/rivetkit-native/src/lib.rs b/rivetkit-typescript/packages/rivetkit-native/src/lib.rs index e8f7c57742..4457b7ebba 100644 --- a/rivetkit-typescript/packages/rivetkit-native/src/lib.rs +++ b/rivetkit-typescript/packages/rivetkit-native/src/lib.rs @@ -32,7 +32,7 @@ fn init_tracing(log_level: Option<&str>) { }); } -use crate::bridge_actor::{BridgeCallbacks, ResponseMap}; +use crate::bridge_actor::{BridgeCallbacks, ResponseMap, WsSenderMap}; use crate::envoy_handle::JsEnvoyHandle; use crate::types::JsEnvoyConfig; @@ -52,17 +52,23 @@ pub fn start_envoy_sync_js( let runtime = Arc::new(runtime); let response_map: ResponseMap = Arc::new(tokio::sync::Mutex::new(HashMap::new())); + let ws_sender_map: WsSenderMap = Arc::new(tokio::sync::Mutex::new(HashMap::new())); // Create threadsafe callback for bridging events to JS - let tsfn: bridge_actor::EventCallback = event_callback - .create_threadsafe_function(0, |ctx: napi::threadsafe_function::ThreadSafeCallContext| { + let tsfn: bridge_actor::EventCallback = event_callback.create_threadsafe_function( + 0, + |ctx: napi::threadsafe_function::ThreadSafeCallContext| { let env = ctx.env; let value = env.to_js_value(&ctx.value)?; Ok(vec![value]) - })?; + }, + )?; - let callbacks = Arc::new(BridgeCallbacks::new(tsfn.clone(), response_map.clone())); - let ws_senders = callbacks.ws_senders.clone(); + let callbacks = Arc::new(BridgeCallbacks::new( + tsfn.clone(), + response_map.clone(), + ws_sender_map.clone(), + )); let metadata: Option> = config.metadata.and_then(|v| { if let serde_json::Value::Object(map) = v { @@ -81,7 +87,6 @@ pub fn start_envoy_sync_js( prepopulate_actor_names: HashMap::new(), metadata, envoy_key: None, - auto_restart: false, debug_latency_ms: None, callbacks, }; @@ -89,7 +94,12 @@ pub fn start_envoy_sync_js( let _guard = runtime.enter(); let handle = start_envoy_sync(envoy_config); - Ok(JsEnvoyHandle::new(runtime, handle, response_map, ws_senders)) + Ok(JsEnvoyHandle::new( + runtime, + handle, + response_map, + ws_sender_map, + )) } /// Start the native envoy client asynchronously. diff --git a/rivetkit-typescript/packages/rivetkit-native/wrapper.js b/rivetkit-typescript/packages/rivetkit-native/wrapper.js index 446aaa09ce..69a7499829 100644 --- a/rivetkit-typescript/packages/rivetkit-native/wrapper.js +++ b/rivetkit-typescript/packages/rivetkit-native/wrapper.js @@ -270,8 +270,7 @@ function handleEvent(event, config, wrappedHandle) { const messageId = Buffer.from(event.messageId); const gatewayId = messageId.subarray(0, 4); const requestId = messageId.subarray(4, 8); - // Use the hex key from Rust (matches the ws_senders map key) - const messageIdHex = event.messageIdHex || messageId.toString("hex"); + const wsIdHex = gatewayId.toString("hex") + requestId.toString("hex"); const headers = new Headers(event.headers || {}); headers.set("Upgrade", "websocket"); @@ -292,18 +291,31 @@ function handleEvent(event, config, wrappedHandle) { readyState: { value: OPEN, writable: true }, OPEN: { value: OPEN }, CLOSED: { value: CLOSED }, - send: { value: (data) => { - if (handle._raw && messageIdHex) { - const binary = data instanceof Buffer || data instanceof Uint8Array || data instanceof ArrayBuffer; - const buf = Buffer.from(data); - handle._raw.sendWsMessage(messageIdHex, buf, binary).catch((e) => { - console.error("ws.send error:", e.message); - }); + send: { + value: (data) => { + if (handle._raw) { + const isBinary = + data instanceof ArrayBuffer || ArrayBuffer.isView(data); + const bytes = isBinary + ? Buffer.from(data instanceof ArrayBuffer ? data : data.buffer, data instanceof ArrayBuffer ? 0 : data.byteOffset, data instanceof ArrayBuffer ? data.byteLength : data.byteLength) + : Buffer.from(String(data)); + handle._raw.sendWsMessage(gatewayId, requestId, bytes, isBinary); + } + } + }, + close: { + value: (code, reason) => { + ws.readyState = CLOSED; + if (handle._raw) { + handle._raw.closeWebsocket( + gatewayId, + requestId, + code != null ? code : undefined, + reason != null ? String(reason) : undefined, + ); + } } - }}, - close: { value: (code, reason) => { - ws.readyState = CLOSED; - }}, + }, addEventListener: { value: target.addEventListener.bind(target) }, removeEventListener: { value: target.removeEventListener.bind(target) }, dispatchEvent: { value: target.dispatchEvent.bind(target) }, @@ -311,18 +323,18 @@ function handleEvent(event, config, wrappedHandle) { // Store the ws object so websocket_message/close events can dispatch to it if (!handle._wsMap) handle._wsMap = new Map(); - handle._wsMap.set(messageIdHex, ws); + handle._wsMap.set(wsIdHex, ws); const canHibernate = config.hibernatableWebSocket ? config.hibernatableWebSocket.canHibernate( - event.actorId, - gatewayId, - requestId, - request, - ) + event.actorId, + gatewayId, + requestId, + request, + ) : false; - console.log("[wrapper] websocket_open actorId:", event.actorId?.slice(0,12), "path:", event.path); + console.log("[wrapper] websocket_open actorId:", event.actorId?.slice(0, 12), "path:", event.path); Promise.resolve( config.websocket( handle, @@ -349,8 +361,12 @@ function handleEvent(event, config, wrappedHandle) { case "websocket_message": { if (handle._wsMap && event.messageId) { const messageId = Buffer.from(event.messageId); - const messageIdHex = messageId.toString("hex"); - const ws = handle._wsMap.get(messageIdHex); + const gatewayId = messageId.subarray(0, 4); + const requestId = messageId.subarray(4, 8); + const wsIdHex = gatewayId.toString("hex") + requestId.toString("hex"); + + const ws = handle._wsMap.get(wsIdHex); + if (ws) { const data = event.data ? (event.binary @@ -369,15 +385,18 @@ function handleEvent(event, config, wrappedHandle) { case "websocket_close": { if (handle._wsMap && event.messageId) { const messageId = Buffer.from(event.messageId); - const messageIdHex = messageId.toString("hex"); - const ws = handle._wsMap.get(messageIdHex); + const gatewayId = messageId.subarray(0, 4); + const requestId = messageId.subarray(4, 8); + const wsIdHex = gatewayId.toString("hex") + requestId.toString("hex"); + + const ws = handle._wsMap.get(wsIdHex); if (ws) { ws.readyState = 3; ws.dispatchEvent(new CloseEvent("close", { code: event.code || 1000, reason: event.reason || "", })); - handle._wsMap.delete(messageIdHex); + handle._wsMap.delete(wsIdHex); } } break; diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry-loader.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry-loader.ts index 97cb765681..0ea84c156d 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry-loader.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry-loader.ts @@ -129,6 +129,7 @@ export async function loadStaticActors(): Promise< Record > { const actors: Record = {}; + console.log(listActorFixtureFiles(), "WHATHATHTHATHTH"); for (const actorFixturePath of listActorFixtureFiles()) { actors[actorNameFromFilePath(actorFixturePath)] = await importActorDefinition(actorFixturePath); diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry-static.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry-static.ts index 4c852b12ad..b3a8b10b91 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry-static.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry-static.ts @@ -3,7 +3,6 @@ import type { registry as DriverTestRegistryType } from "./registry"; import { loadStaticActors } from "./registry-loader"; const use = await loadStaticActors(); - export const registry = setup({ use, }) as typeof DriverTestRegistryType; diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry.ts index 6958d42368..5bc2d3b5f5 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry.ts @@ -320,9 +320,9 @@ export const registry = setup({ stateZodCoercionActor, ...(agentOsTestActor ? { - // From agent-os.ts - agentOsTestActor, - } + // From agent-os.ts + agentOsTestActor, + } : {}), }, }); diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts index e9506a3cc9..a4dce1ca06 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts @@ -288,18 +288,18 @@ export function isStaticActorInstance( export type ExtractActorState = A extends ActorInstance - ? State - : never; + ? State + : never; export type ExtractActorConnParams = A extends ActorInstance - ? ConnParams - : never; + ? ConnParams + : never; export type ExtractActorConnState = A extends ActorInstance - ? ConnState - : never; + ? ConnState + : never; // MARK: - Main ActorInstance Class export class ActorInstance< @@ -311,8 +311,7 @@ export class ActorInstance< DB extends AnyDatabaseProvider, E extends EventSchemaConfig = Record, Q extends QueueSchemaConfig = Record, -> implements BaseActorInstance -{ +> implements BaseActorInstance { // MARK: - Core Properties actorContext: ActorContext; #config: ActorConfig; @@ -971,15 +970,15 @@ export class ActorInstance< // is intentional and safe. try { this.#abortController.abort(); - } catch {} + } catch { } // Wait for run handler to complete await this.#waitForRunHandler( this.overrides.runStopTimeout !== undefined ? Math.min( - this.#config.options.runStopTimeout, - this.overrides.runStopTimeout, - ) + this.#config.options.runStopTimeout, + this.overrides.runStopTimeout, + ) : this.#config.options.runStopTimeout, ); @@ -1044,7 +1043,7 @@ export class ActorInstance< try { this.#abortController.abort(); - } catch {} + } catch { } } finally { this.#shutdownComplete = true; await this.#cleanupDatabase(); @@ -1103,7 +1102,7 @@ export class ActorInstance< // modes. try { this.#abortController.abort(); - } catch {} + } catch { } const destroy = this.driver.startDestroy.bind( this.driver, @@ -1140,14 +1139,14 @@ export class ActorInstance< async processMessage( message: { body: - | { - tag: "ActionRequest"; - val: { id: bigint; name: string; args: unknown }; - } - | { - tag: "SubscriptionRequest"; - val: { eventName: string; subscribe: boolean }; - }; + | { + tag: "ActionRequest"; + val: { id: bigint; name: string; args: unknown }; + } + | { + tag: "SubscriptionRequest"; + val: { eventName: string; subscribe: boolean }; + }; }, conn: Conn, ) { @@ -1518,9 +1517,9 @@ export class ActorInstance< if (this.overrides.sleepGracePeriod !== undefined) { return this.#config.options.sleepGracePeriod !== undefined ? Math.min( - this.#config.options.sleepGracePeriod, - this.overrides.sleepGracePeriod, - ) + this.#config.options.sleepGracePeriod, + this.overrides.sleepGracePeriod, + ) : this.overrides.sleepGracePeriod; } @@ -1531,16 +1530,16 @@ export class ActorInstance< const effectiveOnSleepTimeout = this.overrides.onSleepTimeout !== undefined ? Math.min( - this.#config.options.onSleepTimeout, - this.overrides.onSleepTimeout, - ) + this.#config.options.onSleepTimeout, + this.overrides.onSleepTimeout, + ) : this.#config.options.onSleepTimeout; const effectiveWaitUntilTimeout = this.overrides.waitUntilTimeout !== undefined ? Math.min( - this.#config.options.waitUntilTimeout, - this.overrides.waitUntilTimeout, - ) + this.#config.options.waitUntilTimeout, + this.overrides.waitUntilTimeout, + ) : this.#config.options.waitUntilTimeout; const usesDefaultLegacyTimeouts = @@ -2020,10 +2019,10 @@ export class ActorInstance< result, this.overrides.onDestroyTimeout !== undefined ? Math.min( - this.#config.options - .onDestroyTimeout, - this.overrides.onDestroyTimeout, - ) + this.#config.options + .onDestroyTimeout, + this.overrides.onDestroyTimeout, + ) : this.#config.options.onDestroyTimeout, ); } @@ -2163,16 +2162,16 @@ export class ActorInstance< overrideRawDatabaseClient: this.driver .overrideRawDatabaseClient ? () => - this.driver.overrideRawDatabaseClient!( - this.#actorId, - ) + this.driver.overrideRawDatabaseClient!( + this.#actorId, + ) : undefined, overrideDrizzleDatabaseClient: this.driver .overrideDrizzleDatabaseClient ? () => - this.driver.overrideDrizzleDatabaseClient!( - this.#actorId, - ) + this.driver.overrideDrizzleDatabaseClient!( + this.#actorId, + ) : undefined, kv: { batchPut: (entries: [Uint8Array, Uint8Array][]) => diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/router-websocket-endpoints.ts b/rivetkit-typescript/packages/rivetkit/src/actor/router-websocket-endpoints.ts index 52615d3328..890c16eb00 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/router-websocket-endpoints.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/router-websocket-endpoints.ts @@ -180,6 +180,7 @@ export async function routeWebSocket( ); createdConn = conn; + // Create handler // // This must call actor.connectionManager.connectConn in onOpen. @@ -215,8 +216,8 @@ export async function routeWebSocket( onMessage: (_evt: { data: any }, ws: WSContext) => { ws.close(1011, "actor.not_loaded"); }, - onClose: (_event: any, _ws: WSContext) => {}, - onError: (_error: unknown) => {}, + onClose: (_event: any, _ws: WSContext) => { }, + onError: (_error: unknown) => { }, }; } } @@ -408,7 +409,7 @@ export async function handleRawWebSocket( // this is called synchronously within onOpen. actor.handleRawWebSocket(conn, ws, request); }, - onMessage: (_evt: any, _wsContext: any) => {}, + onMessage: (_evt: any, _wsContext: any) => { }, onClose: (evt: any, ws: any) => { // Resolve the close promise closePromiseResolvers.resolve(); @@ -422,7 +423,7 @@ export async function handleRawWebSocket( }); }); }, - onError: (_error: any, _ws: any) => {}, + onError: (_error: any, _ws: any) => { }, }; } diff --git a/rivetkit-typescript/packages/rivetkit/src/client/actor-handle.ts b/rivetkit-typescript/packages/rivetkit/src/client/actor-handle.ts index 1e1586678d..2406c9dfc9 100644 --- a/rivetkit-typescript/packages/rivetkit/src/client/actor-handle.ts +++ b/rivetkit-typescript/packages/rivetkit/src/client/actor-handle.ts @@ -161,9 +161,9 @@ export class ActorHandleRaw { actorId ? { msg: "using direct actor gateway target", actorId } : { - msg: "using query gateway target for action", - query: this.#actorResolutionState, - }, + msg: "using query gateway target for action", + query: this.#actorResolutionState, + }, ); logger().debug({ @@ -185,10 +185,10 @@ export class ActorHandleRaw { [HEADER_ENCODING]: this.#encoding, ...(this.#params !== undefined ? { - [HEADER_CONN_PARAMS]: JSON.stringify( - this.#params, - ), - } + [HEADER_CONN_PARAMS]: JSON.stringify( + this.#params, + ), + } : {}), }, body: opts.args, diff --git a/rivetkit-typescript/packages/rivetkit/src/client/raw-utils.ts b/rivetkit-typescript/packages/rivetkit/src/client/raw-utils.ts index 3d2bbd0d25..dbcebfa4d8 100644 --- a/rivetkit-typescript/packages/rivetkit/src/client/raw-utils.ts +++ b/rivetkit-typescript/packages/rivetkit/src/client/raw-utils.ts @@ -67,9 +67,9 @@ export async function rawHttpFetch( "directId" in target ? { msg: "sending raw http request to actor", actorId: target.directId } : { - msg: "sending raw http request with actor query", - query: target, - }, + msg: "sending raw http request with actor query", + query: target, + }, ); // Build the URL with normalized path diff --git a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/raw-websocket.ts b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/raw-websocket.ts index e6e462e3f3..87df9f5353 100644 --- a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/raw-websocket.ts +++ b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/raw-websocket.ts @@ -770,7 +770,7 @@ export function runRawWebSocketTests(driverTestConfig: DriverTestConfig) { type: "indexedAckProbe", payload: "x".repeat( HIBERNATABLE_WEBSOCKET_BUFFERED_MESSAGE_SIZE_THRESHOLD + - 8_000, + 8_000, ), }), ); diff --git a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts index 62c12f75fa..490150bb06 100644 --- a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts @@ -1452,6 +1452,7 @@ export class EngineActorDriver implements ActorDriver { if (isRawWebSocketPath) { attachMessageListener(); } + wsHandler.onOpen(event, wsContext); attachPostOpenListeners(); @@ -1667,10 +1668,10 @@ export class EngineActorDriver implements ActorDriver { // envoy can decide it before the actor has fully started. const actorName = "config" in actorInstance && - actorInstance.config && - typeof actorInstance.config === "object" && - "name" in actorInstance.config && - typeof actorInstance.config.name === "string" + actorInstance.config && + typeof actorInstance.config === "object" && + "name" in actorInstance.config && + typeof actorInstance.config.name === "string" ? actorInstance.config.name : this.#actors.get(actorId)?.actorName; invariant( diff --git a/rivetkit-typescript/packages/rivetkit/tests/driver-engine.test.ts b/rivetkit-typescript/packages/rivetkit/tests/driver-engine.test.ts index 498a6a74bc..33f353e722 100644 --- a/rivetkit-typescript/packages/rivetkit/tests/driver-engine.test.ts +++ b/rivetkit-typescript/packages/rivetkit/tests/driver-engine.test.ts @@ -122,20 +122,6 @@ for (const registryVariant of getDriverRegistryVariants(__dirname)) { // Wait for envoy to connect await actorDriver.waitForReady(); - // Refresh metadata so the engine stores envoyProtocolVersion - // which enables v2 POST dispatch for serverless actors. - await fetch( - `${endpoint}/runner-configs/${poolName}/refresh-metadata?namespace=${namespace}`, - { - method: "POST", - headers: { "Content-Type": "application/json", Authorization: `Bearer ${token}` }, - body: JSON.stringify({}), - }, - ); - - // TODO(US-XXX): Remove this delay once the engine processes metadata synchronously - await new Promise((resolve) => setTimeout(resolve, 5000)); - return { rivetEngine: { endpoint, diff --git a/rivetkit-typescript/packages/rivetkit/tests/standalone-native-test.mts b/rivetkit-typescript/packages/rivetkit/tests/standalone-native-test.mts index 7161a23006..56648f5c18 100644 --- a/rivetkit-typescript/packages/rivetkit/tests/standalone-native-test.mts +++ b/rivetkit-typescript/packages/rivetkit/tests/standalone-native-test.mts @@ -62,11 +62,15 @@ const port = (server.address() as any).port; // Point runner config at our serverless server await updateRunnerConfig(clientConfig, poolName, { - datacenters: { default: { serverless: { - url: `http://127.0.0.1:${port}`, - request_lifespan: 300, max_concurrent_actors: 10000, - slots_per_runner: 1, min_runners: 0, max_runners: 10000, - }}}, + datacenters: { + default: { + serverless: { + url: `http://127.0.0.1:${port}`, + request_lifespan: 300, max_concurrent_actors: 10000, + slots_per_runner: 1, min_runners: 0, max_runners: 10000, + } + } + }, }); await actorDriver.waitForReady(); @@ -116,12 +120,6 @@ try { // Connect const conn = handle.connect(); - await new Promise((resolve, reject) => { - const timeout = setTimeout(() => reject(new Error("connect timeout")), 10000); - conn.addEventListener("open", () => { clearTimeout(timeout); resolve(); }); - conn.addEventListener("error", (e: any) => { clearTimeout(timeout); reject(new Error(`error: ${e?.message}`)); }); - }); - ok("connected"); // Action through existing connection const val = await handle.increment(42); diff --git a/rivetkit-typescript/packages/sqlite-native/src/kv.rs b/rivetkit-typescript/packages/sqlite-native/src/kv.rs index 6fff75fc09..f2d3347647 100644 --- a/rivetkit-typescript/packages/sqlite-native/src/kv.rs +++ b/rivetkit-typescript/packages/sqlite-native/src/kv.rs @@ -38,7 +38,7 @@ pub const FILE_TAG_SHM: u8 = 0x03; /// /// Format: `[SQLITE_PREFIX, SCHEMA_VERSION, META_PREFIX, file_tag]` pub fn get_meta_key(file_tag: u8) -> [u8; 4] { - [SQLITE_PREFIX, SQLITE_SCHEMA_VERSION, META_PREFIX, file_tag] + [SQLITE_PREFIX, SQLITE_SCHEMA_VERSION, META_PREFIX, file_tag] } /// Returns the 8-byte chunk key for the given file tag and chunk index. @@ -47,17 +47,17 @@ pub fn get_meta_key(file_tag: u8) -> [u8; 4] { /// /// The chunk index is derived from byte offset as `offset / CHUNK_SIZE`. pub fn get_chunk_key(file_tag: u8, chunk_index: u32) -> [u8; 8] { - let ci = chunk_index.to_be_bytes(); - [ - SQLITE_PREFIX, - SQLITE_SCHEMA_VERSION, - CHUNK_PREFIX, - file_tag, - ci[0], - ci[1], - ci[2], - ci[3], - ] + let ci = chunk_index.to_be_bytes(); + [ + SQLITE_PREFIX, + SQLITE_SCHEMA_VERSION, + CHUNK_PREFIX, + file_tag, + ci[0], + ci[1], + ci[2], + ci[3], + ] } /// Maximum file size in bytes before chunk index overflow. @@ -75,126 +75,131 @@ pub const MAX_FILE_SIZE: u64 = (u32::MAX as u64 + 1) * CHUNK_SIZE as u64; /// This is shorter than a chunk key but lexicographically greater than any /// 8-byte chunk key with the same file_tag prefix. pub fn get_chunk_key_range_end(file_tag: u8) -> [u8; 4] { - [SQLITE_PREFIX, SQLITE_SCHEMA_VERSION, CHUNK_PREFIX, file_tag + 1] + [ + SQLITE_PREFIX, + SQLITE_SCHEMA_VERSION, + CHUNK_PREFIX, + file_tag + 1, + ] } #[cfg(test)] mod tests { - use super::*; - - #[test] - fn constants_match_typescript() { - assert_eq!(CHUNK_SIZE, 4096); - assert_eq!(SQLITE_PREFIX, 8); - assert_eq!(SQLITE_SCHEMA_VERSION, 1); - assert_eq!(META_PREFIX, 0); - assert_eq!(CHUNK_PREFIX, 1); - assert_eq!(FILE_TAG_MAIN, 0); - assert_eq!(FILE_TAG_JOURNAL, 1); - assert_eq!(FILE_TAG_WAL, 2); - assert_eq!(FILE_TAG_SHM, 3); - } - - #[test] - fn meta_key_main() { - // TypeScript: getMetaKey(FILE_TAG_MAIN) => [8, 1, 0, 0] - assert_eq!(get_meta_key(FILE_TAG_MAIN), [0x08, 0x01, 0x00, 0x00]); - } - - #[test] - fn meta_key_journal() { - // TypeScript: getMetaKey(FILE_TAG_JOURNAL) => [8, 1, 0, 1] - assert_eq!(get_meta_key(FILE_TAG_JOURNAL), [0x08, 0x01, 0x00, 0x01]); - } - - #[test] - fn meta_key_wal() { - assert_eq!(get_meta_key(FILE_TAG_WAL), [0x08, 0x01, 0x00, 0x02]); - } - - #[test] - fn meta_key_shm() { - assert_eq!(get_meta_key(FILE_TAG_SHM), [0x08, 0x01, 0x00, 0x03]); - } - - #[test] - fn chunk_key_zero_index() { - // TypeScript: getChunkKey(FILE_TAG_MAIN, 0) => [8, 1, 1, 0, 0, 0, 0, 0] - assert_eq!( - get_chunk_key(FILE_TAG_MAIN, 0), - [0x08, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00] - ); - } - - #[test] - fn chunk_key_index_one() { - // TypeScript: getChunkKey(FILE_TAG_MAIN, 1) => [8, 1, 1, 0, 0, 0, 0, 1] - assert_eq!( - get_chunk_key(FILE_TAG_MAIN, 1), - [0x08, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01] - ); - } - - #[test] - fn chunk_key_large_index() { - // TypeScript: getChunkKey(FILE_TAG_MAIN, 256) => [8, 1, 1, 0, 0, 0, 1, 0] - assert_eq!( - get_chunk_key(FILE_TAG_MAIN, 256), - [0x08, 0x01, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00] - ); - } - - #[test] - fn chunk_key_max_index() { - // TypeScript: getChunkKey(FILE_TAG_MAIN, 0xFFFFFFFF) => [8, 1, 1, 0, 255, 255, 255, 255] - assert_eq!( - get_chunk_key(FILE_TAG_MAIN, u32::MAX), - [0x08, 0x01, 0x01, 0x00, 0xFF, 0xFF, 0xFF, 0xFF] - ); - } - - #[test] - fn chunk_key_journal_tag() { - assert_eq!( - get_chunk_key(FILE_TAG_JOURNAL, 42), - [0x08, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 42] - ); - } - - #[test] - fn chunk_key_big_endian_encoding() { - // 0x01020304 => bytes [1, 2, 3, 4] - assert_eq!( - get_chunk_key(FILE_TAG_MAIN, 0x01020304), - [0x08, 0x01, 0x01, 0x00, 0x01, 0x02, 0x03, 0x04] - ); - } - - #[test] - fn chunk_key_range_end_main() { - // TypeScript: getChunkKeyRangeEnd(FILE_TAG_MAIN) => [8, 1, 1, 1] - assert_eq!( - get_chunk_key_range_end(FILE_TAG_MAIN), - [0x08, 0x01, 0x01, 0x01] - ); - } - - #[test] - fn chunk_key_range_end_journal() { - // TypeScript: getChunkKeyRangeEnd(FILE_TAG_JOURNAL) => [8, 1, 1, 2] - assert_eq!( - get_chunk_key_range_end(FILE_TAG_JOURNAL), - [0x08, 0x01, 0x01, 0x02] - ); - } - - #[test] - fn range_end_is_past_all_chunk_keys() { - // The range end key must be lexicographically greater than any chunk key for the same tag. - let max_chunk = get_chunk_key(FILE_TAG_MAIN, u32::MAX); - let range_end = get_chunk_key_range_end(FILE_TAG_MAIN); - // Compare as slices. The range end [8,1,1,1] > [8,1,1,0,FF,FF,FF,FF] - // because at byte index 3, 1 > 0. - assert!(range_end.as_slice() > max_chunk.as_slice()); - } + use super::*; + + #[test] + fn constants_match_typescript() { + assert_eq!(CHUNK_SIZE, 4096); + assert_eq!(SQLITE_PREFIX, 8); + assert_eq!(SQLITE_SCHEMA_VERSION, 1); + assert_eq!(META_PREFIX, 0); + assert_eq!(CHUNK_PREFIX, 1); + assert_eq!(FILE_TAG_MAIN, 0); + assert_eq!(FILE_TAG_JOURNAL, 1); + assert_eq!(FILE_TAG_WAL, 2); + assert_eq!(FILE_TAG_SHM, 3); + } + + #[test] + fn meta_key_main() { + // TypeScript: getMetaKey(FILE_TAG_MAIN) => [8, 1, 0, 0] + assert_eq!(get_meta_key(FILE_TAG_MAIN), [0x08, 0x01, 0x00, 0x00]); + } + + #[test] + fn meta_key_journal() { + // TypeScript: getMetaKey(FILE_TAG_JOURNAL) => [8, 1, 0, 1] + assert_eq!(get_meta_key(FILE_TAG_JOURNAL), [0x08, 0x01, 0x00, 0x01]); + } + + #[test] + fn meta_key_wal() { + assert_eq!(get_meta_key(FILE_TAG_WAL), [0x08, 0x01, 0x00, 0x02]); + } + + #[test] + fn meta_key_shm() { + assert_eq!(get_meta_key(FILE_TAG_SHM), [0x08, 0x01, 0x00, 0x03]); + } + + #[test] + fn chunk_key_zero_index() { + // TypeScript: getChunkKey(FILE_TAG_MAIN, 0) => [8, 1, 1, 0, 0, 0, 0, 0] + assert_eq!( + get_chunk_key(FILE_TAG_MAIN, 0), + [0x08, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00] + ); + } + + #[test] + fn chunk_key_index_one() { + // TypeScript: getChunkKey(FILE_TAG_MAIN, 1) => [8, 1, 1, 0, 0, 0, 0, 1] + assert_eq!( + get_chunk_key(FILE_TAG_MAIN, 1), + [0x08, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01] + ); + } + + #[test] + fn chunk_key_large_index() { + // TypeScript: getChunkKey(FILE_TAG_MAIN, 256) => [8, 1, 1, 0, 0, 0, 1, 0] + assert_eq!( + get_chunk_key(FILE_TAG_MAIN, 256), + [0x08, 0x01, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00] + ); + } + + #[test] + fn chunk_key_max_index() { + // TypeScript: getChunkKey(FILE_TAG_MAIN, 0xFFFFFFFF) => [8, 1, 1, 0, 255, 255, 255, 255] + assert_eq!( + get_chunk_key(FILE_TAG_MAIN, u32::MAX), + [0x08, 0x01, 0x01, 0x00, 0xFF, 0xFF, 0xFF, 0xFF] + ); + } + + #[test] + fn chunk_key_journal_tag() { + assert_eq!( + get_chunk_key(FILE_TAG_JOURNAL, 42), + [0x08, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 42] + ); + } + + #[test] + fn chunk_key_big_endian_encoding() { + // 0x01020304 => bytes [1, 2, 3, 4] + assert_eq!( + get_chunk_key(FILE_TAG_MAIN, 0x01020304), + [0x08, 0x01, 0x01, 0x00, 0x01, 0x02, 0x03, 0x04] + ); + } + + #[test] + fn chunk_key_range_end_main() { + // TypeScript: getChunkKeyRangeEnd(FILE_TAG_MAIN) => [8, 1, 1, 1] + assert_eq!( + get_chunk_key_range_end(FILE_TAG_MAIN), + [0x08, 0x01, 0x01, 0x01] + ); + } + + #[test] + fn chunk_key_range_end_journal() { + // TypeScript: getChunkKeyRangeEnd(FILE_TAG_JOURNAL) => [8, 1, 1, 2] + assert_eq!( + get_chunk_key_range_end(FILE_TAG_JOURNAL), + [0x08, 0x01, 0x01, 0x02] + ); + } + + #[test] + fn range_end_is_past_all_chunk_keys() { + // The range end key must be lexicographically greater than any chunk key for the same tag. + let max_chunk = get_chunk_key(FILE_TAG_MAIN, u32::MAX); + let range_end = get_chunk_key_range_end(FILE_TAG_MAIN); + // Compare as slices. The range end [8,1,1,1] > [8,1,1,0,FF,FF,FF,FF] + // because at byte index 3, 1 > 0. + assert!(range_end.as_slice() > max_chunk.as_slice()); + } } diff --git a/rivetkit-typescript/packages/sqlite-native/src/sqlite_kv.rs b/rivetkit-typescript/packages/sqlite-native/src/sqlite_kv.rs index 974f6e9a24..4d906b0552 100644 --- a/rivetkit-typescript/packages/sqlite-native/src/sqlite_kv.rs +++ b/rivetkit-typescript/packages/sqlite-native/src/sqlite_kv.rs @@ -98,11 +98,7 @@ pub trait SqliteKv: Send + Sync { ) -> Result<(), SqliteKvError>; /// Delete multiple keys in one batch. - async fn batch_delete( - &self, - actor_id: &str, - keys: Vec>, - ) -> Result<(), SqliteKvError>; + async fn batch_delete(&self, actor_id: &str, keys: Vec>) -> Result<(), SqliteKvError>; /// Delete all keys in the half-open range `[start, end)`. async fn delete_range( diff --git a/rivetkit-typescript/packages/sqlite-native/src/vfs.rs b/rivetkit-typescript/packages/sqlite-native/src/vfs.rs index 750e981cc0..7d00a36974 100644 --- a/rivetkit-typescript/packages/sqlite-native/src/vfs.rs +++ b/rivetkit-typescript/packages/sqlite-native/src/vfs.rs @@ -33,10 +33,7 @@ macro_rules! vfs_catch_unwind { match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| $body)) { Ok(result) => result, Err(panic) => { - tracing::error!( - message = panic_message(&panic), - "vfs callback panicked" - ); + tracing::error!(message = panic_message(&panic), "vfs callback panicked"); $err_val } } @@ -62,11 +59,10 @@ const KV_MAX_BATCH_KEYS: usize = 128; /// This must match `HEADER_PREFIX` in /// `rivetkit-typescript/packages/sqlite-vfs/src/generated/empty-db-page.ts`. const EMPTY_DB_PAGE_HEADER_PREFIX: [u8; 108] = [ - 83, 81, 76, 105, 116, 101, 32, 102, 111, 114, 109, 97, 116, 32, 51, 0, 16, 0, - 1, 1, 0, 64, 32, 32, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 3, 0, 46, 138, 17, 13, 0, 0, 0, 0, 16, 0, 0, + 83, 81, 76, 105, 116, 101, 32, 102, 111, 114, 109, 97, 116, 32, 51, 0, 16, 0, 1, 1, 0, 64, 32, + 32, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 46, 138, 17, 13, 0, 0, 0, 0, 16, 0, 0, ]; fn empty_db_page() -> Vec { @@ -175,7 +171,11 @@ impl VfsContext { .map_err(|e| e.to_string()); let elapsed = start.elapsed(); if std::env::var("RIVET_TRACE_SQL").is_ok() { - eprintln!("[sql-trace] kv_roundtrip op={} duration={}us", op_name, elapsed.as_micros()); + eprintln!( + "[sql-trace] kv_roundtrip op={} duration={}us", + op_name, + elapsed.as_micros() + ); } tracing::debug!( op = %op_name, @@ -194,7 +194,11 @@ impl VfsContext { .map_err(|e| e.to_string()); let elapsed = start.elapsed(); if std::env::var("RIVET_TRACE_SQL").is_ok() { - eprintln!("[sql-trace] kv_roundtrip op={} duration={}us", op_name, elapsed.as_micros()); + eprintln!( + "[sql-trace] kv_roundtrip op={} duration={}us", + op_name, + elapsed.as_micros() + ); } tracing::debug!( op = %op_name, @@ -213,7 +217,11 @@ impl VfsContext { .map_err(|e| e.to_string()); let elapsed = start.elapsed(); if std::env::var("RIVET_TRACE_SQL").is_ok() { - eprintln!("[sql-trace] kv_roundtrip op={} duration={}us", op_name, elapsed.as_micros()); + eprintln!( + "[sql-trace] kv_roundtrip op={} duration={}us", + op_name, + elapsed.as_micros() + ); } tracing::debug!( op = %op_name, @@ -232,7 +240,11 @@ impl VfsContext { .map_err(|e| e.to_string()); let elapsed = start_time.elapsed(); if std::env::var("RIVET_TRACE_SQL").is_ok() { - eprintln!("[sql-trace] kv_roundtrip op={} duration={}us", op_name, elapsed.as_micros()); + eprintln!( + "[sql-trace] kv_roundtrip op={} duration={}us", + op_name, + elapsed.as_micros() + ); } tracing::debug!( op = %op_name, @@ -437,24 +449,23 @@ unsafe extern "C" fn kv_io_read( let value_map = build_value_map(&resp); for chunk_idx in start_chunk..=end_chunk { - let chunk_data: Option<&[u8]> = buffered_chunks.get(&chunk_idx).map(|v| v.as_slice()).or_else(|| { - let key = kv::get_chunk_key(file.file_tag, chunk_idx as u32); - value_map.get(key.as_slice()).copied() - }); + let chunk_data: Option<&[u8]> = buffered_chunks + .get(&chunk_idx) + .map(|v| v.as_slice()) + .or_else(|| { + let key = kv::get_chunk_key(file.file_tag, chunk_idx as u32); + value_map.get(key.as_slice()).copied() + }); let chunk_offset = chunk_idx * kv::CHUNK_SIZE; let read_start = offset.saturating_sub(chunk_offset); - let read_end = std::cmp::min( - kv::CHUNK_SIZE, - offset + requested_length - chunk_offset, - ); + let read_end = std::cmp::min(kv::CHUNK_SIZE, offset + requested_length - chunk_offset); let dest_start = chunk_offset + read_start - offset; if let Some(chunk_data) = chunk_data { let source_end = std::cmp::min(read_end, chunk_data.len()); if source_end > read_start { let dest_end = dest_start + (source_end - read_start); - buf[dest_start..dest_end] - .copy_from_slice(&chunk_data[read_start..source_end]); + buf[dest_start..dest_end].copy_from_slice(&chunk_data[read_start..source_end]); } if source_end < read_end { let zero_start = dest_start + (source_end - read_start); @@ -470,11 +481,15 @@ unsafe extern "C" fn kv_io_read( let actual_bytes = std::cmp::min(requested_length, file_size - offset); if actual_bytes < requested_length { buf[actual_bytes..].fill(0); - ctx.vfs_metrics.xread_us.fetch_add(read_start.elapsed().as_micros() as u64, Ordering::Relaxed); + ctx.vfs_metrics + .xread_us + .fetch_add(read_start.elapsed().as_micros() as u64, Ordering::Relaxed); return SQLITE_IOERR_SHORT_READ; } - ctx.vfs_metrics.xread_us.fetch_add(read_start.elapsed().as_micros() as u64, Ordering::Relaxed); + ctx.vfs_metrics + .xread_us + .fetch_add(read_start.elapsed().as_micros() as u64, Ordering::Relaxed); SQLITE_OK }) } @@ -518,12 +533,10 @@ unsafe extern "C" fn kv_io_write( if state.batch_mode { for chunk_idx in start_chunk..=end_chunk { let chunk_offset = chunk_idx * kv::CHUNK_SIZE; - let source_start = std::cmp::max(0isize, chunk_offset as isize - offset as isize) - as usize; - let source_end = std::cmp::min( - write_length, - chunk_offset + kv::CHUNK_SIZE - offset, - ); + let source_start = + std::cmp::max(0isize, chunk_offset as isize - offset as isize) as usize; + let source_end = + std::cmp::min(write_length, chunk_offset + kv::CHUNK_SIZE - offset); state .dirty_buffer .insert(chunk_idx as u32, data[source_start..source_end].to_vec()); @@ -535,8 +548,12 @@ unsafe extern "C" fn kv_io_write( file.meta_dirty = true; } - ctx.vfs_metrics.xwrite_buffered_count.fetch_add(1, Ordering::Relaxed); - ctx.vfs_metrics.xwrite_us.fetch_add(write_start.elapsed().as_micros() as u64, Ordering::Relaxed); + ctx.vfs_metrics + .xwrite_buffered_count + .fetch_add(1, Ordering::Relaxed); + ctx.vfs_metrics + .xwrite_us + .fetch_add(write_start.elapsed().as_micros() as u64, Ordering::Relaxed); return SQLITE_OK; } } @@ -554,10 +571,7 @@ unsafe extern "C" fn kv_io_write( for chunk_idx in start_chunk..=end_chunk { let chunk_offset = chunk_idx * kv::CHUNK_SIZE; let write_start = offset.saturating_sub(chunk_offset); - let write_end = std::cmp::min( - kv::CHUNK_SIZE, - offset + write_length - chunk_offset, - ); + let write_end = std::cmp::min(kv::CHUNK_SIZE, offset + write_length - chunk_offset); let existing_bytes_in_chunk = if file.size as usize > chunk_offset { std::cmp::min(kv::CHUNK_SIZE, file.size as usize - chunk_offset) } else { @@ -650,15 +664,14 @@ unsafe extern "C" fn kv_io_write( } file.meta_dirty = false; - ctx.vfs_metrics.xwrite_us.fetch_add(write_start.elapsed().as_micros() as u64, Ordering::Relaxed); + ctx.vfs_metrics + .xwrite_us + .fetch_add(write_start.elapsed().as_micros() as u64, Ordering::Relaxed); SQLITE_OK }) } -unsafe extern "C" fn kv_io_truncate( - p_file: *mut sqlite3_file, - size: sqlite3_int64, -) -> c_int { +unsafe extern "C" fn kv_io_truncate(p_file: *mut sqlite3_file, size: sqlite3_int64) -> c_int { vfs_catch_unwind!(SQLITE_IOERR_TRUNCATE, { let file = get_file(p_file); let ctx = &*file.ctx; @@ -692,7 +705,11 @@ unsafe extern "C" fn kv_io_truncate( // Invalidate read cache entries for truncated chunks. { let state = get_file_state(file.state); - let truncate_from_chunk = if size == 0 { 0u32 } else { (size as u32 / kv::CHUNK_SIZE as u32) + 1 }; + let truncate_from_chunk = if size == 0 { + 0u32 + } else { + (size as u32 / kv::CHUNK_SIZE as u32) + 1 + }; state.read_cache.retain(|key, _| { // Chunk keys are 8 bytes: [prefix, version, CHUNK_PREFIX, file_tag, idx_be32] if key.len() == 8 && key[3] == file.file_tag { @@ -772,10 +789,7 @@ unsafe extern "C" fn kv_io_truncate( }) } -unsafe extern "C" fn kv_io_sync( - p_file: *mut sqlite3_file, - _flags: c_int, -) -> c_int { +unsafe extern "C" fn kv_io_sync(p_file: *mut sqlite3_file, _flags: c_int) -> c_int { vfs_catch_unwind!(SQLITE_IOERR_FSYNC, { let file = get_file(p_file); if !file.meta_dirty { @@ -887,16 +901,24 @@ unsafe extern "C" fn kv_io_file_control( // Move dirty buffer entries into the read cache so subsequent // reads can serve them without a KV round-trip. - let flushed: Vec<_> = std::mem::take(&mut state.dirty_buffer).into_iter().collect(); + let flushed: Vec<_> = std::mem::take(&mut state.dirty_buffer) + .into_iter() + .collect(); for (chunk_index, data) in flushed { let key = kv::get_chunk_key(file.file_tag, chunk_index); state.read_cache.insert(key.to_vec(), data); } file.meta_dirty = false; state.batch_mode = false; - ctx.vfs_metrics.commit_atomic_count.fetch_add(1, Ordering::Relaxed); - ctx.vfs_metrics.commit_atomic_pages.fetch_add(dirty_page_count, Ordering::Relaxed); - ctx.vfs_metrics.commit_atomic_us.fetch_add(commit_start.elapsed().as_micros() as u64, Ordering::Relaxed); + ctx.vfs_metrics + .commit_atomic_count + .fetch_add(1, Ordering::Relaxed); + ctx.vfs_metrics + .commit_atomic_pages + .fetch_add(dirty_page_count, Ordering::Relaxed); + ctx.vfs_metrics + .commit_atomic_us + .fetch_add(commit_start.elapsed().as_micros() as u64, Ordering::Relaxed); SQLITE_OK } SQLITE_FCNTL_ROLLBACK_ATOMIC_WRITE => { @@ -977,10 +999,7 @@ unsafe extern "C" fn kv_vfs_open( } else { let size = 0i64; if ctx - .kv_put( - vec![meta_key.to_vec()], - vec![encode_file_meta(size)], - ) + .kv_put(vec![meta_key.to_vec()], vec![encode_file_meta(size)]) .is_err() { return SQLITE_CANTOPEN; @@ -1077,7 +1096,11 @@ unsafe extern "C" fn kv_vfs_access( Err(_) => return SQLITE_IOERR_ACCESS, }; let value_map = build_value_map(&resp); - *p_res_out = if value_map.contains_key(meta_key.as_slice()) { 1 } else { 0 }; + *p_res_out = if value_map.contains_key(meta_key.as_slice()) { + 1 + } else { + 0 + }; SQLITE_OK }) @@ -1119,20 +1142,14 @@ unsafe extern "C" fn kv_vfs_randomness( }) } -unsafe extern "C" fn kv_vfs_sleep( - _p_vfs: *mut sqlite3_vfs, - microseconds: c_int, -) -> c_int { +unsafe extern "C" fn kv_vfs_sleep(_p_vfs: *mut sqlite3_vfs, microseconds: c_int) -> c_int { vfs_catch_unwind!(0, { std::thread::sleep(std::time::Duration::from_micros(microseconds as u64)); microseconds }) } -unsafe extern "C" fn kv_vfs_current_time( - _p_vfs: *mut sqlite3_vfs, - p_time_out: *mut f64, -) -> c_int { +unsafe extern "C" fn kv_vfs_current_time(_p_vfs: *mut sqlite3_vfs, p_time_out: *mut f64) -> c_int { vfs_catch_unwind!(SQLITE_IOERR, { let now = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) @@ -1299,15 +1316,8 @@ pub fn open_database(vfs: KvVfs, file_name: &str) -> Result