diff --git a/CLAUDE.md b/CLAUDE.md index 5fc34fb0a3..90e735de05 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -216,6 +216,14 @@ When the user asks to track something in a note, store it in `.agent/notes/` by - Any behavior, protocol handling, or test coverage added to one runner should be mirrored in the other runner in the same change whenever possible. - When parity cannot be completed in the same change, explicitly document the gap and add a follow-up task. +### Trust Boundaries +- Treat `client <-> engine` as untrusted. +- Treat `envoy <-> pegboard-envoy` as untrusted. +- Treat traffic inside the engine over `nats`, `fdb`, and other internal backends as trusted. +- Treat `gateway`, `api`, `pegboard-envoy`, `nats`, `fdb`, and similar engine-internal services as one trusted internal boundary once traffic is inside the engine. +- Validate and authorize all client-originated data at the engine edge before it reaches trusted internal systems. +- Validate and authorize all envoy-originated data at `pegboard-envoy` before it reaches trusted internal systems. + ### Important Patterns **Error Handling** diff --git a/Cargo.lock b/Cargo.lock index 80b78289c7..76621100bb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -938,15 +938,6 @@ version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" -[[package]] -name = "convert_case" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec182b0ca2f35d8fc196cf3404988fd8b8c739a4d270ff118a398feb0cbec1ca" -dependencies = [ - "unicode-segmentation", -] - [[package]] name = "cookie" version = "0.18.1" @@ -1040,16 +1031,6 @@ dependencies = [ "typenum", ] -[[package]] -name = "ctor" -version = "0.2.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32a2785755761f3ddc1492979ce1e48d2c00d09311c39e4466429188f3dd6501" -dependencies = [ - "quote", - "syn 2.0.104", -] - [[package]] name = "curve25519-dalek" version = "4.1.3" @@ -2612,17 +2593,6 @@ dependencies = [ "zstd-sys", ] -[[package]] -name = "libsqlite3-sys" -version = "0.30.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" -dependencies = [ - "cc", - "pkg-config", - "vcpkg", -] - [[package]] name = "libz-sys" version = "1.1.22" @@ -2879,66 +2849,6 @@ dependencies = [ "vbare", ] -[[package]] -name = "napi" -version = "2.16.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55740c4ae1d8696773c78fdafd5d0e5fe9bc9f1b071c7ba493ba5c413a9184f3" -dependencies = [ - "bitflags", - "ctor", - "napi-derive", - "napi-sys", - "once_cell", - "serde", - "serde_json", - "tokio", -] - -[[package]] -name = "napi-build" -version = "2.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d376940fd5b723c6893cd1ee3f33abbfd86acb1cd1ec079f3ab04a2a3bc4d3b1" - -[[package]] -name = "napi-derive" -version = "2.16.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7cbe2585d8ac223f7d34f13701434b9d5f4eb9c332cccce8dee57ea18ab8ab0c" -dependencies = [ - "cfg-if", - "convert_case", - "napi-derive-backend", - "proc-macro2", - "quote", - "syn 2.0.104", -] - -[[package]] -name = "napi-derive-backend" -version = "1.0.75" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1639aaa9eeb76e91c6ae66da8ce3e89e921cd3885e99ec85f4abacae72fc91bf" -dependencies = [ - "convert_case", - "once_cell", - "proc-macro2", - "quote", - "regex", - "semver", - "syn 2.0.104", -] - -[[package]] -name = "napi-sys" -version = "2.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "427802e8ec3a734331fec1035594a210ce1ff4dc5bc1950530920ab717964ea3" -dependencies = [ - "libloading", -] - [[package]] name = "native-tls" version = "0.2.14" @@ -3485,6 +3395,7 @@ dependencies = [ "rivet-metrics", "rivet-runtime", "rivet-types", + "scc", "serde", "serde_bare", "serde_json", @@ -3615,6 +3526,7 @@ dependencies = [ "rivet-runner-protocol", "rivet-runtime", "rivet-types", + "scc", "serde", "serde_bare", "serde_json", @@ -5157,40 +5069,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "rivetkit-native" -version = "2.2.1" -dependencies = [ - "anyhow", - "async-trait", - "base64 0.22.1", - "hex", - "libsqlite3-sys", - "napi", - "napi-build", - "napi-derive", - "rivet-envoy-client", - "rivet-envoy-protocol", - "rivetkit-sqlite-native", - "serde", - "serde_json", - "tokio", - "tracing", - "tracing-subscriber", - "uuid", -] - -[[package]] -name = "rivetkit-sqlite-native" -version = "2.1.6" -dependencies = [ - "async-trait", - "getrandom 0.2.16", - "libsqlite3-sys", - "tokio", - "tracing", -] - [[package]] name = "rocksdb" version = "0.24.0" diff --git a/Cargo.toml b/Cargo.toml index fa7ac778b1..26cad9083e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ members = [ "engine/packages/pegboard-runner", "engine/packages/pools", "engine/packages/postgres-util", + "engine/packages/runner-protocol", "engine/packages/runtime", "engine/packages/service-manager", "engine/packages/telemetry", @@ -53,11 +54,8 @@ members = [ "engine/sdks/rust/envoy-client", "engine/sdks/rust/envoy-protocol", "engine/sdks/rust/epoxy-protocol", - "engine/packages/runner-protocol", "engine/sdks/rust/test-envoy", - "engine/sdks/rust/ups-protocol", - "rivetkit-typescript/packages/sqlite-native", - "rivetkit-typescript/packages/rivetkit-native" + "engine/sdks/rust/ups-protocol" ] [workspace.package] @@ -448,6 +446,9 @@ members = [ [workspace.dependencies.rivet-postgres-util] path = "engine/packages/postgres-util" + [workspace.dependencies.rivet-runner-protocol] + path = "engine/packages/runner-protocol" + [workspace.dependencies.rivet-runtime] path = "engine/packages/runtime" @@ -503,17 +504,11 @@ members = [ [workspace.dependencies.rivet-envoy-client] path = "engine/sdks/rust/envoy-client" - [workspace.dependencies.epoxy-protocol] - path = "engine/sdks/rust/epoxy-protocol" - [workspace.dependencies.rivet-envoy-protocol] path = "engine/sdks/rust/envoy-protocol" - [workspace.dependencies.rivetkit-sqlite-native] - path = "rivetkit-typescript/packages/sqlite-native" - - [workspace.dependencies.rivet-runner-protocol] - path = "engine/packages/runner-protocol" + [workspace.dependencies.epoxy-protocol] + path = "engine/sdks/rust/epoxy-protocol" [workspace.dependencies.rivet-test-envoy] path = "engine/sdks/rust/test-envoy" diff --git a/engine/packages/pegboard-envoy/Cargo.toml b/engine/packages/pegboard-envoy/Cargo.toml index 765fba175c..1429c27e01 100644 --- a/engine/packages/pegboard-envoy/Cargo.toml +++ b/engine/packages/pegboard-envoy/Cargo.toml @@ -27,6 +27,7 @@ rivet-metrics.workspace = true rivet-envoy-protocol.workspace = true rivet-runtime.workspace = true rivet-types.workspace = true +scc.workspace = true serde_bare.workspace = true serde_json.workspace = true serde.workspace = true diff --git a/engine/packages/pegboard-envoy/src/conn.rs b/engine/packages/pegboard-envoy/src/conn.rs index 768e5e9ce2..2f8f2bdba9 100644 --- a/engine/packages/pegboard-envoy/src/conn.rs +++ b/engine/packages/pegboard-envoy/src/conn.rs @@ -15,6 +15,7 @@ use rivet_data::converted::{ActorNameKeyData, MetadataKeyData}; use rivet_envoy_protocol::{self as protocol, versioned}; use rivet_guard_core::WebSocketHandle; use rivet_types::runner_configs::RunnerConfigKind; +use scc::HashMap; use universaldb::prelude::*; use vbare::OwnedVersionedData; @@ -26,6 +27,7 @@ pub struct Conn { pub envoy_key: String, pub protocol_version: u16, pub ws_handle: WebSocketHandle, + pub authorized_tunnel_routes: HashMap<(protocol::GatewayId, protocol::RequestId), ()>, pub is_serverless: bool, pub last_rtt: AtomicU32, /// Timestamp (epoch ms) of the last pong received from the envoy. @@ -101,6 +103,7 @@ pub async fn init_conn( envoy_key, protocol_version, ws_handle, + authorized_tunnel_routes: HashMap::new(), is_serverless: false, last_rtt: AtomicU32::new(0), last_ping_ts: AtomicI64::new(util::timestamp::now()), @@ -114,7 +117,6 @@ pub async fn init_conn( Ok(Arc::new(conn)) } - #[tracing::instrument(skip_all)] pub async fn handle_init( ctx: &StandaloneCtx, diff --git a/engine/packages/pegboard-envoy/src/tunnel_to_ws_task.rs b/engine/packages/pegboard-envoy/src/tunnel_to_ws_task.rs index 0fdb4c1716..57b07c90a4 100644 --- a/engine/packages/pegboard-envoy/src/tunnel_to_ws_task.rs +++ b/engine/packages/pegboard-envoy/src/tunnel_to_ws_task.rs @@ -162,6 +162,10 @@ async fn handle_message( } protocol::ToEnvoyConn::ToEnvoyAckEvents(x) => protocol::ToEnvoy::ToEnvoyAckEvents(x), protocol::ToEnvoyConn::ToEnvoyTunnelMessage(x) => { + let _ = conn + .authorized_tunnel_routes + .insert_async((x.message_id.gateway_id, x.message_id.request_id), ()) + .await; protocol::ToEnvoy::ToEnvoyTunnelMessage(x) } }; diff --git a/engine/packages/pegboard-envoy/src/ws_to_tunnel_task.rs b/engine/packages/pegboard-envoy/src/ws_to_tunnel_task.rs index 5f1938a75a..b7c12a8194 100644 --- a/engine/packages/pegboard-envoy/src/ws_to_tunnel_task.rs +++ b/engine/packages/pegboard-envoy/src/ws_to_tunnel_task.rs @@ -8,6 +8,7 @@ use pegboard::actor_kv; use pegboard::pubsub_subjects::GatewayReceiverSubject; use rivet_envoy_protocol::{self as protocol, PROTOCOL_VERSION, versioned}; use rivet_guard_core::websocket_handle::WebSocketReceiver; +use scc::HashMap; use std::sync::{Arc, atomic::Ordering}; use tokio::sync::{Mutex, MutexGuard, watch}; use universaldb::utils::end_of_key_range; @@ -366,7 +367,7 @@ async fn handle_message( } } protocol::ToRivet::ToRivetTunnelMessage(tunnel_msg) => { - handle_tunnel_message(&ctx, tunnel_msg) + handle_tunnel_message(ctx, &conn.authorized_tunnel_routes, tunnel_msg) .await .context("failed to handle tunnel message")?; } @@ -447,6 +448,7 @@ async fn ack_commands( #[tracing::instrument(skip_all)] async fn handle_tunnel_message( ctx: &StandaloneCtx, + authorized_tunnel_routes: &HashMap<(protocol::GatewayId, protocol::RequestId), ()>, msg: protocol::ToRivetTunnelMessage, ) -> Result<()> { // Extract inner data length before consuming msg @@ -457,6 +459,15 @@ async fn handle_tunnel_message( return Err(errors::WsError::InvalidPacket("payload too large".to_string()).build()); } + if !authorized_tunnel_routes + .contains_async(&(msg.message_id.gateway_id, msg.message_id.request_id)) + .await + { + return Err( + errors::WsError::InvalidPacket("unauthorized tunnel message".to_string()).build(), + ); + } + let gateway_reply_to = GatewayReceiverSubject::new(msg.message_id.gateway_id).to_string(); let msg_serialized = versioned::ToGateway::wrap_latest(protocol::ToGateway::ToRivetTunnelMessage(msg)) @@ -470,8 +481,7 @@ async fn handle_tunnel_message( ); // Publish message to UPS - ctx.ups() - .context("failed to get UPS instance for tunnel message")? + ctx.ups()? .publish(&gateway_reply_to, &msg_serialized, PublishOpts::one()) .await .with_context(|| { @@ -500,6 +510,10 @@ fn tunnel_message_inner_data_len(kind: &protocol::ToRivetTunnelMessageKind) -> u } } +#[cfg(test)] +#[path = "../tests/support/ws_to_tunnel_task.rs"] +mod tests; + async fn send_actor_kv_error(conn: &Conn, request_id: u32, message: &str) -> Result<()> { let res_msg = versioned::ToEnvoy::wrap_latest(protocol::ToEnvoy::ToEnvoyKvResponse( protocol::ToEnvoyKvResponse { diff --git a/engine/packages/pegboard-envoy/tests/support/ws_to_tunnel_task.rs b/engine/packages/pegboard-envoy/tests/support/ws_to_tunnel_task.rs new file mode 100644 index 0000000000..16dea915d5 --- /dev/null +++ b/engine/packages/pegboard-envoy/tests/support/ws_to_tunnel_task.rs @@ -0,0 +1,82 @@ +// TODO: Use TestCtx +// use std::{sync::Arc, time::Duration}; + +// use pegboard::pubsub_subjects::GatewayReceiverSubject; +// use rivet_envoy_protocol as protocol; +// use scc::HashMap; +// use universalpubsub::{NextOutput, PubSub, driver::memory::MemoryDriver}; + +// use super::handle_tunnel_message; + +// fn memory_pubsub(channel: &str) -> PubSub { +// PubSub::new(Arc::new(MemoryDriver::new(channel.to_string()))) +// } + +// fn response_abort_message( +// gateway_id: protocol::GatewayId, +// request_id: protocol::RequestId, +// ) -> protocol::ToRivetTunnelMessage { +// protocol::ToRivetTunnelMessage { +// message_id: protocol::MessageId { +// gateway_id, +// request_id, +// message_index: 0, +// }, +// message_kind: protocol::ToRivetTunnelMessageKind::ToRivetResponseAbort, +// } +// } + +// #[tokio::test] +// async fn rejects_unissued_tunnel_message_pairs() { +// let pubsub = memory_pubsub("pegboard-envoy-ws-to-tunnel-test-reject"); +// let gateway_id = [1, 2, 3, 4]; +// let request_id = [5, 6, 7, 8]; +// let mut sub = pubsub +// .subscribe(&GatewayReceiverSubject::new(gateway_id).to_string()) +// .await +// .unwrap(); +// let authorized_tunnel_routes = HashMap::new(); + +// let err = handle_tunnel_message( +// &pubsub, +// 1024, +// &authorized_tunnel_routes, +// response_abort_message(gateway_id, request_id), +// ) +// .await +// .unwrap_err(); +// assert!(err.to_string().contains("unauthorized tunnel message")); + +// let recv = tokio::time::timeout(Duration::from_millis(50), sub.next()).await; +// assert!(recv.is_err()); +// } + +// #[tokio::test] +// async fn republishes_issued_tunnel_message_pairs() { +// let pubsub = memory_pubsub("pegboard-envoy-ws-to-tunnel-test-allow"); +// let gateway_id = [9, 10, 11, 12]; +// let request_id = [13, 14, 15, 16]; +// let mut sub = pubsub +// .subscribe(&GatewayReceiverSubject::new(gateway_id).to_string()) +// .await +// .unwrap(); +// let authorized_tunnel_routes = HashMap::new(); +// let _ = authorized_tunnel_routes +// .insert_async((gateway_id, request_id), ()) +// .await; + +// handle_tunnel_message( +// &pubsub, +// 1024, +// &authorized_tunnel_routes, +// response_abort_message(gateway_id, request_id), +// ) +// .await +// .unwrap(); + +// let msg = tokio::time::timeout(Duration::from_secs(1), sub.next()) +// .await +// .unwrap() +// .unwrap(); +// assert!(matches!(msg, NextOutput::Message(_))); +// } diff --git a/engine/packages/pegboard-runner/Cargo.toml b/engine/packages/pegboard-runner/Cargo.toml index e68cc102c9..08cdd6727d 100644 --- a/engine/packages/pegboard-runner/Cargo.toml +++ b/engine/packages/pegboard-runner/Cargo.toml @@ -28,6 +28,7 @@ rivet-metrics.workspace = true rivet-runner-protocol.workspace = true rivet-runtime.workspace = true rivet-types.workspace = true +scc.workspace = true serde_bare.workspace = true serde_json.workspace = true serde.workspace = true diff --git a/engine/packages/pegboard-runner/src/conn.rs b/engine/packages/pegboard-runner/src/conn.rs index 46eca7a87c..d8a107c9e7 100644 --- a/engine/packages/pegboard-runner/src/conn.rs +++ b/engine/packages/pegboard-runner/src/conn.rs @@ -16,6 +16,7 @@ use rivet_data::converted::{ActorNameKeyData, MetadataKeyData}; use rivet_guard_core::WebSocketHandle; use rivet_runner_protocol::{self as protocol, versioned}; use rivet_types::runner_configs::RunnerConfigKind; +use scc::HashMap; use universaldb::prelude::*; use vbare::OwnedVersionedData; @@ -29,6 +30,7 @@ pub struct Conn { pub workflow_id: Id, pub protocol_version: u16, pub ws_handle: WebSocketHandle, + pub authorized_tunnel_routes: HashMap<(protocol::mk2::GatewayId, protocol::mk2::RequestId), ()>, pub last_rtt: AtomicU32, /// Timestamp (epoch ms) of the last pong received from the runner. pub last_ping_ts: AtomicI64, @@ -188,6 +190,7 @@ pub async fn init_conn( workflow_id, protocol_version, ws_handle, + authorized_tunnel_routes: HashMap::new(), last_rtt: AtomicU32::new(0), last_ping_ts: AtomicI64::new(util::timestamp::now()), }); @@ -213,7 +216,6 @@ pub async fn init_conn( Ok(conn) } - enum Init { Mk2(protocol::mk2::ToServerInit), Mk1(protocol::ToServerInit), diff --git a/engine/packages/pegboard-runner/src/tunnel_to_ws_task.rs b/engine/packages/pegboard-runner/src/tunnel_to_ws_task.rs index 6763e27519..cfd6ad649b 100644 --- a/engine/packages/pegboard-runner/src/tunnel_to_ws_task.rs +++ b/engine/packages/pegboard-runner/src/tunnel_to_ws_task.rs @@ -158,6 +158,10 @@ async fn handle_message_mk2( protocol::mk2::ToClient::ToClientAckEvents(x) } protocol::mk2::ToRunner::ToClientTunnelMessage(x) => { + let _ = conn + .authorized_tunnel_routes + .insert_async((x.message_id.gateway_id, x.message_id.request_id), ()) + .await; protocol::mk2::ToClient::ToClientTunnelMessage(x) } }; @@ -250,6 +254,10 @@ async fn handle_message_mk1( protocol::ToRunner::ToClientAckEvents(x) => protocol::ToClient::ToClientAckEvents(x), protocol::ToRunner::ToClientKvResponse(x) => protocol::ToClient::ToClientKvResponse(x), protocol::ToRunner::ToClientTunnelMessage(x) => { + let _ = conn + .authorized_tunnel_routes + .insert_async((x.message_id.gateway_id, x.message_id.request_id), ()) + .await; protocol::ToClient::ToClientTunnelMessage(x) } }; diff --git a/engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs b/engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs index f3492a706f..550e9ff216 100644 --- a/engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs +++ b/engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs @@ -9,11 +9,12 @@ use pegboard::pubsub_subjects::GatewayReceiverSubject; use rivet_envoy_protocol as ep; use rivet_guard_core::websocket_handle::WebSocketReceiver; use rivet_runner_protocol::{self as protocol, PROTOCOL_MK2_VERSION, versioned}; +use scc::HashMap; use std::sync::{Arc, atomic::Ordering}; use tokio::sync::{Mutex, MutexGuard, watch}; use universaldb::utils::end_of_key_range; -use universalpubsub::PublishOpts; use universalpubsub::Subscriber; +use universalpubsub::{PubSub, PublishOpts}; use vbare::OwnedVersionedData; use crate::{LifecycleResult, actor_event_demuxer::ActorEventDemuxer, conn::Conn, errors, metrics}; @@ -455,9 +456,17 @@ async fn handle_message_mk2( } } protocol::mk2::ToServer::ToServerTunnelMessage(tunnel_msg) => { - handle_tunnel_message_mk2(&ctx, tunnel_msg) - .await - .context("failed to handle tunnel message")?; + handle_tunnel_message_mk2( + &ctx.ups() + .context("failed to get UPS instance for tunnel message")?, + ctx.config() + .pegboard() + .runner_max_response_payload_body_size(), + &conn.authorized_tunnel_routes, + tunnel_msg, + ) + .await + .context("failed to handle tunnel message")?; } // NOTE: This does not process the first init event. See `conn::init_conn` protocol::mk2::ToServer::ToServerInit(_) => { @@ -775,9 +784,17 @@ async fn handle_message_mk1(ctx: &StandaloneCtx, conn: &Conn, msg: Bytes) -> Res } } protocol::ToServer::ToServerTunnelMessage(tunnel_msg) => { - handle_tunnel_message_mk1(&ctx, tunnel_msg) - .await - .context("failed to handle tunnel message")?; + handle_tunnel_message_mk1( + &ctx.ups() + .context("failed to get UPS instance for tunnel message")?, + ctx.config() + .pegboard() + .runner_max_response_payload_body_size(), + &conn.authorized_tunnel_routes, + tunnel_msg, + ) + .await + .context("failed to handle tunnel message")?; } // Forward to runner wf protocol::ToServer::ToServerInit(_) @@ -838,22 +855,28 @@ async fn ack_commands( #[tracing::instrument(skip_all)] async fn handle_tunnel_message_mk2( - ctx: &StandaloneCtx, + ups: &PubSub, + max_payload_size: usize, + authorized_tunnel_routes: &HashMap<(protocol::mk2::GatewayId, protocol::mk2::RequestId), ()>, msg: protocol::mk2::ToServerTunnelMessage, ) -> Result<()> { // Extract inner data length before consuming msg let inner_data_len = tunnel_message_inner_data_len_mk2(&msg.message_kind); // Enforce incoming payload size - if inner_data_len - > ctx - .config() - .pegboard() - .runner_max_response_payload_body_size() - { + if inner_data_len > max_payload_size { return Err(errors::WsError::InvalidPacket("payload too large".to_string()).build()); } + if !authorized_tunnel_routes + .contains_async(&(msg.message_id.gateway_id, msg.message_id.request_id)) + .await + { + return Err( + errors::WsError::InvalidPacket("unauthorized tunnel message".to_string()).build(), + ); + } + let gateway_reply_to = GatewayReceiverSubject::new(msg.message_id.gateway_id).to_string(); let msg_serialized = versioned::ToGateway::wrap_latest(protocol::mk2::ToGateway::ToServerTunnelMessage(msg)) @@ -867,9 +890,7 @@ async fn handle_tunnel_message_mk2( ); // Publish message to UPS - ctx.ups() - .context("failed to get UPS instance for tunnel message")? - .publish(&gateway_reply_to, &msg_serialized, PublishOpts::one()) + ups.publish(&gateway_reply_to, &msg_serialized, PublishOpts::one()) .await .with_context(|| { format!( @@ -883,7 +904,9 @@ async fn handle_tunnel_message_mk2( #[tracing::instrument(skip_all)] async fn handle_tunnel_message_mk1( - ctx: &StandaloneCtx, + ups: &PubSub, + max_payload_size: usize, + authorized_tunnel_routes: &HashMap<(protocol::mk2::GatewayId, protocol::mk2::RequestId), ()>, msg: protocol::ToServerTunnelMessage, ) -> Result<()> { // Ignore DeprecatedTunnelAck messages (used only for backwards compatibility) @@ -898,15 +921,19 @@ async fn handle_tunnel_message_mk1( let inner_data_len = tunnel_message_inner_data_len_mk1(&msg.message_kind); // Enforce incoming payload size - if inner_data_len - > ctx - .config() - .pegboard() - .runner_max_response_payload_body_size() - { + if inner_data_len > max_payload_size { return Err(errors::WsError::InvalidPacket("payload too large".to_string()).build()); } + if !authorized_tunnel_routes + .contains_async(&(msg.message_id.gateway_id, msg.message_id.request_id)) + .await + { + return Err( + errors::WsError::InvalidPacket("unauthorized tunnel message".to_string()).build(), + ); + } + // Publish message to UPS let gateway_reply_to = GatewayReceiverSubject::new(msg.message_id.gateway_id).to_string(); let msg_serialized = versioned::ToGateway::v3_to_v7(versioned::ToGateway::V3( @@ -914,9 +941,7 @@ async fn handle_tunnel_message_mk1( ))? .serialize_with_embedded_version(PROTOCOL_MK2_VERSION) .context("failed to serialize tunnel message for gateway")?; - ctx.ups() - .context("failed to get UPS instance for tunnel message")? - .publish(&gateway_reply_to, &msg_serialized, PublishOpts::one()) + ups.publish(&gateway_reply_to, &msg_serialized, PublishOpts::one()) .await .with_context(|| { format!( @@ -961,6 +986,10 @@ fn tunnel_message_inner_data_len_mk1(kind: &protocol::ToServerTunnelMessageKind) } } +#[cfg(test)] +#[path = "../tests/support/ws_to_tunnel_task.rs"] +mod tests; + /// Send ack message for deprecated tunnel versions. /// /// We have to parse as specifically a v2 message since we need the exact request & message ID diff --git a/engine/packages/pegboard-runner/tests/support/ws_to_tunnel_task.rs b/engine/packages/pegboard-runner/tests/support/ws_to_tunnel_task.rs new file mode 100644 index 0000000000..53b4278136 --- /dev/null +++ b/engine/packages/pegboard-runner/tests/support/ws_to_tunnel_task.rs @@ -0,0 +1,150 @@ +use std::{sync::Arc, time::Duration}; + +use pegboard::pubsub_subjects::GatewayReceiverSubject; +use rivet_runner_protocol as protocol; +use scc::HashMap; +use universalpubsub::{NextOutput, PubSub, driver::memory::MemoryDriver}; + +use super::{handle_tunnel_message_mk1, handle_tunnel_message_mk2}; + +fn memory_pubsub(channel: &str) -> PubSub { + PubSub::new(Arc::new(MemoryDriver::new(channel.to_string()))) +} + +fn response_abort_message_mk2( + gateway_id: protocol::mk2::GatewayId, + request_id: protocol::mk2::RequestId, +) -> protocol::mk2::ToServerTunnelMessage { + protocol::mk2::ToServerTunnelMessage { + message_id: protocol::mk2::MessageId { + gateway_id, + request_id, + message_index: 0, + }, + message_kind: protocol::mk2::ToServerTunnelMessageKind::ToServerResponseAbort, + } +} + +fn response_abort_message_mk1( + gateway_id: protocol::mk2::GatewayId, + request_id: protocol::mk2::RequestId, +) -> protocol::ToServerTunnelMessage { + protocol::ToServerTunnelMessage { + message_id: protocol::MessageId { + gateway_id, + request_id, + message_index: 0, + }, + message_kind: protocol::ToServerTunnelMessageKind::ToServerResponseAbort, + } +} + +#[tokio::test] +async fn rejects_unissued_mk2_tunnel_message_pairs() { + let pubsub = memory_pubsub("pegboard-runner-ws-to-tunnel-test-reject-mk2"); + let gateway_id = [1, 2, 3, 4]; + let request_id = [5, 6, 7, 8]; + let mut sub = pubsub + .subscribe(&GatewayReceiverSubject::new(gateway_id).to_string()) + .await + .unwrap(); + let authorized_tunnel_routes = HashMap::new(); + + let err = handle_tunnel_message_mk2( + &pubsub, + 1024, + &authorized_tunnel_routes, + response_abort_message_mk2(gateway_id, request_id), + ) + .await + .unwrap_err(); + assert!(err.to_string().contains("unauthorized tunnel message")); + + let recv = tokio::time::timeout(Duration::from_millis(50), sub.next()).await; + assert!(recv.is_err()); +} + +#[tokio::test] +async fn republishes_issued_mk2_tunnel_message_pairs() { + let pubsub = memory_pubsub("pegboard-runner-ws-to-tunnel-test-allow-mk2"); + let gateway_id = [9, 10, 11, 12]; + let request_id = [13, 14, 15, 16]; + let mut sub = pubsub + .subscribe(&GatewayReceiverSubject::new(gateway_id).to_string()) + .await + .unwrap(); + let authorized_tunnel_routes = HashMap::new(); + let _ = authorized_tunnel_routes + .insert_async((gateway_id, request_id), ()) + .await; + + handle_tunnel_message_mk2( + &pubsub, + 1024, + &authorized_tunnel_routes, + response_abort_message_mk2(gateway_id, request_id), + ) + .await + .unwrap(); + + let msg = tokio::time::timeout(Duration::from_secs(1), sub.next()) + .await + .unwrap() + .unwrap(); + assert!(matches!(msg, NextOutput::Message(_))); +} + +#[tokio::test] +async fn rejects_unissued_mk1_tunnel_message_pairs() { + let pubsub = memory_pubsub("pegboard-runner-ws-to-tunnel-test-reject-mk1"); + let gateway_id = [17, 18, 19, 20]; + let request_id = [21, 22, 23, 24]; + let mut sub = pubsub + .subscribe(&GatewayReceiverSubject::new(gateway_id).to_string()) + .await + .unwrap(); + let authorized_tunnel_routes = HashMap::new(); + + let err = handle_tunnel_message_mk1( + &pubsub, + 1024, + &authorized_tunnel_routes, + response_abort_message_mk1(gateway_id, request_id), + ) + .await + .unwrap_err(); + assert!(err.to_string().contains("unauthorized tunnel message")); + + let recv = tokio::time::timeout(Duration::from_millis(50), sub.next()).await; + assert!(recv.is_err()); +} + +#[tokio::test] +async fn republishes_issued_mk1_tunnel_message_pairs() { + let pubsub = memory_pubsub("pegboard-runner-ws-to-tunnel-test-allow-mk1"); + let gateway_id = [25, 26, 27, 28]; + let request_id = [29, 30, 31, 32]; + let mut sub = pubsub + .subscribe(&GatewayReceiverSubject::new(gateway_id).to_string()) + .await + .unwrap(); + let authorized_tunnel_routes = HashMap::new(); + let _ = authorized_tunnel_routes + .insert_async((gateway_id, request_id), ()) + .await; + + handle_tunnel_message_mk1( + &pubsub, + 1024, + &authorized_tunnel_routes, + response_abort_message_mk1(gateway_id, request_id), + ) + .await + .unwrap(); + + let msg = tokio::time::timeout(Duration::from_secs(1), sub.next()) + .await + .unwrap() + .unwrap(); + assert!(matches!(msg, NextOutput::Message(_))); +}