Skip to content

Commit a4663b0

Browse files
committed
fix(guard): handle bypass preflight before query parse
1 parent a051729 commit a4663b0

3 files changed

Lines changed: 80 additions & 15 deletions

File tree

engine/packages/guard/src/routing/actor_path.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,20 @@ pub enum ParsedActorPath {
6363
Query(QueryActorPathInfo),
6464
}
6565

66+
pub fn is_actor_gateway_path(path: &str) -> bool {
67+
let (base_path, _) = split_path_and_query(path);
68+
69+
if base_path.contains("//") {
70+
return false;
71+
}
72+
73+
base_path
74+
.split('/')
75+
.filter(|segment| !segment.is_empty())
76+
.next()
77+
== Some("gateway")
78+
}
79+
6680
/// Parsed rvt-* query parameters.
6781
#[derive(Debug, Clone, Deserialize)]
6882
#[serde(deny_unknown_fields)]

engine/packages/guard/src/routing/pegboard_gateway/mod.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use super::{
1515
use crate::{
1616
errors,
1717
routing::{
18-
actor_path::parse_actor_path,
18+
actor_path::{is_actor_gateway_path, parse_actor_path},
1919
pegboard_gateway::resolve_actor_query::ResolveQueryActorResult,
2020
},
2121
shared_state::SharedState,
@@ -56,14 +56,18 @@ pub async fn route_request_path_based_inner(
5656
shared_state: &SharedState,
5757
req_ctx: &mut RequestContext,
5858
) -> Result<Option<RoutingOutput>> {
59+
if req_ctx.method() == hyper::Method::OPTIONS {
60+
if is_actor_gateway_path(req_ctx.path()) {
61+
return Ok(Some(RoutingOutput::CustomServe(Arc::new(CorsPreflight))));
62+
}
63+
64+
return Ok(None);
65+
}
66+
5967
let Some(actor_path) = parse_actor_path(req_ctx.path())? else {
6068
return Ok(None);
6169
};
6270

63-
if req_ctx.method() == hyper::Method::OPTIONS {
64-
return Ok(Some(RoutingOutput::CustomServe(Arc::new(CorsPreflight))));
65-
}
66-
6771
tracing::debug!(?actor_path, "routing using path-based actor routing");
6872

6973
let (actor_id, token, stripped_path, bypass_connectable) = match actor_path {

engine/packages/guard/tests/parse_actor_path.rs

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
// Keep this test suite in sync with the TypeScript equivalent at
22
// rivetkit-typescript/packages/rivetkit/tests/parse-actor-path.test.ts
33
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
4-
use rivet_guard::routing::actor_path::{ParsedActorPath, QueryActorQuery, parse_actor_path};
4+
use rivet_guard::routing::actor_path::{
5+
ParsedActorPath, QueryActorQuery, is_actor_gateway_path, parse_actor_path,
6+
};
57

68
#[test]
79
fn parses_direct_actor_paths_with_existing_behavior() {
@@ -37,6 +39,7 @@ fn parses_query_actor_get_paths() {
3739
"shard-2".to_string(),
3840
"alpha@beta".to_string(),
3941
],
42+
bypass_connectable: false,
4043
}
4144
);
4245
}
@@ -66,11 +69,12 @@ fn parses_query_actor_get_or_create_paths_with_input_and_region() {
6669
QueryActorQuery::GetOrCreate {
6770
namespace: "default".to_string(),
6871
name: "worker".to_string(),
69-
runner_name: "default".to_string(),
72+
pool_name: "default".to_string(),
7073
key: vec!["shard-1".to_string()],
7174
input: Some(input_bytes),
7275
region: Some("us-west-2".to_string()),
7376
crash_policy: None,
77+
bypass_connectable: false,
7478
}
7579
);
7680
}
@@ -95,11 +99,12 @@ fn parses_query_actor_get_or_create_paths_with_multi_component_key() {
9599
QueryActorQuery::GetOrCreate {
96100
namespace: "default".to_string(),
97101
name: "worker".to_string(),
98-
runner_name: "default".to_string(),
102+
pool_name: "default".to_string(),
99103
key: vec!["tenant".to_string(), "job".to_string()],
100104
input: Some(input_bytes),
101105
region: None,
102106
crash_policy: None,
107+
bypass_connectable: false,
103108
}
104109
);
105110
assert_eq!(path.stripped_path, "/socket");
@@ -121,6 +126,7 @@ fn parses_query_actor_get_paths_with_empty_key() {
121126
namespace: "default".to_string(),
122127
name: "lobby".to_string(),
123128
key: Vec::new(),
129+
bypass_connectable: false,
124130
}
125131
);
126132
assert_eq!(path.stripped_path, "/");
@@ -141,11 +147,12 @@ fn omits_key_when_not_present() {
141147
QueryActorQuery::GetOrCreate {
142148
namespace: "default".to_string(),
143149
name: "builder".to_string(),
144-
runner_name: "default".to_string(),
150+
pool_name: "default".to_string(),
145151
key: Vec::new(),
146152
input: None,
147153
region: None,
148154
crash_policy: None,
155+
bypass_connectable: false,
149156
}
150157
);
151158
assert_eq!(path.stripped_path, "/");
@@ -167,6 +174,7 @@ fn parses_simple_multi_component_keys() {
167174
namespace: "default".to_string(),
168175
name: "lobby".to_string(),
169176
key: vec!["a".to_string(), "b".to_string(), "c".to_string()],
177+
bypass_connectable: false,
170178
}
171179
);
172180
}
@@ -186,18 +194,55 @@ fn parses_crash_policy_param() {
186194
QueryActorQuery::GetOrCreate {
187195
namespace: "default".to_string(),
188196
name: "worker".to_string(),
189-
runner_name: "default".to_string(),
197+
pool_name: "default".to_string(),
190198
key: Vec::new(),
191199
input: None,
192200
region: None,
193201
crash_policy: Some(rivet_types::actors::CrashPolicy::Restart),
202+
bypass_connectable: false,
194203
}
195204
);
196205
}
197206
ParsedActorPath::Direct(_) => panic!("expected query actor path"),
198207
}
199208
}
200209

210+
#[test]
211+
fn parses_bypass_connectable_query_bool_strings() {
212+
let path = "/gateway/worker/request/bypass?rvt-namespace=default&rvt-method=getOrCreate&rvt-runner=default&rvt-bypass_connectable=true";
213+
let result = parse_actor_path(path).unwrap().unwrap();
214+
215+
match result {
216+
ParsedActorPath::Query(path) => {
217+
assert_eq!(
218+
path.query,
219+
QueryActorQuery::GetOrCreate {
220+
namespace: "default".to_string(),
221+
name: "worker".to_string(),
222+
pool_name: "default".to_string(),
223+
key: Vec::new(),
224+
input: None,
225+
region: None,
226+
crash_policy: None,
227+
bypass_connectable: true,
228+
}
229+
);
230+
assert_eq!(path.stripped_path, "/request/bypass");
231+
}
232+
ParsedActorPath::Direct(_) => panic!("expected query actor path"),
233+
}
234+
}
235+
236+
#[test]
237+
fn identifies_gateway_paths_without_parsing_query_params() {
238+
assert!(is_actor_gateway_path(
239+
"/gateway/worker/request/bypass?rvt-bypass_connectable=true"
240+
));
241+
assert!(is_actor_gateway_path("/gateway/actor-id"));
242+
assert!(!is_actor_gateway_path("/request/bypass"));
243+
assert!(!is_actor_gateway_path("/gateway//worker"));
244+
}
245+
201246
#[test]
202247
fn strips_rvt_params_from_remaining_path() {
203248
let path = "/gateway/lobby/api/v1?rvt-namespace=prod&rvt-method=get&foo=bar&baz=qux";
@@ -272,6 +317,7 @@ fn handles_interleaved_rvt_and_actor_params() {
272317
namespace: "default".to_string(),
273318
name: "lobby".to_string(),
274319
key: Vec::new(),
320+
bypass_connectable: false,
275321
}
276322
);
277323
}
@@ -295,6 +341,7 @@ fn decodes_plus_as_space_in_rvt_values() {
295341
namespace: "my ns".to_string(),
296342
name: "lobby".to_string(),
297343
key: vec!["hello world".to_string()],
344+
bypass_connectable: false,
298345
}
299346
);
300347
// Actor param + is preserved literally.
@@ -421,7 +468,7 @@ fn rejects_input_for_get_queries() {
421468
.unwrap_err()
422469
.to_string();
423470
assert!(err.contains(
424-
"query gateway method=get does not allow rvt-input, rvt-region, rvt-crash-policy, or rvt-runner params"
471+
"query gateway method=get does not allow rvt-input, rvt-region, rvt-crash-policy, or rvt-pool params"
425472
));
426473
}
427474

@@ -433,7 +480,7 @@ fn rejects_region_for_get_queries() {
433480
.unwrap_err()
434481
.to_string();
435482
assert!(err.contains(
436-
"query gateway method=get does not allow rvt-input, rvt-region, rvt-crash-policy, or rvt-runner params"
483+
"query gateway method=get does not allow rvt-input, rvt-region, rvt-crash-policy, or rvt-pool params"
437484
));
438485
}
439486

@@ -445,7 +492,7 @@ fn rejects_crash_policy_for_get_queries() {
445492
.unwrap_err()
446493
.to_string();
447494
assert!(err.contains(
448-
"query gateway method=get does not allow rvt-input, rvt-region, rvt-crash-policy, or rvt-runner params"
495+
"query gateway method=get does not allow rvt-input, rvt-region, rvt-crash-policy, or rvt-pool params"
449496
));
450497
}
451498

@@ -456,7 +503,7 @@ fn rejects_runner_for_get_queries() {
456503
.unwrap_err()
457504
.to_string();
458505
assert!(err.contains(
459-
"query gateway method=get does not allow rvt-input, rvt-region, rvt-crash-policy, or rvt-runner params"
506+
"query gateway method=get does not allow rvt-input, rvt-region, rvt-crash-policy, or rvt-pool params"
460507
));
461508
}
462509

@@ -465,7 +512,7 @@ fn rejects_missing_runner_for_get_or_create_queries() {
465512
let err = parse_actor_path("/gateway/lobby?rvt-namespace=default&rvt-method=getOrCreate")
466513
.unwrap_err()
467514
.to_string();
468-
assert!(err.contains("query gateway method=getOrCreate requires rvt-runner param"));
515+
assert!(err.contains("query gateway method=getOrCreate requires rvt-pool param"));
469516
}
470517

471518
#[test]

0 commit comments

Comments
 (0)