From 25dd07cfd6fb0237ea4fb5f2895d0a2ab78f4df2 Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Sat, 25 Apr 2026 19:32:30 -0700 Subject: [PATCH] fix(rivetkit): route raw request fetches to actors --- .../pegboard-envoy/src/sqlite_runtime.rs | 6 +- .../errors/message.incoming_too_long.json | 2 +- .../rivetkit-core/src/registry/http.rs | 315 ++++-------------- .../tests/modules/registry_http.rs | 255 ++++++++++++++ 4 files changed, 327 insertions(+), 251 deletions(-) create mode 100644 rivetkit-rust/packages/rivetkit-core/tests/modules/registry_http.rs diff --git a/engine/packages/pegboard-envoy/src/sqlite_runtime.rs b/engine/packages/pegboard-envoy/src/sqlite_runtime.rs index c4ad1569b4..7aa6a7c90c 100644 --- a/engine/packages/pegboard-envoy/src/sqlite_runtime.rs +++ b/engine/packages/pegboard-envoy/src/sqlite_runtime.rs @@ -676,9 +676,9 @@ mod tests { use universaldb::driver::RocksDbDatabaseDriver; use super::{ - FILE_TAG_JOURNAL, FILE_TAG_MAIN, FILE_TAG_SHM, FILE_TAG_WAL, - SQLITE_V1_CHUNK_SIZE, SQLITE_V1_MAX_MIGRATION_BYTES, SQLITE_V1_MIGRATION_LEASE_MS, - maybe_migrate_v1_to_v2, read_v1_file, sqlite_subspace, v1_chunk_key, v1_meta_key, + FILE_TAG_JOURNAL, FILE_TAG_MAIN, FILE_TAG_SHM, FILE_TAG_WAL, SQLITE_V1_CHUNK_SIZE, + SQLITE_V1_MAX_MIGRATION_BYTES, SQLITE_V1_MIGRATION_LEASE_MS, maybe_migrate_v1_to_v2, + read_v1_file, sqlite_subspace, v1_chunk_key, v1_meta_key, }; fn recipient(actor_id: Id) -> Recipient { diff --git a/rivetkit-rust/engine/artifacts/errors/message.incoming_too_long.json b/rivetkit-rust/engine/artifacts/errors/message.incoming_too_long.json index 729519603c..e35ce9f122 100644 --- a/rivetkit-rust/engine/artifacts/errors/message.incoming_too_long.json +++ b/rivetkit-rust/engine/artifacts/errors/message.incoming_too_long.json @@ -1,5 +1,5 @@ { "code": "incoming_too_long", "group": "message", - "message": "Incoming message too long." + "message": "Incoming message too long" } \ No newline at end of file diff --git a/rivetkit-rust/packages/rivetkit-core/src/registry/http.rs b/rivetkit-rust/packages/rivetkit-core/src/registry/http.rs index 2ca2ebe6c8..9f02d16dcd 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/registry/http.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/registry/http.rs @@ -15,12 +15,13 @@ impl RegistryDispatcher { return self.handle_metrics_fetch(&instance, &request); } + let original_path = request.path.clone(); let request = build_http_request(request).await?; - let framework_route = framework_http_route(request.uri().path())?; + let route = RegistryHttpRoute::from_paths(&original_path, request.uri().path())?; let instance = match self.active_actor(actor_id).await { Ok(instance) => instance, Err(error) => { - if framework_route.is_some() { + if matches!(route, RegistryHttpRoute::Framework(_)) { return message_boundary_error_response( request_encoding(request.headers()), framework_anyhow_status(&error), @@ -36,20 +37,24 @@ impl RegistryDispatcher { instance.ctx.cancel_sleep_timer(); - let rearm_sleep_after_request = |ctx: ActorContext| { - let sleep_ctx = ctx.clone(); - ctx.wait_until(async move { - sleep_ctx.wait_for_http_requests_idle().await; - sleep_ctx.reset_sleep_timer(); - }); + let response = match route { + RegistryHttpRoute::Framework(route) => { + let response = self.handle_framework_fetch(&instance, request, route).await; + rearm_sleep_after_request(instance.ctx.clone()); + response + } + RegistryHttpRoute::UserRawRequest => { + self.handle_user_request_fetch(&instance, request).await + } }; + response + } - if let Some(route) = framework_route { - let response = self.handle_framework_fetch(&instance, request, route).await; - rearm_sleep_after_request(instance.ctx.clone()); - return response; - } - + async fn handle_user_request_fetch( + &self, + instance: &ActorTaskHandle, + request: Request, + ) -> Result { let (reply_tx, reply_rx) = oneshot::channel(); try_send_dispatch_command( &instance.dispatch, @@ -72,7 +77,11 @@ impl RegistryDispatcher { build_envoy_response(response) } Err(error) => { - tracing::error!(actor_id, ?error, "actor request callback failed"); + tracing::error!( + actor_id = instance.actor_id, + ?error, + "actor request callback failed" + ); rearm_sleep_after_request(instance.ctx.clone()); Ok(inspector_anyhow_response(error)) } @@ -373,6 +382,40 @@ impl RegistryDispatcher { } } +enum RegistryHttpRoute { + Framework(FrameworkHttpRoute), + UserRawRequest, +} + +impl RegistryHttpRoute { + fn from_paths(original_path: &str, normalized_path: &str) -> Result { + if let Some(stripped) = original_path.strip_prefix("/request") { + if stripped.is_empty() || matches!(stripped.as_bytes().first(), Some(b'/') | Some(b'?')) + { + return Ok(Self::UserRawRequest); + } + } + + if let Some(segment) = single_path_segment(normalized_path, "/action/") { + return Ok(Self::Framework(FrameworkHttpRoute::Action( + percent_decode_path_segment(segment)?, + ))); + } + if let Some(segment) = single_path_segment(normalized_path, "/queue/") { + return Ok(Self::Framework(FrameworkHttpRoute::Queue( + percent_decode_path_segment(segment)?, + ))); + } + + match normalized_path { + "/metadata" => Ok(Self::Framework(FrameworkHttpRoute::Metadata)), + "/health" => Ok(Self::Framework(FrameworkHttpRoute::Health)), + "/" => Ok(Self::Framework(FrameworkHttpRoute::Root)), + _ => Ok(Self::UserRawRequest), + } + } +} + pub(super) enum FrameworkHttpRoute { Action(String), Queue(String), @@ -387,25 +430,6 @@ pub(super) struct DecodedHttpQueueRequest { timeout: Option, } -pub(super) fn framework_http_route(path: &str) -> Result> { - if let Some(segment) = single_path_segment(path, "/action/") { - return Ok(Some(FrameworkHttpRoute::Action( - percent_decode_path_segment(segment)?, - ))); - } - if let Some(segment) = single_path_segment(path, "/queue/") { - return Ok(Some(FrameworkHttpRoute::Queue( - percent_decode_path_segment(segment)?, - ))); - } - match path { - "/metadata" => Ok(Some(FrameworkHttpRoute::Metadata)), - "/health" => Ok(Some(FrameworkHttpRoute::Health)), - "/" => Ok(Some(FrameworkHttpRoute::Root)), - _ => Ok(None), - } -} - fn handle_metadata_fetch(request: &Request) -> Result { if request.method() != http::Method::GET { return method_not_allowed_response(request); @@ -469,6 +493,14 @@ fn method_not_allowed_response(request: &Request) -> Result { ) } +fn rearm_sleep_after_request(ctx: ActorContext) { + let sleep_ctx = ctx.clone(); + ctx.wait_until(async move { + sleep_ctx.wait_for_http_requests_idle().await; + sleep_ctx.reset_sleep_timer(); + }); +} + pub(super) fn single_path_segment<'a>(path: &'a str, prefix: &str) -> Option<&'a str> { let segment = path.strip_prefix(prefix)?; (!segment.is_empty() && !segment.contains('/')).then_some(segment) @@ -899,216 +931,5 @@ fn bearer_token_from_authorization(value: &str) -> Option<&str> { } #[cfg(test)] -mod tests { - use std::collections::HashMap; - use std::time::Duration; - - use super::{ - HttpRequest, HttpResponseEncoding, authorization_bearer_token, - authorization_bearer_token_map, framework_action_error_response, - message_boundary_error_response, request_encoding, request_has_bearer_token, - workflow_dispatch_result, - }; - use crate::actor::action::ActionDispatchError; - use crate::error::ActorLifecycle as ActorLifecycleError; - use http::StatusCode; - use rivet_error::RivetError; - use serde_json::json; - use vbare::OwnedVersionedData; - - #[derive(RivetError)] - #[error("message", "incoming_too_long", "Incoming message too long")] - struct IncomingMessageTooLong; - - #[derive(RivetError)] - #[error("message", "outgoing_too_long", "Outgoing message too long")] - struct OutgoingMessageTooLong; - - #[test] - fn workflow_dispatch_result_marks_handled_workflow_as_enabled() { - assert_eq!( - workflow_dispatch_result(Ok(Some(vec![1, 2, 3]))) - .expect("workflow dispatch should succeed"), - (true, Some(vec![1, 2, 3])), - ); - assert_eq!( - workflow_dispatch_result(Ok(None)).expect("workflow dispatch should succeed"), - (true, None), - ); - } - - #[test] - fn workflow_dispatch_result_treats_dropped_reply_as_disabled() { - assert_eq!( - workflow_dispatch_result(Err(ActorLifecycleError::DroppedReply.build())) - .expect("dropped reply should map to workflow disabled"), - (false, None), - ); - } - - #[test] - fn workflow_dispatch_result_preserves_non_dropped_reply_errors() { - let error = workflow_dispatch_result(Err(ActorLifecycleError::Destroying.build())) - .expect_err("non-dropped reply errors should be preserved"); - let error = rivet_error::RivetError::extract(&error); - assert_eq!(error.group(), "actor"); - assert_eq!(error.code(), "destroying"); - } - - #[test] - fn inspector_error_status_maps_action_timeout_to_408() { - assert_eq!( - super::inspector_error_status("actor", "action_timed_out"), - StatusCode::REQUEST_TIMEOUT, - ); - } - - #[test] - fn authorization_bearer_token_accepts_case_insensitive_scheme_and_whitespace() { - let mut headers = http::HeaderMap::new(); - headers.insert( - http::header::AUTHORIZATION, - "bearer test-token".parse().unwrap(), - ); - - assert_eq!(authorization_bearer_token(&headers), Some("test-token")); - - let map = HashMap::from([( - http::header::AUTHORIZATION.as_str().to_owned(), - "BEARER\ttest-token".to_owned(), - )]); - assert_eq!(authorization_bearer_token_map(&map), Some("test-token")); - } - - #[test] - fn request_has_bearer_token_uses_same_authorization_parser() { - let request = HttpRequest { - method: "GET".to_owned(), - path: "/metrics".to_owned(), - headers: HashMap::from([( - http::header::AUTHORIZATION.as_str().to_owned(), - "Bearer configured".to_owned(), - )]), - body: Some(Vec::new()), - body_stream: None, - }; - - assert!(request_has_bearer_token(&request, Some("configured"))); - assert!(!request_has_bearer_token(&request, Some("other"))); - } - - #[tokio::test] - async fn action_dispatch_timeout_returns_structured_error() { - let error = super::with_action_dispatch_timeout(Duration::from_millis(1), async { - tokio::time::sleep(Duration::from_secs(60)).await; - Ok::, ActionDispatchError>(Vec::new()) - }) - .await - .expect_err("timeout should return an action dispatch error"); - - assert_eq!(error.group, "actor"); - assert_eq!(error.code, "action_timed_out"); - assert_eq!(error.message, "Action timed out"); - } - - #[tokio::test] - async fn framework_action_timeout_returns_structured_error() { - let error = super::with_framework_action_timeout(Duration::from_millis(1), async { - tokio::time::sleep(Duration::from_secs(60)).await; - Ok::<(), anyhow::Error>(()) - }) - .await - .expect_err("timeout should return a framework error"); - let error = RivetError::extract(&error); - - assert_eq!(error.group(), "actor"); - assert_eq!(error.code(), "action_timed_out"); - assert_eq!(error.message(), "Action timed out"); - } - - #[test] - fn framework_action_error_response_maps_timeout_to_408() { - let response = framework_action_error_response( - HttpResponseEncoding::Json, - ActionDispatchError { - group: "actor".to_owned(), - code: "action_timed_out".to_owned(), - message: "Action timed out".to_owned(), - metadata: None, - }, - ) - .expect("timeout error response should serialize"); - - assert_eq!(response.status, StatusCode::REQUEST_TIMEOUT.as_u16()); - assert_eq!( - response.body, - Some( - serde_json::to_vec(&json!({ - "group": "actor", - "code": "action_timed_out", - "message": "Action timed out", - })) - .expect("json body should encode") - ) - ); - } - - #[test] - fn message_boundary_error_response_defaults_to_json() { - let response = message_boundary_error_response( - HttpResponseEncoding::Json, - StatusCode::BAD_REQUEST, - IncomingMessageTooLong.build(), - ) - .expect("json response should serialize"); - - assert_eq!(response.status, StatusCode::BAD_REQUEST.as_u16()); - assert_eq!( - response.headers.get(http::header::CONTENT_TYPE.as_str()), - Some(&"application/json".to_owned()) - ); - assert_eq!( - response.body, - Some( - serde_json::to_vec(&json!({ - "group": "message", - "code": "incoming_too_long", - "message": "Incoming message too long", - })) - .expect("json body should encode") - ) - ); - } - - #[test] - fn request_encoding_reads_cbor_header() { - let mut headers = http::HeaderMap::new(); - headers.insert("x-rivet-encoding", "cbor".parse().unwrap()); - - assert_eq!(request_encoding(&headers), HttpResponseEncoding::Cbor); - } - - #[test] - fn message_boundary_error_response_serializes_bare_v3() { - let response = message_boundary_error_response( - HttpResponseEncoding::Bare, - StatusCode::BAD_REQUEST, - OutgoingMessageTooLong.build(), - ) - .expect("bare response should serialize"); - - assert_eq!( - response.headers.get(http::header::CONTENT_TYPE.as_str()), - Some(&"application/octet-stream".to_owned()) - ); - - let body = response.body.expect("bare response should include body"); - let decoded = - ::deserialize_with_embedded_version(&body) - .expect("bare error should decode"); - assert_eq!(decoded.group, "message"); - assert_eq!(decoded.code, "outgoing_too_long"); - assert_eq!(decoded.message, "Outgoing message too long"); - assert_eq!(decoded.metadata, None); - } -} +#[path = "../../tests/modules/registry_http.rs"] +mod tests; diff --git a/rivetkit-rust/packages/rivetkit-core/tests/modules/registry_http.rs b/rivetkit-rust/packages/rivetkit-core/tests/modules/registry_http.rs new file mode 100644 index 0000000000..98049305ef --- /dev/null +++ b/rivetkit-rust/packages/rivetkit-core/tests/modules/registry_http.rs @@ -0,0 +1,255 @@ +use std::collections::HashMap; +use std::time::Duration; + +use super::{ + FrameworkHttpRoute, HttpRequest, HttpResponseEncoding, RegistryHttpRoute, + authorization_bearer_token, authorization_bearer_token_map, framework_action_error_response, + inspector_error_status, message_boundary_error_response, normalize_actor_request_path, + request_encoding, request_has_bearer_token, with_action_dispatch_timeout, + with_framework_action_timeout, workflow_dispatch_result, +}; +use crate::actor::action::ActionDispatchError; +use crate::error::ActorLifecycle as ActorLifecycleError; +use http::StatusCode; +use rivet_error::RivetError; +use serde_json::json; +use vbare::OwnedVersionedData; + +#[derive(RivetError)] +#[error("message", "incoming_too_long", "Incoming message too long")] +struct IncomingMessageTooLong; + +#[derive(RivetError)] +#[error("message", "outgoing_too_long", "Outgoing message too long")] +struct OutgoingMessageTooLong; + +#[test] +fn request_prefix_detection_matches_normalization() { + assert_eq!(normalize_actor_request_path("/request"), "/"); + assert_eq!(normalize_actor_request_path("/request/"), "/"); + assert_eq!(normalize_actor_request_path("/request/users"), "/users"); + assert_eq!(normalize_actor_request_path("/request?foo=bar"), "?foo=bar"); + assert_eq!(normalize_actor_request_path("/requestfoo"), "/requestfoo"); + + assert!(matches!( + RegistryHttpRoute::from_paths("/request", "/").expect("route should decode"), + RegistryHttpRoute::UserRawRequest + )); + assert!(matches!( + RegistryHttpRoute::from_paths("/request/users", "/users").expect("route should decode"), + RegistryHttpRoute::UserRawRequest + )); + assert!(matches!( + RegistryHttpRoute::from_paths("/request?foo=bar", "?foo=bar").expect("route should decode"), + RegistryHttpRoute::UserRawRequest + )); + assert!(matches!( + RegistryHttpRoute::from_paths("/requestfoo", "/requestfoo").expect("route should decode"), + RegistryHttpRoute::UserRawRequest + )); +} + +#[test] +fn classifier_keeps_framework_and_user_routing_separate() { + let route = RegistryHttpRoute::from_paths("/action/increment", "/action/increment") + .expect("route should decode"); + assert!(matches!( + route, + RegistryHttpRoute::Framework(FrameworkHttpRoute::Action(name)) if name == "increment" + )); + + let route = RegistryHttpRoute::from_paths("/request/action/increment", "/action/increment") + .expect("route should decode"); + assert!(matches!(route, RegistryHttpRoute::UserRawRequest)); + + let route = RegistryHttpRoute::from_paths("/custom", "/custom").expect("route should decode"); + assert!(matches!(route, RegistryHttpRoute::UserRawRequest)); +} + +#[test] +fn workflow_dispatch_result_marks_handled_workflow_as_enabled() { + assert_eq!( + workflow_dispatch_result(Ok(Some(vec![1, 2, 3]))) + .expect("workflow dispatch should succeed"), + (true, Some(vec![1, 2, 3])), + ); + assert_eq!( + workflow_dispatch_result(Ok(None)).expect("workflow dispatch should succeed"), + (true, None), + ); +} + +#[test] +fn workflow_dispatch_result_treats_dropped_reply_as_disabled() { + assert_eq!( + workflow_dispatch_result(Err(ActorLifecycleError::DroppedReply.build())) + .expect("dropped reply should map to workflow disabled"), + (false, None), + ); +} + +#[test] +fn workflow_dispatch_result_preserves_non_dropped_reply_errors() { + let error = workflow_dispatch_result(Err(ActorLifecycleError::Destroying.build())) + .expect_err("non-dropped reply errors should be preserved"); + let error = rivet_error::RivetError::extract(&error); + assert_eq!(error.group(), "actor"); + assert_eq!(error.code(), "destroying"); +} + +#[test] +fn inspector_error_status_maps_action_timeout_to_408() { + assert_eq!( + inspector_error_status("actor", "action_timed_out"), + StatusCode::REQUEST_TIMEOUT, + ); +} + +#[test] +fn authorization_bearer_token_accepts_case_insensitive_scheme_and_whitespace() { + let mut headers = http::HeaderMap::new(); + headers.insert( + http::header::AUTHORIZATION, + "bearer test-token".parse().unwrap(), + ); + + assert_eq!(authorization_bearer_token(&headers), Some("test-token")); + + let map = HashMap::from([( + http::header::AUTHORIZATION.as_str().to_owned(), + "BEARER\ttest-token".to_owned(), + )]); + assert_eq!(authorization_bearer_token_map(&map), Some("test-token")); +} + +#[test] +fn request_has_bearer_token_uses_same_authorization_parser() { + let request = HttpRequest { + method: "GET".to_owned(), + path: "/metrics".to_owned(), + headers: HashMap::from([( + http::header::AUTHORIZATION.as_str().to_owned(), + "Bearer configured".to_owned(), + )]), + body: Some(Vec::new()), + body_stream: None, + }; + + assert!(request_has_bearer_token(&request, Some("configured"))); + assert!(!request_has_bearer_token(&request, Some("other"))); +} + +#[tokio::test] +async fn action_dispatch_timeout_returns_structured_error() { + let error = with_action_dispatch_timeout(Duration::from_millis(1), async { + tokio::time::sleep(Duration::from_secs(60)).await; + Ok::, ActionDispatchError>(Vec::new()) + }) + .await + .expect_err("timeout should return an action dispatch error"); + + assert_eq!(error.group, "actor"); + assert_eq!(error.code, "action_timed_out"); + assert_eq!(error.message, "Action timed out"); +} + +#[tokio::test] +async fn framework_action_timeout_returns_structured_error() { + let error = with_framework_action_timeout(Duration::from_millis(1), async { + tokio::time::sleep(Duration::from_secs(60)).await; + Ok::<(), anyhow::Error>(()) + }) + .await + .expect_err("timeout should return a framework error"); + let error = rivet_error::RivetError::extract(&error); + + assert_eq!(error.group(), "actor"); + assert_eq!(error.code(), "action_timed_out"); + assert_eq!(error.message(), "Action timed out"); +} + +#[test] +fn framework_action_error_response_maps_timeout_to_408() { + let response = framework_action_error_response( + HttpResponseEncoding::Json, + ActionDispatchError { + group: "actor".to_owned(), + code: "action_timed_out".to_owned(), + message: "Action timed out".to_owned(), + metadata: None, + }, + ) + .expect("timeout error response should serialize"); + + assert_eq!(response.status, StatusCode::REQUEST_TIMEOUT.as_u16()); + assert_eq!( + response.body, + Some( + serde_json::to_vec(&json!({ + "group": "actor", + "code": "action_timed_out", + "message": "Action timed out", + })) + .expect("json body should encode") + ) + ); +} + +#[test] +fn message_boundary_error_response_defaults_to_json() { + let response = message_boundary_error_response( + HttpResponseEncoding::Json, + StatusCode::BAD_REQUEST, + IncomingMessageTooLong.build(), + ) + .expect("json response should serialize"); + + assert_eq!(response.status, StatusCode::BAD_REQUEST.as_u16()); + assert_eq!( + response.headers.get(http::header::CONTENT_TYPE.as_str()), + Some(&"application/json".to_owned()) + ); + assert_eq!( + response.body, + Some( + serde_json::to_vec(&json!({ + "group": "message", + "code": "incoming_too_long", + "message": "Incoming message too long", + })) + .expect("json body should encode") + ) + ); +} + +#[test] +fn request_encoding_reads_cbor_header() { + let mut headers = http::HeaderMap::new(); + headers.insert("x-rivet-encoding", "cbor".parse().unwrap()); + + assert_eq!(request_encoding(&headers), HttpResponseEncoding::Cbor); +} + +#[test] +fn message_boundary_error_response_serializes_bare_v3() { + let response = message_boundary_error_response( + HttpResponseEncoding::Bare, + StatusCode::BAD_REQUEST, + OutgoingMessageTooLong.build(), + ) + .expect("bare response should serialize"); + + assert_eq!( + response.headers.get(http::header::CONTENT_TYPE.as_str()), + Some(&"application/octet-stream".to_owned()) + ); + + let body = response.body.expect("bare response should include body"); + let decoded = + ::deserialize_with_embedded_version(&body) + .expect("bare error should decode"); + assert_eq!(decoded.group, "message"); + assert_eq!(decoded.code, "outgoing_too_long"); + assert_eq!(decoded.message, "Outgoing message too long"); + assert_eq!(decoded.metadata, None); +}