diff --git a/engine/artifacts/openapi.json b/engine/artifacts/openapi.json index fe9ae7bd15..fcc973f4cd 100644 --- a/engine/artifacts/openapi.json +++ b/engine/artifacts/openapi.json @@ -2188,11 +2188,15 @@ "invalid_response_json": { "type": "object", "required": [ - "body" + "body", + "parse_error" ], "properties": { "body": { "type": "string" + }, + "parse_error": { + "type": "string" } } } @@ -2220,6 +2224,27 @@ } } } + }, + { + "type": "object", + "required": [ + "invalid_envoy_protocol_version" + ], + "properties": { + "invalid_envoy_protocol_version": { + "type": "object", + "required": [ + "version" + ], + "properties": { + "version": { + "type": "integer", + "format": "int32", + "minimum": 0 + } + } + } + } } ] }, diff --git a/engine/packages/guard/src/routing/mod.rs b/engine/packages/guard/src/routing/mod.rs index 27542eced8..7f5ec70706 100644 --- a/engine/packages/guard/src/routing/mod.rs +++ b/engine/packages/guard/src/routing/mod.rs @@ -95,6 +95,7 @@ pub fn create_routing_function(ctx: &StandaloneCtx, shared_state: SharedState) - .split(',') .map(|p| p.trim()) .find_map(|p| p.strip_prefix(WS_PROTOCOL_TARGET)) + .map(ToOwned::to_owned) }) } else { // For HTTP, use the x-rivet-target header @@ -102,12 +103,13 @@ pub fn create_routing_function(ctx: &StandaloneCtx, shared_state: SharedState) - .headers() .get(X_RIVET_TARGET) .and_then(|x| x.to_str().ok()) + .map(ToOwned::to_owned) }; // Read target if let Some(target) = target { if let Some(routing_output) = - pegboard_gateway::route_request(&ctx, &shared_state, req_ctx, target) + pegboard_gateway::route_request(&ctx, &shared_state, req_ctx, &target) .await? { metrics::ROUTE_TOTAL.with_label_values(&["gateway"]).inc(); @@ -116,7 +118,7 @@ pub fn create_routing_function(ctx: &StandaloneCtx, shared_state: SharedState) - } if let Some(routing_output) = - runner::route_request(&ctx, req_ctx, target).await? + runner::route_request(&ctx, req_ctx, &target).await? { metrics::ROUTE_TOTAL.with_label_values(&["runner"]).inc(); @@ -124,14 +126,14 @@ pub fn create_routing_function(ctx: &StandaloneCtx, shared_state: SharedState) - } if let Some(routing_output) = - envoy::route_request(&ctx, req_ctx, target).await? + envoy::route_request(&ctx, req_ctx, &target).await? { metrics::ROUTE_TOTAL.with_label_values(&["envoy"]).inc(); return Ok(routing_output); } - if let Some(routing_output) = api_public::route_request(&ctx, target).await? { + if let Some(routing_output) = api_public::route_request(&ctx, &target).await? { metrics::ROUTE_TOTAL.with_label_values(&["api"]).inc(); return Ok(routing_output); diff --git a/engine/packages/guard/src/routing/pegboard_gateway/cors.rs b/engine/packages/guard/src/routing/pegboard_gateway/cors.rs new file mode 100644 index 0000000000..76f15c222d --- /dev/null +++ b/engine/packages/guard/src/routing/pegboard_gateway/cors.rs @@ -0,0 +1,71 @@ +use anyhow::Result; +use async_trait::async_trait; +use bytes::Bytes; +use http_body_util::Full; +use hyper::{Request, Response, StatusCode}; +use rivet_guard_core::{ + ResponseBody, + custom_serve::CustomServeTrait, + request_context::{CorsConfig, RequestContext}, +}; + +pub fn origin_header(req_ctx: &RequestContext) -> String { + req_ctx + .headers() + .get("origin") + .and_then(|v| v.to_str().ok()) + .unwrap_or("*") + .to_string() +} + +pub fn set_non_preflight_cors(req_ctx: &mut RequestContext) { + let allow_origin = origin_header(req_ctx); + req_ctx.set_cors(CorsConfig { + allow_origin, + allow_credentials: true, + expose_headers: "*".to_string(), + allow_methods: None, + allow_headers: None, + max_age: None, + }); +} + +/// Responds to CORS preflight OPTIONS requests with 204 and permissive CORS +/// headers. Avoids actor lookup, wake, and auth because browsers cannot attach +/// credentials to preflights. The actual request that follows is still authed. +pub struct CorsPreflight; + +#[async_trait] +impl CustomServeTrait for CorsPreflight { + async fn handle_request( + &self, + req: Request>, + req_ctx: &mut RequestContext, + ) -> Result> { + let allow_origin = req + .headers() + .get("origin") + .and_then(|v| v.to_str().ok()) + .unwrap_or("*") + .to_string(); + let allow_headers = req + .headers() + .get("access-control-request-headers") + .and_then(|v| v.to_str().ok()) + .unwrap_or("*") + .to_string(); + + req_ctx.set_cors(CorsConfig { + allow_origin, + allow_credentials: true, + expose_headers: "*".to_string(), + allow_methods: Some("GET, POST, PUT, DELETE, OPTIONS, PATCH".to_string()), + allow_headers: Some(allow_headers), + max_age: Some(86400), + }); + + Ok(Response::builder() + .status(StatusCode::NO_CONTENT) + .body(ResponseBody::Full(Full::new(Bytes::new())))?) + } +} diff --git a/engine/packages/guard/src/routing/pegboard_gateway/mod.rs b/engine/packages/guard/src/routing/pegboard_gateway/mod.rs index 5637e8e770..c3eb5cdc56 100644 --- a/engine/packages/guard/src/routing/pegboard_gateway/mod.rs +++ b/engine/packages/guard/src/routing/pegboard_gateway/mod.rs @@ -1,6 +1,7 @@ +mod cors; mod resolve_actor_query; -use std::time::Duration; +use std::{sync::Arc, time::Duration}; use anyhow::Result; use gas::{ctx::message::SubscriptionHandle, prelude::*}; @@ -19,6 +20,7 @@ use crate::{ }, shared_state::SharedState, }; +use cors::{CorsPreflight, set_non_preflight_cors}; use resolve_actor_query::resolve_query; const ACTOR_FORCE_WAKE_PENDING_TIMEOUT: i64 = util::duration::seconds(60); @@ -36,12 +38,16 @@ pub const X_RIVET_ACTOR: HeaderName = HeaderName::from_static("x-rivet-actor"); pub async fn route_request_path_based( ctx: &StandaloneCtx, shared_state: &SharedState, - req_ctx: &RequestContext, + req_ctx: &mut RequestContext, ) -> Result> { let Some(actor_path) = parse_actor_path(req_ctx.path())? else { return Ok(None); }; + if req_ctx.method() == hyper::Method::OPTIONS { + return Ok(Some(RoutingOutput::CustomServe(Arc::new(CorsPreflight)))); + } + tracing::debug!(?actor_path, "routing using path-based actor routing"); let (actor_id, token, stripped_path, bypass_connectable) = match actor_path { @@ -101,7 +107,7 @@ pub async fn route_request_path_based( pub async fn route_request( ctx: &StandaloneCtx, shared_state: &SharedState, - req_ctx: &RequestContext, + req_ctx: &mut RequestContext, target: &str, ) -> Result> { // Check target @@ -109,6 +115,10 @@ pub async fn route_request( return Ok(None); } + if req_ctx.method() == hyper::Method::OPTIONS { + return Ok(Some(RoutingOutput::CustomServe(Arc::new(CorsPreflight)))); + } + // Extract actor ID and token from WebSocket protocol or HTTP headers let (actor_id_str, token, bypass_connectable) = if req_ctx.is_websocket() { // For WebSocket, parse the sec-websocket-protocol header @@ -141,7 +151,8 @@ pub async fn route_request( let token = protocols .iter() - .find_map(|p| p.strip_prefix(WS_PROTOCOL_TOKEN)); + .find_map(|p| p.strip_prefix(WS_PROTOCOL_TOKEN)) + .map(ToOwned::to_owned); let bypass_connectable = protocols .iter() @@ -168,7 +179,8 @@ pub async fn route_request( .get(X_RIVET_TOKEN) .map(|x| x.to_str()) .transpose() - .context("invalid x-rivet-token header")?; + .context("invalid x-rivet-token header")? + .map(ToOwned::to_owned); let bypass_connectable = req_ctx.headers().contains_key(X_RIVET_BYPASS_CONNECTABLE); @@ -177,14 +189,15 @@ pub async fn route_request( // Find actor to route to let actor_id = Id::parse(&actor_id_str).context("invalid x-rivet-actor header")?; + let stripped_path = req_ctx.path().to_owned(); route_request_inner( ctx, shared_state, req_ctx, actor_id, - req_ctx.path(), - token, + &stripped_path, + token.as_deref(), bypass_connectable, ) .await @@ -194,12 +207,17 @@ pub async fn route_request( async fn route_request_inner( ctx: &StandaloneCtx, shared_state: &SharedState, - req_ctx: &RequestContext, + req_ctx: &mut RequestContext, actor_id: Id, stripped_path: &str, _token: Option<&str>, bypass_connectable: bool, ) -> Result { + // Attach CORS headers to the actual (non-OPTIONS) response so both the + // actor response and any early error (e.g. EE auth failure) are readable + // by the browser. + set_non_preflight_cors(req_ctx); + // NOTE: Token validation implemented in EE // Route to peer dc where the actor lives diff --git a/engine/packages/pegboard-gateway/src/lib.rs b/engine/packages/pegboard-gateway/src/lib.rs index 218f7ba5ef..c3b3b8dcdb 100644 --- a/engine/packages/pegboard-gateway/src/lib.rs +++ b/engine/packages/pegboard-gateway/src/lib.rs @@ -10,7 +10,7 @@ use rivet_guard_core::{ ResponseBody, WebSocketHandle, custom_serve::{CustomServeTrait, HibernationResult}, errors::{ServiceUnavailable, WebSocketServiceUnavailable}, - request_context::{CorsConfig, RequestContext}, + request_context::RequestContext, utils::is_ws_hibernate, websocket_handle::WebSocketReceiver, }; @@ -92,15 +92,6 @@ impl PegboardGateway { let actor_id = self.actor_id.to_string(); let request_id = req_ctx.in_flight_request_id()?; - // Extract origin for CORS (before consuming request) - // When credentials: true, we must echo back the actual origin, not "*" - let origin = req - .headers() - .get("origin") - .and_then(|v| v.to_str().ok()) - .unwrap_or("*") - .to_string(); - // Extract request parts let headers = req .headers() @@ -113,47 +104,6 @@ impl PegboardGateway { }) .collect::>(); - // Handle CORS preflight OPTIONS requests at gateway level - // - // We need to do this in the gateway because there is no way of sending an OPTIONS request to the - // actor since we don't have the `x-rivet-token` header. This implementation allows - // requests from anywhere and lets the actor handle CORS manually in `onBeforeConnect`. - // This had the added benefit of also applying to WebSockets. - if req.method() == hyper::Method::OPTIONS { - tracing::debug!("handling OPTIONS preflight request at gateway"); - - // Extract requested headers - let requested_headers = req - .headers() - .get("access-control-request-headers") - .and_then(|v| v.to_str().ok()) - .unwrap_or("*"); - - req_ctx.set_cors(CorsConfig { - allow_origin: origin.clone(), - allow_credentials: true, - expose_headers: "*".to_string(), - allow_methods: Some("GET, POST, PUT, DELETE, OPTIONS, PATCH".to_string()), - allow_headers: Some(requested_headers.to_string()), - max_age: Some(86400), - }); - - return Ok(Response::builder() - .status(StatusCode::NO_CONTENT) - .body(ResponseBody::Full(Full::new(Bytes::new())))?); - } - - // Set CORS headers through guard - req_ctx.set_cors(CorsConfig { - allow_origin: origin.clone(), - allow_credentials: true, - expose_headers: "*".to_string(), - // Not an options req, not required - allow_methods: None, - allow_headers: None, - max_age: None, - }); - // NOTE: Size constraints have already been applied by guard let body_bytes = req .into_body() diff --git a/engine/packages/pegboard-gateway2/src/lib.rs b/engine/packages/pegboard-gateway2/src/lib.rs index 69cacb2eef..e1571c9699 100644 --- a/engine/packages/pegboard-gateway2/src/lib.rs +++ b/engine/packages/pegboard-gateway2/src/lib.rs @@ -11,7 +11,7 @@ use rivet_guard_core::{ ResponseBody, WebSocketHandle, custom_serve::{CustomServeTrait, HibernationResult}, errors::{ServiceUnavailable, WebSocketServiceUnavailable}, - request_context::{CorsConfig, RequestContext}, + request_context::RequestContext, utils::is_ws_hibernate, websocket_handle::WebSocketReceiver, }; @@ -95,15 +95,6 @@ impl PegboardGateway2 { let actor_id = self.actor_id.to_string(); let request_id = req_ctx.in_flight_request_id()?; - // Extract origin for CORS (before consuming request) - // When credentials: true, we must echo back the actual origin, not "*" - let origin = req - .headers() - .get("origin") - .and_then(|v| v.to_str().ok()) - .unwrap_or("*") - .to_string(); - // Extract request parts let headers = req .headers() @@ -116,47 +107,6 @@ impl PegboardGateway2 { }) .collect::>(); - // Handle CORS preflight OPTIONS requests at gateway level - // - // We need to do this in the gateway because there is no way of sending an OPTIONS request to the - // actor since we don't have the `x-rivet-token` header. This implementation allows - // requests from anywhere and lets the actor handle CORS manually in `onBeforeConnect`. - // This had the added benefit of also applying to WebSockets. - if req.method() == hyper::Method::OPTIONS { - tracing::debug!("handling OPTIONS preflight request at gateway"); - - // Extract requested headers - let requested_headers = req - .headers() - .get("access-control-request-headers") - .and_then(|v| v.to_str().ok()) - .unwrap_or("*"); - - req_ctx.set_cors(CorsConfig { - allow_origin: origin.clone(), - allow_credentials: true, - expose_headers: "*".to_string(), - allow_methods: Some("GET, POST, PUT, DELETE, OPTIONS, PATCH".to_string()), - allow_headers: Some(requested_headers.to_string()), - max_age: Some(86400), - }); - - return Ok(Response::builder() - .status(StatusCode::NO_CONTENT) - .body(ResponseBody::Full(Full::new(Bytes::new())))?); - } - - // Set CORS headers through guard - req_ctx.set_cors(CorsConfig { - allow_origin: origin.clone(), - allow_credentials: true, - expose_headers: "*".to_string(), - // Not an options req, not required - allow_methods: None, - allow_headers: None, - max_age: None, - }); - // NOTE: Size constraints have already been applied by guard let body_bytes = req .into_body()