From 3cb32e3987cb01ad7c07b446cabc7f31afda1c27 Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Sun, 26 Apr 2026 18:07:05 -0700 Subject: [PATCH] test(rivetkit-core): move private tests to crate root --- .../packages/rivetkit-core/Cargo.toml | 1 + .../rivetkit-core/src/actor/config.rs | 3 +- .../rivetkit-core/src/actor/connection.rs | 334 +------------ .../rivetkit-core/src/actor/context.rs | 3 +- .../packages/rivetkit-core/src/actor/kv.rs | 3 +- .../rivetkit-core/src/actor/messages.rs | 3 +- .../rivetkit-core/src/actor/metrics.rs | 38 +- .../packages/rivetkit-core/src/actor/queue.rs | 161 +------ .../rivetkit-core/src/actor/schedule.rs | 200 +------- .../packages/rivetkit-core/src/actor/sleep.rs | 439 +----------------- .../packages/rivetkit-core/src/actor/state.rs | 11 +- .../packages/rivetkit-core/src/actor/task.rs | 3 +- .../rivetkit-core/src/actor/work_registry.rs | 38 +- .../rivetkit-core/src/inspector/auth.rs | 8 + .../rivetkit-core/src/inspector/mod.rs | 3 +- .../src/registry/envoy_callbacks.rs | 59 +-- .../rivetkit-core/src/registry/http.rs | 12 +- .../packages/rivetkit-core/src/serverless.rs | 145 +----- .../packages/rivetkit-core/src/websocket.rs | 3 +- .../tests/{modules => }/config.rs | 0 .../rivetkit-core/tests/connection.rs | 324 +++++++++++++ .../tests/{modules => }/context.rs | 0 .../rivetkit-core/tests/envoy_callbacks.rs | 59 +++ .../tests/{modules => }/inspector.rs | 10 +- .../rivetkit-core/tests/{modules => }/kv.rs | 0 .../tests/{modules => }/messages.rs | 0 .../packages/rivetkit-core/tests/metrics.rs | 37 ++ .../packages/rivetkit-core/tests/queue.rs | 160 +++++++ .../rivetkit-core/tests/registry_http.rs | 237 ++++++++++ .../packages/rivetkit-core/tests/schedule.rs | 197 ++++++++ .../rivetkit-core/tests/serverless.rs | 144 ++++++ .../packages/rivetkit-core/tests/sleep.rs | 428 +++++++++++++++++ .../tests/{modules => }/state.rs | 0 .../rivetkit-core/tests/{modules => }/task.rs | 68 +-- .../tests/{modules => }/websocket.rs | 0 .../rivetkit-core/tests/work_registry.rs | 37 ++ 36 files changed, 1732 insertions(+), 1436 deletions(-) rename rivetkit-rust/packages/rivetkit-core/tests/{modules => }/config.rs (100%) create mode 100644 rivetkit-rust/packages/rivetkit-core/tests/connection.rs rename rivetkit-rust/packages/rivetkit-core/tests/{modules => }/context.rs (100%) create mode 100644 rivetkit-rust/packages/rivetkit-core/tests/envoy_callbacks.rs rename rivetkit-rust/packages/rivetkit-core/tests/{modules => }/inspector.rs (96%) rename rivetkit-rust/packages/rivetkit-core/tests/{modules => }/kv.rs (100%) rename rivetkit-rust/packages/rivetkit-core/tests/{modules => }/messages.rs (100%) create mode 100644 rivetkit-rust/packages/rivetkit-core/tests/metrics.rs create mode 100644 rivetkit-rust/packages/rivetkit-core/tests/queue.rs create mode 100644 rivetkit-rust/packages/rivetkit-core/tests/registry_http.rs create mode 100644 rivetkit-rust/packages/rivetkit-core/tests/schedule.rs create mode 100644 rivetkit-rust/packages/rivetkit-core/tests/serverless.rs create mode 100644 rivetkit-rust/packages/rivetkit-core/tests/sleep.rs rename rivetkit-rust/packages/rivetkit-core/tests/{modules => }/state.rs (100%) rename rivetkit-rust/packages/rivetkit-core/tests/{modules => }/task.rs (98%) rename rivetkit-rust/packages/rivetkit-core/tests/{modules => }/websocket.rs (100%) create mode 100644 rivetkit-rust/packages/rivetkit-core/tests/work_registry.rs diff --git a/rivetkit-rust/packages/rivetkit-core/Cargo.toml b/rivetkit-rust/packages/rivetkit-core/Cargo.toml index d544d92950..a92c2e0959 100644 --- a/rivetkit-rust/packages/rivetkit-core/Cargo.toml +++ b/rivetkit-rust/packages/rivetkit-core/Cargo.toml @@ -5,6 +5,7 @@ authors.workspace = true license.workspace = true edition.workspace = true workspace = "../../../" +autotests = false [features] default = [] diff --git a/rivetkit-rust/packages/rivetkit-core/src/actor/config.rs b/rivetkit-rust/packages/rivetkit-core/src/actor/config.rs index ab5a3e3666..c8b3d79d44 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/actor/config.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/actor/config.rs @@ -282,6 +282,7 @@ fn duration_ms(value: u32) -> Duration { Duration::from_millis(u64::from(value)) } +// Test shim keeps moved tests in crate-root tests/ with private-module access. #[cfg(test)] -#[path = "../../tests/modules/config.rs"] +#[path = "../../tests/config.rs"] mod tests; diff --git a/rivetkit-rust/packages/rivetkit-core/src/actor/connection.rs b/rivetkit-rust/packages/rivetkit-core/src/actor/connection.rs index 0b0297ff68..63d3ea6235 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/actor/connection.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/actor/connection.rs @@ -982,335 +982,7 @@ pub(crate) fn make_connection_key(conn_id: &str) -> Vec { key } +// Test shim keeps moved tests in crate-root tests/ with private-module access. #[cfg(test)] -mod tests { - use std::collections::BTreeSet; - use std::sync::Arc; - use std::sync::atomic::{AtomicUsize, Ordering}; - - use parking_lot::Mutex; - use tokio::sync::{Barrier, mpsc}; - use tokio::task::yield_now; - - use super::{ - HibernatableConnectionMetadata, PersistedConnection, decode_persisted_connection, - encode_persisted_connection, hibernatable_id_from_slice, make_connection_key, - }; - use crate::actor::context::ActorContext; - use crate::actor::messages::ActorEvent; - use crate::actor::preload::PreloadedKv; - use crate::actor::task::LifecycleEvent; - use crate::kv::Kv; - - fn next_non_activity_lifecycle_event( - rx: &mut mpsc::Receiver, - ) -> Option { - rx.try_recv().ok() - } - - #[tokio::test] - async fn restore_persisted_uses_preloaded_connection_prefix_when_present() { - let ctx = ActorContext::new_with_kv( - "actor-preload", - "actor", - Vec::new(), - "local", - Kv::new_in_memory(), - ); - let persisted = PersistedConnection { - id: "conn-preloaded".to_owned(), - parameters: vec![1], - state: vec![2], - gateway_id: [1, 2, 3, 4], - request_id: [5, 6, 7, 8], - request_path: "/socket".to_owned(), - ..PersistedConnection::default() - }; - let preloaded = PreloadedKv::new_with_requested_get_keys( - [( - make_connection_key(&persisted.id), - encode_persisted_connection(&persisted) - .expect("persisted connection should encode"), - )], - Vec::new(), - vec![vec![2]], - ); - - let restored = ctx - .restore_persisted(Some(&preloaded)) - .await - .expect("restore should use preloaded entries instead of unconfigured kv"); - - assert_eq!(restored.len(), 1); - assert_eq!(restored[0].id(), "conn-preloaded"); - assert_eq!(restored[0].state(), vec![2]); - assert!(ctx.connection("conn-preloaded").is_some()); - } - - #[test] - fn persisted_connection_uses_ts_v4_fixed_id_wire_format() { - let persisted = PersistedConnection { - id: "c".to_owned(), - parameters: vec![1, 2], - state: vec![3], - gateway_id: [10, 11, 12, 13], - request_id: [20, 21, 22, 23], - server_message_index: 9, - client_message_index: 10, - request_path: "/".to_owned(), - ..PersistedConnection::default() - }; - - let encoded = - encode_persisted_connection(&persisted).expect("persisted connection should encode"); - - assert_eq!( - encoded, - vec![ - 4, 0, // embedded version - 1, b'c', // id - 2, 1, 2, // parameters - 1, 3, // state - 0, // subscriptions - 10, 11, 12, 13, // gatewayId fixed data[4] - 20, 21, 22, 23, // requestId fixed data[4] - 9, 0, // serverMessageIndex - 10, 0, // clientMessageIndex - 1, b'/', // requestPath - 0, // requestHeaders - ] - ); - - let decoded = - decode_persisted_connection(&encoded).expect("persisted connection should decode"); - assert_eq!(decoded.gateway_id, [10, 11, 12, 13]); - assert_eq!(decoded.request_id, [20, 21, 22, 23]); - } - - #[test] - fn hibernatable_id_validation_returns_rivet_error() { - let error = hibernatable_id_from_slice("gateway_id", &[1, 2, 3]) - .expect_err("invalid id should fail"); - let error = rivet_error::RivetError::extract(&error); - - assert_eq!(error.group(), "actor"); - assert_eq!(error.code(), "invalid_request"); - } - - #[tokio::test(start_paused = true)] - async fn concurrent_disconnects_only_emit_one_close_and_one_hibernation_removal() { - let ctx = ActorContext::new_with_kv( - "actor-race", - "actor", - Vec::new(), - "local", - Kv::new_in_memory(), - ); - ctx.configure_connection_runtime(crate::actor::config::ActorConfig::default()); - let (events_tx, mut events_rx) = mpsc::unbounded_channel(); - ctx.configure_actor_events(Some(events_tx)); - let closed = Arc::new(AtomicUsize::new(0)); - let observed_conn_id = Arc::new(Mutex::new(None::)); - - let recv = tokio::spawn({ - let closed = closed.clone(); - let observed_conn_id = observed_conn_id.clone(); - async move { - while let Some(event) = events_rx.recv().await { - match event { - ActorEvent::ConnectionOpen { reply, .. } => reply.send(Ok(())), - ActorEvent::ConnectionClosed { conn } => { - *observed_conn_id.lock() = Some(conn.id().to_owned()); - closed.fetch_add(1, Ordering::SeqCst); - break; - } - other => panic!("unexpected event: {other:?}"), - } - } - } - }); - - let conn = ctx - .connect_with_state( - vec![1], - true, - Some(HibernatableConnectionMetadata { - gateway_id: [1, 2, 3, 4], - request_id: [5, 6, 7, 8], - ..HibernatableConnectionMetadata::default() - }), - None, - async { Ok(vec![9]) }, - ) - .await - .expect("connection should open"); - let conn_id = conn.id().to_owned(); - ctx.record_connections_updated(); - ctx.reset_sleep_timer(); - - let barrier = Arc::new(Barrier::new(2)); - conn.configure_transport_disconnect_handler(Some(Arc::new({ - let barrier = barrier.clone(); - move |_reason| { - let barrier = barrier.clone(); - Box::pin(async move { - barrier.wait().await; - Ok(()) - }) - } - }))); - - let first = tokio::spawn({ - let conn = conn.clone(); - async move { conn.disconnect(Some("first")).await } - }); - let second = tokio::spawn({ - let conn = conn.clone(); - async move { conn.disconnect(Some("second")).await } - }); - - yield_now().await; - first - .await - .expect("first disconnect task should join") - .expect("first disconnect should succeed"); - second - .await - .expect("second disconnect task should join") - .expect("second disconnect should succeed"); - recv.await.expect("event receiver should join"); - - assert_eq!(closed.load(Ordering::SeqCst), 1); - assert_eq!(observed_conn_id.lock().as_deref(), Some(conn_id.as_str())); - assert!(ctx.connection(&conn_id).is_none()); - - let pending = ctx.take_pending_hibernation_changes_inner(); - assert!(pending.updated.is_empty()); - assert_eq!(pending.removed, BTreeSet::from([conn_id])); - } - - #[tokio::test] - async fn hibernatable_set_state_queues_save_and_non_hibernatable_stays_memory_only() { - let ctx = ActorContext::new_with_kv( - "actor-state-dirty", - "actor", - Vec::new(), - "local", - Kv::new_in_memory(), - ); - let (actor_events_tx, mut actor_events_rx) = mpsc::unbounded_channel(); - let (lifecycle_events_tx, mut lifecycle_events_rx) = mpsc::channel(4); - ctx.configure_actor_events(Some(actor_events_tx)); - ctx.configure_lifecycle_events(Some(lifecycle_events_tx)); - - let open_replies = tokio::spawn(async move { - for _ in 0..2 { - match actor_events_rx - .recv() - .await - .expect("open event should arrive") - { - ActorEvent::ConnectionOpen { reply, .. } => reply.send(Ok(())), - other => panic!("unexpected actor event: {other:?}"), - } - } - }); - - let non_hibernatable = ctx - .connect_with_state(vec![1], false, None, None, async { Ok(vec![2]) }) - .await - .expect("non-hibernatable connection should open"); - non_hibernatable.set_state(vec![3]); - assert_eq!(non_hibernatable.state(), vec![3]); - assert!( - ctx.dirty_hibernatable_conns_inner().is_empty(), - "non-hibernatable state changes should not queue persistence" - ); - assert!( - next_non_activity_lifecycle_event(&mut lifecycle_events_rx).is_none(), - "non-hibernatable state changes should not request actor save" - ); - - let hibernatable = ctx - .connect_with_state( - vec![4], - true, - Some(HibernatableConnectionMetadata { - gateway_id: [1, 2, 3, 4], - request_id: [5, 6, 7, 8], - ..HibernatableConnectionMetadata::default() - }), - None, - async { Ok(vec![5]) }, - ) - .await - .expect("hibernatable connection should open"); - hibernatable.set_state(vec![6]); - - assert_eq!( - ctx.dirty_hibernatable_conns_inner() - .into_iter() - .map(|conn| conn.id().to_owned()) - .collect::>(), - vec![hibernatable.id().to_owned()] - ); - assert_eq!( - next_non_activity_lifecycle_event(&mut lifecycle_events_rx) - .expect("hibernatable state change should request save"), - LifecycleEvent::SaveRequested { immediate: false } - ); - - open_replies - .await - .expect("open reply task should join cleanly"); - } - - #[tokio::test(start_paused = true)] - async fn remove_existing_for_disconnect_has_exactly_one_winner() { - let ctx = ActorContext::new_with_kv( - "actor-race", - "actor", - Vec::new(), - "local", - Kv::new_in_memory(), - ); - let conn = super::ConnHandle::new("conn-race", vec![1], vec![2], true); - conn.configure_hibernation(Some(HibernatableConnectionMetadata { - gateway_id: [1, 2, 3, 4], - request_id: [5, 6, 7, 8], - ..HibernatableConnectionMetadata::default() - })); - ctx.insert_existing(conn); - - let barrier = Arc::new(Barrier::new(2)); - let first = tokio::spawn({ - let ctx = ctx.clone(); - let barrier = barrier.clone(); - async move { - barrier.wait().await; - ctx.remove_existing_for_disconnect("conn-race") - .map(|conn| conn.id().to_owned()) - } - }); - let second = tokio::spawn({ - let ctx = ctx.clone(); - let barrier = barrier.clone(); - async move { - barrier.wait().await; - ctx.remove_existing_for_disconnect("conn-race") - .map(|conn| conn.id().to_owned()) - } - }); - - let first = first.await.expect("first task should join"); - let second = second.await.expect("second task should join"); - let winners = [first, second].into_iter().flatten().collect::>(); - - assert_eq!(winners, vec!["conn-race".to_owned()]); - assert!(ctx.connection("conn-race").is_none()); - - let pending = ctx.take_pending_hibernation_changes_inner(); - assert!(pending.updated.is_empty()); - assert_eq!(pending.removed, BTreeSet::from(["conn-race".to_owned()])); - } -} +#[path = "../../tests/connection.rs"] +mod tests; diff --git a/rivetkit-rust/packages/rivetkit-core/src/actor/context.rs b/rivetkit-rust/packages/rivetkit-core/src/actor/context.rs index c60a45278f..d6cb69dc54 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/actor/context.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/actor/context.rs @@ -1547,6 +1547,7 @@ impl std::fmt::Debug for ActorContext { } } +// Test shim keeps moved tests in crate-root tests/ with private-module access. #[cfg(test)] -#[path = "../../tests/modules/context.rs"] +#[path = "../../tests/context.rs"] pub(crate) mod tests; diff --git a/rivetkit-rust/packages/rivetkit-core/src/actor/kv.rs b/rivetkit-rust/packages/rivetkit-core/src/actor/kv.rs index 1ecfa52f86..5c7d39f479 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/actor/kv.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/actor/kv.rs @@ -450,6 +450,7 @@ fn apply_list_opts(entries: &mut Vec<(Vec, Vec)>, opts: ListOpts) { } } +// Test shim keeps moved tests in crate-root tests/ with private-module access. #[cfg(test)] -#[path = "../../tests/modules/kv.rs"] +#[path = "../../tests/kv.rs"] pub(crate) mod tests; diff --git a/rivetkit-rust/packages/rivetkit-core/src/actor/messages.rs b/rivetkit-rust/packages/rivetkit-core/src/actor/messages.rs index ac7b2eca74..970c517423 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/actor/messages.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/actor/messages.rs @@ -365,6 +365,7 @@ impl ActorEvent { } } +// Test shim keeps moved tests in crate-root tests/ with private-module access. #[cfg(test)] -#[path = "../../tests/modules/messages.rs"] +#[path = "../../tests/messages.rs"] mod tests; diff --git a/rivetkit-rust/packages/rivetkit-core/src/actor/metrics.rs b/rivetkit-rust/packages/rivetkit-core/src/actor/metrics.rs index 7566870fae..fe556c45b1 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/actor/metrics.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/actor/metrics.rs @@ -466,39 +466,7 @@ where } } +// Test shim keeps moved tests in crate-root tests/ with private-module access. #[cfg(test)] -mod tests { - use std::panic::{AssertUnwindSafe, catch_unwind}; - - use super::*; - - #[test] - fn duplicate_metric_registration_uses_noop_fallback() { - let registry = Registry::new(); - let first = IntGauge::with_opts(Opts::new( - "duplicate_actor_metric", - "first duplicate metric", - )) - .expect("first gauge should be valid"); - let second = IntGauge::with_opts(Opts::new( - "duplicate_actor_metric", - "second duplicate metric", - )) - .expect("second gauge should be valid"); - - register_metric(®istry, first.clone()); - let result = catch_unwind(AssertUnwindSafe(|| { - register_metric(®istry, second.clone()); - })); - - assert!(result.is_ok()); - assert_eq!( - 1, - registry - .gather() - .iter() - .filter(|family| family.name() == "duplicate_actor_metric") - .count() - ); - } -} +#[path = "../../tests/metrics.rs"] +mod tests; diff --git a/rivetkit-rust/packages/rivetkit-core/src/actor/queue.rs b/rivetkit-rust/packages/rivetkit-core/src/actor/queue.rs index cd4dbba51c..d5d6fdbf1d 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/actor/queue.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/actor/queue.rs @@ -1178,162 +1178,7 @@ fn duration_ms(duration: Duration) -> u64 { duration.as_millis().try_into().unwrap_or(u64::MAX) } +// Test shim keeps moved tests in crate-root tests/ with private-module access. #[cfg(test)] -mod tests { - use super::{ - PersistedQueueMessage, QUEUE_MESSAGES_PREFIX, QUEUE_METADATA_KEY, QueueMetadata, - QueueNextOpts, QueueWaitOpts, encode_queue_message, encode_queue_metadata, - make_queue_message_key, - }; - - use crate::actor::context::ActorContext; - use crate::actor::preload::PreloadedKv; - use crate::kv::Kv; - use std::time::Duration; - use tokio::task::yield_now; - use tokio_util::sync::CancellationToken; - - fn test_queue() -> ActorContext { - ActorContext::new_with_kv( - "actor-queue", - "queue-test", - Vec::new(), - "local", - Kv::new_in_memory(), - ) - } - - fn assert_actor_aborted(error: anyhow::Error) { - let error = rivet_error::RivetError::extract(&error); - assert_eq!(error.group(), "actor"); - assert_eq!(error.code(), "aborted"); - } - - #[tokio::test] - async fn inspect_messages_uses_preloaded_queue_entries_when_present() { - let queue = ActorContext::new_with_kv( - "actor-queue", - "queue-preload", - Vec::new(), - "local", - Kv::default(), - ); - let metadata = QueueMetadata { - next_id: 8, - size: 1, - }; - let persisted = PersistedQueueMessage { - name: "preloaded".to_owned(), - body: b"body".to_vec(), - created_at: 42, - failure_count: None, - available_at: None, - in_flight: None, - in_flight_at: None, - }; - queue.configure_preload(Some(PreloadedKv::new_with_requested_get_keys( - [ - ( - QUEUE_METADATA_KEY.to_vec(), - encode_queue_metadata(&metadata).expect("metadata should encode"), - ), - ( - make_queue_message_key(7), - encode_queue_message(&persisted).expect("message should encode"), - ), - ], - vec![QUEUE_METADATA_KEY.to_vec()], - vec![QUEUE_MESSAGES_PREFIX.to_vec()], - ))); - - let messages = queue - .inspect_messages() - .await - .expect("queue should initialize from preload without touching kv"); - - assert_eq!(messages.len(), 1); - assert_eq!(messages[0].id, 7); - assert_eq!(messages[0].name, "preloaded"); - assert_eq!(messages[0].body, b"body"); - assert_eq!(*queue.0.queue_metadata.lock().await, metadata); - } - - #[tokio::test] - async fn wait_for_names_returns_aborted_when_signal_is_already_cancelled() { - let queue = test_queue(); - let signal = CancellationToken::new(); - signal.cancel(); - - let error = queue - .wait_for_names( - vec!["missing".to_owned()], - QueueWaitOpts { - signal: Some(signal), - ..Default::default() - }, - ) - .await - .expect_err("already-cancelled waits should abort immediately"); - - assert_actor_aborted(error); - } - - #[tokio::test(start_paused = true)] - async fn wait_for_names_returns_aborted_when_signal_cancels_during_wait() { - let queue = test_queue(); - let signal = CancellationToken::new(); - let wait_signal = signal.clone(); - let wait_queue = queue.clone(); - - let wait = tokio::spawn(async move { - wait_queue - .wait_for_names( - vec!["missing".to_owned()], - QueueWaitOpts { - timeout: Some(Duration::from_secs(60)), - signal: Some(wait_signal), - ..Default::default() - }, - ) - .await - }); - - yield_now().await; - signal.cancel(); - - let error = wait - .await - .expect("wait task should join") - .expect_err("cancelled waits should abort"); - - assert_actor_aborted(error); - } - - #[tokio::test(start_paused = true)] - async fn next_returns_aborted_when_actor_signal_cancels_during_wait() { - let queue = test_queue(); - - let wait = tokio::spawn({ - let queue = queue.clone(); - async move { - queue - .next(QueueNextOpts { - names: Some(vec!["missing".to_owned()]), - timeout: Some(Duration::from_secs(60)), - ..Default::default() - }) - .await - } - }); - - yield_now().await; - queue.mark_destroy_requested(); - - let error = wait - .await - .expect("wait task should join") - .expect_err("cancelled actor waits should abort"); - - assert_actor_aborted(error); - } -} +#[path = "../../tests/queue.rs"] +mod tests; diff --git a/rivetkit-rust/packages/rivetkit-core/src/actor/schedule.rs b/rivetkit-rust/packages/rivetkit-core/src/actor/schedule.rs index f795a765f1..ceee0937ef 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/actor/schedule.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/actor/schedule.rs @@ -418,201 +418,7 @@ fn now_timestamp_ms() -> i64 { i64::try_from(duration.as_millis()).unwrap_or(i64::MAX) } +// Test shim keeps moved tests in crate-root tests/ with private-module access. #[cfg(test)] -mod tests { - use std::collections::HashMap; - use std::sync::Mutex as EnvoySharedMutex; - use std::sync::atomic::AtomicBool; - - use rivet_envoy_client::config::{ - BoxFuture, EnvoyCallbacks, EnvoyConfig, HttpRequest, HttpResponse, WebSocketHandler, - WebSocketSender, - }; - use rivet_envoy_client::context::{SharedContext, WsTxMessage}; - use rivet_envoy_client::envoy::ToEnvoyMessage; - use rivet_envoy_client::protocol; - use tokio::sync::mpsc; - - use super::*; - - struct IdleEnvoyCallbacks; - - impl EnvoyCallbacks for IdleEnvoyCallbacks { - fn on_actor_start( - &self, - _handle: EnvoyHandle, - _actor_id: String, - _generation: u32, - _config: protocol::ActorConfig, - _preloaded_kv: Option, - _sqlite_startup_data: Option, - ) -> BoxFuture> { - Box::pin(async { Ok(()) }) - } - - fn on_shutdown(&self) {} - - fn fetch( - &self, - _handle: EnvoyHandle, - _actor_id: String, - _gateway_id: protocol::GatewayId, - _request_id: protocol::RequestId, - _request: HttpRequest, - ) -> BoxFuture> { - Box::pin(async { anyhow::bail!("fetch should not run in schedule tests") }) - } - - fn websocket( - &self, - _handle: EnvoyHandle, - _actor_id: String, - _gateway_id: protocol::GatewayId, - _request_id: protocol::RequestId, - _request: HttpRequest, - _path: String, - _headers: HashMap, - _is_hibernatable: bool, - _is_restoring_hibernatable: bool, - _sender: WebSocketSender, - ) -> BoxFuture> { - Box::pin(async { anyhow::bail!("websocket should not run in schedule tests") }) - } - - fn can_hibernate( - &self, - _actor_id: &str, - _gateway_id: &protocol::GatewayId, - _request_id: &protocol::RequestId, - _request: &HttpRequest, - ) -> BoxFuture> { - Box::pin(async { Ok(false) }) - } - } - - fn test_envoy_handle() -> (EnvoyHandle, mpsc::UnboundedReceiver) { - let (envoy_tx, envoy_rx) = mpsc::unbounded_channel(); - let shared = Arc::new(SharedContext { - config: EnvoyConfig { - version: 1, - endpoint: "http://127.0.0.1:1".to_string(), - token: None, - namespace: "test".to_string(), - pool_name: "test".to_string(), - prepopulate_actor_names: HashMap::new(), - metadata: None, - not_global: true, - debug_latency_ms: None, - callbacks: Arc::new(IdleEnvoyCallbacks), - }, - envoy_key: "test-envoy".to_string(), - envoy_tx, - // Forced-std-sync: envoy-client's test SharedContext owns these - // fields as std mutexes, so construction must match that API. - actors: Arc::new(EnvoySharedMutex::new(HashMap::new())), - live_tunnel_requests: Arc::new(EnvoySharedMutex::new(HashMap::new())), - pending_hibernation_restores: Arc::new(EnvoySharedMutex::new(HashMap::new())), - ws_tx: Arc::new(tokio::sync::Mutex::new( - None::>, - )), - protocol_metadata: Arc::new(tokio::sync::Mutex::new(None)), - shutting_down: AtomicBool::new(false), - stopped_tx: tokio::sync::watch::channel(true).0, - }); - - (EnvoyHandle::from_shared(shared), envoy_rx) - } - - fn recv_alarm_now( - rx: &mut mpsc::UnboundedReceiver, - expected_actor_id: &str, - expected_generation: Option, - ) -> Option { - match rx.try_recv() { - Ok(ToEnvoyMessage::SetAlarm { - actor_id, - generation, - alarm_ts, - ack_tx, - }) => { - assert_eq!(actor_id, expected_actor_id); - assert_eq!(generation, expected_generation); - if let Some(ack_tx) = ack_tx { - let _ = ack_tx.send(()); - } - alarm_ts - } - Ok(_) => panic!("expected set_alarm envoy message"), - Err(error) => panic!("expected set_alarm envoy message, got {error:?}"), - } - } - - fn assert_no_alarm(rx: &mut mpsc::UnboundedReceiver) { - assert!(matches!( - rx.try_recv(), - Err(mpsc::error::TryRecvError::Empty) - )); - } - - #[test] - fn sync_alarm_skips_driver_push_until_schedule_changes() { - let schedule = ActorContext::new_for_schedule_tests("actor-schedule-dirty"); - let (handle, mut rx) = test_envoy_handle(); - schedule.configure_schedule_envoy(handle, Some(7)); - - schedule.sync_alarm_logged(); - assert_eq!( - recv_alarm_now(&mut rx, "actor-schedule-dirty", Some(7)), - None - ); - - schedule.sync_alarm_logged(); - assert_no_alarm(&mut rx); - - schedule.at(123, "tick", b"args"); - assert_eq!( - recv_alarm_now(&mut rx, "actor-schedule-dirty", Some(7)), - Some(123) - ); - - schedule.sync_alarm_logged(); - assert_no_alarm(&mut rx); - - let event_id = schedule - .next_event() - .expect("scheduled event should exist") - .event_id; - assert!(schedule.cancel_scheduled_event(&event_id)); - assert_eq!( - recv_alarm_now(&mut rx, "actor-schedule-dirty", Some(7)), - None - ); - - schedule.sync_alarm_logged(); - assert_no_alarm(&mut rx); - } - - #[test] - fn sync_future_alarm_uses_dirty_since_push_gate() { - let schedule = ActorContext::new_for_schedule_tests("actor-future-alarm-dirty"); - let (handle, mut rx) = test_envoy_handle(); - schedule.configure_schedule_envoy(handle, Some(8)); - - let future_ts = now_timestamp_ms() + 60_000; - schedule.set_scheduled_events(vec![PersistedScheduleEvent { - event_id: "event-1".to_owned(), - timestamp_ms: future_ts, - action: "tick".to_owned(), - args: vec![1, 2, 3], - }]); - - schedule.sync_future_alarm_logged(); - assert_eq!( - recv_alarm_now(&mut rx, "actor-future-alarm-dirty", Some(8)), - Some(future_ts) - ); - - schedule.sync_future_alarm_logged(); - assert_no_alarm(&mut rx); - } -} +#[path = "../../tests/schedule.rs"] +mod tests; diff --git a/rivetkit-rust/packages/rivetkit-core/src/actor/sleep.rs b/rivetkit-rust/packages/rivetkit-core/src/actor/sleep.rs index aa23acad61..0988b7d43a 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/actor/sleep.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/actor/sleep.rs @@ -542,440 +542,7 @@ impl ActorContext { } } +// Test shim keeps moved tests in crate-root tests/ with private-module access. #[cfg(test)] -mod tests { - use std::sync::Arc; - use std::sync::atomic::{AtomicUsize, Ordering}; - - use super::CanSleep; - use crate::actor::context::ActorContext; - use parking_lot::Mutex as DropMutex; - use rivet_util::async_counter::AsyncCounter; - use tokio::sync::oneshot; - use tokio::task::yield_now; - use tokio::time::{Duration, Instant, advance}; - use tracing::field::{Field, Visit}; - use tracing::{Event, Subscriber}; - use tracing_subscriber::layer::{Context as LayerContext, Layer}; - use tracing_subscriber::prelude::*; - use tracing_subscriber::registry::Registry; - - #[derive(Default)] - struct MessageVisitor { - message: Option, - } - - impl Visit for MessageVisitor { - fn record_str(&mut self, field: &Field, value: &str) { - if field.name() == "message" { - self.message = Some(value.to_owned()); - } - } - - fn record_debug(&mut self, field: &Field, value: &dyn std::fmt::Debug) { - if field.name() == "message" { - self.message = Some(format!("{value:?}").trim_matches('"').to_owned()); - } - } - } - - #[derive(Clone)] - struct ShutdownTaskRefusedLayer { - count: Arc, - } - - impl Layer for ShutdownTaskRefusedLayer - where - S: Subscriber, - { - fn on_event(&self, event: &Event<'_>, _ctx: LayerContext<'_, S>) { - if *event.metadata().level() != tracing::Level::WARN { - return; - } - - let mut visitor = MessageVisitor::default(); - event.record(&mut visitor); - if visitor.message.as_deref() - == Some("shutdown task spawned after teardown; aborting immediately") - { - self.count.fetch_add(1, Ordering::SeqCst); - } - } - } - - struct NotifyOnDrop(DropMutex>>); - - impl NotifyOnDrop { - fn new(sender: oneshot::Sender<()>) -> Self { - Self(DropMutex::new(Some(sender))) - } - } - - impl Drop for NotifyOnDrop { - fn drop(&mut self) { - if let Some(sender) = self.0.lock().take() { - let _ = sender.send(()); - } - } - } - - #[tokio::test(start_paused = true)] - async fn shutdown_task_counter_reaches_zero_after_completion() { - let ctx = ActorContext::new_for_sleep_tests("actor-shutdown-complete"); - let (done_tx, done_rx) = oneshot::channel(); - - ctx.track_shutdown_task(async move { - let _ = done_tx.send(()); - }); - - done_rx.await.expect("shutdown task should complete"); - yield_now().await; - - assert_eq!(ctx.shutdown_task_count(), 0); - assert!( - ctx.0 - .sleep - .work - .shutdown_counter - .wait_zero(Instant::now() + Duration::from_millis(1)) - .await - ); - } - - #[tokio::test(start_paused = true)] - async fn shutdown_task_counter_reaches_zero_after_panic() { - let ctx = ActorContext::new_for_sleep_tests("actor-shutdown-panic"); - - ctx.track_shutdown_task(async move { - panic!("boom"); - }); - - yield_now().await; - yield_now().await; - - assert_eq!(ctx.shutdown_task_count(), 0); - assert!( - ctx.0 - .sleep - .work - .shutdown_counter - .wait_zero(Instant::now() + Duration::from_millis(1)) - .await - ); - } - - #[tokio::test(start_paused = true)] - async fn teardown_aborts_tracked_shutdown_tasks() { - let ctx = ActorContext::new_for_sleep_tests("actor-shutdown-teardown"); - let (drop_tx, drop_rx) = oneshot::channel(); - let (_never_tx, never_rx) = oneshot::channel::<()>(); - let notify = NotifyOnDrop::new(drop_tx); - - ctx.track_shutdown_task(async move { - let _notify = notify; - let _ = never_rx.await; - }); - - assert_eq!(ctx.shutdown_task_count(), 1); - - ctx.teardown_sleep_state().await; - advance(Duration::from_millis(1)).await; - - drop_rx - .await - .expect("teardown should abort the tracked task"); - assert_eq!(ctx.shutdown_task_count(), 0); - } - - #[tokio::test(start_paused = true)] - async fn track_shutdown_task_refuses_spawns_after_teardown() { - let ctx = ActorContext::new_for_sleep_tests("actor-shutdown-refuse"); - let warning_count = Arc::new(AtomicUsize::new(0)); - let subscriber = Registry::default().with(ShutdownTaskRefusedLayer { - count: warning_count.clone(), - }); - let _guard = tracing::subscriber::set_default(subscriber); - - ctx.teardown_sleep_state().await; - ctx.track_shutdown_task(async move { - panic!("post-teardown shutdown task should never spawn"); - }); - yield_now().await; - - assert_eq!(ctx.shutdown_task_count(), 0); - assert_eq!(warning_count.load(Ordering::SeqCst), 1); - } - - #[tokio::test(start_paused = true)] - async fn sleep_then_destroy_signal_tasks_do_not_leak_after_teardown() { - let ctx = ActorContext::new_for_sleep_tests("actor-sleep-destroy"); - ctx.set_started(true); - - ctx.sleep().expect("sleep should be accepted after startup"); - ctx.destroy() - .expect("destroy should be accepted after startup"); - - assert_eq!( - ctx.shutdown_task_count(), - 2, - "sleep and destroy bridge work should be tracked before it runs" - ); - - ctx.teardown_sleep_state().await; - advance(Duration::from_millis(1)).await; - - assert_eq!(ctx.shutdown_task_count(), 0); - } - - #[tokio::test(start_paused = true)] - async fn sleep_idle_window_without_work_returns_next_tick() { - let ctx = ActorContext::new_for_sleep_tests("actor-sleep-idle"); - - let waiter = tokio::spawn({ - let ctx = ctx.clone(); - async move { - ctx.wait_for_sleep_idle_window(Instant::now() + Duration::from_secs(1)) - .await - } - }); - - yield_now().await; - - assert!( - waiter.is_finished(), - "idle wait should not poll in 10ms slices" - ); - assert!(waiter.await.expect("idle waiter should join")); - } - - #[tokio::test(start_paused = true)] - async fn sleep_idle_window_waits_for_http_counter_zero_transition() { - let ctx = ActorContext::new_for_sleep_tests("actor-http-idle"); - let counter = Arc::new(AsyncCounter::new()); - counter.register_zero_notify(&ctx.0.sleep.work.idle_notify); - counter.register_change_notify(&ctx.sleep_activity_notify()); - *ctx.0.sleep.http_request_counter.lock() = Some(counter.clone()); - - counter.increment(); - let waiter = tokio::spawn({ - let ctx = ctx.clone(); - async move { - ctx.wait_for_sleep_idle_window(Instant::now() + Duration::from_secs(1)) - .await - } - }); - - yield_now().await; - assert!( - !waiter.is_finished(), - "http request drain should stay blocked while the counter is non-zero" - ); - - counter.decrement(); - advance(Duration::from_millis(1)).await; - yield_now().await; - assert!(waiter.await.expect("http idle waiter should join")); - } - - #[tokio::test(start_paused = true)] - async fn http_request_idle_wait_uses_zero_notify() { - let ctx = ActorContext::new_for_sleep_tests("actor-http-zero-notify"); - let counter = Arc::new(AsyncCounter::new()); - counter.register_zero_notify(&ctx.0.sleep.work.idle_notify); - *ctx.0.sleep.http_request_counter.lock() = Some(counter.clone()); - - counter.increment(); - let waiter = tokio::spawn({ - let ctx = ctx.clone(); - async move { - ctx.wait_for_http_requests_idle().await; - } - }); - - yield_now().await; - assert!( - !waiter.is_finished(), - "http request idle wait should block while the counter is non-zero" - ); - - counter.decrement(); - yield_now().await; - - assert!( - waiter.is_finished(), - "http request idle wait should wake on the zero notification" - ); - waiter.await.expect("http idle waiter should join"); - } - - #[tokio::test(start_paused = true)] - async fn sleep_idle_window_waits_for_websocket_callback_zero_transition() { - let ctx = ActorContext::new_for_sleep_tests("actor-websocket-idle"); - let guard = ctx.websocket_callback_region(); - - let waiter = tokio::spawn({ - let ctx = ctx.clone(); - async move { - ctx.wait_for_sleep_idle_window(Instant::now() + Duration::from_secs(1)) - .await - } - }); - - yield_now().await; - assert!( - !waiter.is_finished(), - "websocket callback drain should stay blocked while the counter is non-zero" - ); - - drop(guard); - advance(Duration::from_millis(1)).await; - yield_now().await; - assert!(waiter.await.expect("websocket idle waiter should join")); - } - - #[tokio::test(start_paused = true)] - async fn sleep_before_started_errors_with_actor_starting() { - let ctx = ActorContext::new_for_sleep_tests("actor-sleep-before-started"); - - let err = ctx - .sleep() - .expect_err("sleep should fail before started is set"); - let rivet_err = rivet_error::RivetError::extract(&err); - assert_eq!(rivet_err.group(), "actor"); - assert_eq!(rivet_err.code(), "starting"); - } - - #[tokio::test(start_paused = true)] - async fn destroy_before_started_errors_with_actor_starting() { - let ctx = ActorContext::new_for_sleep_tests("actor-destroy-before-started"); - - let err = ctx - .destroy() - .expect_err("destroy should fail before started is set"); - let rivet_err = rivet_error::RivetError::extract(&err); - assert_eq!(rivet_err.group(), "actor"); - assert_eq!(rivet_err.code(), "starting"); - } - - #[tokio::test(start_paused = true)] - async fn double_sleep_errors_with_actor_stopping() { - let ctx = ActorContext::new_for_sleep_tests("actor-double-sleep"); - ctx.set_started(true); - - ctx.sleep() - .expect("first sleep call should be accepted after startup"); - - let err = ctx - .sleep() - .expect_err("second sleep call should fail as already requested"); - let rivet_err = rivet_error::RivetError::extract(&err); - assert_eq!(rivet_err.group(), "actor"); - assert_eq!(rivet_err.code(), "stopping"); - } - - #[tokio::test(start_paused = true)] - async fn double_destroy_errors_with_actor_stopping() { - let ctx = ActorContext::new_for_sleep_tests("actor-double-destroy"); - ctx.set_started(true); - - ctx.destroy() - .expect("first destroy call should be accepted after startup"); - - let err = ctx - .destroy() - .expect_err("second destroy call should fail as already requested"); - let rivet_err = rivet_error::RivetError::extract(&err); - assert_eq!(rivet_err.group(), "actor"); - assert_eq!(rivet_err.code(), "stopping"); - } - - #[tokio::test(start_paused = true)] - #[allow(deprecated)] - async fn set_prevent_sleep_is_a_deprecated_noop() { - let ctx = ActorContext::new_for_sleep_tests("actor-prevent-sleep-noop"); - ctx.set_started(true); - - // The stub must never flip the underlying flag. - ctx.set_prevent_sleep(true); - assert!( - !ctx.prevent_sleep(), - "prevent_sleep must stay false because the stub is a no-op" - ); - - // Exhaustive match guards against reintroducing a `PreventSleep` enum - // variant. If a future change adds the variant back, this match stops - // compiling — surfacing the regression at build time rather than via a - // runtime assertion that could silently pass. - match ctx.can_sleep().await { - CanSleep::Yes - | CanSleep::NotReady - | CanSleep::NoSleep - | CanSleep::ActiveHttpRequests - | CanSleep::ActiveKeepAwake - | CanSleep::ActiveInternalKeepAwake - | CanSleep::ActiveRunHandler - | CanSleep::ActiveDisconnectCallbacks - | CanSleep::ActiveConnections - | CanSleep::ActiveWebSocketCallbacks => {} - } - - ctx.set_prevent_sleep(false); - assert!(!ctx.prevent_sleep()); - } - - #[tokio::test(start_paused = true)] - async fn shutdown_deadline_token_aborts_select_awaiting_task() { - // Mirrors the NAPI `RunGracefulCleanup` pattern: a task awaits user - // work and the shutdown_deadline cancellation in a `tokio::select!`. - // If `cancel_shutdown_deadline()` does not propagate to clones of the - // token (a regression we cannot catch with `is_cancelled()` alone), - // the spawned task would hang and the test would time out. - let ctx = ActorContext::new_for_sleep_tests("actor-shutdown-deadline"); - let token = ctx.shutdown_deadline_token(); - assert!(!token.is_cancelled()); - - let aborted = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)); - let aborted_clone = aborted.clone(); - let task = tokio::spawn(async move { - tokio::select! { - _ = token.cancelled() => { - aborted_clone.store(true, Ordering::SeqCst); - } - _ = futures::future::pending::<()>() => {} - } - }); - - yield_now().await; - assert!(!aborted.load(Ordering::SeqCst)); - - ctx.cancel_shutdown_deadline(); - task.await.expect("select task should join after cancel"); - assert!( - aborted.load(Ordering::SeqCst), - "select-awaiting task must observe cancel via the cloned token" - ); - } - - #[tokio::test(start_paused = true)] - async fn sleep_after_grace_clears_started_returns_stopping_not_starting() { - // Simulate the lifecycle state machine clearing `started` when it - // transitions into SleepGrace. Calls into `sleep()` after that point - // must surface `Stopping`, not `Starting`. - let ctx = ActorContext::new_for_sleep_tests("actor-sleep-after-grace"); - ctx.set_started(true); - - ctx.sleep().expect("first sleep call should be accepted"); - - // Lifecycle machine clears `started` on transition into SleepGrace. - ctx.set_started(false); - - let err = ctx.sleep().expect_err("second sleep should fail"); - let rivet_err = rivet_error::RivetError::extract(&err); - assert_eq!(rivet_err.group(), "actor"); - assert_eq!( - rivet_err.code(), - "stopping", - "started=false during shutdown must surface stopping, not starting" - ); - } -} +#[path = "../../tests/sleep.rs"] +mod tests; diff --git a/rivetkit-rust/packages/rivetkit-core/src/actor/state.rs b/rivetkit-rust/packages/rivetkit-core/src/actor/state.rs index 7949798dbb..5c4c331eba 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/actor/state.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/actor/state.rs @@ -130,11 +130,13 @@ impl ActorContext { self.set_initial_state(state); } - /// Requests a save without surfacing delivery failures to the caller. + /// Fire-and-forget save request helper. /// /// If the lifecycle event inbox is overloaded or unavailable, this only logs - /// a warning and returns. Call [`Self::request_save_and_wait`] when the caller - /// needs a `Result` and must observe save-request delivery failures. + /// a warning and returns. That `warn!` is the sole failure signal for this + /// path; callers do not receive a `Result`. Call + /// [`Self::request_save_and_wait`] when the caller must observe + /// save-request delivery failures. pub fn request_save(&self, opts: RequestSaveOpts) { if let Err(error) = self.request_save_with_revision(opts) { tracing::warn!(?error, "failed to request actor state save"); @@ -752,6 +754,7 @@ fn throttled_save_delay( } } +// Test shim keeps moved tests in crate-root tests/ with private-module access. #[cfg(test)] -#[path = "../../tests/modules/state.rs"] +#[path = "../../tests/state.rs"] mod tests; diff --git a/rivetkit-rust/packages/rivetkit-core/src/actor/task.rs b/rivetkit-rust/packages/rivetkit-core/src/actor/task.rs index 670228c9af..029c007d69 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/actor/task.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/actor/task.rs @@ -81,8 +81,9 @@ pub(crate) const DISPATCH_INBOX_CHANNEL: &str = "dispatch_inbox"; pub(crate) const LIFECYCLE_EVENT_INBOX_CHANNEL: &str = "lifecycle_event_inbox"; pub use crate::actor::task_types::LifecycleState; +// Test shim keeps moved tests in crate-root tests/ with private-module access. #[cfg(test)] -#[path = "../../tests/modules/task.rs"] +#[path = "../../tests/task.rs"] mod tests; #[cfg(test)] diff --git a/rivetkit-rust/packages/rivetkit-core/src/actor/work_registry.rs b/rivetkit-rust/packages/rivetkit-core/src/actor/work_registry.rs index 7e621f4370..232c8e9910 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/actor/work_registry.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/actor/work_registry.rs @@ -120,39 +120,7 @@ impl Drop for RegionGuard { /// `CountGuard` is the same RAII shape as `RegionGuard`, but used for task-counting sites. pub(crate) type CountGuard = RegionGuard; +// Test shim keeps moved tests in crate-root tests/ with private-module access. #[cfg(test)] -mod tests { - use std::panic::{AssertUnwindSafe, catch_unwind}; - - use super::WorkRegistry; - - #[test] - fn region_guard_drop_decrements_counter() { - let work = WorkRegistry::new(); - assert_eq!(work.keep_awake.load(), 0); - - { - let _guard = work.keep_awake_guard(); - assert_eq!(work.keep_awake.load(), 1); - } - - assert_eq!(work.keep_awake.load(), 0); - } - - #[test] - fn region_guard_drop_during_panic_unwind_decrements_counter() { - let work = WorkRegistry::new(); - - let result = catch_unwind(AssertUnwindSafe(|| { - let _guard = work.keep_awake_guard(); - assert_eq!(work.keep_awake.load(), 1); - panic!("boom"); - })); - - assert!( - result.is_err(), - "panic should propagate through catch_unwind" - ); - assert_eq!(work.keep_awake.load(), 0); - } -} +#[path = "../../tests/work_registry.rs"] +mod tests; diff --git a/rivetkit-rust/packages/rivetkit-core/src/inspector/auth.rs b/rivetkit-rust/packages/rivetkit-core/src/inspector/auth.rs index 3aeb42a04a..c0093567df 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/inspector/auth.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/inspector/auth.rs @@ -3,6 +3,8 @@ use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD}; use rand::RngCore; use rivet_error::RivetError as RivetErrorDerive; use serde::{Deserialize, Serialize}; +#[cfg(test)] +use std::sync::{Mutex, OnceLock}; use subtle::ConstantTimeEq; use crate::ActorContext; @@ -13,6 +15,12 @@ const INSPECTOR_TOKEN_KEY: [u8; 1] = [3]; const INSPECTOR_TOKEN_ENV: &str = "_RIVET_TEST_INSPECTOR_TOKEN"; const INSPECTOR_TOKEN_BYTES: usize = 32; +#[cfg(test)] +pub(crate) fn test_inspector_env_lock() -> &'static Mutex<()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) +} + #[derive(Clone, Copy, Debug, Default)] pub struct InspectorAuth; diff --git a/rivetkit-rust/packages/rivetkit-core/src/inspector/mod.rs b/rivetkit-rust/packages/rivetkit-core/src/inspector/mod.rs index 8704ee18be..96dc6f4181 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/inspector/mod.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/inspector/mod.rs @@ -199,6 +199,7 @@ pub fn encode_response_payload(payload: &[u8], target_version: u16) -> anyhow::R protocol::encode_server_payload(&message, target_version) } +// Test shim keeps moved tests in crate-root tests/ with private-module access. #[cfg(test)] -#[path = "../../tests/modules/inspector.rs"] +#[path = "../../tests/inspector.rs"] mod tests; diff --git a/rivetkit-rust/packages/rivetkit-core/src/registry/envoy_callbacks.rs b/rivetkit-rust/packages/rivetkit-core/src/registry/envoy_callbacks.rs index aadfa67e42..a40ebabc31 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/registry/envoy_callbacks.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/registry/envoy_callbacks.rs @@ -294,60 +294,7 @@ fn preloaded_kv_from_protocol(preloaded_kv: protocol::PreloadedKv) -> PreloadedK ) } +// Test shim keeps moved tests in crate-root tests/ with private-module access. #[cfg(test)] -mod preload_tests { - use super::*; - use crate::actor::state::{PersistedActor, encode_persisted_actor}; - - #[test] - fn decode_preloaded_persisted_actor_distinguishes_bundle_states() { - assert_eq!( - decode_preloaded_persisted_actor(None).expect("no bundle should decode"), - PreloadedPersistedActor::NoBundle - ); - - let requested_empty = protocol::PreloadedKv { - entries: Vec::new(), - requested_get_keys: vec![PERSIST_DATA_KEY.to_vec()], - requested_prefixes: Vec::new(), - }; - assert_eq!( - decode_preloaded_persisted_actor(Some(&requested_empty)) - .expect("empty bundle should decode"), - PreloadedPersistedActor::BundleExistsButEmpty - ); - - let not_requested = protocol::PreloadedKv { - entries: Vec::new(), - requested_get_keys: Vec::new(), - requested_prefixes: Vec::new(), - }; - assert_eq!( - decode_preloaded_persisted_actor(Some(¬_requested)) - .expect("unrequested bundle should decode"), - PreloadedPersistedActor::NoBundle - ); - - let persisted = PersistedActor { - state: vec![1, 2, 3], - ..PersistedActor::default() - }; - let with_actor = protocol::PreloadedKv { - entries: vec![protocol::PreloadedKvEntry { - key: PERSIST_DATA_KEY.to_vec(), - value: encode_persisted_actor(&persisted).expect("persisted actor should encode"), - metadata: protocol::KvMetadata { - version: Vec::new(), - update_ts: 0, - }, - }], - requested_get_keys: vec![PERSIST_DATA_KEY.to_vec()], - requested_prefixes: Vec::new(), - }; - assert_eq!( - decode_preloaded_persisted_actor(Some(&with_actor)) - .expect("persisted actor bundle should decode"), - PreloadedPersistedActor::Some(persisted) - ); - } -} +#[path = "../../tests/envoy_callbacks.rs"] +mod tests; diff --git a/rivetkit-rust/packages/rivetkit-core/src/registry/http.rs b/rivetkit-rust/packages/rivetkit-core/src/registry/http.rs index 9f02d16dcd..614f623622 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/registry/http.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/registry/http.rs @@ -584,6 +584,16 @@ pub(super) async fn build_http_request(request: HttpRequest) -> Result .with_context(|| format!("build actor request for `{}`", request.path)) } +pub(super) fn is_actor_request_path(path: &str) -> bool { + let Some(stripped) = path.strip_prefix("/request") else { + return false; + }; + if stripped.is_empty() { + return true; + } + matches!(stripped.as_bytes().first(), Some(b'/') | Some(b'?')) +} + pub(super) fn normalize_actor_request_path(path: &str) -> String { let Some(stripped) = path.strip_prefix("/request") else { return path.to_owned(); @@ -931,5 +941,5 @@ fn bearer_token_from_authorization(value: &str) -> Option<&str> { } #[cfg(test)] -#[path = "../../tests/modules/registry_http.rs"] +#[path = "../../tests/registry_http.rs"] mod tests; diff --git a/rivetkit-rust/packages/rivetkit-core/src/serverless.rs b/rivetkit-rust/packages/rivetkit-core/src/serverless.rs index 97c67eab6a..24e5c5170e 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/serverless.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/serverless.rs @@ -699,146 +699,7 @@ fn is_loopback_address(hostname: &str) -> bool { matches!(hostname, "127.0.0.1" | "0.0.0.0" | "::1" | "[::1]") } +// Test shim keeps moved tests in crate-root tests/ with private-module access. #[cfg(test)] -mod tests { - use std::collections::HashMap; - - use tokio_util::sync::CancellationToken; - - use super::{ - CoreServerlessRuntime, ServerlessRequest, endpoints_match, normalize_endpoint_url, - }; - use crate::registry::ServeConfig; - - #[test] - fn normalizes_loopback_addresses() { - assert_eq!( - normalize_endpoint_url("http://127.0.0.1:6420/").as_deref(), - Some("http://localhost:6420/") - ); - assert!(endpoints_match( - "http://0.0.0.0:6420/api/", - "http://localhost:6420/api" - )); - } - - #[test] - fn normalizes_rivet_regional_hosts() { - assert!(endpoints_match( - "https://api-us-west-1.rivet.dev", - "https://api.rivet.dev/" - )); - assert!(endpoints_match( - "https://api-lax.staging.rivet.dev", - "https://api.staging.rivet.dev/" - )); - assert!(!endpoints_match( - "https://api-us-west-1.example.com", - "https://api.example.com" - )); - } - - #[test] - fn invalid_urls_fall_back_to_string_comparison() { - assert!(endpoints_match("not a url", "not a url")); - assert!(!endpoints_match("not a url", "also not a url")); - } - - #[tokio::test] - async fn handles_basic_routes() { - let runtime = test_runtime().await; - - let health = runtime - .handle_request(test_request("GET", "/api/rivet/health")) - .await; - assert_eq!(health.status, 200); - let health_body = read_body(health).await; - assert_eq!(health_body["status"], "ok"); - assert_eq!(health_body["runtime"], "rivetkit"); - assert_eq!(health_body["version"], "test-version"); - - let metadata = runtime - .handle_request(test_request("GET", "/api/rivet/metadata")) - .await; - assert_eq!(metadata.status, 200); - let metadata_body = read_body(metadata).await; - assert_eq!(metadata_body["runtime"], "rivetkit"); - assert_eq!(metadata_body["version"], "test-version"); - assert_eq!( - metadata_body["envoy"]["kind"]["serverless"], - serde_json::json!({}) - ); - assert_eq!(metadata_body["clientEndpoint"], "http://client.example"); - assert_eq!(metadata_body["clientNamespace"], "default"); - assert_eq!(metadata_body["clientToken"], "client-token"); - - let root = runtime - .handle_request(test_request("GET", "/api/rivet")) - .await; - assert_eq!(root.status, 200); - let root_body = read_text(root).await; - assert_eq!( - root_body, - "This is a RivetKit server.\n\nLearn more at https://rivet.dev" - ); - } - - #[tokio::test] - async fn start_requires_serverless_headers() { - let runtime = test_runtime().await; - let response = runtime - .handle_request(test_request("POST", "/api/rivet/start")) - .await; - assert_eq!(response.status, 400); - let body = read_body(response).await; - assert_eq!(body["group"], "request"); - assert_eq!(body["code"], "invalid"); - } - - async fn test_runtime() -> CoreServerlessRuntime { - CoreServerlessRuntime::new( - HashMap::new(), - ServeConfig { - version: 1, - endpoint: "http://127.0.0.1:6420".to_owned(), - token: Some("dev".to_owned()), - namespace: "default".to_owned(), - pool_name: "default".to_owned(), - engine_binary_path: None, - handle_inspector_http_in_runtime: true, - serverless_base_path: Some("/api/rivet".to_owned()), - serverless_package_version: "test-version".to_owned(), - serverless_client_endpoint: Some("http://client.example".to_owned()), - serverless_client_namespace: Some("default".to_owned()), - serverless_client_token: Some("client-token".to_owned()), - serverless_validate_endpoint: true, - serverless_max_start_payload_bytes: 1_048_576, - }, - ) - .await - .expect("runtime should build") - } - - fn test_request(method: &str, path: &str) -> ServerlessRequest { - ServerlessRequest { - method: method.to_owned(), - url: format!("http://localhost{path}"), - headers: HashMap::new(), - body: Vec::new(), - cancel_token: CancellationToken::new(), - } - } - - async fn read_body(response: super::ServerlessResponse) -> serde_json::Value { - let text = read_text(response).await; - serde_json::from_str(&text).expect("response should be json") - } - - async fn read_text(mut response: super::ServerlessResponse) -> String { - let mut body = Vec::new(); - while let Some(chunk) = response.body.recv().await { - body.extend(chunk.expect("stream should not error")); - } - String::from_utf8(body).expect("response should be utf-8") - } -} +#[path = "../tests/serverless.rs"] +mod tests; diff --git a/rivetkit-rust/packages/rivetkit-core/src/websocket.rs b/rivetkit-rust/packages/rivetkit-core/src/websocket.rs index af4e67cf34..9346583222 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/websocket.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/websocket.rs @@ -227,6 +227,7 @@ impl fmt::Debug for WebSocket { } } +// Test shim keeps moved tests in crate-root tests/ with private-module access. #[cfg(test)] -#[path = "../tests/modules/websocket.rs"] +#[path = "../tests/websocket.rs"] mod tests; diff --git a/rivetkit-rust/packages/rivetkit-core/tests/modules/config.rs b/rivetkit-rust/packages/rivetkit-core/tests/config.rs similarity index 100% rename from rivetkit-rust/packages/rivetkit-core/tests/modules/config.rs rename to rivetkit-rust/packages/rivetkit-core/tests/config.rs diff --git a/rivetkit-rust/packages/rivetkit-core/tests/connection.rs b/rivetkit-rust/packages/rivetkit-core/tests/connection.rs new file mode 100644 index 0000000000..48bbe3d844 --- /dev/null +++ b/rivetkit-rust/packages/rivetkit-core/tests/connection.rs @@ -0,0 +1,324 @@ +use super::*; + +mod moved_tests { + use std::collections::BTreeSet; + use std::sync::Arc; + use std::sync::atomic::{AtomicUsize, Ordering}; + + use parking_lot::Mutex; + use tokio::sync::{Barrier, mpsc}; + use tokio::task::yield_now; + + use super::{ + HibernatableConnectionMetadata, PersistedConnection, decode_persisted_connection, + encode_persisted_connection, hibernatable_id_from_slice, make_connection_key, + }; + use crate::actor::context::ActorContext; + use crate::actor::messages::ActorEvent; + use crate::actor::preload::PreloadedKv; + use crate::actor::task::LifecycleEvent; + use crate::kv::Kv; + + fn next_non_activity_lifecycle_event( + rx: &mut mpsc::Receiver, + ) -> Option { + rx.try_recv().ok() + } + + #[tokio::test] + async fn restore_persisted_uses_preloaded_connection_prefix_when_present() { + let ctx = ActorContext::new_with_kv( + "actor-preload", + "actor", + Vec::new(), + "local", + Kv::new_in_memory(), + ); + let persisted = PersistedConnection { + id: "conn-preloaded".to_owned(), + parameters: vec![1], + state: vec![2], + gateway_id: [1, 2, 3, 4], + request_id: [5, 6, 7, 8], + request_path: "/socket".to_owned(), + ..PersistedConnection::default() + }; + let preloaded = PreloadedKv::new_with_requested_get_keys( + [( + make_connection_key(&persisted.id), + encode_persisted_connection(&persisted) + .expect("persisted connection should encode"), + )], + Vec::new(), + vec![vec![2]], + ); + + let restored = ctx + .restore_persisted(Some(&preloaded)) + .await + .expect("restore should use preloaded entries instead of unconfigured kv"); + + assert_eq!(restored.len(), 1); + assert_eq!(restored[0].id(), "conn-preloaded"); + assert_eq!(restored[0].state(), vec![2]); + assert!(ctx.connection("conn-preloaded").is_some()); + } + + #[test] + fn persisted_connection_uses_ts_v4_fixed_id_wire_format() { + let persisted = PersistedConnection { + id: "c".to_owned(), + parameters: vec![1, 2], + state: vec![3], + gateway_id: [10, 11, 12, 13], + request_id: [20, 21, 22, 23], + server_message_index: 9, + client_message_index: 10, + request_path: "/".to_owned(), + ..PersistedConnection::default() + }; + + let encoded = + encode_persisted_connection(&persisted).expect("persisted connection should encode"); + + assert_eq!( + encoded, + vec![ + 4, 0, 1, b'c', 2, 1, 2, 1, 3, 0, 10, 11, 12, 13, 20, 21, 22, 23, 9, 0, 10, 0, 1, + b'/', 0, + ] + ); + + let decoded = + decode_persisted_connection(&encoded).expect("persisted connection should decode"); + assert_eq!(decoded.gateway_id, [10, 11, 12, 13]); + assert_eq!(decoded.request_id, [20, 21, 22, 23]); + } + + #[test] + fn hibernatable_id_validation_returns_rivet_error() { + let error = hibernatable_id_from_slice("gateway_id", &[1, 2, 3]) + .expect_err("invalid id should fail"); + let error = rivet_error::RivetError::extract(&error); + + assert_eq!(error.group(), "actor"); + assert_eq!(error.code(), "invalid_request"); + } + + #[tokio::test(start_paused = true)] + async fn concurrent_disconnects_only_emit_one_close_and_one_hibernation_removal() { + let ctx = ActorContext::new_with_kv( + "actor-race", + "actor", + Vec::new(), + "local", + Kv::new_in_memory(), + ); + ctx.configure_connection_runtime(crate::actor::config::ActorConfig::default()); + let (events_tx, mut events_rx) = mpsc::unbounded_channel(); + ctx.configure_actor_events(Some(events_tx)); + let closed = Arc::new(AtomicUsize::new(0)); + let observed_conn_id = Arc::new(Mutex::new(None::)); + + let recv = tokio::spawn({ + let closed = closed.clone(); + let observed_conn_id = observed_conn_id.clone(); + async move { + while let Some(event) = events_rx.recv().await { + match event { + ActorEvent::ConnectionOpen { reply, .. } => reply.send(Ok(())), + ActorEvent::ConnectionClosed { conn } => { + *observed_conn_id.lock() = Some(conn.id().to_owned()); + closed.fetch_add(1, Ordering::SeqCst); + break; + } + other => panic!("unexpected event: {other:?}"), + } + } + } + }); + + let conn = ctx + .connect_with_state( + vec![1], + true, + Some(HibernatableConnectionMetadata { + gateway_id: [1, 2, 3, 4], + request_id: [5, 6, 7, 8], + ..HibernatableConnectionMetadata::default() + }), + None, + async { Ok(vec![9]) }, + ) + .await + .expect("connection should open"); + let conn_id = conn.id().to_owned(); + ctx.record_connections_updated(); + ctx.reset_sleep_timer(); + + let barrier = Arc::new(Barrier::new(2)); + conn.configure_transport_disconnect_handler(Some(Arc::new({ + let barrier = barrier.clone(); + move |_reason| { + let barrier = barrier.clone(); + Box::pin(async move { + barrier.wait().await; + Ok(()) + }) + } + }))); + + let first = tokio::spawn({ + let conn = conn.clone(); + async move { conn.disconnect(Some("first")).await } + }); + let second = tokio::spawn({ + let conn = conn.clone(); + async move { conn.disconnect(Some("second")).await } + }); + + yield_now().await; + first + .await + .expect("first disconnect task should join") + .expect("first disconnect should succeed"); + second + .await + .expect("second disconnect task should join") + .expect("second disconnect should succeed"); + recv.await.expect("event receiver should join"); + + assert_eq!(closed.load(Ordering::SeqCst), 1); + assert_eq!(observed_conn_id.lock().as_deref(), Some(conn_id.as_str())); + assert!(ctx.connection(&conn_id).is_none()); + + let pending = ctx.take_pending_hibernation_changes_inner(); + assert!(pending.updated.is_empty()); + assert_eq!(pending.removed, BTreeSet::from([conn_id])); + } + + #[tokio::test] + async fn hibernatable_set_state_queues_save_and_non_hibernatable_stays_memory_only() { + let ctx = ActorContext::new_with_kv( + "actor-state-dirty", + "actor", + Vec::new(), + "local", + Kv::new_in_memory(), + ); + let (actor_events_tx, mut actor_events_rx) = mpsc::unbounded_channel(); + let (lifecycle_events_tx, mut lifecycle_events_rx) = mpsc::channel(4); + ctx.configure_actor_events(Some(actor_events_tx)); + ctx.configure_lifecycle_events(Some(lifecycle_events_tx)); + + let open_replies = tokio::spawn(async move { + for _ in 0..2 { + match actor_events_rx + .recv() + .await + .expect("open event should arrive") + { + ActorEvent::ConnectionOpen { reply, .. } => reply.send(Ok(())), + other => panic!("unexpected actor event: {other:?}"), + } + } + }); + + let non_hibernatable = ctx + .connect_with_state(vec![1], false, None, None, async { Ok(vec![2]) }) + .await + .expect("non-hibernatable connection should open"); + non_hibernatable.set_state(vec![3]); + assert_eq!(non_hibernatable.state(), vec![3]); + assert!( + ctx.dirty_hibernatable_conns_inner().is_empty(), + "non-hibernatable state changes should not queue persistence" + ); + assert!( + next_non_activity_lifecycle_event(&mut lifecycle_events_rx).is_none(), + "non-hibernatable state changes should not request actor save" + ); + + let hibernatable = ctx + .connect_with_state( + vec![4], + true, + Some(HibernatableConnectionMetadata { + gateway_id: [1, 2, 3, 4], + request_id: [5, 6, 7, 8], + ..HibernatableConnectionMetadata::default() + }), + None, + async { Ok(vec![5]) }, + ) + .await + .expect("hibernatable connection should open"); + hibernatable.set_state(vec![6]); + + assert_eq!( + ctx.dirty_hibernatable_conns_inner() + .into_iter() + .map(|conn| conn.id().to_owned()) + .collect::>(), + vec![hibernatable.id().to_owned()] + ); + assert_eq!( + next_non_activity_lifecycle_event(&mut lifecycle_events_rx) + .expect("hibernatable state change should request save"), + LifecycleEvent::SaveRequested { immediate: false } + ); + + open_replies + .await + .expect("open reply task should join cleanly"); + } + + #[tokio::test(start_paused = true)] + async fn remove_existing_for_disconnect_has_exactly_one_winner() { + let ctx = ActorContext::new_with_kv( + "actor-race", + "actor", + Vec::new(), + "local", + Kv::new_in_memory(), + ); + let conn = super::ConnHandle::new("conn-race", vec![1], vec![2], true); + conn.configure_hibernation(Some(HibernatableConnectionMetadata { + gateway_id: [1, 2, 3, 4], + request_id: [5, 6, 7, 8], + ..HibernatableConnectionMetadata::default() + })); + ctx.insert_existing(conn); + + let barrier = Arc::new(Barrier::new(2)); + let first = tokio::spawn({ + let ctx = ctx.clone(); + let barrier = barrier.clone(); + async move { + barrier.wait().await; + ctx.remove_existing_for_disconnect("conn-race") + .map(|conn| conn.id().to_owned()) + } + }); + let second = tokio::spawn({ + let ctx = ctx.clone(); + let barrier = barrier.clone(); + async move { + barrier.wait().await; + ctx.remove_existing_for_disconnect("conn-race") + .map(|conn| conn.id().to_owned()) + } + }); + + let first = first.await.expect("first task should join"); + let second = second.await.expect("second task should join"); + let winners = [first, second].into_iter().flatten().collect::>(); + + assert_eq!(winners, vec!["conn-race".to_owned()]); + assert!(ctx.connection("conn-race").is_none()); + + let pending = ctx.take_pending_hibernation_changes_inner(); + assert!(pending.updated.is_empty()); + assert_eq!(pending.removed, BTreeSet::from(["conn-race".to_owned()])); + } +} diff --git a/rivetkit-rust/packages/rivetkit-core/tests/modules/context.rs b/rivetkit-rust/packages/rivetkit-core/tests/context.rs similarity index 100% rename from rivetkit-rust/packages/rivetkit-core/tests/modules/context.rs rename to rivetkit-rust/packages/rivetkit-core/tests/context.rs diff --git a/rivetkit-rust/packages/rivetkit-core/tests/envoy_callbacks.rs b/rivetkit-rust/packages/rivetkit-core/tests/envoy_callbacks.rs new file mode 100644 index 0000000000..d02e0ba974 --- /dev/null +++ b/rivetkit-rust/packages/rivetkit-core/tests/envoy_callbacks.rs @@ -0,0 +1,59 @@ +mod moved_tests { + use crate::actor::state::{PersistedActor, encode_persisted_actor}; + use crate::registry::envoy_callbacks::{ + PERSIST_DATA_KEY, PreloadedPersistedActor, decode_preloaded_persisted_actor, + }; + use rivet_envoy_client::protocol; + + #[test] + fn decode_preloaded_persisted_actor_distinguishes_bundle_states() { + assert_eq!( + decode_preloaded_persisted_actor(None).expect("no bundle should decode"), + PreloadedPersistedActor::NoBundle + ); + + let requested_empty = protocol::PreloadedKv { + entries: Vec::new(), + requested_get_keys: vec![PERSIST_DATA_KEY.to_vec()], + requested_prefixes: Vec::new(), + }; + assert_eq!( + decode_preloaded_persisted_actor(Some(&requested_empty)) + .expect("empty bundle should decode"), + PreloadedPersistedActor::BundleExistsButEmpty + ); + + let not_requested = protocol::PreloadedKv { + entries: Vec::new(), + requested_get_keys: Vec::new(), + requested_prefixes: Vec::new(), + }; + assert_eq!( + decode_preloaded_persisted_actor(Some(¬_requested)) + .expect("unrequested bundle should decode"), + PreloadedPersistedActor::NoBundle + ); + + let persisted = PersistedActor { + state: vec![1, 2, 3], + ..PersistedActor::default() + }; + let with_actor = protocol::PreloadedKv { + entries: vec![protocol::PreloadedKvEntry { + key: PERSIST_DATA_KEY.to_vec(), + value: encode_persisted_actor(&persisted).expect("persisted actor should encode"), + metadata: protocol::KvMetadata { + version: Vec::new(), + update_ts: 0, + }, + }], + requested_get_keys: vec![PERSIST_DATA_KEY.to_vec()], + requested_prefixes: Vec::new(), + }; + assert_eq!( + decode_preloaded_persisted_actor(Some(&with_actor)) + .expect("persisted actor bundle should decode"), + PreloadedPersistedActor::Some(persisted) + ); + } +} diff --git a/rivetkit-rust/packages/rivetkit-core/tests/modules/inspector.rs b/rivetkit-rust/packages/rivetkit-core/tests/inspector.rs similarity index 96% rename from rivetkit-rust/packages/rivetkit-core/tests/modules/inspector.rs rename to rivetkit-rust/packages/rivetkit-core/tests/inspector.rs index 60546b237f..acbd64b978 100644 --- a/rivetkit-rust/packages/rivetkit-core/tests/modules/inspector.rs +++ b/rivetkit-rust/packages/rivetkit-core/tests/inspector.rs @@ -10,14 +10,12 @@ mod moved_tests { use crate::actor::context::tests::new_with_kv; use crate::actor::messages::StateDelta; use crate::inspector::InspectorAuth; + use crate::inspector::auth::test_inspector_env_lock; use rivet_error::RivetError; use std::collections::BTreeMap; use std::sync::Arc; - use std::sync::Mutex; use std::sync::atomic::{AtomicUsize, Ordering}; - static INSPECTOR_ENV_LOCK: Mutex<()> = Mutex::new(()); - #[tokio::test] async fn state_updates_increment_inspector_revisions() { let ctx = new_with_kv( @@ -249,7 +247,7 @@ mod moved_tests { #[tokio::test] async fn inspector_auth_uses_env_token_before_kv_fallback() { - let _env_guard = INSPECTOR_ENV_LOCK.lock().expect("env lock poisoned"); + let _env_guard = test_inspector_env_lock().lock().expect("env lock poisoned"); unsafe { std::env::set_var("_RIVET_TEST_INSPECTOR_TOKEN", "env-token"); } @@ -286,7 +284,7 @@ mod moved_tests { #[tokio::test] async fn inspector_auth_falls_back_to_actor_kv_token() { - let _env_guard = INSPECTOR_ENV_LOCK.lock().expect("env lock poisoned"); + let _env_guard = test_inspector_env_lock().lock().expect("env lock poisoned"); unsafe { std::env::remove_var("_RIVET_TEST_INSPECTOR_TOKEN"); } @@ -319,7 +317,7 @@ mod moved_tests { #[tokio::test] async fn inspector_auth_rejects_missing_token() { - let _env_guard = INSPECTOR_ENV_LOCK.lock().expect("env lock poisoned"); + let _env_guard = test_inspector_env_lock().lock().expect("env lock poisoned"); unsafe { std::env::remove_var("_RIVET_TEST_INSPECTOR_TOKEN"); } diff --git a/rivetkit-rust/packages/rivetkit-core/tests/modules/kv.rs b/rivetkit-rust/packages/rivetkit-core/tests/kv.rs similarity index 100% rename from rivetkit-rust/packages/rivetkit-core/tests/modules/kv.rs rename to rivetkit-rust/packages/rivetkit-core/tests/kv.rs diff --git a/rivetkit-rust/packages/rivetkit-core/tests/modules/messages.rs b/rivetkit-rust/packages/rivetkit-core/tests/messages.rs similarity index 100% rename from rivetkit-rust/packages/rivetkit-core/tests/modules/messages.rs rename to rivetkit-rust/packages/rivetkit-core/tests/messages.rs diff --git a/rivetkit-rust/packages/rivetkit-core/tests/metrics.rs b/rivetkit-rust/packages/rivetkit-core/tests/metrics.rs new file mode 100644 index 0000000000..61eccc978f --- /dev/null +++ b/rivetkit-rust/packages/rivetkit-core/tests/metrics.rs @@ -0,0 +1,37 @@ +use super::*; + +mod moved_tests { + use std::panic::{AssertUnwindSafe, catch_unwind}; + + use super::*; + + #[test] + fn duplicate_metric_registration_uses_noop_fallback() { + let registry = Registry::new(); + let first = IntGauge::with_opts(Opts::new( + "duplicate_actor_metric", + "first duplicate metric", + )) + .expect("first gauge should be valid"); + let second = IntGauge::with_opts(Opts::new( + "duplicate_actor_metric", + "second duplicate metric", + )) + .expect("second gauge should be valid"); + + register_metric(®istry, first.clone()); + let result = catch_unwind(AssertUnwindSafe(|| { + register_metric(®istry, second.clone()); + })); + + assert!(result.is_ok()); + assert_eq!( + 1, + registry + .gather() + .iter() + .filter(|family| family.name() == "duplicate_actor_metric") + .count() + ); + } +} diff --git a/rivetkit-rust/packages/rivetkit-core/tests/queue.rs b/rivetkit-rust/packages/rivetkit-core/tests/queue.rs new file mode 100644 index 0000000000..9806fc406a --- /dev/null +++ b/rivetkit-rust/packages/rivetkit-core/tests/queue.rs @@ -0,0 +1,160 @@ +use super::*; + +mod moved_tests { + use super::{ + PersistedQueueMessage, QUEUE_MESSAGES_PREFIX, QUEUE_METADATA_KEY, QueueMetadata, + QueueNextOpts, QueueWaitOpts, encode_queue_message, encode_queue_metadata, + make_queue_message_key, + }; + + use crate::actor::context::ActorContext; + use crate::actor::preload::PreloadedKv; + use crate::kv::Kv; + use std::time::Duration; + use tokio::task::yield_now; + use tokio_util::sync::CancellationToken; + + fn test_queue() -> ActorContext { + ActorContext::new_with_kv( + "actor-queue", + "queue-test", + Vec::new(), + "local", + Kv::new_in_memory(), + ) + } + + fn assert_actor_aborted(error: anyhow::Error) { + let error = rivet_error::RivetError::extract(&error); + assert_eq!(error.group(), "actor"); + assert_eq!(error.code(), "aborted"); + } + + #[tokio::test] + async fn inspect_messages_uses_preloaded_queue_entries_when_present() { + let queue = ActorContext::new_with_kv( + "actor-queue", + "queue-preload", + Vec::new(), + "local", + Kv::default(), + ); + let metadata = QueueMetadata { + next_id: 8, + size: 1, + }; + let persisted = PersistedQueueMessage { + name: "preloaded".to_owned(), + body: b"body".to_vec(), + created_at: 42, + failure_count: None, + available_at: None, + in_flight: None, + in_flight_at: None, + }; + queue.configure_preload(Some(PreloadedKv::new_with_requested_get_keys( + [ + ( + QUEUE_METADATA_KEY.to_vec(), + encode_queue_metadata(&metadata).expect("metadata should encode"), + ), + ( + make_queue_message_key(7), + encode_queue_message(&persisted).expect("message should encode"), + ), + ], + vec![QUEUE_METADATA_KEY.to_vec()], + vec![QUEUE_MESSAGES_PREFIX.to_vec()], + ))); + + let messages = queue + .inspect_messages() + .await + .expect("queue should initialize from preload without touching kv"); + + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].id, 7); + assert_eq!(messages[0].name, "preloaded"); + assert_eq!(messages[0].body, b"body"); + assert_eq!(*queue.0.queue_metadata.lock().await, metadata); + } + + #[tokio::test] + async fn wait_for_names_returns_aborted_when_signal_is_already_cancelled() { + let queue = test_queue(); + let signal = CancellationToken::new(); + signal.cancel(); + + let error = queue + .wait_for_names( + vec!["missing".to_owned()], + QueueWaitOpts { + signal: Some(signal), + ..Default::default() + }, + ) + .await + .expect_err("already-cancelled waits should abort immediately"); + + assert_actor_aborted(error); + } + + #[tokio::test(start_paused = true)] + async fn wait_for_names_returns_aborted_when_signal_cancels_during_wait() { + let queue = test_queue(); + let signal = CancellationToken::new(); + let wait_signal = signal.clone(); + let wait_queue = queue.clone(); + + let wait = tokio::spawn(async move { + wait_queue + .wait_for_names( + vec!["missing".to_owned()], + QueueWaitOpts { + timeout: Some(Duration::from_secs(60)), + signal: Some(wait_signal), + ..Default::default() + }, + ) + .await + }); + + yield_now().await; + signal.cancel(); + + let error = wait + .await + .expect("wait task should join") + .expect_err("cancelled waits should abort"); + + assert_actor_aborted(error); + } + + #[tokio::test(start_paused = true)] + async fn next_returns_aborted_when_actor_signal_cancels_during_wait() { + let queue = test_queue(); + + let wait = tokio::spawn({ + let queue = queue.clone(); + async move { + queue + .next(QueueNextOpts { + names: Some(vec!["missing".to_owned()]), + timeout: Some(Duration::from_secs(60)), + ..Default::default() + }) + .await + } + }); + + yield_now().await; + queue.mark_destroy_requested(); + + let error = wait + .await + .expect("wait task should join") + .expect_err("cancelled actor waits should abort"); + + assert_actor_aborted(error); + } +} diff --git a/rivetkit-rust/packages/rivetkit-core/tests/registry_http.rs b/rivetkit-rust/packages/rivetkit-core/tests/registry_http.rs new file mode 100644 index 0000000000..376d297948 --- /dev/null +++ b/rivetkit-rust/packages/rivetkit-core/tests/registry_http.rs @@ -0,0 +1,237 @@ +use super::*; + +mod moved_tests { + use std::collections::HashMap; + use std::time::Duration; + + use super::{ + HttpRequest, HttpResponseEncoding, authorization_bearer_token, + authorization_bearer_token_map, framework_action_error_response, is_actor_request_path, + message_boundary_error_response, normalize_actor_request_path, request_encoding, + request_has_bearer_token, workflow_dispatch_result, + }; + use crate::actor::action::ActionDispatchError; + use crate::error::ActorLifecycle as ActorLifecycleError; + use http::StatusCode; + use rivet_error::RivetError; + use serde_json::json; + use vbare::OwnedVersionedData; + + #[derive(RivetError)] + #[error("message", "incoming_too_long", "Incoming message too long")] + struct IncomingMessageTooLong; + + #[derive(RivetError)] + #[error("message", "outgoing_too_long", "Outgoing message too long")] + struct OutgoingMessageTooLong; + + #[test] + fn workflow_dispatch_result_marks_handled_workflow_as_enabled() { + assert_eq!( + workflow_dispatch_result(Ok(Some(vec![1, 2, 3]))) + .expect("workflow dispatch should succeed"), + (true, Some(vec![1, 2, 3])), + ); + assert_eq!( + workflow_dispatch_result(Ok(None)).expect("workflow dispatch should succeed"), + (true, None), + ); + } + + #[test] + fn workflow_dispatch_result_treats_dropped_reply_as_disabled() { + assert_eq!( + workflow_dispatch_result(Err(ActorLifecycleError::DroppedReply.build())) + .expect("dropped reply should map to workflow disabled"), + (false, None), + ); + } + + #[test] + fn workflow_dispatch_result_preserves_non_dropped_reply_errors() { + let error = workflow_dispatch_result(Err(ActorLifecycleError::Destroying.build())) + .expect_err("non-dropped reply errors should be preserved"); + let error = rivet_error::RivetError::extract(&error); + assert_eq!(error.group(), "actor"); + assert_eq!(error.code(), "destroying"); + } + + #[test] + fn inspector_error_status_maps_action_timeout_to_408() { + assert_eq!( + super::inspector_error_status("actor", "action_timed_out"), + StatusCode::REQUEST_TIMEOUT, + ); + } + + #[test] + fn authorization_bearer_token_accepts_case_insensitive_scheme_and_whitespace() { + let mut headers = http::HeaderMap::new(); + headers.insert( + http::header::AUTHORIZATION, + "bearer test-token".parse().unwrap(), + ); + + assert_eq!(authorization_bearer_token(&headers), Some("test-token")); + + let map = HashMap::from([( + http::header::AUTHORIZATION.as_str().to_owned(), + "BEARER\ttest-token".to_owned(), + )]); + assert_eq!(authorization_bearer_token_map(&map), Some("test-token")); + } + + #[test] + fn request_has_bearer_token_uses_same_authorization_parser() { + let request = HttpRequest { + method: "GET".to_owned(), + path: "/metrics".to_owned(), + headers: HashMap::from([( + http::header::AUTHORIZATION.as_str().to_owned(), + "Bearer configured".to_owned(), + )]), + body: Some(Vec::new()), + body_stream: None, + }; + + assert!(request_has_bearer_token(&request, Some("configured"))); + assert!(!request_has_bearer_token(&request, Some("other"))); + } + + #[tokio::test] + async fn action_dispatch_timeout_returns_structured_error() { + let error = super::with_action_dispatch_timeout(Duration::from_millis(1), async { + tokio::time::sleep(Duration::from_secs(60)).await; + Ok::, ActionDispatchError>(Vec::new()) + }) + .await + .expect_err("timeout should return an action dispatch error"); + + assert_eq!(error.group, "actor"); + assert_eq!(error.code, "action_timed_out"); + assert_eq!(error.message, "Action timed out"); + } + + #[tokio::test] + async fn framework_action_timeout_returns_structured_error() { + let error = super::with_framework_action_timeout(Duration::from_millis(1), async { + tokio::time::sleep(Duration::from_secs(60)).await; + Ok::<(), anyhow::Error>(()) + }) + .await + .expect_err("timeout should return a framework error"); + let error = RivetError::extract(&error); + + assert_eq!(error.group(), "actor"); + assert_eq!(error.code(), "action_timed_out"); + assert_eq!(error.message(), "Action timed out"); + } + + #[test] + fn framework_action_error_response_maps_timeout_to_408() { + let response = framework_action_error_response( + HttpResponseEncoding::Json, + ActionDispatchError { + group: "actor".to_owned(), + code: "action_timed_out".to_owned(), + message: "Action timed out".to_owned(), + metadata: None, + }, + ) + .expect("timeout error response should serialize"); + + assert_eq!(response.status, StatusCode::REQUEST_TIMEOUT.as_u16()); + assert_eq!( + response.body, + Some( + serde_json::to_vec(&json!({ + "group": "actor", + "code": "action_timed_out", + "message": "Action timed out", + })) + .expect("json body should encode") + ) + ); + } + + #[test] + fn message_boundary_error_response_defaults_to_json() { + let response = message_boundary_error_response( + HttpResponseEncoding::Json, + StatusCode::BAD_REQUEST, + IncomingMessageTooLong.build(), + ) + .expect("json response should serialize"); + + assert_eq!(response.status, StatusCode::BAD_REQUEST.as_u16()); + assert_eq!( + response.headers.get(http::header::CONTENT_TYPE.as_str()), + Some(&"application/json".to_owned()) + ); + assert_eq!( + response.body, + Some( + serde_json::to_vec(&json!({ + "group": "message", + "code": "incoming_too_long", + "message": "Incoming message too long", + })) + .expect("json body should encode") + ) + ); + } + + #[test] + fn request_encoding_reads_cbor_header() { + let mut headers = http::HeaderMap::new(); + headers.insert("x-rivet-encoding", "cbor".parse().unwrap()); + + assert_eq!(request_encoding(&headers), HttpResponseEncoding::Cbor); + } + + #[test] + fn normalize_actor_request_path_preserves_raw_root_paths() { + assert!(is_actor_request_path("/request")); + assert!(is_actor_request_path("/request/")); + assert!(is_actor_request_path("/request/users/1")); + assert!(is_actor_request_path("/request?foo=bar")); + assert_eq!(normalize_actor_request_path("/request"), "/"); + assert_eq!(normalize_actor_request_path("/request/"), "/"); + assert_eq!(normalize_actor_request_path("/request/users/1"), "/users/1",); + assert_eq!(normalize_actor_request_path("/request?foo=bar"), "?foo=bar"); + } + + #[test] + fn normalize_actor_request_path_does_not_mark_framework_routes_as_raw() { + assert!(!is_actor_request_path("/")); + assert!(!is_actor_request_path("/action/ping")); + assert!(!is_actor_request_path("/requestfoo")); + assert_eq!(normalize_actor_request_path("/"), "/"); + assert_eq!(normalize_actor_request_path("/action/ping"), "/action/ping"); + assert_eq!(normalize_actor_request_path("/requestfoo"), "/requestfoo"); + } + + #[test] + fn message_boundary_error_response_serializes_bare_v3() { + let response = message_boundary_error_response( + HttpResponseEncoding::Bare, + StatusCode::BAD_REQUEST, + OutgoingMessageTooLong.build(), + ) + .expect("bare response should serialize"); + + assert_eq!( + response.headers.get(http::header::CONTENT_TYPE.as_str()), + Some(&"application/octet-stream".to_owned()) + ); + + let body = response.body.expect("bare response should include body"); + let decoded = + ::deserialize_with_embedded_version(&body) + .expect("bare error should decode"); + assert_eq!(decoded.group, "message"); + assert_eq!(decoded.code, "outgoing_too_long"); + assert_eq!(decoded.message, "Outgoing message too long"); + assert_eq!(decoded.metadata, None); + } +} diff --git a/rivetkit-rust/packages/rivetkit-core/tests/schedule.rs b/rivetkit-rust/packages/rivetkit-core/tests/schedule.rs new file mode 100644 index 0000000000..11665fecc3 --- /dev/null +++ b/rivetkit-rust/packages/rivetkit-core/tests/schedule.rs @@ -0,0 +1,197 @@ +use super::*; + +mod moved_tests { + use std::collections::HashMap; + use std::sync::Mutex as EnvoySharedMutex; + use std::sync::atomic::AtomicBool; + + use rivet_envoy_client::config::{ + BoxFuture, EnvoyCallbacks, EnvoyConfig, HttpRequest, HttpResponse, WebSocketHandler, + WebSocketSender, + }; + use rivet_envoy_client::context::{SharedContext, WsTxMessage}; + use rivet_envoy_client::envoy::ToEnvoyMessage; + use rivet_envoy_client::protocol; + use tokio::sync::mpsc; + + use super::*; + + struct IdleEnvoyCallbacks; + + impl EnvoyCallbacks for IdleEnvoyCallbacks { + fn on_actor_start( + &self, + _handle: EnvoyHandle, + _actor_id: String, + _generation: u32, + _config: protocol::ActorConfig, + _preloaded_kv: Option, + _sqlite_startup_data: Option, + ) -> BoxFuture> { + Box::pin(async { Ok(()) }) + } + + fn on_shutdown(&self) {} + + fn fetch( + &self, + _handle: EnvoyHandle, + _actor_id: String, + _gateway_id: protocol::GatewayId, + _request_id: protocol::RequestId, + _request: HttpRequest, + ) -> BoxFuture> { + Box::pin(async { anyhow::bail!("fetch should not run in schedule tests") }) + } + + fn websocket( + &self, + _handle: EnvoyHandle, + _actor_id: String, + _gateway_id: protocol::GatewayId, + _request_id: protocol::RequestId, + _request: HttpRequest, + _path: String, + _headers: HashMap, + _is_hibernatable: bool, + _is_restoring_hibernatable: bool, + _sender: WebSocketSender, + ) -> BoxFuture> { + Box::pin(async { anyhow::bail!("websocket should not run in schedule tests") }) + } + + fn can_hibernate( + &self, + _actor_id: &str, + _gateway_id: &protocol::GatewayId, + _request_id: &protocol::RequestId, + _request: &HttpRequest, + ) -> BoxFuture> { + Box::pin(async { Ok(false) }) + } + } + + fn test_envoy_handle() -> (EnvoyHandle, mpsc::UnboundedReceiver) { + let (envoy_tx, envoy_rx) = mpsc::unbounded_channel(); + let shared = Arc::new(SharedContext { + config: EnvoyConfig { + version: 1, + endpoint: "http://127.0.0.1:1".to_string(), + token: None, + namespace: "test".to_string(), + pool_name: "test".to_string(), + prepopulate_actor_names: HashMap::new(), + metadata: None, + not_global: true, + debug_latency_ms: None, + callbacks: Arc::new(IdleEnvoyCallbacks), + }, + envoy_key: "test-envoy".to_string(), + envoy_tx, + actors: Arc::new(EnvoySharedMutex::new(HashMap::new())), + live_tunnel_requests: Arc::new(EnvoySharedMutex::new(HashMap::new())), + pending_hibernation_restores: Arc::new(EnvoySharedMutex::new(HashMap::new())), + ws_tx: Arc::new(tokio::sync::Mutex::new( + None::>, + )), + protocol_metadata: Arc::new(tokio::sync::Mutex::new(None)), + shutting_down: AtomicBool::new(false), + stopped_tx: tokio::sync::watch::channel(true).0, + }); + + (EnvoyHandle::from_shared(shared), envoy_rx) + } + + fn recv_alarm_now( + rx: &mut mpsc::UnboundedReceiver, + expected_actor_id: &str, + expected_generation: Option, + ) -> Option { + match rx.try_recv() { + Ok(ToEnvoyMessage::SetAlarm { + actor_id, + generation, + alarm_ts, + ack_tx, + }) => { + assert_eq!(actor_id, expected_actor_id); + assert_eq!(generation, expected_generation); + if let Some(ack_tx) = ack_tx { + let _ = ack_tx.send(()); + } + alarm_ts + } + Ok(_) => panic!("expected set_alarm envoy message"), + Err(error) => panic!("expected set_alarm envoy message, got {error:?}"), + } + } + + fn assert_no_alarm(rx: &mut mpsc::UnboundedReceiver) { + assert!(matches!( + rx.try_recv(), + Err(mpsc::error::TryRecvError::Empty) + )); + } + + #[test] + fn sync_alarm_skips_driver_push_until_schedule_changes() { + let schedule = ActorContext::new_for_schedule_tests("actor-schedule-dirty"); + let (handle, mut rx) = test_envoy_handle(); + schedule.configure_schedule_envoy(handle, Some(7)); + + schedule.sync_alarm_logged(); + assert_eq!( + recv_alarm_now(&mut rx, "actor-schedule-dirty", Some(7)), + None + ); + + schedule.sync_alarm_logged(); + assert_no_alarm(&mut rx); + + schedule.at(123, "tick", b"args"); + assert_eq!( + recv_alarm_now(&mut rx, "actor-schedule-dirty", Some(7)), + Some(123) + ); + + schedule.sync_alarm_logged(); + assert_no_alarm(&mut rx); + + let event_id = schedule + .next_event() + .expect("scheduled event should exist") + .event_id; + assert!(schedule.cancel_scheduled_event(&event_id)); + assert_eq!( + recv_alarm_now(&mut rx, "actor-schedule-dirty", Some(7)), + None + ); + + schedule.sync_alarm_logged(); + assert_no_alarm(&mut rx); + } + + #[test] + fn sync_future_alarm_uses_dirty_since_push_gate() { + let schedule = ActorContext::new_for_schedule_tests("actor-future-alarm-dirty"); + let (handle, mut rx) = test_envoy_handle(); + schedule.configure_schedule_envoy(handle, Some(8)); + + let future_ts = now_timestamp_ms() + 60_000; + schedule.set_scheduled_events(vec![PersistedScheduleEvent { + event_id: "event-1".to_owned(), + timestamp_ms: future_ts, + action: "tick".to_owned(), + args: vec![1, 2, 3], + }]); + + schedule.sync_future_alarm_logged(); + assert_eq!( + recv_alarm_now(&mut rx, "actor-future-alarm-dirty", Some(8)), + Some(future_ts) + ); + + schedule.sync_future_alarm_logged(); + assert_no_alarm(&mut rx); + } +} diff --git a/rivetkit-rust/packages/rivetkit-core/tests/serverless.rs b/rivetkit-rust/packages/rivetkit-core/tests/serverless.rs new file mode 100644 index 0000000000..b91e3a1d49 --- /dev/null +++ b/rivetkit-rust/packages/rivetkit-core/tests/serverless.rs @@ -0,0 +1,144 @@ +use super::*; + +mod moved_tests { + use std::collections::HashMap; + + use tokio_util::sync::CancellationToken; + + use super::{ + CoreServerlessRuntime, ServerlessRequest, endpoints_match, normalize_endpoint_url, + }; + use crate::registry::ServeConfig; + + #[test] + fn normalizes_loopback_addresses() { + assert_eq!( + normalize_endpoint_url("http://127.0.0.1:6420/").as_deref(), + Some("http://localhost:6420/") + ); + assert!(endpoints_match( + "http://0.0.0.0:6420/api/", + "http://localhost:6420/api" + )); + } + + #[test] + fn normalizes_rivet_regional_hosts() { + assert!(endpoints_match( + "https://api-us-west-1.rivet.dev", + "https://api.rivet.dev/" + )); + assert!(endpoints_match( + "https://api-lax.staging.rivet.dev", + "https://api.staging.rivet.dev/" + )); + assert!(!endpoints_match( + "https://api-us-west-1.example.com", + "https://api.example.com" + )); + } + + #[test] + fn invalid_urls_fall_back_to_string_comparison() { + assert!(endpoints_match("not a url", "not a url")); + assert!(!endpoints_match("not a url", "also not a url")); + } + + #[tokio::test] + async fn handles_basic_routes() { + let runtime = test_runtime().await; + + let health = runtime + .handle_request(test_request("GET", "/api/rivet/health")) + .await; + assert_eq!(health.status, 200); + let health_body = read_body(health).await; + assert_eq!(health_body["status"], "ok"); + assert_eq!(health_body["runtime"], "rivetkit"); + assert_eq!(health_body["version"], "test-version"); + + let metadata = runtime + .handle_request(test_request("GET", "/api/rivet/metadata")) + .await; + assert_eq!(metadata.status, 200); + let metadata_body = read_body(metadata).await; + assert_eq!(metadata_body["runtime"], "rivetkit"); + assert_eq!(metadata_body["version"], "test-version"); + assert_eq!( + metadata_body["envoy"]["kind"]["serverless"], + serde_json::json!({}) + ); + assert_eq!(metadata_body["clientEndpoint"], "http://client.example"); + assert_eq!(metadata_body["clientNamespace"], "default"); + assert_eq!(metadata_body["clientToken"], "client-token"); + + let root = runtime + .handle_request(test_request("GET", "/api/rivet")) + .await; + assert_eq!(root.status, 200); + let root_body = read_text(root).await; + assert_eq!( + root_body, + "This is a RivetKit server.\n\nLearn more at https://rivet.dev" + ); + } + + #[tokio::test] + async fn start_requires_serverless_headers() { + let runtime = test_runtime().await; + let response = runtime + .handle_request(test_request("POST", "/api/rivet/start")) + .await; + assert_eq!(response.status, 400); + let body = read_body(response).await; + assert_eq!(body["group"], "request"); + assert_eq!(body["code"], "invalid"); + } + + async fn test_runtime() -> CoreServerlessRuntime { + CoreServerlessRuntime::new( + HashMap::new(), + ServeConfig { + version: 1, + endpoint: "http://127.0.0.1:6420".to_owned(), + token: Some("dev".to_owned()), + namespace: "default".to_owned(), + pool_name: "default".to_owned(), + engine_binary_path: None, + handle_inspector_http_in_runtime: true, + serverless_base_path: Some("/api/rivet".to_owned()), + serverless_package_version: "test-version".to_owned(), + serverless_client_endpoint: Some("http://client.example".to_owned()), + serverless_client_namespace: Some("default".to_owned()), + serverless_client_token: Some("client-token".to_owned()), + serverless_validate_endpoint: true, + serverless_max_start_payload_bytes: 1_048_576, + }, + ) + .await + .expect("runtime should build") + } + + fn test_request(method: &str, path: &str) -> ServerlessRequest { + ServerlessRequest { + method: method.to_owned(), + url: format!("http://localhost{path}"), + headers: HashMap::new(), + body: Vec::new(), + cancel_token: CancellationToken::new(), + } + } + + async fn read_body(response: super::ServerlessResponse) -> serde_json::Value { + let text = read_text(response).await; + serde_json::from_str(&text).expect("response should be json") + } + + async fn read_text(mut response: super::ServerlessResponse) -> String { + let mut body = Vec::new(); + while let Some(chunk) = response.body.recv().await { + body.extend(chunk.expect("stream should not error")); + } + String::from_utf8(body).expect("response should be utf-8") + } +} diff --git a/rivetkit-rust/packages/rivetkit-core/tests/sleep.rs b/rivetkit-rust/packages/rivetkit-core/tests/sleep.rs new file mode 100644 index 0000000000..e08ba68d93 --- /dev/null +++ b/rivetkit-rust/packages/rivetkit-core/tests/sleep.rs @@ -0,0 +1,428 @@ +mod moved_tests { + use std::sync::Arc; + use std::sync::atomic::{AtomicUsize, Ordering}; + + use crate::actor::context::ActorContext; + use parking_lot::Mutex as DropMutex; + use rivet_util::async_counter::AsyncCounter; + use tokio::sync::oneshot; + use tokio::task::yield_now; + use tokio::time::{Duration, Instant, advance}; + use tracing::field::{Field, Visit}; + use tracing::{Event, Subscriber}; + use tracing_subscriber::layer::{Context as LayerContext, Layer}; + use tracing_subscriber::prelude::*; + use tracing_subscriber::registry::Registry; + + #[derive(Default)] + struct MessageVisitor { + message: Option, + } + + impl Visit for MessageVisitor { + fn record_str(&mut self, field: &Field, value: &str) { + if field.name() == "message" { + self.message = Some(value.to_owned()); + } + } + + fn record_debug(&mut self, field: &Field, value: &dyn std::fmt::Debug) { + if field.name() == "message" { + self.message = Some(format!("{value:?}").trim_matches('"').to_owned()); + } + } + } + + #[derive(Clone)] + struct ShutdownTaskRefusedLayer { + count: Arc, + } + + impl Layer for ShutdownTaskRefusedLayer + where + S: Subscriber, + { + fn on_event(&self, event: &Event<'_>, _ctx: LayerContext<'_, S>) { + if *event.metadata().level() != tracing::Level::WARN { + return; + } + + let mut visitor = MessageVisitor::default(); + event.record(&mut visitor); + if visitor.message.as_deref() + == Some("shutdown task spawned after teardown; aborting immediately") + { + self.count.fetch_add(1, Ordering::SeqCst); + } + } + } + + struct NotifyOnDrop(DropMutex>>); + + impl NotifyOnDrop { + fn new(sender: oneshot::Sender<()>) -> Self { + Self(DropMutex::new(Some(sender))) + } + } + + impl Drop for NotifyOnDrop { + fn drop(&mut self) { + if let Some(sender) = self.0.lock().take() { + let _ = sender.send(()); + } + } + } + + #[tokio::test(start_paused = true)] + async fn shutdown_task_counter_reaches_zero_after_completion() { + let ctx = ActorContext::new_for_sleep_tests("actor-shutdown-complete"); + let (done_tx, done_rx) = oneshot::channel(); + + ctx.track_shutdown_task(async move { + let _ = done_tx.send(()); + }); + + done_rx.await.expect("shutdown task should complete"); + yield_now().await; + + assert_eq!(ctx.shutdown_task_count(), 0); + assert!( + ctx.0 + .sleep + .work + .shutdown_counter + .wait_zero(Instant::now() + Duration::from_millis(1)) + .await + ); + } + + #[tokio::test(start_paused = true)] + async fn shutdown_task_counter_reaches_zero_after_panic() { + let ctx = ActorContext::new_for_sleep_tests("actor-shutdown-panic"); + + ctx.track_shutdown_task(async move { + panic!("boom"); + }); + + yield_now().await; + yield_now().await; + + assert_eq!(ctx.shutdown_task_count(), 0); + assert!( + ctx.0 + .sleep + .work + .shutdown_counter + .wait_zero(Instant::now() + Duration::from_millis(1)) + .await + ); + } + + #[tokio::test(start_paused = true)] + async fn teardown_aborts_tracked_shutdown_tasks() { + let ctx = ActorContext::new_for_sleep_tests("actor-shutdown-teardown"); + let (drop_tx, drop_rx) = oneshot::channel(); + let (_never_tx, never_rx) = oneshot::channel::<()>(); + let notify = NotifyOnDrop::new(drop_tx); + + ctx.track_shutdown_task(async move { + let _notify = notify; + let _ = never_rx.await; + }); + + assert_eq!(ctx.shutdown_task_count(), 1); + + ctx.teardown_sleep_state().await; + advance(Duration::from_millis(1)).await; + + drop_rx + .await + .expect("teardown should abort the tracked task"); + assert_eq!(ctx.shutdown_task_count(), 0); + } + + #[tokio::test(start_paused = true)] + async fn track_shutdown_task_refuses_spawns_after_teardown() { + let ctx = ActorContext::new_for_sleep_tests("actor-shutdown-refuse"); + let warning_count = Arc::new(AtomicUsize::new(0)); + let subscriber = Registry::default().with(ShutdownTaskRefusedLayer { + count: warning_count.clone(), + }); + let _guard = tracing::subscriber::set_default(subscriber); + + ctx.teardown_sleep_state().await; + ctx.track_shutdown_task(async move { + panic!("post-teardown shutdown task should never spawn"); + }); + yield_now().await; + + assert_eq!(ctx.shutdown_task_count(), 0); + assert_eq!(warning_count.load(Ordering::SeqCst), 1); + } + + #[tokio::test(start_paused = true)] + async fn sleep_then_destroy_signal_tasks_do_not_leak_after_teardown() { + let ctx = ActorContext::new_for_sleep_tests("actor-sleep-destroy"); + ctx.set_started(true); + + ctx.sleep().expect("sleep should succeed after started is set"); + ctx.destroy() + .expect("destroy should succeed after started is set"); + + assert_eq!( + ctx.shutdown_task_count(), + 2, + "sleep and destroy bridge work should be tracked before it runs" + ); + + ctx.teardown_sleep_state().await; + advance(Duration::from_millis(1)).await; + + assert_eq!(ctx.shutdown_task_count(), 0); + } + + #[tokio::test(start_paused = true)] + async fn sleep_idle_window_without_work_returns_next_tick() { + let ctx = ActorContext::new_for_sleep_tests("actor-sleep-idle"); + + let waiter = tokio::spawn({ + let ctx = ctx.clone(); + async move { + ctx.wait_for_sleep_idle_window(Instant::now() + Duration::from_secs(1)) + .await + } + }); + + yield_now().await; + + assert!( + waiter.is_finished(), + "idle wait should not poll in 10ms slices" + ); + assert!(waiter.await.expect("idle waiter should join")); + } + + #[tokio::test(start_paused = true)] + async fn sleep_idle_window_waits_for_http_counter_zero_transition() { + let ctx = ActorContext::new_for_sleep_tests("actor-http-idle"); + let counter = Arc::new(AsyncCounter::new()); + counter.register_zero_notify(&ctx.0.sleep.work.idle_notify); + counter.register_change_notify(&ctx.sleep_activity_notify()); + *ctx.0.sleep.http_request_counter.lock() = Some(counter.clone()); + + counter.increment(); + let waiter = tokio::spawn({ + let ctx = ctx.clone(); + async move { + ctx.wait_for_sleep_idle_window(Instant::now() + Duration::from_secs(1)) + .await + } + }); + + yield_now().await; + assert!( + !waiter.is_finished(), + "http request drain should stay blocked while the counter is non-zero" + ); + + counter.decrement(); + advance(Duration::from_millis(1)).await; + yield_now().await; + assert!(waiter.await.expect("http idle waiter should join")); + } + + #[tokio::test(start_paused = true)] + async fn http_request_idle_wait_uses_zero_notify() { + let ctx = ActorContext::new_for_sleep_tests("actor-http-zero-notify"); + let counter = Arc::new(AsyncCounter::new()); + counter.register_zero_notify(&ctx.0.sleep.work.idle_notify); + *ctx.0.sleep.http_request_counter.lock() = Some(counter.clone()); + + counter.increment(); + let waiter = tokio::spawn({ + let ctx = ctx.clone(); + async move { + ctx.wait_for_http_requests_idle().await; + } + }); + + yield_now().await; + assert!( + !waiter.is_finished(), + "http request idle wait should block while the counter is non-zero" + ); + + counter.decrement(); + yield_now().await; + + assert!( + waiter.is_finished(), + "http request idle wait should wake on the zero notification" + ); + waiter.await.expect("http idle waiter should join"); + } + + #[tokio::test(start_paused = true)] + async fn sleep_idle_window_waits_for_websocket_callback_zero_transition() { + let ctx = ActorContext::new_for_sleep_tests("actor-websocket-idle"); + let guard = ctx.websocket_callback_region(); + + let waiter = tokio::spawn({ + let ctx = ctx.clone(); + async move { + ctx.wait_for_sleep_idle_window(Instant::now() + Duration::from_secs(1)) + .await + } + }); + + yield_now().await; + assert!( + !waiter.is_finished(), + "websocket callback drain should stay blocked while the counter is non-zero" + ); + + drop(guard); + advance(Duration::from_millis(1)).await; + yield_now().await; + assert!(waiter.await.expect("websocket idle waiter should join")); + } + + #[tokio::test(start_paused = true)] + async fn sleep_before_started_errors_with_actor_starting() { + let ctx = ActorContext::new_for_sleep_tests("actor-sleep-before-started"); + + let err = ctx + .sleep() + .expect_err("sleep should fail before started is set"); + let rivet_err = rivet_error::RivetError::extract(&err); + assert_eq!(rivet_err.group(), "actor"); + assert_eq!(rivet_err.code(), "starting"); + } + + #[tokio::test(start_paused = true)] + async fn destroy_before_started_errors_with_actor_starting() { + let ctx = ActorContext::new_for_sleep_tests("actor-destroy-before-started"); + + let err = ctx + .destroy() + .expect_err("destroy should fail before started is set"); + let rivet_err = rivet_error::RivetError::extract(&err); + assert_eq!(rivet_err.group(), "actor"); + assert_eq!(rivet_err.code(), "starting"); + } + + #[tokio::test(start_paused = true)] + async fn double_sleep_errors_with_actor_stopping() { + let ctx = ActorContext::new_for_sleep_tests("actor-double-sleep"); + ctx.set_started(true); + + ctx.sleep() + .expect("first sleep call should be accepted after startup"); + + let err = ctx + .sleep() + .expect_err("second sleep call should fail as already requested"); + let rivet_err = rivet_error::RivetError::extract(&err); + assert_eq!(rivet_err.group(), "actor"); + assert_eq!(rivet_err.code(), "stopping"); + } + + #[tokio::test(start_paused = true)] + async fn double_destroy_errors_with_actor_stopping() { + let ctx = ActorContext::new_for_sleep_tests("actor-double-destroy"); + ctx.set_started(true); + + ctx.destroy() + .expect("first destroy call should be accepted after startup"); + + let err = ctx + .destroy() + .expect_err("second destroy call should fail as already requested"); + let rivet_err = rivet_error::RivetError::extract(&err); + assert_eq!(rivet_err.group(), "actor"); + assert_eq!(rivet_err.code(), "stopping"); + } + + // `set_prevent_sleep` is a deprecated no-op kept for NAPI bridge + // compatibility. The exhaustive `CanSleep` match below is a build-time + // guard against reintroducing a `PreventSleep` enum variant. + #[tokio::test(start_paused = true)] + #[allow(deprecated)] + async fn set_prevent_sleep_is_a_deprecated_noop() { + use crate::actor::sleep::CanSleep; + + let ctx = ActorContext::new_for_sleep_tests("actor-prevent-sleep-noop"); + ctx.set_started(true); + + ctx.set_prevent_sleep(true); + match ctx.can_sleep().await { + CanSleep::Yes + | CanSleep::NotReady + | CanSleep::NoSleep + | CanSleep::ActiveHttpRequests + | CanSleep::ActiveKeepAwake + | CanSleep::ActiveInternalKeepAwake + | CanSleep::ActiveRunHandler + | CanSleep::ActiveDisconnectCallbacks + | CanSleep::ActiveConnections + | CanSleep::ActiveWebSocketCallbacks => {} + } + + ctx.set_prevent_sleep(false); + } + + #[tokio::test(start_paused = true)] + async fn shutdown_deadline_token_aborts_select_awaiting_task() { + // Mirrors the NAPI `RunGracefulCleanup` pattern: a task awaits user + // work and the shutdown_deadline cancellation in a `tokio::select!`. + // If `cancel_shutdown_deadline()` does not propagate to clones of the + // token, the spawned task would hang and the test would time out. + let ctx = ActorContext::new_for_sleep_tests("actor-shutdown-deadline"); + let token = ctx.shutdown_deadline_token(); + assert!(!token.is_cancelled()); + + let aborted = Arc::new(std::sync::atomic::AtomicBool::new(false)); + let aborted_clone = aborted.clone(); + let task = tokio::spawn(async move { + tokio::select! { + _ = token.cancelled() => { + aborted_clone.store(true, Ordering::SeqCst); + } + _ = futures::future::pending::<()>() => {} + } + }); + + yield_now().await; + assert!(!aborted.load(Ordering::SeqCst)); + + ctx.cancel_shutdown_deadline(); + task.await.expect("select task should join after cancel"); + assert!( + aborted.load(Ordering::SeqCst), + "select-awaiting task must observe cancel via the cloned token" + ); + } + + #[tokio::test(start_paused = true)] + async fn sleep_after_grace_clears_started_returns_stopping_not_starting() { + // Simulate the lifecycle state machine clearing `started` when it + // transitions into SleepGrace. Calls into `sleep()` after that point + // must surface `Stopping`, not `Starting`. + let ctx = ActorContext::new_for_sleep_tests("actor-sleep-after-grace"); + ctx.set_started(true); + + ctx.sleep().expect("first sleep call should be accepted"); + + // Lifecycle machine clears `started` on transition into SleepGrace. + ctx.set_started(false); + + let err = ctx.sleep().expect_err("second sleep should fail"); + let rivet_err = rivet_error::RivetError::extract(&err); + assert_eq!(rivet_err.group(), "actor"); + assert_eq!( + rivet_err.code(), + "stopping", + "started=false during shutdown must surface stopping, not starting" + ); + } +} diff --git a/rivetkit-rust/packages/rivetkit-core/tests/modules/state.rs b/rivetkit-rust/packages/rivetkit-core/tests/state.rs similarity index 100% rename from rivetkit-rust/packages/rivetkit-core/tests/modules/state.rs rename to rivetkit-rust/packages/rivetkit-core/tests/state.rs diff --git a/rivetkit-rust/packages/rivetkit-core/tests/modules/task.rs b/rivetkit-rust/packages/rivetkit-core/tests/task.rs similarity index 98% rename from rivetkit-rust/packages/rivetkit-core/tests/modules/task.rs rename to rivetkit-rust/packages/rivetkit-core/tests/task.rs index 416269c5b2..20c5dc4b96 100644 --- a/rivetkit-rust/packages/rivetkit-core/tests/modules/task.rs +++ b/rivetkit-rust/packages/rivetkit-core/tests/task.rs @@ -21,6 +21,7 @@ mod moved_tests { use tokio::task::yield_now; use tokio::time::{Instant, advance, sleep, timeout}; use tracing::field::{Field, Visit}; + use tracing::instrument::WithSubscriber; use tracing::{Event, Subscriber}; use tracing_subscriber::layer::{Context as LayerContext, Layer}; use tracing_subscriber::prelude::*; @@ -43,6 +44,7 @@ mod moved_tests { LifecycleEvent, LifecycleState, LiveExit, }; use crate::actor::task_types::ShutdownKind; + use crate::inspector::auth::test_inspector_env_lock; use crate::kv::tests::new_in_memory; use crate::{ActorConfig, ActorContext, ActorFactory}; @@ -1908,7 +1910,11 @@ mod moved_tests { } #[tokio::test] - async fn startup_uses_empty_preloaded_persisted_actor_without_fallback_get() { + async fn startup_uses_empty_preloaded_persisted_actor_without_startup_bundle_batch_get() { + let _env_guard = test_inspector_env_lock().lock().expect("env lock poisoned"); + unsafe { + std::env::remove_var("_RIVET_TEST_INSPECTOR_TOKEN"); + } let kv = new_in_memory(); let ctx = new_with_kv( "actor-preload-empty", @@ -1928,7 +1934,10 @@ mod moved_tests { .expect("start reply should send") .expect("start should succeed"); - assert_eq!(kv.test_batch_get_call_count(), 0); + // Startup still probes the inspector token key at [3], but it should not + // batch-get the persisted actor/alarm startup bundle when the registry + // already told us the persisted bundle exists and is empty. + assert_eq!(kv.test_batch_get_call_count(), 1); assert!(ctx.persisted_actor().has_initialized); } @@ -3294,8 +3303,8 @@ mod moved_tests { let subscriber = Registry::default().with(ActorTaskLogLayer { records: records.clone(), }); - let _guard = tracing::subscriber::set_default(subscriber); - let run = tokio::spawn(task.run()); + let dispatch = tracing::Dispatch::new(subscriber); + let run = tokio::spawn(task.run().with_subscriber(dispatch)); let (start_tx, start_rx) = oneshot::channel(); lifecycle_tx @@ -3345,30 +3354,33 @@ mod moved_tests { .lock() .expect("actor-task log lock poisoned") .clone(); - assert!(logs.iter().any(|log| { - log.level == tracing::Level::INFO - && log.actor_id.as_deref() == Some("actor-log-flow") - && log.message.as_deref() == Some("actor lifecycle transition") - && log.new.as_deref() == Some("Started") - })); - assert!(logs.iter().any(|log| { - log.level == tracing::Level::DEBUG - && log.actor_id.as_deref() == Some("actor-log-flow") - && log.message.as_deref() == Some("actor lifecycle command received") - && log.command.as_deref() == Some("start") - })); - assert!(logs.iter().any(|log| { - log.level == tracing::Level::DEBUG - && log.actor_id.as_deref() == Some("actor-log-flow") - && log.message.as_deref() == Some("actor event enqueued") - && log.event.as_deref() == Some("action") - })); - assert!(logs.iter().any(|log| { - log.level == tracing::Level::DEBUG - && log.actor_id.as_deref() == Some("actor-log-flow") - && log.message.as_deref() == Some("actor event drained") - && log.event.as_deref() == Some("action") - })); + assert!( + logs.iter().any(|log| { + log.level == tracing::Level::DEBUG + && log.actor_id.as_deref() == Some("actor-log-flow") + && log.message.as_deref() == Some("actor lifecycle command received") + && log.command.as_deref() == Some("start") + }), + "expected `actor lifecycle command received` log for actor-log-flow; logs={logs:?}" + ); + assert!( + logs.iter().any(|log| { + log.level == tracing::Level::DEBUG + && log.actor_id.as_deref() == Some("actor-log-flow") + && log.message.as_deref() == Some("actor event enqueued") + && log.event.as_deref() == Some("action") + }), + "expected `actor event enqueued` log for actor-log-flow; logs={logs:?}" + ); + assert!( + logs.iter().any(|log| { + log.level == tracing::Level::DEBUG + && log.actor_id.as_deref() == Some("actor-log-flow") + && log.message.as_deref() == Some("actor event drained") + && log.event.as_deref() == Some("action") + }), + "expected `actor event drained` log for actor-log-flow; logs={logs:?}" + ); } #[tokio::test] diff --git a/rivetkit-rust/packages/rivetkit-core/tests/modules/websocket.rs b/rivetkit-rust/packages/rivetkit-core/tests/websocket.rs similarity index 100% rename from rivetkit-rust/packages/rivetkit-core/tests/modules/websocket.rs rename to rivetkit-rust/packages/rivetkit-core/tests/websocket.rs diff --git a/rivetkit-rust/packages/rivetkit-core/tests/work_registry.rs b/rivetkit-rust/packages/rivetkit-core/tests/work_registry.rs new file mode 100644 index 0000000000..632e64f8c4 --- /dev/null +++ b/rivetkit-rust/packages/rivetkit-core/tests/work_registry.rs @@ -0,0 +1,37 @@ +use super::*; + +mod moved_tests { + use std::panic::{AssertUnwindSafe, catch_unwind}; + + use super::WorkRegistry; + + #[test] + fn region_guard_drop_decrements_counter() { + let work = WorkRegistry::new(); + assert_eq!(work.keep_awake.load(), 0); + + { + let _guard = work.keep_awake_guard(); + assert_eq!(work.keep_awake.load(), 1); + } + + assert_eq!(work.keep_awake.load(), 0); + } + + #[test] + fn region_guard_drop_during_panic_unwind_decrements_counter() { + let work = WorkRegistry::new(); + + let result = catch_unwind(AssertUnwindSafe(|| { + let _guard = work.keep_awake_guard(); + assert_eq!(work.keep_awake.load(), 1); + panic!("boom"); + })); + + assert!( + result.is_err(), + "panic should propagate through catch_unwind" + ); + assert_eq!(work.keep_awake.load(), 0); + } +}