diff --git a/CLAUDE.md b/CLAUDE.md index 88f9cfe954..dc83ee1859 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -211,6 +211,11 @@ When the user asks to track something in a note, store it in `.agent/notes/` by - **Shared Libraries** (`shared/{language}/{package}/`) - Libraries shared between the engine and rivetkit (e.g., `shared/typescript/virtual-websocket/`) - **Service Infrastructure** - Distributed services communicate via NATS messaging with service discovery +### Engine Runner Parity +- Keep `engine/sdks/typescript/runner` and `engine/sdks/rust/engine-runner` at feature parity. +- Any behavior, protocol handling, or test coverage added to one runner should be mirrored in the other runner in the same change whenever possible. +- When parity cannot be completed in the same change, explicitly document the gap and add a follow-up task. + ### Important Patterns **Error Handling** diff --git a/engine/packages/engine/tests/runner/actors_alarm.rs b/engine/packages/engine/tests/runner/actors_alarm.rs index abc7ae603d..7b86411ff6 100644 --- a/engine/packages/engine/tests/runner/actors_alarm.rs +++ b/engine/packages/engine/tests/runner/actors_alarm.rs @@ -140,7 +140,7 @@ impl AlarmAndSleepActor { } #[async_trait] -impl TestActor for AlarmAndSleepActor { +impl Actor for AlarmAndSleepActor { async fn on_start(&mut self, config: ActorConfig) -> anyhow::Result { let generation = config.generation; tracing::info!(?config.actor_id, generation, "alarm actor starting"); @@ -195,7 +195,7 @@ impl AlarmAndSleepOnceActor { } #[async_trait] -impl TestActor for AlarmAndSleepOnceActor { +impl Actor for AlarmAndSleepOnceActor { async fn on_start(&mut self, config: ActorConfig) -> anyhow::Result { let generation = config.generation; tracing::info!(?config.actor_id, generation, "alarm once actor starting"); @@ -250,7 +250,7 @@ impl AlarmSleepThenClearActor { } #[async_trait] -impl TestActor for AlarmSleepThenClearActor { +impl Actor for AlarmSleepThenClearActor { async fn on_start(&mut self, config: ActorConfig) -> anyhow::Result { let generation = config.generation; tracing::info!(?config.actor_id, generation, "alarm actor starting"); @@ -311,7 +311,7 @@ impl AlarmSleepThenReplaceActor { } #[async_trait] -impl TestActor for AlarmSleepThenReplaceActor { +impl Actor for AlarmSleepThenReplaceActor { async fn on_start(&mut self, config: ActorConfig) -> anyhow::Result { let generation = config.generation; tracing::info!(?config.actor_id, generation, "alarm actor starting"); @@ -374,7 +374,7 @@ impl MultipleAlarmSetActor { } #[async_trait] -impl TestActor for MultipleAlarmSetActor { +impl Actor for MultipleAlarmSetActor { async fn on_start(&mut self, config: ActorConfig) -> anyhow::Result { let generation = config.generation; tracing::info!(?config.actor_id, generation, "multi alarm actor starting"); @@ -429,7 +429,7 @@ impl MultiCycleAlarmActor { } #[async_trait] -impl TestActor for MultiCycleAlarmActor { +impl Actor for MultiCycleAlarmActor { async fn on_start(&mut self, config: ActorConfig) -> anyhow::Result { let generation = config.generation; tracing::info!(?config.actor_id, generation, "multi cycle alarm actor starting"); @@ -481,7 +481,7 @@ impl AlarmOnceActor { } #[async_trait] -impl TestActor for AlarmOnceActor { +impl Actor for AlarmOnceActor { async fn on_start(&mut self, config: ActorConfig) -> anyhow::Result { let generation = config.generation; tracing::info!(?config.actor_id, generation, "alarm once actor starting"); @@ -536,7 +536,7 @@ impl AlarmSleepThenCrashActor { } #[async_trait] -impl TestActor for AlarmSleepThenCrashActor { +impl Actor for AlarmSleepThenCrashActor { async fn on_start(&mut self, config: ActorConfig) -> anyhow::Result { let generation = config.generation; tracing::info!(?config.actor_id, generation, "alarm crash actor starting"); @@ -599,7 +599,7 @@ impl RapidAlarmCycleActor { } #[async_trait] -impl TestActor for RapidAlarmCycleActor { +impl Actor for RapidAlarmCycleActor { async fn on_start(&mut self, config: ActorConfig) -> anyhow::Result { let generation = config.generation; tracing::info!(?config.actor_id, generation, "rapid alarm cycle actor starting"); @@ -647,7 +647,7 @@ impl SetClearAlarmAndSleepActor { } #[async_trait] -impl TestActor for SetClearAlarmAndSleepActor { +impl Actor for SetClearAlarmAndSleepActor { async fn on_start(&mut self, config: ActorConfig) -> anyhow::Result { let generation = config.generation; tracing::info!(?config.actor_id, generation, "alarm actor starting"); diff --git a/engine/packages/engine/tests/runner/actors_kv_crud.rs b/engine/packages/engine/tests/runner/actors_kv_crud.rs index 4e10b23aa1..55f19cc156 100644 --- a/engine/packages/engine/tests/runner/actors_kv_crud.rs +++ b/engine/packages/engine/tests/runner/actors_kv_crud.rs @@ -38,7 +38,7 @@ impl PutAndGetActor { } #[async_trait] -impl TestActor for PutAndGetActor { +impl Actor for PutAndGetActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "put and get actor starting"); @@ -116,7 +116,7 @@ impl GetNonexistentKeyActor { } #[async_trait] -impl TestActor for GetNonexistentKeyActor { +impl Actor for GetNonexistentKeyActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "get nonexistent key actor starting"); @@ -191,7 +191,7 @@ impl PutOverwriteActor { } #[async_trait] -impl TestActor for PutOverwriteActor { +impl Actor for PutOverwriteActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "put overwrite actor starting"); @@ -295,7 +295,7 @@ impl DeleteKeyActor { } #[async_trait] -impl TestActor for DeleteKeyActor { +impl Actor for DeleteKeyActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "delete key actor starting"); @@ -383,7 +383,7 @@ impl DeleteNonexistentKeyActor { } #[async_trait] -impl TestActor for DeleteNonexistentKeyActor { +impl Actor for DeleteNonexistentKeyActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "delete nonexistent key actor starting"); @@ -638,7 +638,7 @@ impl BatchPutActor { } #[async_trait] -impl TestActor for BatchPutActor { +impl Actor for BatchPutActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "batch put actor starting"); @@ -721,7 +721,7 @@ impl BatchGetActor { } #[async_trait] -impl TestActor for BatchGetActor { +impl Actor for BatchGetActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "batch get actor starting"); @@ -808,7 +808,7 @@ impl BatchDeleteActor { } #[async_trait] -impl TestActor for BatchDeleteActor { +impl Actor for BatchDeleteActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "batch delete actor starting"); diff --git a/engine/packages/engine/tests/runner/actors_kv_drop.rs b/engine/packages/engine/tests/runner/actors_kv_drop.rs index 4fa5e03e15..384ff3fb5b 100644 --- a/engine/packages/engine/tests/runner/actors_kv_drop.rs +++ b/engine/packages/engine/tests/runner/actors_kv_drop.rs @@ -39,7 +39,7 @@ impl DropClearsAllActor { } #[async_trait] -impl TestActor for DropClearsAllActor { +impl Actor for DropClearsAllActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "drop clears all actor starting"); @@ -137,7 +137,7 @@ impl DropEmptyActor { } #[async_trait] -impl TestActor for DropEmptyActor { +impl Actor for DropEmptyActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "drop empty actor starting"); diff --git a/engine/packages/engine/tests/runner/actors_kv_list.rs b/engine/packages/engine/tests/runner/actors_kv_list.rs index d75e66dea1..d1bd585cda 100644 --- a/engine/packages/engine/tests/runner/actors_kv_list.rs +++ b/engine/packages/engine/tests/runner/actors_kv_list.rs @@ -39,7 +39,7 @@ impl ListAllEmptyActor { } #[async_trait] -impl TestActor for ListAllEmptyActor { +impl Actor for ListAllEmptyActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "list all empty actor starting"); @@ -102,7 +102,7 @@ impl ListAllKeysActor { } #[async_trait] -impl TestActor for ListAllKeysActor { +impl Actor for ListAllKeysActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "list all keys actor starting"); @@ -199,7 +199,7 @@ impl ListAllLimitActor { } #[async_trait] -impl TestActor for ListAllLimitActor { +impl Actor for ListAllLimitActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "list all limit actor starting"); @@ -277,7 +277,7 @@ impl ListAllReverseActor { } #[async_trait] -impl TestActor for ListAllReverseActor { +impl Actor for ListAllReverseActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "list all reverse actor starting"); @@ -368,7 +368,7 @@ impl ListRangeInclusiveActor { } #[async_trait] -impl TestActor for ListRangeInclusiveActor { +impl Actor for ListRangeInclusiveActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "list range inclusive actor starting"); @@ -467,7 +467,7 @@ impl ListRangeExclusiveActor { } #[async_trait] -impl TestActor for ListRangeExclusiveActor { +impl Actor for ListRangeExclusiveActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "list range exclusive actor starting"); @@ -566,7 +566,7 @@ impl ListPrefixActor { } #[async_trait] -impl TestActor for ListPrefixActor { +impl Actor for ListPrefixActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "list prefix actor starting"); @@ -669,7 +669,7 @@ impl ListPrefixNoMatchActor { } #[async_trait] -impl TestActor for ListPrefixNoMatchActor { +impl Actor for ListPrefixNoMatchActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "list prefix no match actor starting"); diff --git a/engine/packages/engine/tests/runner/actors_kv_misc.rs b/engine/packages/engine/tests/runner/actors_kv_misc.rs index a26837ec39..2d2e75688e 100644 --- a/engine/packages/engine/tests/runner/actors_kv_misc.rs +++ b/engine/packages/engine/tests/runner/actors_kv_misc.rs @@ -39,7 +39,7 @@ impl BinaryDataActor { } #[async_trait] -impl TestActor for BinaryDataActor { +impl Actor for BinaryDataActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "binary data actor starting"); @@ -114,7 +114,7 @@ impl EmptyValueActor { } #[async_trait] -impl TestActor for EmptyValueActor { +impl Actor for EmptyValueActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "empty value actor starting"); @@ -203,7 +203,7 @@ impl LargeValueActor { } #[async_trait] -impl TestActor for LargeValueActor { +impl Actor for LargeValueActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "large value actor starting"); @@ -286,7 +286,7 @@ impl GetEmptyKeysActor { } #[async_trait] -impl TestActor for GetEmptyKeysActor { +impl Actor for GetEmptyKeysActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "get empty keys actor starting"); @@ -350,7 +350,7 @@ impl ListLimitZeroActor { } #[async_trait] -impl TestActor for ListLimitZeroActor { +impl Actor for ListLimitZeroActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "list limit zero actor starting"); @@ -429,7 +429,7 @@ impl KeyOrderingActor { } #[async_trait] -impl TestActor for KeyOrderingActor { +impl Actor for KeyOrderingActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "key ordering actor starting"); @@ -520,7 +520,7 @@ impl ManyKeysActor { } #[async_trait] -impl TestActor for ManyKeysActor { +impl Actor for ManyKeysActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "many keys actor starting"); diff --git a/engine/packages/engine/tests/runner/actors_lifecycle.rs b/engine/packages/engine/tests/runner/actors_lifecycle.rs index b28c4bcff4..531b5e2e49 100644 --- a/engine/packages/engine/tests/runner/actors_lifecycle.rs +++ b/engine/packages/engine/tests/runner/actors_lifecycle.rs @@ -42,7 +42,7 @@ fn actor_basic_create() { "runner should have the actor allocated" ); - tracing::info!(?actor_id, runner_id = ?runner.runner_id, "actor allocated to runner"); + tracing::info!(?actor_id, "actor allocated to runner"); }); } diff --git a/engine/sdks/rust/engine-runner/examples/counter.rs b/engine/sdks/rust/engine-runner/examples/counter.rs new file mode 100644 index 0000000000..40529e419e --- /dev/null +++ b/engine/sdks/rust/engine-runner/examples/counter.rs @@ -0,0 +1,69 @@ +//! Counter example using the Rust engine runner API. + +use anyhow::Result; +use axum::{Json, Router, extract::State, routing::{get, post}}; +use rivet_engine_runner::{ + ActorContext, ActorRequestContext, AxumActorDefinition, AxumRunnerApp, Runner, RunnerConfig, +}; +use serde_json::json; +use std::sync::Arc; + +#[tokio::main] +async fn main() -> Result<()> { + let app = AxumRunnerApp::new().with_actor( + "counter", + AxumActorDefinition::new( + Router::new() + .route("/count", get(get_count)) + .route("/increment", post(increment)), + ) + .on_start(|ctx: ActorContext| async move { + tracing::info!(actor_id = %ctx.actor_id, generation = ctx.generation, "counter actor started"); + Ok(()) + }) + .on_stop(|ctx: ActorContext| async move { + tracing::info!(actor_id = %ctx.actor_id, generation = ctx.generation, "counter actor stopped"); + Ok(()) + }), + ); + + let runner = Runner::builder( + RunnerConfig::builder() + .endpoint("http://127.0.0.1:6420") + .namespace("default") + .runner_name("counter-runner") + .build()?, + ) + .app(app) + .build()?; + + println!( + "runner configured. call runner.start().await in an integration environment with a running engine" + ); + let _ = Arc::new(runner); + Ok(()) +} + +async fn get_count(State(ctx): State) -> Result, axum::http::StatusCode> { + let count = ctx + .kv_get_u64("count") + .await + .map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)? + .unwrap_or(0); + Ok(Json(json!({ "count": count }))) +} + +async fn increment(State(ctx): State) -> Result, axum::http::StatusCode> { + let next = ctx + .kv_get_u64("count") + .await + .map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)? + .unwrap_or(0) + + 1; + + ctx.kv_put_u64("count", next) + .await + .map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?; + + Ok(Json(json!({ "count": next }))) +} diff --git a/engine/sdks/rust/engine-runner/tests/common/mod.rs b/engine/sdks/rust/engine-runner/tests/common/mod.rs new file mode 100644 index 0000000000..99bbed105b --- /dev/null +++ b/engine/sdks/rust/engine-runner/tests/common/mod.rs @@ -0,0 +1,357 @@ +use anyhow::{Context, Result, bail}; +use reqwest::Method; +use serde_json::{Value, json}; +use std::{ + fmt::Write as _, + path::PathBuf, + process::{Child, Command, Stdio}, + sync::{Arc, OnceLock}, + time::{Duration, Instant}, +}; +use tempfile::TempDir; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use tokio_tungstenite::{ + connect_async, + tungstenite::client::IntoClientRequest, + WebSocketStream, + MaybeTlsStream, +}; +use urlencoding::encode; + +pub struct EngineProcess { + pub deps: rivet_test_deps::TestDeps, + child: Child, + _config_dir: TempDir, +} + +impl EngineProcess { + pub async fn start() -> Result { + let deps = rivet_test_deps::TestDeps::new().await?; + + let config_dir = tempfile::tempdir().context("failed to create config dir")?; + let config_path = config_dir.path().join("rivet.test.yaml"); + let mut root = (**deps.config()).clone(); + if let Some(rivet_config::config::Database::FileSystem(database)) = root.database.as_mut() { + let db_path = config_dir.path().join("engine-db"); + std::fs::create_dir_all(&db_path).context("failed to create engine db dir")?; + database.path = db_path; + } + + let config_yaml = serde_yaml::to_string(&root) + .context("failed to serialize config")?; + std::fs::write(&config_path, config_yaml).context("failed to write config")?; + + let engine_bin = ensure_engine_binary()?; + let mut cmd = Command::new(engine_bin); + cmd.arg("--config") + .arg(&config_path) + .arg("start") + .arg("-s") + .arg("api_peer") + .arg("-s") + .arg("guard") + .arg("-s") + .arg("workflow_worker") + .arg("-s") + .arg("bootstrap") + .stdout(Stdio::inherit()) + .stderr(Stdio::inherit()) + .stdin(Stdio::null()); + + let child = cmd.spawn().context("failed to spawn rivet-engine")?; + + wait_for_port(deps.api_peer_port()).await?; + wait_for_port(deps.guard_port()).await?; + + Ok(Self { + deps, + child, + _config_dir: config_dir, + }) + } + + pub fn guard_url(&self) -> String { + format!("http://127.0.0.1:{}", self.deps.guard_port()) + } + + pub async fn create_actor( + &self, + namespace: &str, + name: &str, + runner_name_selector: &str, + key: Option<&str>, + ) -> Result { + let client = reqwest::Client::new(); + let response = client + .post(format!("{}/actors", self.guard_url())) + .query(&[("namespace", namespace)]) + .json(&json!({ + "datacenter": null, + "name": name, + "key": key, + "input": null, + "runner_name_selector": runner_name_selector, + "crash_policy": "sleep", + })) + .send() + .await + .context("failed to create actor")?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + bail!("create actor failed: {status} {body}"); + } + + let body: Value = response.json().await.context("failed to decode actor response")?; + let actor_id = body + .get("actor") + .and_then(|x| x.get("actor_id")) + .and_then(Value::as_str) + .context("actor id missing from create actor response")?; + Ok(actor_id.to_string()) + } + + #[allow(dead_code)] + pub async fn actor_request_json( + &self, + method: Method, + actor_id: &str, + path: &str, + body: Option, + ) -> Result { + let response = self + .actor_request_with_retry(method, actor_id, path, body) + .await?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + bail!("actor request failed: {status} {body}"); + } + + response + .json() + .await + .context("failed to decode actor response json") + } + + #[allow(dead_code)] + pub async fn get_actor(&self, namespace: &str, actor_id: &str) -> Result> { + let client = reqwest::Client::new(); + let response = client + .get(format!("{}/actors", self.guard_url())) + .query(&[ + ("namespace", namespace), + ("actor_id", actor_id), + ("include_destroyed", "true"), + ]) + .send() + .await + .context("failed to fetch actors list")?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + bail!("actors list request failed: {status} {body}"); + } + + let body: Value = response + .json() + .await + .context("failed to decode actors list response json")?; + let actor = body + .get("actors") + .and_then(Value::as_array) + .and_then(|actors| actors.first()) + .cloned(); + Ok(actor) + } + + #[allow(dead_code)] + pub async fn actor_request_with_retry( + &self, + method: Method, + actor_id: &str, + path: &str, + body: Option, + ) -> Result { + let url = format!("{}{}", self.guard_url(), path); + let client = reqwest::Client::new(); + + let start = Instant::now(); + let timeout = Duration::from_secs(30); + let mut last_error: Option = None; + + loop { + if start.elapsed() > timeout { + if let Some(err) = last_error { + return Err(err).context("timed out waiting for actor response"); + } + bail!("timed out waiting for actor response"); + } + + let mut request = client + .request(method.clone(), &url) + .header("x-rivet-target", "actor") + .header("x-rivet-token", "dev") + .header("x-rivet-actor", actor_id); + + if let Some(json) = &body { + request = request.json(json); + } + + match request.send().await { + Ok(response) + if response.status() == reqwest::StatusCode::SERVICE_UNAVAILABLE + || response.status() == reqwest::StatusCode::NOT_FOUND => + { + tokio::time::sleep(Duration::from_millis(250)).await; + continue; + } + Ok(response) if response.status() == reqwest::StatusCode::BAD_REQUEST => { + tokio::time::sleep(Duration::from_millis(250)).await; + drop(response); + continue; + } + Ok(response) => return Ok(response), + Err(err) => { + last_error = Some(err.into()); + tokio::time::sleep(Duration::from_millis(250)).await; + } + } + } + } + + #[allow(dead_code)] + pub async fn actor_websocket_connect( + &self, + actor_id: &str, + path: &str, + ) -> Result>> { + let start = Instant::now(); + let timeout = Duration::from_secs(30); + let mut last_error: Option = None; + + loop { + if start.elapsed() > timeout { + if let Some(err) = last_error { + return Err(err).context("timed out connecting actor websocket"); + } + bail!("timed out connecting actor websocket"); + } + + let mut ws_url = self.guard_url().replace("http://", "ws://"); + if path.starts_with('/') { + ws_url.push_str(path); + } else { + ws_url.push('/'); + ws_url.push_str(path); + } + + let mut request = ws_url + .into_client_request() + .context("failed to build websocket request")?; + request + .headers_mut() + .insert("x-rivet-target", "actor".parse().context("invalid target header")?); + request + .headers_mut() + .insert("x-rivet-token", "dev".parse().context("invalid token header")?); + request + .headers_mut() + .insert("x-rivet-actor", actor_id.parse().context("invalid actor header")?); + let actor_id_protocol = format!("rivet_actor.{}", encode(actor_id)); + let websocket_protocol = format!( + "rivet_target.actor, {actor_id_protocol}, rivet_token.dev, rivet" + ); + request.headers_mut().insert( + "Sec-WebSocket-Protocol", + websocket_protocol + .parse() + .context("invalid websocket protocol header")?, + ); + + match connect_async(request).await { + Ok((ws, _response)) => return Ok(ws), + Err(err) => { + last_error = Some(err.into()); + tokio::time::sleep(Duration::from_millis(250)).await; + } + } + } + } +} + +pub async fn acquire_test_lock() -> Result { + static TEST_LOCK: OnceLock> = OnceLock::new(); + let lock = TEST_LOCK + .get_or_init(|| Arc::new(Semaphore::new(1))) + .clone(); + lock.acquire_owned() + .await + .context("failed to acquire test lock") +} + +pub fn random_name(prefix: &str) -> String { + let mut name = String::with_capacity(prefix.len() + 17); + let _ = write!(&mut name, "{}-{:016x}", prefix, rand::random::()); + name +} + +impl Drop for EngineProcess { + fn drop(&mut self) { + let _ = self.child.kill(); + let _ = self.child.wait(); + } +} + +fn ensure_engine_binary() -> Result { + static BUILD_RESULT: OnceLock> = OnceLock::new(); + + let result = BUILD_RESULT.get_or_init(|| { + let workspace = workspace_root(); + let status = Command::new("cargo") + .arg("build") + .arg("-p") + .arg("rivet-engine") + .current_dir(&workspace) + .status(); + + match status { + Ok(status) if status.success() => { + let bin = workspace.join("target").join("debug").join("rivet-engine"); + if bin.exists() { + Ok(bin) + } else { + Err(format!("engine binary not found at {}", bin.display())) + } + } + Ok(status) => Err(format!("cargo build -p rivet-engine failed with status {status}")), + Err(err) => Err(format!("failed to execute cargo build: {err}")), + } + }); + + result.clone().map_err(anyhow::Error::msg) +} + +fn workspace_root() -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("../../../../") + .canonicalize() + .expect("workspace root") +} + +async fn wait_for_port(port: u16) -> Result<()> { + let addr = format!("127.0.0.1:{port}"); + let start = Instant::now(); + let timeout = Duration::from_secs(30); + + loop { + match tokio::net::TcpStream::connect(&addr).await { + Ok(_) => return Ok(()), + Err(_) if start.elapsed() <= timeout => tokio::time::sleep(Duration::from_millis(100)).await, + Err(err) => return Err(err).with_context(|| format!("timed out waiting for port {port}")), + } + } +} diff --git a/engine/sdks/rust/engine-runner/tests/e2e_counter_runner.rs b/engine/sdks/rust/engine-runner/tests/e2e_counter_runner.rs new file mode 100644 index 0000000000..16efde33b7 --- /dev/null +++ b/engine/sdks/rust/engine-runner/tests/e2e_counter_runner.rs @@ -0,0 +1,201 @@ +mod common; + +use anyhow::{Result, bail}; +use axum::{ + Json, Router, + extract::State, + http::StatusCode, + routing::{get, post}, +}; +use reqwest::Method; +use rivet_engine_runner::{ + ActorContext, ActorRequestContext, AxumActorDefinition, AxumRunnerApp, Runner, RunnerConfig, +}; +use serde_json::{Value, json}; +use std::{collections::HashSet, sync::Arc, time::{Duration, Instant}}; +use tokio::sync::Mutex; + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn counter_actor_runner_http_kv_e2e() -> Result<()> { + let _test_lock = common::acquire_test_lock().await?; + let engine = common::EngineProcess::start().await?; + let namespace = "default".to_string(); + + let runner_name = common::random_name("rust-counter-runner"); + let runner_key = common::random_name("key"); + let actor_key = common::random_name("counter"); + let actor_registry = Arc::new(Mutex::new(HashSet::::new())); + + let runner = Runner::builder( + RunnerConfig::builder() + .endpoint(engine.guard_url()) + .namespace(namespace.clone()) + .runner_name(runner_name.clone()) + .runner_key(runner_key) + .token("dev") + .total_slots(16) + .build()?, + ) + .app(build_counter_app(actor_registry.clone())) + .build()?; + + runner.start().await?; + runner.wait_ready().await?; + + let actor_id = engine + .create_actor(&namespace, "counter", &runner_name, Some(&actor_key)) + .await?; + wait_for_actor_presence(&actor_registry, &actor_id, true, Duration::from_secs(30)).await?; + let actor = engine + .get_actor(&namespace, &actor_id) + .await? + .ok_or_else(|| anyhow::anyhow!("actor missing after create: {actor_id}"))?; + if actor.get("destroy_ts").is_some_and(|x| !x.is_null()) { + bail!("actor is already destroyed before first request: {actor}"); + } + + let count = match engine + .actor_request_json(Method::GET, &actor_id, "/count", None) + .await + { + Ok(value) => value, + Err(err) => { + let actor = engine.get_actor(&namespace, &actor_id).await?; + bail!("initial actor request failed actor={actor:?}: {err}"); + } + }; + assert_count(&count, 0)?; + + let incremented = match engine + .actor_request_json(Method::POST, &actor_id, "/increment", None) + .await + { + Ok(value) => value, + Err(err) => { + let actor = engine.get_actor(&namespace, &actor_id).await?; + bail!("first increment request failed actor={actor:?}: {err}"); + } + }; + assert_count(&incremented, 1)?; + + let incremented_again = match engine + .actor_request_json(Method::POST, &actor_id, "/increment", None) + .await + { + Ok(value) => value, + Err(err) => { + let actor = engine.get_actor(&namespace, &actor_id).await?; + bail!("second increment request failed actor={actor:?}: {err}"); + } + }; + assert_count(&incremented_again, 2)?; + + runner.handle().sleep_actor(&actor_id, None).await?; + wait_for_actor_presence(&actor_registry, &actor_id, false, Duration::from_secs(30)).await?; + + let persisted = match engine + .actor_request_json(Method::GET, &actor_id, "/count", None) + .await + { + Ok(value) => value, + Err(err) => { + let actor = engine.get_actor(&namespace, &actor_id).await?; + bail!("persisted count request failed actor={actor:?}: {err}"); + } + }; + assert_count(&persisted, 2)?; + wait_for_actor_presence(&actor_registry, &actor_id, true, Duration::from_secs(30)).await?; + + runner.shutdown(true).await?; + + Ok(()) +} + +fn build_counter_app(actor_registry: Arc>>) -> AxumRunnerApp { + let on_start_registry = actor_registry.clone(); + let on_stop_registry = actor_registry; + + AxumRunnerApp::new().with_actor( + "counter", + AxumActorDefinition::new( + Router::new() + .route("/count", get(get_count)) + .route("/increment", post(increment)), + ) + .on_start(move |ctx: ActorContext| { + let actor_registry = on_start_registry.clone(); + async move { + actor_registry.lock().await.insert(ctx.actor_id); + Ok(()) + } + }) + .on_stop(move |ctx: ActorContext| { + let actor_registry = on_stop_registry.clone(); + async move { + actor_registry.lock().await.remove(&ctx.actor_id); + Ok(()) + } + }), + ) +} + +async fn get_count( + State(ctx): State, +) -> Result, StatusCode> { + let count = ctx + .kv_get_u64("count") + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .unwrap_or(0); + Ok(Json(json!({ "count": count }))) +} + +async fn increment( + State(ctx): State, +) -> Result, StatusCode> { + let count = ctx + .kv_get_u64("count") + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .unwrap_or(0) + + 1; + ctx.kv_put_u64("count", count) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + Ok(Json(json!({ "count": count }))) +} + +fn assert_count(value: &Value, expected: u64) -> Result<()> { + let actual = value + .get("count") + .and_then(Value::as_u64) + .ok_or_else(|| anyhow::anyhow!("response missing `count` field: {value}"))?; + if actual != expected { + bail!("count mismatch: expected {expected}, got {actual}"); + } + Ok(()) +} + +async fn wait_for_actor_presence( + actor_registry: &Arc>>, + actor_id: &str, + expected: bool, + timeout: Duration, +) -> Result<()> { + let deadline = Instant::now() + timeout; + loop { + let present = actor_registry.lock().await.contains(actor_id); + if present == expected { + return Ok(()); + } + if Instant::now() >= deadline { + bail!( + "timed out waiting for actor presence state actor_id={} expected_present={} actual_present={}", + actor_id, + expected, + present + ); + } + tokio::time::sleep(Duration::from_millis(100)).await; + } +} diff --git a/engine/sdks/rust/engine-runner/tests/e2e_counter_serverless.rs b/engine/sdks/rust/engine-runner/tests/e2e_counter_serverless.rs new file mode 100644 index 0000000000..2a5e46e6be --- /dev/null +++ b/engine/sdks/rust/engine-runner/tests/e2e_counter_serverless.rs @@ -0,0 +1,218 @@ +mod common; + +use anyhow::{Context, Result, bail}; +use axum::{ + Json, Router, + extract::State, + http::StatusCode, + routing::{get, post}, +}; +use reqwest::Method; +use rivet_engine_runner::{ + ActorContext, ActorRequestContext, AxumActorDefinition, AxumRunnerApp, ServerlessConfig, + ServerlessRunner, +}; +use serde_json::{Value, json}; +use std::{collections::HashSet, sync::Arc, time::{Duration, Instant}}; +use tokio::sync::{Mutex, oneshot}; + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn counter_actor_serverless_http_kv_e2e() -> Result<()> { + let _test_lock = common::acquire_test_lock().await?; + let engine = common::EngineProcess::start().await?; + let namespace = "default".to_string(); + + let runner_name = common::random_name("rust-counter-serverless"); + let runner_key = common::random_name("key"); + let actor_key = common::random_name("counter"); + let actor_registry = Arc::new(Mutex::new(HashSet::::new())); + + let serverless_runner = ServerlessRunner::builder( + ServerlessConfig::builder() + .endpoint(engine.guard_url()) + .namespace(namespace.clone()) + .runner_name(runner_name.clone()) + .runner_key(runner_key) + .prepopulate_actor_name("counter", json!({})) + .token("dev") + .total_slots(1) + .max_runners(1000) + .slots_per_runner(1) + .request_lifespan(300) + .build()?, + ) + .app(build_counter_app(actor_registry.clone())) + .build()?; + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .context("failed to bind serverless test listener")?; + let addr = listener.local_addr().context("missing listener local addr")?; + let serverless_url = format!("http://localhost:{}", addr.port()); + + let routes = Arc::new(serverless_runner.clone()).axum_routes(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let mut server_task = tokio::spawn(async move { + axum::serve(listener, routes) + .with_graceful_shutdown(async move { + let _ = shutdown_rx.await; + }) + .await + .context("serverless axum server exited with error") + }); + + let metadata_response = reqwest::get(format!("{serverless_url}/api/rivet/metadata")) + .await + .context("failed to call serverless metadata endpoint")?; + if metadata_response.status() != reqwest::StatusCode::OK { + bail!("metadata endpoint returned {}", metadata_response.status()); + } + + let start_response = reqwest::Client::new() + .get(format!("{serverless_url}/api/rivet/start")) + .send() + .await?; + if start_response.status() != reqwest::StatusCode::OK { + bail!("serverless start endpoint returned {}", start_response.status()); + } + + let actor_id = engine + .create_actor(&namespace, "counter", &runner_name, Some(&actor_key)) + .await?; + + let count = engine + .actor_request_json(Method::GET, &actor_id, "/count", None) + .await?; + assert_count(&count, 0)?; + + let incremented = engine + .actor_request_json(Method::POST, &actor_id, "/increment", None) + .await?; + assert_count(&incremented, 1)?; + + let incremented_again = engine + .actor_request_json(Method::POST, &actor_id, "/increment", None) + .await?; + assert_count(&incremented_again, 2)?; + + tokio::time::timeout(Duration::from_secs(30), serverless_runner.runner().wait_ready()) + .await + .context("timed out waiting for serverless runner init")??; + + wait_for_actor_presence(&actor_registry, &actor_id, true, Duration::from_secs(30)).await?; + + serverless_runner + .runner() + .handle() + .sleep_actor(&actor_id, None) + .await?; + wait_for_actor_presence(&actor_registry, &actor_id, false, Duration::from_secs(30)).await?; + + let persisted = engine + .actor_request_json(Method::GET, &actor_id, "/count", None) + .await?; + assert_count(&persisted, 2)?; + wait_for_actor_presence(&actor_registry, &actor_id, true, Duration::from_secs(30)).await?; + + serverless_runner.runner().shutdown(true).await?; + + let _ = shutdown_tx.send(()); + if tokio::time::timeout(Duration::from_secs(10), &mut server_task) + .await + .is_err() + { + server_task.abort(); + } + let _ = server_task.await; + + Ok(()) +} + +fn build_counter_app(actor_registry: Arc>>) -> AxumRunnerApp { + let on_start_registry = actor_registry.clone(); + let on_stop_registry = actor_registry; + + AxumRunnerApp::new().with_actor( + "counter", + AxumActorDefinition::new( + Router::new() + .route("/count", get(get_count)) + .route("/increment", post(increment)), + ) + .on_start(move |ctx: ActorContext| { + let actor_registry = on_start_registry.clone(); + async move { + actor_registry.lock().await.insert(ctx.actor_id); + Ok(()) + } + }) + .on_stop(move |ctx: ActorContext| { + let actor_registry = on_stop_registry.clone(); + async move { + actor_registry.lock().await.remove(&ctx.actor_id); + Ok(()) + } + }), + ) +} + +async fn get_count( + State(ctx): State, +) -> Result, StatusCode> { + let count = ctx + .kv_get_u64("count") + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .unwrap_or(0); + Ok(Json(json!({ "count": count }))) +} + +async fn increment( + State(ctx): State, +) -> Result, StatusCode> { + let count = ctx + .kv_get_u64("count") + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .unwrap_or(0) + + 1; + ctx.kv_put_u64("count", count) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + Ok(Json(json!({ "count": count }))) +} + +fn assert_count(value: &Value, expected: u64) -> Result<()> { + let actual = value + .get("count") + .and_then(Value::as_u64) + .ok_or_else(|| anyhow::anyhow!("response missing `count` field: {value}"))?; + if actual != expected { + bail!("count mismatch: expected {expected}, got {actual}"); + } + Ok(()) +} + +async fn wait_for_actor_presence( + actor_registry: &Arc>>, + actor_id: &str, + expected: bool, + timeout: Duration, +) -> Result<()> { + let deadline = Instant::now() + timeout; + loop { + let present = actor_registry.lock().await.contains(actor_id); + if present == expected { + return Ok(()); + } + if Instant::now() >= deadline { + bail!( + "timed out waiting for actor presence state actor_id={} expected_present={} actual_present={}", + actor_id, + expected, + present + ); + } + tokio::time::sleep(Duration::from_millis(100)).await; + } +} diff --git a/engine/sdks/rust/engine-runner/tests/e2e_counter_serverless_upsert.rs b/engine/sdks/rust/engine-runner/tests/e2e_counter_serverless_upsert.rs new file mode 100644 index 0000000000..fd91cf6469 --- /dev/null +++ b/engine/sdks/rust/engine-runner/tests/e2e_counter_serverless_upsert.rs @@ -0,0 +1,208 @@ +mod common; + +use anyhow::{Context, Result, bail}; +use axum::{ + Json, Router, + extract::State, + http::StatusCode, + routing::{get, post}, +}; +use reqwest::Method; +use rivet_engine_runner::{ + ActorContext, ActorRequestContext, AxumActorDefinition, AxumRunnerApp, ServerlessConfig, + ServerlessRunner, +}; +use serde_json::{Value, json}; +use std::{ + collections::HashSet, + sync::Arc, + time::{Duration, Instant}, +}; +use tokio::sync::{Mutex, oneshot}; + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn counter_actor_serverless_upsert_config_http_kv_e2e() -> Result<()> { + let _test_lock = common::acquire_test_lock().await?; + let engine = common::EngineProcess::start().await?; + let namespace = "default".to_string(); + + let runner_name = common::random_name("rust-counter-serverless-upsert"); + let runner_key = common::random_name("key"); + let actor_key = common::random_name("counter"); + let actor_registry = Arc::new(Mutex::new(HashSet::::new())); + + let serverless_runner = ServerlessRunner::builder( + ServerlessConfig::builder() + .endpoint(engine.guard_url()) + .namespace(namespace.clone()) + .runner_name(runner_name.clone()) + .runner_key(runner_key) + .prepopulate_actor_name("counter", json!({})) + .token("dev") + .total_slots(1) + .max_runners(1000) + .slots_per_runner(1) + .request_lifespan(300) + .build()?, + ) + .app(build_counter_app(actor_registry.clone())) + .build()?; + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .context("failed to bind serverless test listener")?; + let addr = listener.local_addr().context("missing listener local addr")?; + let serverless_url = format!("http://localhost:{}", addr.port()); + + let routes = Arc::new(serverless_runner.clone()).axum_routes(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let mut server_task = tokio::spawn(async move { + axum::serve(listener, routes) + .with_graceful_shutdown(async move { + let _ = shutdown_rx.await; + }) + .await + .context("serverless axum server exited with error") + }); + + serverless_runner + .upsert_serverless_runner_config(&serverless_url) + .await + .context("failed to upsert serverless runner config")?; + + let actor_id = engine + .create_actor(&namespace, "counter", &runner_name, Some(&actor_key)) + .await?; + + let count = engine + .actor_request_json(Method::GET, &actor_id, "/count", None) + .await?; + assert_count(&count, 0)?; + + let incremented = engine + .actor_request_json(Method::POST, &actor_id, "/increment", None) + .await?; + assert_count(&incremented, 1)?; + + let incremented_again = engine + .actor_request_json(Method::POST, &actor_id, "/increment", None) + .await?; + assert_count(&incremented_again, 2)?; + + tokio::time::timeout(Duration::from_secs(30), serverless_runner.runner().wait_ready()) + .await + .context("timed out waiting for serverless runner init")??; + + wait_for_actor_presence(&actor_registry, &actor_id, true, Duration::from_secs(30)).await?; + + serverless_runner + .runner() + .handle() + .sleep_actor(&actor_id, None) + .await?; + wait_for_actor_presence(&actor_registry, &actor_id, false, Duration::from_secs(30)).await?; + + let persisted = engine + .actor_request_json(Method::GET, &actor_id, "/count", None) + .await?; + assert_count(&persisted, 2)?; + wait_for_actor_presence(&actor_registry, &actor_id, true, Duration::from_secs(30)).await?; + + serverless_runner.runner().shutdown(true).await?; + + let _ = shutdown_tx.send(()); + if tokio::time::timeout(Duration::from_secs(10), &mut server_task) + .await + .is_err() + { + server_task.abort(); + } + let _ = server_task.await; + + Ok(()) +} + +fn build_counter_app(actor_registry: Arc>>) -> AxumRunnerApp { + let on_start_registry = actor_registry.clone(); + let on_stop_registry = actor_registry; + + AxumRunnerApp::new().with_actor( + "counter", + AxumActorDefinition::new( + Router::new() + .route("/count", get(get_count)) + .route("/increment", post(increment)), + ) + .on_start(move |ctx: ActorContext| { + let actor_registry = on_start_registry.clone(); + async move { + actor_registry.lock().await.insert(ctx.actor_id); + Ok(()) + } + }) + .on_stop(move |ctx: ActorContext| { + let actor_registry = on_stop_registry.clone(); + async move { + actor_registry.lock().await.remove(&ctx.actor_id); + Ok(()) + } + }), + ) +} + +async fn get_count(State(ctx): State) -> Result, StatusCode> { + let count = ctx + .kv_get_u64("count") + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .unwrap_or(0); + Ok(Json(json!({ "count": count }))) +} + +async fn increment(State(ctx): State) -> Result, StatusCode> { + let count = ctx + .kv_get_u64("count") + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .unwrap_or(0) + + 1; + ctx.kv_put_u64("count", count) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + Ok(Json(json!({ "count": count }))) +} + +fn assert_count(value: &Value, expected: u64) -> Result<()> { + let actual = value + .get("count") + .and_then(Value::as_u64) + .ok_or_else(|| anyhow::anyhow!("response missing `count` field: {value}"))?; + if actual != expected { + bail!("count mismatch: expected {expected}, got {actual}"); + } + Ok(()) +} + +async fn wait_for_actor_presence( + actor_registry: &Arc>>, + actor_id: &str, + expected: bool, + timeout: Duration, +) -> Result<()> { + let deadline = Instant::now() + timeout; + loop { + let present = actor_registry.lock().await.contains(actor_id); + if present == expected { + return Ok(()); + } + if Instant::now() >= deadline { + bail!( + "timed out waiting for actor presence state actor_id={} expected_present={} actual_present={}", + actor_id, + expected, + present + ); + } + tokio::time::sleep(Duration::from_millis(100)).await; + } +} diff --git a/engine/sdks/rust/engine-runner/tests/e2e_websocket.rs b/engine/sdks/rust/engine-runner/tests/e2e_websocket.rs new file mode 100644 index 0000000000..8db6160173 --- /dev/null +++ b/engine/sdks/rust/engine-runner/tests/e2e_websocket.rs @@ -0,0 +1,386 @@ +mod common; + +use anyhow::{Context, Result, bail}; +use async_trait::async_trait; +use futures_util::{SinkExt, StreamExt}; +use rivet_engine_runner::{ + ActorContext, HibernatingWebSocketMetadata, Runner, RunnerApp, RunnerConfig, RunnerHandle, + ServerlessConfig, ServerlessRunner, WebSocketContext, WebSocketMessage, +}; +use serde_json::json; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, + time::{Duration, Instant}, +}; +use tokio::sync::{Mutex, oneshot}; +use tokio_tungstenite::tungstenite::Message; + +#[derive(Clone, Default)] +struct EchoWebSocketApp { + actors: Arc>>, + closes: Arc>>, + hibernating_metadata: + Arc>>>, +} + +#[async_trait] +impl RunnerApp for EchoWebSocketApp { + async fn on_actor_start(&self, runner: RunnerHandle, ctx: ActorContext) -> Result<()> { + self.actors.lock().await.insert(ctx.actor_id.clone()); + let metadata = self + .hibernating_metadata + .lock() + .await + .get(&ctx.actor_id) + .map(|entries| entries.values().cloned().collect()) + .unwrap_or_default(); + runner + .restore_hibernating_requests(&ctx.actor_id, metadata) + .await?; + Ok(()) + } + + async fn on_actor_stop(&self, _runner: RunnerHandle, ctx: ActorContext) -> Result<()> { + self.actors.lock().await.remove(&ctx.actor_id); + Ok(()) + } + + async fn websocket(&self, _runner: RunnerHandle, ctx: WebSocketContext) -> Result<()> { + self.hibernating_metadata + .lock() + .await + .entry(ctx.actor_id.clone()) + .or_default() + .insert( + (ctx.gateway_id, ctx.request_id), + HibernatingWebSocketMetadata { + gateway_id: ctx.gateway_id, + request_id: ctx.request_id, + client_message_index: 0, + server_message_index: 0, + path: ctx.path, + headers: ctx.headers, + }, + ); + Ok(()) + } + + async fn websocket_message( + &self, + runner: RunnerHandle, + ctx: WebSocketContext, + message: WebSocketMessage, + ) -> Result<()> { + if ctx.is_hibernatable { + runner + .send_hibernatable_websocket_message_ack( + ctx.gateway_id, + ctx.request_id, + message.message_index, + ) + .await?; + } + + let response_data = message.data.clone(); + let response_binary = message.binary; + runner + .send_websocket_message( + ctx.gateway_id, + ctx.request_id, + response_data, + response_binary, + ) + .await?; + + if let Some(actor_entries) = self + .hibernating_metadata + .lock() + .await + .get_mut(&ctx.actor_id) + { + if let Some(meta) = actor_entries.get_mut(&(ctx.gateway_id, ctx.request_id)) { + meta.server_message_index = message.message_index; + meta.client_message_index = meta.client_message_index.wrapping_add(1); + } + } + + Ok(()) + } + + async fn websocket_close( + &self, + _runner: RunnerHandle, + ctx: WebSocketContext, + _code: Option, + _reason: Option, + ) -> Result<()> { + let actor_id = ctx.actor_id.clone(); + self.closes.lock().await.push(actor_id.clone()); + if let Some(actor_entries) = self + .hibernating_metadata + .lock() + .await + .get_mut(&actor_id) + { + actor_entries.remove(&(ctx.gateway_id, ctx.request_id)); + } + Ok(()) + } + + fn can_hibernate(&self, _ctx: &WebSocketContext) -> bool { + true + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn websocket_runner_e2e() -> Result<()> { + let _test_lock = common::acquire_test_lock().await?; + let engine = common::EngineProcess::start().await?; + let namespace = "default".to_string(); + let runner_name = common::random_name("rust-ws-runner"); + let actor_key = common::random_name("ws"); + let app = EchoWebSocketApp::default(); + + let runner = Runner::builder( + RunnerConfig::builder() + .endpoint(engine.guard_url()) + .namespace(namespace.clone()) + .runner_name(runner_name.clone()) + .runner_key(common::random_name("key")) + .token("dev") + .total_slots(16) + .build()?, + ) + .app(app.clone()) + .build()?; + runner.start().await?; + runner.wait_ready().await?; + + let actor_id = engine + .create_actor(&namespace, "ws-echo", &runner_name, Some(&actor_key)) + .await?; + wait_for_actor_presence(&app.actors, &actor_id, true, Duration::from_secs(30)).await?; + + let mut ws = engine.actor_websocket_connect(&actor_id, "/ws").await?; + ws.send(Message::Text("ping".to_string().into())).await?; + let echoed = ws + .next() + .await + .context("missing echoed text frame")??; + assert_text_message(&echoed, "ping")?; + + ws.send(Message::Binary(vec![1u8, 2, 3].into())).await?; + let echoed_binary = ws + .next() + .await + .context("missing echoed binary frame")??; + assert_binary_message(&echoed_binary, &[1, 2, 3])?; + + let mut large_payload = vec![0u8; 64 * 1024]; + for (idx, byte) in large_payload.iter_mut().enumerate() { + *byte = (idx % 251) as u8; + } + ws.send(Message::Binary(large_payload.clone().into())).await?; + let echoed_large_binary = ws + .next() + .await + .context("missing echoed large binary frame")??; + assert_binary_message(&echoed_large_binary, &large_payload)?; + + ws.close(None).await?; + wait_for_close(&app.closes, &actor_id, Duration::from_secs(10)).await?; + + runner.handle().sleep_actor(&actor_id, None).await?; + wait_for_actor_presence(&app.actors, &actor_id, false, Duration::from_secs(30)).await?; + runner.shutdown(true).await?; + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn websocket_hibernation_restore_runner_e2e() -> Result<()> { + let _test_lock = common::acquire_test_lock().await?; + let engine = common::EngineProcess::start().await?; + let namespace = "default".to_string(); + let runner_name = common::random_name("rust-ws-hibernation-runner"); + let actor_key = common::random_name("ws"); + let app = EchoWebSocketApp::default(); + + let runner = Runner::builder( + RunnerConfig::builder() + .endpoint(engine.guard_url()) + .namespace(namespace.clone()) + .runner_name(runner_name.clone()) + .runner_key(common::random_name("key")) + .token("dev") + .total_slots(16) + .build()?, + ) + .app(app.clone()) + .build()?; + runner.start().await?; + runner.wait_ready().await?; + + let actor_id = engine + .create_actor(&namespace, "ws-echo", &runner_name, Some(&actor_key)) + .await?; + wait_for_actor_presence(&app.actors, &actor_id, true, Duration::from_secs(30)).await?; + + let mut ws = engine.actor_websocket_connect(&actor_id, "/ws").await?; + ws.send(Message::Text("before-sleep".to_string().into())).await?; + let echoed = ws + .next() + .await + .context("missing echoed before-sleep frame")??; + assert_text_message(&echoed, "before-sleep")?; + + runner.handle().sleep_actor(&actor_id, None).await?; + wait_for_actor_presence(&app.actors, &actor_id, false, Duration::from_secs(30)).await?; + + ws.send(Message::Text("after-sleep".to_string().into())).await?; + let echoed_after_sleep = ws + .next() + .await + .context("missing echoed after-sleep frame")??; + assert_text_message(&echoed_after_sleep, "after-sleep")?; + wait_for_actor_presence(&app.actors, &actor_id, true, Duration::from_secs(30)).await?; + + ws.close(None).await?; + wait_for_close(&app.closes, &actor_id, Duration::from_secs(10)).await?; + runner.shutdown(true).await?; + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn websocket_serverless_e2e() -> Result<()> { + let _test_lock = common::acquire_test_lock().await?; + let engine = common::EngineProcess::start().await?; + let namespace = "default".to_string(); + let runner_name = common::random_name("rust-ws-serverless"); + let actor_key = common::random_name("ws"); + let app = EchoWebSocketApp::default(); + + let serverless_runner = ServerlessRunner::builder( + ServerlessConfig::builder() + .endpoint(engine.guard_url()) + .namespace(namespace.clone()) + .runner_name(runner_name.clone()) + .runner_key(common::random_name("key")) + .token("dev") + .prepopulate_actor_name("ws-echo", json!({})) + .total_slots(1) + .max_runners(1000) + .slots_per_runner(1) + .request_lifespan(300) + .build()?, + ) + .app(app.clone()) + .build()?; + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .context("failed to bind serverless test listener")?; + let addr = listener.local_addr().context("missing listener local addr")?; + let serverless_url = format!("http://localhost:{}", addr.port()); + + let routes = Arc::new(serverless_runner.clone()).axum_routes(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let mut server_task = tokio::spawn(async move { + axum::serve(listener, routes) + .with_graceful_shutdown(async move { + let _ = shutdown_rx.await; + }) + .await + .context("serverless axum server exited with error") + }); + + let start_response = reqwest::Client::new() + .get(format!("{serverless_url}/api/rivet/start")) + .send() + .await?; + if start_response.status() != reqwest::StatusCode::OK { + bail!("serverless start endpoint returned {}", start_response.status()); + } + + let actor_id = engine + .create_actor(&namespace, "ws-echo", &runner_name, Some(&actor_key)) + .await?; + wait_for_actor_presence(&app.actors, &actor_id, true, Duration::from_secs(30)).await?; + + let mut ws = engine.actor_websocket_connect(&actor_id, "/ws").await?; + ws.send(Message::Text("pong".to_string().into())).await?; + let echoed = ws + .next() + .await + .context("missing echoed text frame")??; + assert_text_message(&echoed, "pong")?; + + ws.close(None).await?; + wait_for_close(&app.closes, &actor_id, Duration::from_secs(10)).await?; + + serverless_runner.runner().shutdown(true).await?; + let _ = shutdown_tx.send(()); + if tokio::time::timeout(Duration::from_secs(10), &mut server_task) + .await + .is_err() + { + server_task.abort(); + } + let _ = server_task.await; + Ok(()) +} + +fn assert_text_message(message: &Message, expected: &str) -> Result<()> { + match message { + Message::Text(text) if text.as_str() == expected => Ok(()), + _ => bail!("expected text websocket message `{expected}`, got `{message:?}`"), + } +} + +fn assert_binary_message(message: &Message, expected: &[u8]) -> Result<()> { + match message { + Message::Binary(data) if data.as_ref() == expected => Ok(()), + _ => bail!("expected binary websocket message `{expected:?}`, got `{message:?}`"), + } +} + +async fn wait_for_actor_presence( + actor_registry: &Arc>>, + actor_id: &str, + expected: bool, + timeout: Duration, +) -> Result<()> { + let deadline = Instant::now() + timeout; + loop { + let present = actor_registry.lock().await.contains(actor_id); + if present == expected { + return Ok(()); + } + if Instant::now() >= deadline { + bail!( + "timed out waiting for actor presence state actor_id={} expected_present={} actual_present={}", + actor_id, + expected, + present + ); + } + tokio::time::sleep(Duration::from_millis(100)).await; + } +} + +async fn wait_for_close( + close_registry: &Arc>>, + actor_id: &str, + timeout: Duration, +) -> Result<()> { + let deadline = Instant::now() + timeout; + loop { + if close_registry.lock().await.iter().any(|x| x == actor_id) { + return Ok(()); + } + if Instant::now() >= deadline { + bail!("timed out waiting for websocket close callback actor_id={actor_id}"); + } + tokio::time::sleep(Duration::from_millis(100)).await; + } +}