diff --git a/rivetkit-rust/packages/rivetkit-core/src/serverless.rs b/rivetkit-rust/packages/rivetkit-core/src/serverless.rs index 8f61305941..cae84bfa89 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/serverless.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/serverless.rs @@ -79,7 +79,7 @@ pub struct ServerlessStreamError { #[derive(Debug)] struct StartHeaders { endpoint: String, - token: String, + token: Option, pool_name: String, namespace: String, } @@ -428,6 +428,8 @@ impl CoreServerlessRuntime { } let mut guard = self.envoy.lock().await; if let Some(handle) = guard.as_ref() { + // The start request token authenticates the serverless callback. It is not part + // of envoy identity, and may differ from the token used for the engine connection. if !endpoints_match(handle.endpoint(), &headers.endpoint) || handle.namespace() != headers.namespace || handle.pool_name() != headers.pool_name @@ -447,7 +449,7 @@ impl CoreServerlessRuntime { let handle = start_envoy(EnvoyConfig { version: self.settings.version, endpoint: headers.endpoint.clone(), - token: Some(headers.token.clone()), + token: headers.token.clone(), namespace: headers.namespace.clone(), pool_name: headers.pool_name.clone(), prepopulate_actor_names: HashMap::new(), @@ -494,7 +496,7 @@ fn route_path(base_path: &str, url: &str) -> Result { fn parse_start_headers(headers: &HashMap) -> Result { Ok(StartHeaders { endpoint: required_header(headers, "x-rivet-endpoint")?, - token: required_header(headers, "x-rivet-token")?, + token: optional_header(headers, "x-rivet-token"), pool_name: required_header(headers, "x-rivet-pool-name")?, namespace: required_header(headers, "x-rivet-namespace-name")?, }) diff --git a/rivetkit-rust/packages/rivetkit-core/tests/serverless.rs b/rivetkit-rust/packages/rivetkit-core/tests/serverless.rs index b91e3a1d49..f7fd5aea16 100644 --- a/rivetkit-rust/packages/rivetkit-core/tests/serverless.rs +++ b/rivetkit-rust/packages/rivetkit-core/tests/serverless.rs @@ -7,6 +7,7 @@ mod moved_tests { use super::{ CoreServerlessRuntime, ServerlessRequest, endpoints_match, normalize_endpoint_url, + parse_start_headers, }; use crate::registry::ServeConfig; @@ -95,6 +96,43 @@ mod moved_tests { assert_eq!(body["code"], "invalid"); } + #[test] + fn start_headers_do_not_require_token() { + let headers = HashMap::from([ + ( + "x-rivet-endpoint".to_owned(), + "http://127.0.0.1:6420".to_owned(), + ), + ("x-rivet-pool-name".to_owned(), "default".to_owned()), + ("x-rivet-namespace-name".to_owned(), "default".to_owned()), + ]); + + let parsed = parse_start_headers(&headers).expect("headers should parse"); + + assert_eq!(parsed.token, None); + } + + #[test] + fn start_headers_only_use_x_rivet_token() { + let headers = HashMap::from([ + ( + "x-rivet-endpoint".to_owned(), + "http://127.0.0.1:6420".to_owned(), + ), + ("authorization".to_owned(), "Bearer fallback".to_owned()), + ("x-rivet-pool-name".to_owned(), "default".to_owned()), + ("x-rivet-namespace-name".to_owned(), "default".to_owned()), + ]); + + let parsed = parse_start_headers(&headers).expect("headers should parse"); + assert_eq!(parsed.token, None); + + let mut headers = headers; + headers.insert("x-rivet-token".to_owned(), "dev".to_owned()); + let parsed = parse_start_headers(&headers).expect("headers should parse"); + assert_eq!(parsed.token.as_deref(), Some("dev")); + } + async fn test_runtime() -> CoreServerlessRuntime { CoreServerlessRuntime::new( HashMap::new(),