Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 34 additions & 13 deletions engine/packages/guard-core/src/proxy_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use http_body_util::{BodyExt, Full, Limited};
use hyper::{
Request, Response, StatusCode,
body::Incoming as BodyIncoming,
header::{HeaderName, HeaderValue},
header::{HeaderMap, HeaderName, HeaderValue},
};
use hyper_tungstenite;
use hyper_util::{client::legacy::Client, rt::TokioExecutor};
Expand Down Expand Up @@ -43,6 +43,10 @@ pub const X_RIVET_ERROR: HeaderName = HeaderName::from_static("x-rivet-error");

const PROXY_STATE_CACHE_TTL: Duration = Duration::from_secs(60 * 60); // 1 hour
const WEBSOCKET_CLOSE_LINGER: Duration = Duration::from_millis(5); // Keep TCP connection open briefly after WebSocket close
const SEC_WEBSOCKET_PROTOCOL: HeaderName = HeaderName::from_static("sec-websocket-protocol");
const WS_PROTOCOL_RIVET: &str = "rivet";
const WS_PROTOCOL_CONN_PARAMS: &str = "rivet_conn_params.";
const WS_PROTOCOL_TOKEN: &str = "rivet_token.";

// State shared across all request handlers
pub struct ProxyState {
Expand Down Expand Up @@ -520,12 +524,11 @@ impl ProxyService {
// Extract the parts from the response but preserve all headers and status
let (mut parts, _) = client_response.into_parts();

// Add Sec-WebSocket-Protocol header to the response
// Many WebSocket clients (e.g. node-ws & Cloudflare) require a protocol in the response
parts.headers.insert(
"sec-websocket-protocol",
hyper::header::HeaderValue::from_static("rivet"),
);
if let Some(protocol) =
select_websocket_response_protocol(req_ctx.headers())
{
parts.headers.insert(SEC_WEBSOCKET_PROTOCOL, protocol);
}

// Create a new response with an empty body - WebSocket upgrades don't need a body
Response::from_parts(
Expand Down Expand Up @@ -1762,12 +1765,9 @@ impl ProxyService {
// Extract the parts from the response but preserve all headers and status
let (mut parts, _) = client_response.into_parts();

// Add Sec-WebSocket-Protocol header to the response
// Many WebSocket clients (e.g. node-ws & Cloudflare) require a protocol in the response
parts.headers.insert(
"sec-websocket-protocol",
hyper::header::HeaderValue::from_static("rivet"),
);
if let Some(protocol) = select_websocket_response_protocol(req_ctx.headers()) {
parts.headers.insert(SEC_WEBSOCKET_PROTOCOL, protocol);
}

// Create a new response with an empty body - WebSocket upgrades don't need a body
Ok(Response::from_parts(
Expand All @@ -1777,6 +1777,27 @@ impl ProxyService {
}
}

fn select_websocket_response_protocol(headers: &HeaderMap) -> Option<HeaderValue> {
let protocols = headers.get(SEC_WEBSOCKET_PROTOCOL)?.to_str().ok()?;
let protocols = protocols
.split(',')
.map(str::trim)
.filter(|protocol| !protocol.is_empty())
.collect::<Vec<_>>();

let selected = protocols
.iter()
.find(|protocol| **protocol == WS_PROTOCOL_RIVET)
.or_else(|| {
protocols.iter().find(|protocol| {
!protocol.starts_with(WS_PROTOCOL_CONN_PARAMS)
&& !protocol.starts_with(WS_PROTOCOL_TOKEN)
})
})?;

HeaderValue::from_str(selected).ok()
}

impl Clone for ProxyService {
fn clone(&self) -> Self {
Self {
Expand Down
11 changes: 4 additions & 7 deletions engine/packages/guard/src/routing/pegboard_gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -564,16 +564,13 @@ fn read_gateway_token_for_path_based<'a>(
}

if req_ctx.is_websocket() {
let protocols_header = req_ctx
let Some(protocols_header) = req_ctx
.headers()
.get(SEC_WEBSOCKET_PROTOCOL)
.and_then(|protocols| protocols.to_str().ok())
.ok_or_else(|| {
crate::errors::MissingHeader {
header: "sec-websocket-protocol".to_string(),
}
.build()
})?;
else {
return Ok(None);
};

let protocols = protocols_header
.split(',')
Expand Down
Loading