Skip to content

Commit 1d1fa8a

Browse files
committed
chore: fix cors for envoys
1 parent 94df783 commit 1d1fa8a

6 files changed

Lines changed: 131 additions & 115 deletions

File tree

engine/artifacts/openapi.json

Lines changed: 26 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

engine/packages/guard/src/routing/mod.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,19 +95,21 @@ pub fn create_routing_function(ctx: &StandaloneCtx, shared_state: SharedState) -
9595
.split(',')
9696
.map(|p| p.trim())
9797
.find_map(|p| p.strip_prefix(WS_PROTOCOL_TARGET))
98+
.map(ToOwned::to_owned)
9899
})
99100
} else {
100101
// For HTTP, use the x-rivet-target header
101102
req_ctx
102103
.headers()
103104
.get(X_RIVET_TARGET)
104105
.and_then(|x| x.to_str().ok())
106+
.map(ToOwned::to_owned)
105107
};
106108

107109
// Read target
108110
if let Some(target) = target {
109111
if let Some(routing_output) =
110-
pegboard_gateway::route_request(&ctx, &shared_state, req_ctx, target)
112+
pegboard_gateway::route_request(&ctx, &shared_state, req_ctx, &target)
111113
.await?
112114
{
113115
metrics::ROUTE_TOTAL.with_label_values(&["gateway"]).inc();
@@ -116,22 +118,22 @@ pub fn create_routing_function(ctx: &StandaloneCtx, shared_state: SharedState) -
116118
}
117119

118120
if let Some(routing_output) =
119-
runner::route_request(&ctx, req_ctx, target).await?
121+
runner::route_request(&ctx, req_ctx, &target).await?
120122
{
121123
metrics::ROUTE_TOTAL.with_label_values(&["runner"]).inc();
122124

123125
return Ok(routing_output);
124126
}
125127

126128
if let Some(routing_output) =
127-
envoy::route_request(&ctx, req_ctx, target).await?
129+
envoy::route_request(&ctx, req_ctx, &target).await?
128130
{
129131
metrics::ROUTE_TOTAL.with_label_values(&["envoy"]).inc();
130132

131133
return Ok(routing_output);
132134
}
133135

134-
if let Some(routing_output) = api_public::route_request(&ctx, target).await? {
136+
if let Some(routing_output) = api_public::route_request(&ctx, &target).await? {
135137
metrics::ROUTE_TOTAL.with_label_values(&["api"]).inc();
136138

137139
return Ok(routing_output);
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
use anyhow::Result;
2+
use async_trait::async_trait;
3+
use bytes::Bytes;
4+
use http_body_util::Full;
5+
use hyper::{Request, Response, StatusCode};
6+
use rivet_guard_core::{
7+
ResponseBody,
8+
custom_serve::CustomServeTrait,
9+
request_context::{CorsConfig, RequestContext},
10+
};
11+
12+
pub fn origin_header(req_ctx: &RequestContext) -> String {
13+
req_ctx
14+
.headers()
15+
.get("origin")
16+
.and_then(|v| v.to_str().ok())
17+
.unwrap_or("*")
18+
.to_string()
19+
}
20+
21+
pub fn set_non_preflight_cors(req_ctx: &mut RequestContext) {
22+
let allow_origin = origin_header(req_ctx);
23+
req_ctx.set_cors(CorsConfig {
24+
allow_origin,
25+
allow_credentials: true,
26+
expose_headers: "*".to_string(),
27+
allow_methods: None,
28+
allow_headers: None,
29+
max_age: None,
30+
});
31+
}
32+
33+
/// Responds to CORS preflight OPTIONS requests with 204 and permissive CORS
34+
/// headers. Avoids actor lookup, wake, and auth because browsers cannot attach
35+
/// credentials to preflights. The actual request that follows is still authed.
36+
pub struct CorsPreflight;
37+
38+
#[async_trait]
39+
impl CustomServeTrait for CorsPreflight {
40+
async fn handle_request(
41+
&self,
42+
req: Request<Full<Bytes>>,
43+
req_ctx: &mut RequestContext,
44+
) -> Result<Response<ResponseBody>> {
45+
let allow_origin = req
46+
.headers()
47+
.get("origin")
48+
.and_then(|v| v.to_str().ok())
49+
.unwrap_or("*")
50+
.to_string();
51+
let allow_headers = req
52+
.headers()
53+
.get("access-control-request-headers")
54+
.and_then(|v| v.to_str().ok())
55+
.unwrap_or("*")
56+
.to_string();
57+
58+
req_ctx.set_cors(CorsConfig {
59+
allow_origin,
60+
allow_credentials: true,
61+
expose_headers: "*".to_string(),
62+
allow_methods: Some("GET, POST, PUT, DELETE, OPTIONS, PATCH".to_string()),
63+
allow_headers: Some(allow_headers),
64+
max_age: Some(86400),
65+
});
66+
67+
Ok(Response::builder()
68+
.status(StatusCode::NO_CONTENT)
69+
.body(ResponseBody::Full(Full::new(Bytes::new())))?)
70+
}
71+
}

engine/packages/guard/src/routing/pegboard_gateway/mod.rs

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
mod cors;
12
mod resolve_actor_query;
23

3-
use std::time::Duration;
4+
use std::{sync::Arc, time::Duration};
45

56
use anyhow::Result;
67
use gas::{ctx::message::SubscriptionHandle, prelude::*};
@@ -19,6 +20,7 @@ use crate::{
1920
},
2021
shared_state::SharedState,
2122
};
23+
use cors::{CorsPreflight, set_non_preflight_cors};
2224
use resolve_actor_query::resolve_query;
2325

2426
const ACTOR_FORCE_WAKE_PENDING_TIMEOUT: i64 = util::duration::seconds(60);
@@ -36,12 +38,16 @@ pub const X_RIVET_ACTOR: HeaderName = HeaderName::from_static("x-rivet-actor");
3638
pub async fn route_request_path_based(
3739
ctx: &StandaloneCtx,
3840
shared_state: &SharedState,
39-
req_ctx: &RequestContext,
41+
req_ctx: &mut RequestContext,
4042
) -> Result<Option<RoutingOutput>> {
4143
let Some(actor_path) = parse_actor_path(req_ctx.path())? else {
4244
return Ok(None);
4345
};
4446

47+
if req_ctx.method() == hyper::Method::OPTIONS {
48+
return Ok(Some(RoutingOutput::CustomServe(Arc::new(CorsPreflight))));
49+
}
50+
4551
tracing::debug!(?actor_path, "routing using path-based actor routing");
4652

4753
let (actor_id, token, stripped_path, bypass_connectable) = match actor_path {
@@ -101,14 +107,18 @@ pub async fn route_request_path_based(
101107
pub async fn route_request(
102108
ctx: &StandaloneCtx,
103109
shared_state: &SharedState,
104-
req_ctx: &RequestContext,
110+
req_ctx: &mut RequestContext,
105111
target: &str,
106112
) -> Result<Option<RoutingOutput>> {
107113
// Check target
108114
if target != "actor" {
109115
return Ok(None);
110116
}
111117

118+
if req_ctx.method() == hyper::Method::OPTIONS {
119+
return Ok(Some(RoutingOutput::CustomServe(Arc::new(CorsPreflight))));
120+
}
121+
112122
// Extract actor ID and token from WebSocket protocol or HTTP headers
113123
let (actor_id_str, token, bypass_connectable) = if req_ctx.is_websocket() {
114124
// For WebSocket, parse the sec-websocket-protocol header
@@ -141,7 +151,8 @@ pub async fn route_request(
141151

142152
let token = protocols
143153
.iter()
144-
.find_map(|p| p.strip_prefix(WS_PROTOCOL_TOKEN));
154+
.find_map(|p| p.strip_prefix(WS_PROTOCOL_TOKEN))
155+
.map(ToOwned::to_owned);
145156

146157
let bypass_connectable = protocols
147158
.iter()
@@ -168,7 +179,8 @@ pub async fn route_request(
168179
.get(X_RIVET_TOKEN)
169180
.map(|x| x.to_str())
170181
.transpose()
171-
.context("invalid x-rivet-token header")?;
182+
.context("invalid x-rivet-token header")?
183+
.map(ToOwned::to_owned);
172184

173185
let bypass_connectable = req_ctx.headers().contains_key(X_RIVET_BYPASS_CONNECTABLE);
174186

@@ -177,14 +189,15 @@ pub async fn route_request(
177189

178190
// Find actor to route to
179191
let actor_id = Id::parse(&actor_id_str).context("invalid x-rivet-actor header")?;
192+
let stripped_path = req_ctx.path().to_owned();
180193

181194
route_request_inner(
182195
ctx,
183196
shared_state,
184197
req_ctx,
185198
actor_id,
186-
req_ctx.path(),
187-
token,
199+
&stripped_path,
200+
token.as_deref(),
188201
bypass_connectable,
189202
)
190203
.await
@@ -194,12 +207,17 @@ pub async fn route_request(
194207
async fn route_request_inner(
195208
ctx: &StandaloneCtx,
196209
shared_state: &SharedState,
197-
req_ctx: &RequestContext,
210+
req_ctx: &mut RequestContext,
198211
actor_id: Id,
199212
stripped_path: &str,
200213
_token: Option<&str>,
201214
bypass_connectable: bool,
202215
) -> Result<RoutingOutput> {
216+
// Attach CORS headers to the actual (non-OPTIONS) response so both the
217+
// actor response and any early error (e.g. EE auth failure) are readable
218+
// by the browser.
219+
set_non_preflight_cors(req_ctx);
220+
203221
// NOTE: Token validation implemented in EE
204222

205223
// Route to peer dc where the actor lives

engine/packages/pegboard-gateway/src/lib.rs

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use rivet_guard_core::{
1010
ResponseBody, WebSocketHandle,
1111
custom_serve::{CustomServeTrait, HibernationResult},
1212
errors::{ServiceUnavailable, WebSocketServiceUnavailable},
13-
request_context::{CorsConfig, RequestContext},
13+
request_context::RequestContext,
1414
utils::is_ws_hibernate,
1515
websocket_handle::WebSocketReceiver,
1616
};
@@ -92,15 +92,6 @@ impl PegboardGateway {
9292
let actor_id = self.actor_id.to_string();
9393
let request_id = req_ctx.in_flight_request_id()?;
9494

95-
// Extract origin for CORS (before consuming request)
96-
// When credentials: true, we must echo back the actual origin, not "*"
97-
let origin = req
98-
.headers()
99-
.get("origin")
100-
.and_then(|v| v.to_str().ok())
101-
.unwrap_or("*")
102-
.to_string();
103-
10495
// Extract request parts
10596
let headers = req
10697
.headers()
@@ -113,47 +104,6 @@ impl PegboardGateway {
113104
})
114105
.collect::<HashableMap<_, _>>();
115106

116-
// Handle CORS preflight OPTIONS requests at gateway level
117-
//
118-
// We need to do this in the gateway because there is no way of sending an OPTIONS request to the
119-
// actor since we don't have the `x-rivet-token` header. This implementation allows
120-
// requests from anywhere and lets the actor handle CORS manually in `onBeforeConnect`.
121-
// This had the added benefit of also applying to WebSockets.
122-
if req.method() == hyper::Method::OPTIONS {
123-
tracing::debug!("handling OPTIONS preflight request at gateway");
124-
125-
// Extract requested headers
126-
let requested_headers = req
127-
.headers()
128-
.get("access-control-request-headers")
129-
.and_then(|v| v.to_str().ok())
130-
.unwrap_or("*");
131-
132-
req_ctx.set_cors(CorsConfig {
133-
allow_origin: origin.clone(),
134-
allow_credentials: true,
135-
expose_headers: "*".to_string(),
136-
allow_methods: Some("GET, POST, PUT, DELETE, OPTIONS, PATCH".to_string()),
137-
allow_headers: Some(requested_headers.to_string()),
138-
max_age: Some(86400),
139-
});
140-
141-
return Ok(Response::builder()
142-
.status(StatusCode::NO_CONTENT)
143-
.body(ResponseBody::Full(Full::new(Bytes::new())))?);
144-
}
145-
146-
// Set CORS headers through guard
147-
req_ctx.set_cors(CorsConfig {
148-
allow_origin: origin.clone(),
149-
allow_credentials: true,
150-
expose_headers: "*".to_string(),
151-
// Not an options req, not required
152-
allow_methods: None,
153-
allow_headers: None,
154-
max_age: None,
155-
});
156-
157107
// NOTE: Size constraints have already been applied by guard
158108
let body_bytes = req
159109
.into_body()

0 commit comments

Comments
 (0)