Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions rivetkit-rust/packages/rivetkit-core/src/serverless.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ pub struct ServerlessStreamError {
#[derive(Debug)]
struct StartHeaders {
endpoint: String,
token: String,
token: Option<String>,
pool_name: String,
namespace: String,
}
Expand Down Expand Up @@ -428,6 +428,8 @@ impl CoreServerlessRuntime {
}
let mut guard = self.envoy.lock().await;
if let Some(handle) = guard.as_ref() {
// The start request token authenticates the serverless callback. It is not part
// of envoy identity, and may differ from the token used for the engine connection.
if !endpoints_match(handle.endpoint(), &headers.endpoint)
|| handle.namespace() != headers.namespace
|| handle.pool_name() != headers.pool_name
Expand All @@ -447,7 +449,7 @@ impl CoreServerlessRuntime {
let handle = start_envoy(EnvoyConfig {
version: self.settings.version,
endpoint: headers.endpoint.clone(),
token: Some(headers.token.clone()),
token: headers.token.clone(),
namespace: headers.namespace.clone(),
pool_name: headers.pool_name.clone(),
prepopulate_actor_names: HashMap::new(),
Expand Down Expand Up @@ -494,7 +496,7 @@ fn route_path(base_path: &str, url: &str) -> Result<String> {
fn parse_start_headers(headers: &HashMap<String, String>) -> Result<StartHeaders> {
Ok(StartHeaders {
endpoint: required_header(headers, "x-rivet-endpoint")?,
token: required_header(headers, "x-rivet-token")?,
token: optional_header(headers, "x-rivet-token"),
pool_name: required_header(headers, "x-rivet-pool-name")?,
namespace: required_header(headers, "x-rivet-namespace-name")?,
})
Expand Down
38 changes: 38 additions & 0 deletions rivetkit-rust/packages/rivetkit-core/tests/serverless.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod moved_tests {

use super::{
CoreServerlessRuntime, ServerlessRequest, endpoints_match, normalize_endpoint_url,
parse_start_headers,
};
use crate::registry::ServeConfig;

Expand Down Expand Up @@ -95,6 +96,43 @@ mod moved_tests {
assert_eq!(body["code"], "invalid");
}

#[test]
fn start_headers_do_not_require_token() {
let headers = HashMap::from([
(
"x-rivet-endpoint".to_owned(),
"http://127.0.0.1:6420".to_owned(),
),
("x-rivet-pool-name".to_owned(), "default".to_owned()),
("x-rivet-namespace-name".to_owned(), "default".to_owned()),
]);

let parsed = parse_start_headers(&headers).expect("headers should parse");

assert_eq!(parsed.token, None);
}

#[test]
fn start_headers_only_use_x_rivet_token() {
let headers = HashMap::from([
(
"x-rivet-endpoint".to_owned(),
"http://127.0.0.1:6420".to_owned(),
),
("authorization".to_owned(), "Bearer fallback".to_owned()),
("x-rivet-pool-name".to_owned(), "default".to_owned()),
("x-rivet-namespace-name".to_owned(), "default".to_owned()),
]);

let parsed = parse_start_headers(&headers).expect("headers should parse");
assert_eq!(parsed.token, None);

let mut headers = headers;
headers.insert("x-rivet-token".to_owned(), "dev".to_owned());
let parsed = parse_start_headers(&headers).expect("headers should parse");
assert_eq!(parsed.token.as_deref(), Some("dev"));
}

async fn test_runtime() -> CoreServerlessRuntime {
CoreServerlessRuntime::new(
HashMap::new(),
Expand Down
Loading