Skip to content

Commit 822f594

Browse files
committed
chore: fix cors for envoys
1 parent 0586791 commit 822f594

6 files changed

Lines changed: 139 additions & 116 deletions

File tree

engine/artifacts/openapi.json

Lines changed: 26 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

engine/packages/guard/src/routing/mod.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,19 +92,21 @@ pub fn create_routing_function(ctx: &StandaloneCtx, shared_state: SharedState) -
9292
.split(',')
9393
.map(|p| p.trim())
9494
.find_map(|p| p.strip_prefix(WS_PROTOCOL_TARGET))
95+
.map(ToOwned::to_owned)
9596
})
9697
} else {
9798
// For HTTP, use the x-rivet-target header
9899
req_ctx
99100
.headers()
100101
.get(X_RIVET_TARGET)
101102
.and_then(|x| x.to_str().ok())
103+
.map(ToOwned::to_owned)
102104
};
103105

104106
// Read target
105107
if let Some(target) = target {
106108
if let Some(routing_output) =
107-
pegboard_gateway::route_request(&ctx, &shared_state, req_ctx, target)
109+
pegboard_gateway::route_request(&ctx, &shared_state, req_ctx, &target)
108110
.await?
109111
{
110112
metrics::ROUTE_TOTAL.with_label_values(&["gateway"]).inc();
@@ -113,22 +115,22 @@ pub fn create_routing_function(ctx: &StandaloneCtx, shared_state: SharedState) -
113115
}
114116

115117
if let Some(routing_output) =
116-
runner::route_request(&ctx, req_ctx, target).await?
118+
runner::route_request(&ctx, req_ctx, &target).await?
117119
{
118120
metrics::ROUTE_TOTAL.with_label_values(&["runner"]).inc();
119121

120122
return Ok(routing_output);
121123
}
122124

123125
if let Some(routing_output) =
124-
envoy::route_request(&ctx, req_ctx, target).await?
126+
envoy::route_request(&ctx, req_ctx, &target).await?
125127
{
126128
metrics::ROUTE_TOTAL.with_label_values(&["envoy"]).inc();
127129

128130
return Ok(routing_output);
129131
}
130132

131-
if let Some(routing_output) = api_public::route_request(&ctx, target).await? {
133+
if let Some(routing_output) = api_public::route_request(&ctx, &target).await? {
132134
metrics::ROUTE_TOTAL.with_label_values(&["api"]).inc();
133135

134136
return Ok(routing_output);
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
use anyhow::Result;
2+
use async_trait::async_trait;
3+
use bytes::Bytes;
4+
use http_body_util::Full;
5+
use hyper::{Request, Response, StatusCode};
6+
use rivet_guard_core::{
7+
ResponseBody,
8+
custom_serve::CustomServeTrait,
9+
request_context::{CorsConfig, RequestContext},
10+
};
11+
12+
pub fn origin_header(req_ctx: &RequestContext) -> String {
13+
req_ctx
14+
.headers()
15+
.get("origin")
16+
.and_then(|v| v.to_str().ok())
17+
.unwrap_or("*")
18+
.to_string()
19+
}
20+
21+
pub fn set_non_preflight_cors(req_ctx: &mut RequestContext) {
22+
let allow_origin = origin_header(req_ctx);
23+
req_ctx.set_cors(CorsConfig {
24+
allow_origin,
25+
allow_credentials: true,
26+
expose_headers: "*".to_string(),
27+
allow_methods: None,
28+
allow_headers: None,
29+
max_age: None,
30+
});
31+
}
32+
33+
/// Responds to CORS preflight OPTIONS requests with 204 and permissive CORS
34+
/// headers. Avoids actor lookup, wake, and auth because browsers cannot attach
35+
/// credentials to preflights. The actual request that follows is still authed.
36+
pub struct CorsPreflight;
37+
38+
#[async_trait]
39+
impl CustomServeTrait for CorsPreflight {
40+
async fn handle_request(
41+
&self,
42+
req: Request<Full<Bytes>>,
43+
req_ctx: &mut RequestContext,
44+
) -> Result<Response<ResponseBody>> {
45+
let allow_origin = req
46+
.headers()
47+
.get("origin")
48+
.and_then(|v| v.to_str().ok())
49+
.unwrap_or("*")
50+
.to_string();
51+
let allow_headers = req
52+
.headers()
53+
.get("access-control-request-headers")
54+
.and_then(|v| v.to_str().ok())
55+
.unwrap_or("*")
56+
.to_string();
57+
58+
req_ctx.set_cors(CorsConfig {
59+
allow_origin,
60+
allow_credentials: true,
61+
expose_headers: "*".to_string(),
62+
allow_methods: Some("GET, POST, PUT, DELETE, OPTIONS, PATCH".to_string()),
63+
allow_headers: Some(allow_headers),
64+
max_age: Some(86400),
65+
});
66+
67+
Ok(Response::builder()
68+
.status(StatusCode::NO_CONTENT)
69+
.body(ResponseBody::Full(Full::new(Bytes::new())))?)
70+
}
71+
}

engine/packages/guard/src/routing/pegboard_gateway/mod.rs

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
mod cors;
12
mod resolve_actor_query;
23

3-
use std::time::Duration;
4+
use std::{sync::Arc, time::Duration};
45

56
use anyhow::Result;
67
use gas::{ctx::message::SubscriptionHandle, prelude::*};
@@ -12,6 +13,7 @@ use super::{
1213
actor_path::ParsedActorPath,
1314
};
1415
use crate::{errors, routing::actor_path::parse_actor_path, shared_state::SharedState};
16+
use cors::{CorsPreflight, set_non_preflight_cors};
1517
use resolve_actor_query::resolve_query_actor_id;
1618

1719
const ACTOR_FORCE_WAKE_PENDING_TIMEOUT: i64 = util::duration::seconds(60);
@@ -29,12 +31,16 @@ pub const X_RIVET_ACTOR: HeaderName = HeaderName::from_static("x-rivet-actor");
2931
pub async fn route_request_path_based(
3032
ctx: &StandaloneCtx,
3133
shared_state: &SharedState,
32-
req_ctx: &RequestContext,
34+
req_ctx: &mut RequestContext,
3335
) -> Result<Option<RoutingOutput>> {
3436
let Some(actor_path) = parse_actor_path(req_ctx.path())? else {
3537
return Ok(None);
3638
};
3739

40+
if req_ctx.method() == hyper::Method::OPTIONS {
41+
return Ok(Some(RoutingOutput::CustomServe(Arc::new(CorsPreflight))));
42+
}
43+
3844
tracing::debug!(?actor_path, "routing using path-based actor routing");
3945

4046
let resolved_route = resolve_path_based_route(ctx, req_ctx, &actor_path).await?;
@@ -56,14 +62,18 @@ pub async fn route_request_path_based(
5662
pub async fn route_request(
5763
ctx: &StandaloneCtx,
5864
shared_state: &SharedState,
59-
req_ctx: &RequestContext,
65+
req_ctx: &mut RequestContext,
6066
target: &str,
6167
) -> Result<Option<RoutingOutput>> {
6268
// Check target
6369
if target != "actor" {
6470
return Ok(None);
6571
}
6672

73+
if req_ctx.method() == hyper::Method::OPTIONS {
74+
return Ok(Some(RoutingOutput::CustomServe(Arc::new(CorsPreflight))));
75+
}
76+
6777
// Extract actor ID and token from WebSocket protocol or HTTP headers
6878
let (actor_id_str, token) = if req_ctx.is_websocket() {
6979
// For WebSocket, parse the sec-websocket-protocol header
@@ -96,7 +106,8 @@ pub async fn route_request(
96106

97107
let token = protocols
98108
.iter()
99-
.find_map(|p| p.strip_prefix(WS_PROTOCOL_TOKEN));
109+
.find_map(|p| p.strip_prefix(WS_PROTOCOL_TOKEN))
110+
.map(ToOwned::to_owned);
100111

101112
(actor_id, token)
102113
} else {
@@ -119,17 +130,26 @@ pub async fn route_request(
119130
.get(X_RIVET_TOKEN)
120131
.map(|x| x.to_str())
121132
.transpose()
122-
.context("invalid x-rivet-token header")?;
133+
.context("invalid x-rivet-token header")?
134+
.map(ToOwned::to_owned);
123135

124136
(actor_id.to_string(), token)
125137
};
126138

127139
// Find actor to route to
128140
let actor_id = Id::parse(&actor_id_str).context("invalid x-rivet-actor header")?;
141+
let stripped_path = req_ctx.path().to_owned();
129142

130-
route_request_inner(ctx, shared_state, req_ctx, actor_id, req_ctx.path(), token)
131-
.await
132-
.map(Some)
143+
route_request_inner(
144+
ctx,
145+
shared_state,
146+
req_ctx,
147+
actor_id,
148+
&stripped_path,
149+
token.as_deref(),
150+
)
151+
.await
152+
.map(Some)
133153
}
134154

135155
#[derive(Debug)]
@@ -201,11 +221,16 @@ fn read_gateway_token_from_request<'a>(
201221
async fn route_request_inner(
202222
ctx: &StandaloneCtx,
203223
shared_state: &SharedState,
204-
req_ctx: &RequestContext,
224+
req_ctx: &mut RequestContext,
205225
actor_id: Id,
206226
stripped_path: &str,
207227
_token: Option<&str>,
208228
) -> Result<RoutingOutput> {
229+
// Attach CORS headers to the actual (non-OPTIONS) response so both the
230+
// actor response and any early error (e.g. EE auth failure) are readable
231+
// by the browser.
232+
set_non_preflight_cors(req_ctx);
233+
209234
// NOTE: Token validation implemented in EE
210235

211236
// Route to peer dc where the actor lives

engine/packages/pegboard-gateway/src/lib.rs

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use rivet_guard_core::{
1010
ResponseBody, WebSocketHandle,
1111
custom_serve::{CustomServeTrait, HibernationResult},
1212
errors::{ServiceUnavailable, WebSocketServiceUnavailable},
13-
request_context::{CorsConfig, RequestContext},
13+
request_context::RequestContext,
1414
utils::is_ws_hibernate,
1515
websocket_handle::WebSocketReceiver,
1616
};
@@ -92,15 +92,6 @@ impl PegboardGateway {
9292
let actor_id = self.actor_id.to_string();
9393
let request_id = req_ctx.in_flight_request_id()?;
9494

95-
// Extract origin for CORS (before consuming request)
96-
// When credentials: true, we must echo back the actual origin, not "*"
97-
let origin = req
98-
.headers()
99-
.get("origin")
100-
.and_then(|v| v.to_str().ok())
101-
.unwrap_or("*")
102-
.to_string();
103-
10495
// Extract request parts
10596
let headers = req
10697
.headers()
@@ -113,47 +104,6 @@ impl PegboardGateway {
113104
})
114105
.collect::<HashableMap<_, _>>();
115106

116-
// Handle CORS preflight OPTIONS requests at gateway level
117-
//
118-
// We need to do this in the gateway because there is no way of sending an OPTIONS request to the
119-
// actor since we don't have the `x-rivet-token` header. This implementation allows
120-
// requests from anywhere and lets the actor handle CORS manually in `onBeforeConnect`.
121-
// This had the added benefit of also applying to WebSockets.
122-
if req.method() == hyper::Method::OPTIONS {
123-
tracing::debug!("handling OPTIONS preflight request at gateway");
124-
125-
// Extract requested headers
126-
let requested_headers = req
127-
.headers()
128-
.get("access-control-request-headers")
129-
.and_then(|v| v.to_str().ok())
130-
.unwrap_or("*");
131-
132-
req_ctx.set_cors(CorsConfig {
133-
allow_origin: origin.clone(),
134-
allow_credentials: true,
135-
expose_headers: "*".to_string(),
136-
allow_methods: Some("GET, POST, PUT, DELETE, OPTIONS, PATCH".to_string()),
137-
allow_headers: Some(requested_headers.to_string()),
138-
max_age: Some(86400),
139-
});
140-
141-
return Ok(Response::builder()
142-
.status(StatusCode::NO_CONTENT)
143-
.body(ResponseBody::Full(Full::new(Bytes::new())))?);
144-
}
145-
146-
// Set CORS headers through guard
147-
req_ctx.set_cors(CorsConfig {
148-
allow_origin: origin.clone(),
149-
allow_credentials: true,
150-
expose_headers: "*".to_string(),
151-
// Not an options req, not required
152-
allow_methods: None,
153-
allow_headers: None,
154-
max_age: None,
155-
});
156-
157107
// NOTE: Size constraints have already been applied by guard
158108
let body_bytes = req
159109
.into_body()

0 commit comments

Comments
 (0)