diff --git a/engine/packages/guard-core/src/proxy_service.rs b/engine/packages/guard-core/src/proxy_service.rs index 68081dbb23..46d24407f3 100644 --- a/engine/packages/guard-core/src/proxy_service.rs +++ b/engine/packages/guard-core/src/proxy_service.rs @@ -5,7 +5,7 @@ use http_body_util::{BodyExt, Full, Limited}; use hyper::{ Request, Response, StatusCode, body::Incoming as BodyIncoming, - header::{HeaderName, HeaderValue}, + header::{HeaderMap, HeaderName, HeaderValue}, }; use hyper_tungstenite; use hyper_util::{client::legacy::Client, rt::TokioExecutor}; @@ -43,6 +43,10 @@ pub const X_RIVET_ERROR: HeaderName = HeaderName::from_static("x-rivet-error"); const PROXY_STATE_CACHE_TTL: Duration = Duration::from_secs(60 * 60); // 1 hour const WEBSOCKET_CLOSE_LINGER: Duration = Duration::from_millis(5); // Keep TCP connection open briefly after WebSocket close +const SEC_WEBSOCKET_PROTOCOL: HeaderName = HeaderName::from_static("sec-websocket-protocol"); +const WS_PROTOCOL_RIVET: &str = "rivet"; +const WS_PROTOCOL_CONN_PARAMS: &str = "rivet_conn_params."; +const WS_PROTOCOL_TOKEN: &str = "rivet_token."; // State shared across all request handlers pub struct ProxyState { @@ -520,12 +524,11 @@ impl ProxyService { // Extract the parts from the response but preserve all headers and status let (mut parts, _) = client_response.into_parts(); - // Add Sec-WebSocket-Protocol header to the response - // Many WebSocket clients (e.g. node-ws & Cloudflare) require a protocol in the response - parts.headers.insert( - "sec-websocket-protocol", - hyper::header::HeaderValue::from_static("rivet"), - ); + if let Some(protocol) = + select_websocket_response_protocol(req_ctx.headers()) + { + parts.headers.insert(SEC_WEBSOCKET_PROTOCOL, protocol); + } // Create a new response with an empty body - WebSocket upgrades don't need a body Response::from_parts( @@ -1762,12 +1765,9 @@ impl ProxyService { // Extract the parts from the response but preserve all headers and status let (mut parts, _) = client_response.into_parts(); - // Add Sec-WebSocket-Protocol header to the response - // Many WebSocket clients (e.g. node-ws & Cloudflare) require a protocol in the response - parts.headers.insert( - "sec-websocket-protocol", - hyper::header::HeaderValue::from_static("rivet"), - ); + if let Some(protocol) = select_websocket_response_protocol(req_ctx.headers()) { + parts.headers.insert(SEC_WEBSOCKET_PROTOCOL, protocol); + } // Create a new response with an empty body - WebSocket upgrades don't need a body Ok(Response::from_parts( @@ -1777,6 +1777,27 @@ impl ProxyService { } } +fn select_websocket_response_protocol(headers: &HeaderMap) -> Option { + let protocols = headers.get(SEC_WEBSOCKET_PROTOCOL)?.to_str().ok()?; + let protocols = protocols + .split(',') + .map(str::trim) + .filter(|protocol| !protocol.is_empty()) + .collect::>(); + + let selected = protocols + .iter() + .find(|protocol| **protocol == WS_PROTOCOL_RIVET) + .or_else(|| { + protocols.iter().find(|protocol| { + !protocol.starts_with(WS_PROTOCOL_CONN_PARAMS) + && !protocol.starts_with(WS_PROTOCOL_TOKEN) + }) + })?; + + HeaderValue::from_str(selected).ok() +} + impl Clone for ProxyService { fn clone(&self) -> Self { Self { diff --git a/engine/packages/guard/src/routing/pegboard_gateway/mod.rs b/engine/packages/guard/src/routing/pegboard_gateway/mod.rs index e6f8dc1adc..d7081bcedc 100644 --- a/engine/packages/guard/src/routing/pegboard_gateway/mod.rs +++ b/engine/packages/guard/src/routing/pegboard_gateway/mod.rs @@ -564,16 +564,13 @@ fn read_gateway_token_for_path_based<'a>( } if req_ctx.is_websocket() { - let protocols_header = req_ctx + let Some(protocols_header) = req_ctx .headers() .get(SEC_WEBSOCKET_PROTOCOL) .and_then(|protocols| protocols.to_str().ok()) - .ok_or_else(|| { - crate::errors::MissingHeader { - header: "sec-websocket-protocol".to_string(), - } - .build() - })?; + else { + return Ok(None); + }; let protocols = protocols_header .split(',')