Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion engine/artifacts/openapi.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 6 additions & 4 deletions engine/packages/guard/src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,21 @@ pub fn create_routing_function(ctx: &StandaloneCtx, shared_state: SharedState) -
.split(',')
.map(|p| p.trim())
.find_map(|p| p.strip_prefix(WS_PROTOCOL_TARGET))
.map(ToOwned::to_owned)
})
} else {
// For HTTP, use the x-rivet-target header
req_ctx
.headers()
.get(X_RIVET_TARGET)
.and_then(|x| x.to_str().ok())
.map(ToOwned::to_owned)
};

// Read target
if let Some(target) = target {
if let Some(routing_output) =
pegboard_gateway::route_request(&ctx, &shared_state, req_ctx, target)
pegboard_gateway::route_request(&ctx, &shared_state, req_ctx, &target)
.await?
{
metrics::ROUTE_TOTAL.with_label_values(&["gateway"]).inc();
Expand All @@ -116,22 +118,22 @@ pub fn create_routing_function(ctx: &StandaloneCtx, shared_state: SharedState) -
}

if let Some(routing_output) =
runner::route_request(&ctx, req_ctx, target).await?
runner::route_request(&ctx, req_ctx, &target).await?
{
metrics::ROUTE_TOTAL.with_label_values(&["runner"]).inc();

return Ok(routing_output);
}

if let Some(routing_output) =
envoy::route_request(&ctx, req_ctx, target).await?
envoy::route_request(&ctx, req_ctx, &target).await?
{
metrics::ROUTE_TOTAL.with_label_values(&["envoy"]).inc();

return Ok(routing_output);
}

if let Some(routing_output) = api_public::route_request(&ctx, target).await? {
if let Some(routing_output) = api_public::route_request(&ctx, &target).await? {
metrics::ROUTE_TOTAL.with_label_values(&["api"]).inc();

return Ok(routing_output);
Expand Down
71 changes: 71 additions & 0 deletions engine/packages/guard/src/routing/pegboard_gateway/cors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
use anyhow::Result;
use async_trait::async_trait;
use bytes::Bytes;
use http_body_util::Full;
use hyper::{Request, Response, StatusCode};
use rivet_guard_core::{
ResponseBody,
custom_serve::CustomServeTrait,
request_context::{CorsConfig, RequestContext},
};

pub fn origin_header(req_ctx: &RequestContext) -> String {
req_ctx
.headers()
.get("origin")
.and_then(|v| v.to_str().ok())
.unwrap_or("*")
.to_string()
}

pub fn set_non_preflight_cors(req_ctx: &mut RequestContext) {
let allow_origin = origin_header(req_ctx);
req_ctx.set_cors(CorsConfig {
allow_origin,
allow_credentials: true,
expose_headers: "*".to_string(),
allow_methods: None,
allow_headers: None,
max_age: None,
});
}

/// Responds to CORS preflight OPTIONS requests with 204 and permissive CORS
/// headers. Avoids actor lookup, wake, and auth because browsers cannot attach
/// credentials to preflights. The actual request that follows is still authed.
pub struct CorsPreflight;

#[async_trait]
impl CustomServeTrait for CorsPreflight {
async fn handle_request(
&self,
req: Request<Full<Bytes>>,
req_ctx: &mut RequestContext,
) -> Result<Response<ResponseBody>> {
let allow_origin = req
.headers()
.get("origin")
.and_then(|v| v.to_str().ok())
.unwrap_or("*")
.to_string();
let allow_headers = req
.headers()
.get("access-control-request-headers")
.and_then(|v| v.to_str().ok())
.unwrap_or("*")
.to_string();

req_ctx.set_cors(CorsConfig {
allow_origin,
allow_credentials: true,
expose_headers: "*".to_string(),
allow_methods: Some("GET, POST, PUT, DELETE, OPTIONS, PATCH".to_string()),
allow_headers: Some(allow_headers),
max_age: Some(86400),
});

Ok(Response::builder()
.status(StatusCode::NO_CONTENT)
.body(ResponseBody::Full(Full::new(Bytes::new())))?)
}
}
34 changes: 26 additions & 8 deletions engine/packages/guard/src/routing/pegboard_gateway/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod cors;
mod resolve_actor_query;

use std::time::Duration;
use std::{sync::Arc, time::Duration};

use anyhow::Result;
use gas::{ctx::message::SubscriptionHandle, prelude::*};
Expand All @@ -19,6 +20,7 @@ use crate::{
},
shared_state::SharedState,
};
use cors::{CorsPreflight, set_non_preflight_cors};
use resolve_actor_query::resolve_query;

const ACTOR_FORCE_WAKE_PENDING_TIMEOUT: i64 = util::duration::seconds(60);
Expand All @@ -36,12 +38,16 @@ pub const X_RIVET_ACTOR: HeaderName = HeaderName::from_static("x-rivet-actor");
pub async fn route_request_path_based(
ctx: &StandaloneCtx,
shared_state: &SharedState,
req_ctx: &RequestContext,
req_ctx: &mut RequestContext,
) -> Result<Option<RoutingOutput>> {
let Some(actor_path) = parse_actor_path(req_ctx.path())? else {
return Ok(None);
};

if req_ctx.method() == hyper::Method::OPTIONS {
return Ok(Some(RoutingOutput::CustomServe(Arc::new(CorsPreflight))));
}

tracing::debug!(?actor_path, "routing using path-based actor routing");

let (actor_id, token, stripped_path, bypass_connectable) = match actor_path {
Expand Down Expand Up @@ -101,14 +107,18 @@ pub async fn route_request_path_based(
pub async fn route_request(
ctx: &StandaloneCtx,
shared_state: &SharedState,
req_ctx: &RequestContext,
req_ctx: &mut RequestContext,
target: &str,
) -> Result<Option<RoutingOutput>> {
// Check target
if target != "actor" {
return Ok(None);
}

if req_ctx.method() == hyper::Method::OPTIONS {
return Ok(Some(RoutingOutput::CustomServe(Arc::new(CorsPreflight))));
}

// Extract actor ID and token from WebSocket protocol or HTTP headers
let (actor_id_str, token, bypass_connectable) = if req_ctx.is_websocket() {
// For WebSocket, parse the sec-websocket-protocol header
Expand Down Expand Up @@ -141,7 +151,8 @@ pub async fn route_request(

let token = protocols
.iter()
.find_map(|p| p.strip_prefix(WS_PROTOCOL_TOKEN));
.find_map(|p| p.strip_prefix(WS_PROTOCOL_TOKEN))
.map(ToOwned::to_owned);

let bypass_connectable = protocols
.iter()
Expand All @@ -168,7 +179,8 @@ pub async fn route_request(
.get(X_RIVET_TOKEN)
.map(|x| x.to_str())
.transpose()
.context("invalid x-rivet-token header")?;
.context("invalid x-rivet-token header")?
.map(ToOwned::to_owned);

let bypass_connectable = req_ctx.headers().contains_key(X_RIVET_BYPASS_CONNECTABLE);

Expand All @@ -177,14 +189,15 @@ pub async fn route_request(

// Find actor to route to
let actor_id = Id::parse(&actor_id_str).context("invalid x-rivet-actor header")?;
let stripped_path = req_ctx.path().to_owned();

route_request_inner(
ctx,
shared_state,
req_ctx,
actor_id,
req_ctx.path(),
token,
&stripped_path,
token.as_deref(),
bypass_connectable,
)
.await
Expand All @@ -194,12 +207,17 @@ pub async fn route_request(
async fn route_request_inner(
ctx: &StandaloneCtx,
shared_state: &SharedState,
req_ctx: &RequestContext,
req_ctx: &mut RequestContext,
actor_id: Id,
stripped_path: &str,
_token: Option<&str>,
bypass_connectable: bool,
) -> Result<RoutingOutput> {
// Attach CORS headers to the actual (non-OPTIONS) response so both the
// actor response and any early error (e.g. EE auth failure) are readable
// by the browser.
set_non_preflight_cors(req_ctx);

// NOTE: Token validation implemented in EE

// Route to peer dc where the actor lives
Expand Down
52 changes: 1 addition & 51 deletions engine/packages/pegboard-gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use rivet_guard_core::{
ResponseBody, WebSocketHandle,
custom_serve::{CustomServeTrait, HibernationResult},
errors::{ServiceUnavailable, WebSocketServiceUnavailable},
request_context::{CorsConfig, RequestContext},
request_context::RequestContext,
utils::is_ws_hibernate,
websocket_handle::WebSocketReceiver,
};
Expand Down Expand Up @@ -92,15 +92,6 @@ impl PegboardGateway {
let actor_id = self.actor_id.to_string();
let request_id = req_ctx.in_flight_request_id()?;

// Extract origin for CORS (before consuming request)
// When credentials: true, we must echo back the actual origin, not "*"
let origin = req
.headers()
.get("origin")
.and_then(|v| v.to_str().ok())
.unwrap_or("*")
.to_string();

// Extract request parts
let headers = req
.headers()
Expand All @@ -113,47 +104,6 @@ impl PegboardGateway {
})
.collect::<HashableMap<_, _>>();

// Handle CORS preflight OPTIONS requests at gateway level
//
// We need to do this in the gateway because there is no way of sending an OPTIONS request to the
// actor since we don't have the `x-rivet-token` header. This implementation allows
// requests from anywhere and lets the actor handle CORS manually in `onBeforeConnect`.
// This had the added benefit of also applying to WebSockets.
if req.method() == hyper::Method::OPTIONS {
tracing::debug!("handling OPTIONS preflight request at gateway");

// Extract requested headers
let requested_headers = req
.headers()
.get("access-control-request-headers")
.and_then(|v| v.to_str().ok())
.unwrap_or("*");

req_ctx.set_cors(CorsConfig {
allow_origin: origin.clone(),
allow_credentials: true,
expose_headers: "*".to_string(),
allow_methods: Some("GET, POST, PUT, DELETE, OPTIONS, PATCH".to_string()),
allow_headers: Some(requested_headers.to_string()),
max_age: Some(86400),
});

return Ok(Response::builder()
.status(StatusCode::NO_CONTENT)
.body(ResponseBody::Full(Full::new(Bytes::new())))?);
}

// Set CORS headers through guard
req_ctx.set_cors(CorsConfig {
allow_origin: origin.clone(),
allow_credentials: true,
expose_headers: "*".to_string(),
// Not an options req, not required
allow_methods: None,
allow_headers: None,
max_age: None,
});

// NOTE: Size constraints have already been applied by guard
let body_bytes = req
.into_body()
Expand Down
Loading
Loading