From c3aad6f409afa14ddaff5ec6c19e75a84ca209eb Mon Sep 17 00:00:00 2001 From: Yordis Prieto Date: Wed, 22 Apr 2026 12:59:04 -0400 Subject: [PATCH 01/23] fix(acp-nats-ws): align websocket bridge with acp transport draft Signed-off-by: Yordis Prieto --- rsworkspace/Cargo.lock | 1 + rsworkspace/crates/acp-nats-ws/Cargo.toml | 1 + rsworkspace/crates/acp-nats-ws/README.md | 8 ++-- .../acp-nats-ws/src/acp_connection_id.rs | 30 +++++++++++++ .../crates/acp-nats-ws/src/connection.rs | 42 ++++++++++++++++--- .../crates/acp-nats-ws/src/constants.rs | 3 ++ rsworkspace/crates/acp-nats-ws/src/main.rs | 18 ++++---- rsworkspace/crates/acp-nats-ws/src/upgrade.rs | 35 ++++++++++++---- 8 files changed, 115 insertions(+), 23 deletions(-) create mode 100644 rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs diff --git a/rsworkspace/Cargo.lock b/rsworkspace/Cargo.lock index 9ed295afb..669da1dc4 100644 --- a/rsworkspace/Cargo.lock +++ b/rsworkspace/Cargo.lock @@ -83,6 +83,7 @@ dependencies = [ "tracing-subscriber", "trogon-nats", "trogon-std", + "uuid", ] [[package]] diff --git a/rsworkspace/crates/acp-nats-ws/Cargo.toml b/rsworkspace/crates/acp-nats-ws/Cargo.toml index c5451d991..0eccdc916 100644 --- a/rsworkspace/crates/acp-nats-ws/Cargo.toml +++ b/rsworkspace/crates/acp-nats-ws/Cargo.toml @@ -21,6 +21,7 @@ tokio = { workspace = true, features = ["rt-multi-thread", "macros", "signal", " tracing = { workspace = true } trogon-nats = { workspace = true } trogon-std = { workspace = true, features = ["telemetry-http"] } +uuid = { workspace = true } [dev-dependencies] serde_json = { workspace = true } diff --git a/rsworkspace/crates/acp-nats-ws/README.md b/rsworkspace/crates/acp-nats-ws/README.md index 846fc17b6..831d5a229 100644 --- a/rsworkspace/crates/acp-nats-ws/README.md +++ b/rsworkspace/crates/acp-nats-ws/README.md @@ -1,6 +1,6 @@ # ACP NATS WebSocket -Translates [Agent Client Protocol](https://agentclientprotocol.com) (ACP) messages between WebSocket connections and [NATS](https://nats.io), letting browser-based UIs and remote clients talk to distributed agent backends over a standard WebSocket endpoint. +Translates [Agent Client Protocol](https://agentclientprotocol.com) (ACP) messages between WebSocket connections and [NATS](https://nats.io), letting browser-based UIs and remote clients talk to distributed agent backends over the draft remote transport's WebSocket profile. For managed NATS infrastructure in production, we recommend Synadia Synadia. @@ -14,7 +14,7 @@ graph LR ## Features -- Multiple concurrent WebSocket connections, each with its own ACP session +- Multiple concurrent WebSocket connections - Bidirectional ACP bridge with request forwarding - OpenTelemetry integration (logs, metrics, traces) - Graceful shutdown (SIGINT/SIGTERM) with per-connection drain @@ -33,9 +33,11 @@ cargo build --release -p acp-nats-ws Connect with any WebSocket client: ```bash -websocat ws://127.0.0.1:8080/ws +websocat ws://127.0.0.1:8080/acp ``` +The WebSocket upgrade response includes `Acp-Connection-Id`. A legacy `/ws` alias remains available for older clients. + ## Configuration ### WebSocket Server diff --git a/rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs b/rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs new file mode 100644 index 000000000..d2d1fcd18 --- /dev/null +++ b/rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs @@ -0,0 +1,30 @@ +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct AcpConnectionId(uuid::Uuid); + +impl AcpConnectionId { + pub fn new() -> Self { + Self(uuid::Uuid::now_v7()) + } +} + +impl Default for AcpConnectionId { + fn default() -> Self { + Self::new() + } +} + +impl std::fmt::Display for AcpConnectionId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn new_generates_non_empty_id() { + assert!(!AcpConnectionId::new().to_string().is_empty()); + } +} diff --git a/rsworkspace/crates/acp-nats-ws/src/connection.rs b/rsworkspace/crates/acp-nats-ws/src/connection.rs index 7f54f25b9..8e54ced6e 100644 --- a/rsworkspace/crates/acp-nats-ws/src/connection.rs +++ b/rsworkspace/crates/acp-nats-ws/src/connection.rs @@ -9,10 +9,12 @@ use tokio::sync::watch; use tracing::{error, info, warn}; use trogon_std::time::SystemClock; +use crate::acp_connection_id::AcpConnectionId; use crate::constants::DUPLEX_BUFFER_SIZE; /// Handles a single WebSocket connection by bridging it to NATS via ACP. pub async fn handle( + connection_id: AcpConnectionId, socket: WebSocket, nats_client: N, js_client: J, @@ -68,7 +70,7 @@ pub async fn handle( let mut io_task = tokio::task::spawn_local(io_task); - info!("WebSocket connection established, ACP bridge running"); + info!(%connection_id, "WebSocket connection established, ACP bridge running"); let shutdown_result = tokio::select! { result = &mut client_task => { @@ -118,8 +120,8 @@ pub async fn handle( } match shutdown_result { - Ok(()) => info!("WebSocket connection closed cleanly"), - Err(e) => warn!(error = e, "WebSocket connection closed with error"), + Ok(()) => info!(%connection_id, "WebSocket connection closed cleanly"), + Err(e) => warn!(%connection_id, error = e, "WebSocket connection closed with error"), } } @@ -130,7 +132,7 @@ async fn run_recv_pump( while let Some(Ok(msg)) = ws_receiver.next().await { let bytes = match msg { Message::Text(t) => bytes::Bytes::from(t), - Message::Binary(b) => b, + Message::Binary(_) => continue, Message::Close(_) => break, _ => continue, }; @@ -196,6 +198,8 @@ mod tests { use tokio_tungstenite::connect_async; use tokio_tungstenite::tungstenite::Message as TungsteniteMessage; + use crate::constants::ACP_ENDPOINT; + #[derive(Clone)] struct EchoState; @@ -213,12 +217,12 @@ mod tests { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); let app = axum::Router::new() - .route("/ws", axum::routing::get(echo_handler)) + .route(ACP_ENDPOINT, axum::routing::get(echo_handler)) .with_state(EchoState); tokio::spawn(async move { axum::serve(listener, app).await.unwrap(); }); - format!("ws://{}/ws", addr) + format!("ws://{}{}", addr, ACP_ENDPOINT) } #[tokio::test] @@ -245,4 +249,30 @@ mod tests { } } } + + #[tokio::test] + async fn binary_messages_are_ignored() { + let url = start_echo_server().await; + let (mut ws, _) = connect_async(&url).await.unwrap(); + + ws.send(TungsteniteMessage::Binary(bytes::Bytes::from_static( + b"ignored", + ))) + .await + .unwrap(); + ws.send(TungsteniteMessage::Text("kept".into())) + .await + .unwrap(); + + let msg = tokio::time::timeout(Duration::from_secs(2), ws.next()) + .await + .expect("timeout") + .expect("stream ended") + .unwrap(); + + match msg { + TungsteniteMessage::Text(text) => assert_eq!(text, "kept"), + other => panic!("expected text frame, got {other:?}"), + } + } } diff --git a/rsworkspace/crates/acp-nats-ws/src/constants.rs b/rsworkspace/crates/acp-nats-ws/src/constants.rs index f78fd678c..30f1f8097 100644 --- a/rsworkspace/crates/acp-nats-ws/src/constants.rs +++ b/rsworkspace/crates/acp-nats-ws/src/constants.rs @@ -1,6 +1,9 @@ use std::net::{IpAddr, Ipv4Addr}; +pub const ACP_CONNECTION_ID_HEADER: &str = "acp-connection-id"; +pub const ACP_ENDPOINT: &str = "/acp"; pub const DEFAULT_HOST: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); pub const DEFAULT_PORT: u16 = 8080; pub const DUPLEX_BUFFER_SIZE: usize = 64 * 1024; +pub const LEGACY_WS_ENDPOINT: &str = "/ws"; pub const THREAD_NAME: &str = "acp-ws-local"; diff --git a/rsworkspace/crates/acp-nats-ws/src/main.rs b/rsworkspace/crates/acp-nats-ws/src/main.rs index 6faee93d7..409612fb1 100644 --- a/rsworkspace/crates/acp-nats-ws/src/main.rs +++ b/rsworkspace/crates/acp-nats-ws/src/main.rs @@ -1,3 +1,4 @@ +mod acp_connection_id; mod config; mod connection; mod constants; @@ -48,7 +49,8 @@ async fn main() -> Result<(), Box> { let app = trogon_std::telemetry::http::instrument_router( axum::Router::new() - .route("/ws", axum::routing::get(upgrade::handle)) + .route(ACP_ENDPOINT, axum::routing::get(upgrade::handle)) + .route(LEGACY_WS_ENDPOINT, axum::routing::get(upgrade::handle)) .with_state(state), ); @@ -88,7 +90,7 @@ async fn main() -> Result<(), Box> { #[cfg(coverage)] fn main() {} -use constants::THREAD_NAME; +use constants::{ACP_CONNECTION_ID_HEADER, ACP_ENDPOINT, LEGACY_WS_ENDPOINT, THREAD_NAME}; /// Runs a single-threaded tokio runtime with a /// `LocalSet`. All WebSocket connections are processed here because the ACP @@ -153,6 +155,7 @@ async fn process_connections( let js = js_client.clone(); let cfg = config.clone(); conn_handles.push(tokio::task::spawn_local(connection::handle( + req.connection_id, req.socket, client, js, @@ -259,7 +262,7 @@ mod tests { }; let app = axum::Router::new() - .route("/ws", axum::routing::get(upgrade::handle)) + .route(ACP_ENDPOINT, axum::routing::get(upgrade::handle)) .with_state(state); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); @@ -279,8 +282,9 @@ mod tests { nats_mock.set_response("acp.agent.initialize", nats_response.into()); // Connect client - let ws_url = format!("ws://{}/ws", addr); - let (mut ws_stream, _) = connect_async(ws_url).await.unwrap(); + let ws_url = format!("ws://{}{}", addr, ACP_ENDPOINT); + let (mut ws_stream, response) = connect_async(ws_url).await.unwrap(); + assert!(response.headers().contains_key(ACP_CONNECTION_ID_HEADER)); // Send initialize request let req = @@ -349,7 +353,7 @@ mod tests { }; let app = axum::Router::new() - .route("/ws", axum::routing::get(upgrade::handle)) + .route(ACP_ENDPOINT, axum::routing::get(upgrade::handle)) .with_state(state); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); @@ -364,7 +368,7 @@ mod tests { .unwrap(); }); - let ws_url = format!("ws://{}/ws", addr); + let ws_url = format!("ws://{}{}", addr, ACP_ENDPOINT); let (mut ws_stream, _) = connect_async(&ws_url).await.unwrap(); let req = diff --git a/rsworkspace/crates/acp-nats-ws/src/upgrade.rs b/rsworkspace/crates/acp-nats-ws/src/upgrade.rs index a63ad1343..ad577fcd6 100644 --- a/rsworkspace/crates/acp-nats-ws/src/upgrade.rs +++ b/rsworkspace/crates/acp-nats-ws/src/upgrade.rs @@ -1,10 +1,14 @@ +use crate::acp_connection_id::AcpConnectionId; +use crate::constants::ACP_CONNECTION_ID_HEADER; use axum::extract::State; use axum::extract::ws::{WebSocket, WebSocketUpgrade}; +use axum::http::HeaderValue; use axum::response::Response; use tokio::sync::{mpsc, watch}; use tracing::error; pub struct ConnectionRequest { + pub connection_id: AcpConnectionId, pub socket: WebSocket, pub shutdown_rx: watch::Receiver, } @@ -16,11 +20,15 @@ pub struct UpgradeState { } pub async fn handle(ws: WebSocketUpgrade, State(state): State) -> Response { + let connection_id = AcpConnectionId::new(); + let response_header = HeaderValue::from_str(&connection_id.to_string()) + .expect("generated ACP connection id must be a valid header value"); let shutdown_rx = state.shutdown_tx.subscribe(); - ws.on_upgrade(move |socket| async move { + let mut response = ws.on_upgrade(move |socket| async move { if state .conn_tx .send(ConnectionRequest { + connection_id, socket, shutdown_rx, }) @@ -28,12 +36,17 @@ pub async fn handle(ws: WebSocketUpgrade, State(state): State) -> { error!("Connection thread is gone; dropping WebSocket"); } - }) + }); + response + .headers_mut() + .insert(ACP_CONNECTION_ID_HEADER, response_header); + response } #[cfg(test)] mod tests { use super::*; + use crate::constants::{ACP_CONNECTION_ID_HEADER, ACP_ENDPOINT}; use std::time::Duration; use tokio::net::TcpListener; use tokio_tungstenite::connect_async; @@ -49,7 +62,7 @@ mod tests { }; let app = axum::Router::new() - .route("/ws", axum::routing::get(handle)) + .route(ACP_ENDPOINT, axum::routing::get(handle)) .with_state(state); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); @@ -58,8 +71,15 @@ mod tests { axum::serve(listener, app).await.unwrap(); }); - let url = format!("ws://{}/ws", addr); - let (_ws, _) = connect_async(&url).await.unwrap(); + let url = format!("ws://{}{}", addr, ACP_ENDPOINT); + let (_ws, response) = connect_async(&url).await.unwrap(); + let connection_id = response + .headers() + .get(ACP_CONNECTION_ID_HEADER) + .expect("upgrade response should include Acp-Connection-Id") + .to_str() + .unwrap() + .to_string(); let req = tokio::time::timeout(Duration::from_secs(2), conn_rx.recv()) .await @@ -67,6 +87,7 @@ mod tests { .expect("channel closed"); assert!(!*req.shutdown_rx.borrow()); + assert_eq!(req.connection_id.to_string(), connection_id); } #[tokio::test] @@ -80,7 +101,7 @@ mod tests { }; let app = axum::Router::new() - .route("/ws", axum::routing::get(handle)) + .route(ACP_ENDPOINT, axum::routing::get(handle)) .with_state(state); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); @@ -91,7 +112,7 @@ mod tests { drop(conn_rx); - let url = format!("ws://{}/ws", addr); + let url = format!("ws://{}{}", addr, ACP_ENDPOINT); let (_ws, _) = connect_async(&url).await.unwrap(); tokio::time::sleep(Duration::from_millis(100)).await; From 38d7d3bd0ba14b4f7321015028288335f16a0418 Mon Sep 17 00:00:00 2001 From: Yordis Prieto Date: Wed, 22 Apr 2026 13:37:24 -0400 Subject: [PATCH 02/23] feat(acp-nats-ws): unblock streamable HTTP clients Signed-off-by: Yordis Prieto --- rsworkspace/crates/acp-nats-ws/Cargo.toml | 4 +- rsworkspace/crates/acp-nats-ws/README.md | 25 +- .../acp-nats-ws/src/acp_connection_id.rs | 39 + .../crates/acp-nats-ws/src/constants.rs | 2 + rsworkspace/crates/acp-nats-ws/src/main.rs | 429 +++++-- .../crates/acp-nats-ws/src/transport.rs | 1021 +++++++++++++++++ rsworkspace/crates/acp-nats-ws/src/upgrade.rs | 120 -- 7 files changed, 1421 insertions(+), 219 deletions(-) create mode 100644 rsworkspace/crates/acp-nats-ws/src/transport.rs delete mode 100644 rsworkspace/crates/acp-nats-ws/src/upgrade.rs diff --git a/rsworkspace/crates/acp-nats-ws/Cargo.toml b/rsworkspace/crates/acp-nats-ws/Cargo.toml index 0eccdc916..9850b06d3 100644 --- a/rsworkspace/crates/acp-nats-ws/Cargo.toml +++ b/rsworkspace/crates/acp-nats-ws/Cargo.toml @@ -17,6 +17,8 @@ bytes = { workspace = true } clap = { workspace = true, features = ["env"] } futures-util = { workspace = true, features = ["sink"] } opentelemetry = { workspace = true } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "macros", "signal", "net", "sync", "io-util"] } tracing = { workspace = true } trogon-nats = { workspace = true } @@ -24,7 +26,7 @@ trogon-std = { workspace = true, features = ["telemetry-http"] } uuid = { workspace = true } [dev-dependencies] -serde_json = { workspace = true } +tower = { version = "0.5", features = ["util"] } tokio-tungstenite = { workspace = true } tracing-subscriber = { workspace = true, features = ["fmt"] } trogon-nats = { workspace = true, features = ["test-support"] } diff --git a/rsworkspace/crates/acp-nats-ws/README.md b/rsworkspace/crates/acp-nats-ws/README.md index 831d5a229..e08d933f6 100644 --- a/rsworkspace/crates/acp-nats-ws/README.md +++ b/rsworkspace/crates/acp-nats-ws/README.md @@ -1,21 +1,22 @@ -# ACP NATS WebSocket +# ACP NATS Streamable HTTP & WebSocket -Translates [Agent Client Protocol](https://agentclientprotocol.com) (ACP) messages between WebSocket connections and [NATS](https://nats.io), letting browser-based UIs and remote clients talk to distributed agent backends over the draft remote transport's WebSocket profile. +Translates [Agent Client Protocol](https://agentclientprotocol.com) (ACP) messages between [NATS](https://nats.io) and the draft remote transport served on `/acp`, including both Streamable HTTP (`POST`/`GET`/`DELETE`) and WebSocket upgrade. For managed NATS infrastructure in production, we recommend Synadia Synadia. ```mermaid graph LR - A1[Client1] <-->|ws| B[acp-nats-ws] - A2[Client2] <-->|ws| B + A1[Client1] <-->|http or ws| B[acp-nats-ws] + A2[Client2] <-->|http or ws| B AN[ClientN] <-->|ws| B B <-->|NATS| C[Backend] ``` ## Features -- Multiple concurrent WebSocket connections -- Bidirectional ACP bridge with request forwarding +- Streamable HTTP transport on `/acp` with session-scoped SSE listeners +- WebSocket upgrade on `/acp` plus a legacy `/ws` compatibility alias +- Multiple concurrent ACP connections sharing the same NATS bridge - OpenTelemetry integration (logs, metrics, traces) - Graceful shutdown (SIGINT/SIGTERM) with per-connection drain - Custom prefix support for multi-tenancy @@ -36,7 +37,17 @@ Connect with any WebSocket client: websocat ws://127.0.0.1:8080/acp ``` -The WebSocket upgrade response includes `Acp-Connection-Id`. A legacy `/ws` alias remains available for older clients. +Or use Streamable HTTP: + +```bash +curl -i \ + -H 'Content-Type: application/json' \ + -H 'Accept: application/json, text/event-stream' \ + -d '{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":0}}' \ + http://127.0.0.1:8080/acp +``` + +`POST /acp` returns an SSE response for JSON-RPC requests, `GET /acp` opens a session-scoped SSE listener with `Acp-Connection-Id` and `Acp-Session-Id`, and `DELETE /acp` terminates a connection. The WebSocket upgrade response and HTTP initialize response both include `Acp-Connection-Id`. ## Configuration diff --git a/rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs b/rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs index d2d1fcd18..5cc001dee 100644 --- a/rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs +++ b/rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs @@ -5,6 +5,12 @@ impl AcpConnectionId { pub fn new() -> Self { Self(uuid::Uuid::now_v7()) } + + pub fn parse(s: &str) -> Result { + uuid::Uuid::parse_str(s) + .map(Self) + .map_err(AcpConnectionIdError::InvalidUuid) + } } impl Default for AcpConnectionId { @@ -19,6 +25,27 @@ impl std::fmt::Display for AcpConnectionId { } } +#[derive(Debug)] +pub enum AcpConnectionIdError { + InvalidUuid(uuid::Error), +} + +impl std::fmt::Display for AcpConnectionIdError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::InvalidUuid(error) => write!(f, "invalid ACP connection id: {error}"), + } + } +} + +impl std::error::Error for AcpConnectionIdError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::InvalidUuid(error) => Some(error), + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -27,4 +54,16 @@ mod tests { fn new_generates_non_empty_id() { assert!(!AcpConnectionId::new().to_string().is_empty()); } + + #[test] + fn parse_round_trips_uuid() { + let id = AcpConnectionId::new(); + let parsed = AcpConnectionId::parse(&id.to_string()).unwrap(); + assert_eq!(parsed, id); + } + + #[test] + fn parse_rejects_invalid_uuid() { + assert!(AcpConnectionId::parse("not-a-uuid").is_err()); + } } diff --git a/rsworkspace/crates/acp-nats-ws/src/constants.rs b/rsworkspace/crates/acp-nats-ws/src/constants.rs index 30f1f8097..501c15547 100644 --- a/rsworkspace/crates/acp-nats-ws/src/constants.rs +++ b/rsworkspace/crates/acp-nats-ws/src/constants.rs @@ -2,8 +2,10 @@ use std::net::{IpAddr, Ipv4Addr}; pub const ACP_CONNECTION_ID_HEADER: &str = "acp-connection-id"; pub const ACP_ENDPOINT: &str = "/acp"; +pub const ACP_SESSION_ID_HEADER: &str = "acp-session-id"; pub const DEFAULT_HOST: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); pub const DEFAULT_PORT: u16 = 8080; pub const DUPLEX_BUFFER_SIZE: usize = 64 * 1024; pub const LEGACY_WS_ENDPOINT: &str = "/ws"; pub const THREAD_NAME: &str = "acp-ws-local"; +pub const X_ACCEL_BUFFERING_HEADER: &str = "x-accel-buffering"; diff --git a/rsworkspace/crates/acp-nats-ws/src/main.rs b/rsworkspace/crates/acp-nats-ws/src/main.rs index 409612fb1..270b13f28 100644 --- a/rsworkspace/crates/acp-nats-ws/src/main.rs +++ b/rsworkspace/crates/acp-nats-ws/src/main.rs @@ -2,11 +2,11 @@ mod acp_connection_id; mod config; mod connection; mod constants; -mod upgrade; +mod transport; use tokio::sync::{mpsc, watch}; use tracing::info; -use upgrade::{ConnectionRequest, UpgradeState}; +use transport::{AppState, ManagerRequest}; #[cfg(not(coverage))] use { @@ -27,7 +27,7 @@ async fn main() -> Result<(), Box> { ); let ws_config = config::apply_timeout_overrides(ws_config, &SystemEnv); - info!("ACP WebSocket bridge starting"); + info!("ACP remote transport bridge starting"); let nats_connect_timeout = acp_nats::nats_connect_timeout(&SystemEnv); let nats_client = nats::connect(ws_config.acp.nats(), nats_connect_timeout).await?; @@ -36,28 +36,36 @@ async fn main() -> Result<(), Box> { let js_client = trogon_nats::jetstream::NatsJetStreamClient::new(js_context); let (shutdown_tx, _) = watch::channel(false); - let (conn_tx, conn_rx) = mpsc::unbounded_channel::(); + let (manager_tx, manager_rx) = mpsc::unbounded_channel::(); let conn_thread = std::thread::Builder::new() .name(THREAD_NAME.into()) - .spawn(move || run_connection_thread(conn_rx, nats_client, js_client, ws_config.acp))?; + .spawn(move || run_connection_thread(manager_rx, nats_client, js_client, ws_config.acp))?; - let state = UpgradeState { - conn_tx, + let state = AppState { + manager_tx, shutdown_tx: shutdown_tx.clone(), }; let app = trogon_std::telemetry::http::instrument_router( axum::Router::new() - .route(ACP_ENDPOINT, axum::routing::get(upgrade::handle)) - .route(LEGACY_WS_ENDPOINT, axum::routing::get(upgrade::handle)) + .route( + ACP_ENDPOINT, + axum::routing::get(transport::get) + .post(transport::post) + .delete(transport::delete), + ) + .route( + LEGACY_WS_ENDPOINT, + axum::routing::get(transport::legacy_websocket_get), + ) .with_state(state), ); let addr = SocketAddr::from((ws_config.host, ws_config.port)); let listener = tokio::net::TcpListener::bind(addr).await?; - info!(address = %addr, "Listening for WebSocket connections"); + info!(address = %addr, "Listening for ACP transport connections"); let result = axum::serve(listener, app) .with_graceful_shutdown(async move { @@ -68,8 +76,8 @@ async fn main() -> Result<(), Box> { .await; match &result { - Ok(()) => info!("ACP WebSocket bridge stopped"), - Err(e) => error!(error = %e, "ACP WebSocket bridge stopped with error"), + Ok(()) => info!("ACP remote transport bridge stopped"), + Err(e) => error!(error = %e, "ACP remote transport bridge stopped with error"), } // `serve` returning drops the Router (and its AppState.conn_tx), which @@ -90,13 +98,13 @@ async fn main() -> Result<(), Box> { #[cfg(coverage)] fn main() {} -use constants::{ACP_CONNECTION_ID_HEADER, ACP_ENDPOINT, LEGACY_WS_ENDPOINT, THREAD_NAME}; +use constants::{ACP_ENDPOINT, LEGACY_WS_ENDPOINT, THREAD_NAME}; /// Runs a single-threaded tokio runtime with a /// `LocalSet`. All WebSocket connections are processed here because the ACP /// `Agent` trait is `?Send`, requiring `spawn_local` / `Rc`. fn run_connection_thread( - conn_rx: mpsc::UnboundedReceiver, + manager_rx: mpsc::UnboundedReceiver, nats_client: N, js_client: J, config: acp_nats::Config, @@ -117,7 +125,12 @@ fn run_connection_thread( .expect("failed to create per-connection runtime"); let local = tokio::task::LocalSet::new(); - rt.block_on(local.run_until(process_connections(conn_rx, nats_client, js_client, config))); + rt.block_on(local.run_until(process_connections( + manager_rx, + nats_client, + js_client, + config, + ))); // run_until returns once process_connections completes, but // sub-tasks spawned by connection handlers (pumps, @@ -132,7 +145,7 @@ fn run_connection_thread( } async fn process_connections( - mut conn_rx: mpsc::UnboundedReceiver, + mut manager_rx: mpsc::UnboundedReceiver, nats_client: N, js_client: J, config: acp_nats::Config, @@ -147,30 +160,31 @@ async fn process_connections( J: acp_nats::JetStreamPublisher + acp_nats::JetStreamGetStream + 'static, trogon_nats::jetstream::JsMessageOf: trogon_nats::jetstream::JsRequestMessage, { - let mut conn_handles: Vec> = Vec::new(); - - while let Some(req) = conn_rx.recv().await { - conn_handles.retain(|h| !h.is_finished()); - let client = nats_client.clone(); - let js = js_client.clone(); - let cfg = config.clone(); - conn_handles.push(tokio::task::spawn_local(connection::handle( - req.connection_id, - req.socket, - client, - js, - cfg, - req.shutdown_rx, - ))); + let mut websocket_handles: Vec> = Vec::new(); + let mut http_connections = std::collections::HashMap::new(); + + while let Some(request) = manager_rx.recv().await { + transport::process_manager_request( + request, + &mut http_connections, + &mut websocket_handles, + &nats_client, + &js_client, + &config, + ) + .await; } - let active = conn_handles.iter().filter(|h| !h.is_finished()).count(); + let active = websocket_handles + .iter() + .filter(|h| !h.is_finished()) + .count(); info!( active_connections = active, "Connection channel closed, draining active connections" ); - for handle in conn_handles { + for handle in websocket_handles { let _ = handle.await; } @@ -180,12 +194,18 @@ async fn process_connections( #[cfg(test)] mod tests { use super::*; + use crate::constants::{ACP_CONNECTION_ID_HEADER, ACP_SESSION_ID_HEADER}; use acp_nats::Config; + use axum::body::{Body, to_bytes}; + use axum::http::header::{ACCEPT, CONTENT_TYPE}; + use axum::http::{Request, StatusCode}; use futures_util::{SinkExt, StreamExt}; + use serde_json::Value; use std::time::Duration; use tokio::net::TcpListener; use tokio_tungstenite::connect_async; use tokio_tungstenite::tungstenite::Message; + use tower::ServiceExt; use trogon_nats::AdvancedMockNatsClient; #[derive(Clone)] @@ -233,41 +253,77 @@ mod tests { } } - #[tokio::test] - async fn test_websocket_connection_lifecycle() { - let nats_mock = AdvancedMockNatsClient::new(); - let config = Config::new( + fn test_config() -> Config { + Config::new( acp_nats::AcpPrefix::new("acp").unwrap(), acp_nats::NatsConfig { servers: vec!["localhost:4222".to_string()], auth: trogon_nats::NatsAuth::None, }, - ); - - // Required by AdvancedMockNatsClient to not error out on subscribe() - let _injector = nats_mock.inject_messages(); + ) + } - let (shutdown_tx, mut shutdown_rx) = watch::channel(false); - let (conn_tx, conn_rx) = mpsc::unbounded_channel::(); + fn test_app(state: AppState) -> axum::Router { + axum::Router::new() + .route( + ACP_ENDPOINT, + axum::routing::get(transport::get) + .post(transport::post) + .delete(transport::delete), + ) + .route( + LEGACY_WS_ENDPOINT, + axum::routing::get(transport::legacy_websocket_get), + ) + .with_state(state) + } - let nats_mock_clone = nats_mock.clone(); - let conn_thread = std::thread::Builder::new() + fn spawn_connection_thread( + nats_mock: AdvancedMockNatsClient, + manager_rx: mpsc::UnboundedReceiver, + ) -> std::thread::JoinHandle<()> { + let config = test_config(); + std::thread::Builder::new() .name(THREAD_NAME.into()) - .spawn(move || run_connection_thread(conn_rx, nats_mock_clone, MockJs::new(), config)) - .expect("failed to spawn connection thread"); + .spawn(move || run_connection_thread(manager_rx, nats_mock, MockJs::new(), config)) + .expect("failed to spawn connection thread") + } - let state = UpgradeState { - conn_tx, + fn build_test_app( + nats_mock: AdvancedMockNatsClient, + ) -> ( + axum::Router, + watch::Sender, + std::thread::JoinHandle<()>, + ) { + let (shutdown_tx, _) = watch::channel(false); + let (manager_tx, manager_rx) = mpsc::unbounded_channel::(); + let conn_thread = spawn_connection_thread(nats_mock, manager_rx); + let app = test_app(AppState { + manager_tx, shutdown_tx: shutdown_tx.clone(), - }; + }); + (app, shutdown_tx, conn_thread) + } - let app = axum::Router::new() - .route(ACP_ENDPOINT, axum::routing::get(upgrade::handle)) - .with_state(state); + async fn start_test_server( + nats_mock: AdvancedMockNatsClient, + ) -> ( + std::net::SocketAddr, + watch::Sender, + tokio::task::JoinHandle<()>, + std::thread::JoinHandle<()>, + ) { + let (shutdown_tx, mut shutdown_rx) = watch::channel(false); + let (manager_tx, manager_rx) = mpsc::unbounded_channel::(); + let conn_thread = spawn_connection_thread(nats_mock, manager_rx); + let app = test_app(AppState { + manager_tx, + shutdown_tx: shutdown_tx.clone(), + }); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); - let server_task = tokio::spawn(async move { axum::serve(listener, app) .with_graceful_shutdown(async move { @@ -277,10 +333,44 @@ mod tests { .unwrap(); }); + (addr, shutdown_tx, server_task, conn_thread) + } + + async fn body_text(response: axum::response::Response) -> String { + let bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap(); + String::from_utf8(bytes.to_vec()).unwrap() + } + + fn sse_events(body: &str) -> Vec { + body.lines() + .filter_map(|line| line.strip_prefix("data: ")) + .map(|json| serde_json::from_str(json).unwrap()) + .collect() + } + + fn http_post_request(body: &str) -> Request { + Request::builder() + .method("POST") + .uri(ACP_ENDPOINT) + .header(CONTENT_TYPE, "application/json") + .header(ACCEPT, "application/json, text/event-stream") + .body(Body::from(body.to_owned())) + .unwrap() + } + + #[tokio::test] + async fn test_websocket_connection_lifecycle() { + let nats_mock = AdvancedMockNatsClient::new(); + + // Required by AdvancedMockNatsClient to not error out on subscribe() + let _injector = nats_mock.inject_messages(); + // Setup mock response for NATS let nats_response = r#"{"agentCapabilities": {"loadSession": false, "mcpCapabilities": {"http": false, "sse": false}, "promptCapabilities": {"audio": false, "embeddedContext": false, "image": false}, "sessionCapabilities": {}}, "authMethods": [], "protocolVersion": 0}"#; nats_mock.set_response("acp.agent.initialize", nats_response.into()); + let (addr, shutdown_tx, server_task, conn_thread) = start_test_server(nats_mock).await; + // Connect client let ws_url = format!("ws://{}{}", addr, ACP_ENDPOINT); let (mut ws_stream, response) = connect_async(ws_url).await.unwrap(); @@ -326,47 +416,11 @@ mod tests { #[tokio::test] async fn test_shutdown_while_connection_active() { let nats_mock = AdvancedMockNatsClient::new(); - let config = Config::new( - acp_nats::AcpPrefix::new("acp").unwrap(), - acp_nats::NatsConfig { - servers: vec!["localhost:4222".to_string()], - auth: trogon_nats::NatsAuth::None, - }, - ); let _injector = nats_mock.inject_messages(); nats_mock.hang_next_request(); - - let (shutdown_tx, mut shutdown_rx) = watch::channel(false); - let (conn_tx, conn_rx) = mpsc::unbounded_channel::(); - - let nats_mock_clone = nats_mock.clone(); - let conn_thread = std::thread::Builder::new() - .name(THREAD_NAME.into()) - .spawn(move || run_connection_thread(conn_rx, nats_mock_clone, MockJs::new(), config)) - .expect("failed to spawn connection thread"); - - let state = UpgradeState { - conn_tx, - shutdown_tx: shutdown_tx.clone(), - }; - - let app = axum::Router::new() - .route(ACP_ENDPOINT, axum::routing::get(upgrade::handle)) - .with_state(state); - - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - - let server_task = tokio::spawn(async move { - axum::serve(listener, app) - .with_graceful_shutdown(async move { - let _ = shutdown_rx.changed().await; - }) - .await - .unwrap(); - }); + let (addr, shutdown_tx, server_task, conn_thread) = start_test_server(nats_mock).await; let ws_url = format!("ws://{}{}", addr, ACP_ENDPOINT); let (mut ws_stream, _) = connect_async(&ws_url).await.unwrap(); @@ -387,4 +441,197 @@ mod tests { conn_thread.join().unwrap(); } + + #[tokio::test] + async fn streamable_http_initialize_returns_connection_id_and_sse_response() { + let nats_mock = AdvancedMockNatsClient::new(); + let _injector = nats_mock.inject_messages(); + nats_mock.set_response( + "acp.agent.initialize", + r#"{"agentCapabilities":{"loadSession":false,"mcpCapabilities":{"http":false,"sse":false},"promptCapabilities":{"audio":false,"embeddedContext":false,"image":false},"sessionCapabilities":{}},"authMethods":[],"protocolVersion":0}"# + .into(), + ); + + let (app, shutdown_tx, conn_thread) = build_test_app(nats_mock); + let response = app + .clone() + .oneshot(http_post_request( + r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":0}}"#, + )) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response.headers().get(CONTENT_TYPE).unwrap(), + "text/event-stream" + ); + + let connection_id = response + .headers() + .get(ACP_CONNECTION_ID_HEADER) + .unwrap() + .to_str() + .unwrap() + .to_owned(); + assert!(crate::acp_connection_id::AcpConnectionId::parse(&connection_id).is_ok()); + + let body = body_text(response).await; + let events = sse_events(&body); + assert_eq!(events.len(), 1); + assert_eq!(events[0]["id"], 1); + assert_eq!(events[0]["result"]["protocolVersion"], 0); + + shutdown_tx.send(true).unwrap(); + drop(app); + conn_thread.join().unwrap(); + } + + #[tokio::test] + async fn streamable_http_session_new_returns_session_header_and_body() { + let nats_mock = AdvancedMockNatsClient::new(); + let _injector = nats_mock.inject_messages(); + nats_mock.set_response( + "acp.agent.initialize", + r#"{"agentCapabilities":{"loadSession":false,"mcpCapabilities":{"http":false,"sse":false},"promptCapabilities":{"audio":false,"embeddedContext":false,"image":false},"sessionCapabilities":{}},"authMethods":[],"protocolVersion":0}"# + .into(), + ); + nats_mock.set_response( + "acp.agent.session.new", + r#"{"sessionId":"test-session-1"}"#.into(), + ); + + let (app, shutdown_tx, conn_thread) = build_test_app(nats_mock); + let initialize = app + .clone() + .oneshot(http_post_request( + r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":0}}"#, + )) + .await + .unwrap(); + let connection_id = initialize + .headers() + .get(ACP_CONNECTION_ID_HEADER) + .unwrap() + .to_str() + .unwrap() + .to_owned(); + let _ = body_text(initialize).await; + + let session_new = app + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri(ACP_ENDPOINT) + .header(CONTENT_TYPE, "application/json") + .header(ACCEPT, "application/json, text/event-stream") + .header(ACP_CONNECTION_ID_HEADER, &connection_id) + .body(Body::from( + r#"{"jsonrpc":"2.0","id":2,"method":"session/new","params":{"cwd":".","mcpServers":[]}}"#, + )) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(session_new.status(), StatusCode::OK); + assert_eq!( + session_new + .headers() + .get(ACP_CONNECTION_ID_HEADER) + .unwrap() + .to_str() + .unwrap(), + connection_id + ); + + let session_id = session_new + .headers() + .get(ACP_SESSION_ID_HEADER) + .and_then(|value| value.to_str().ok()) + .map(str::to_owned); + let body = body_text(session_new).await; + let events = sse_events(&body); + assert_eq!(events.len(), 1); + assert_eq!(events[0]["id"], 2); + assert_eq!(events[0]["result"]["sessionId"], "test-session-1"); + assert_eq!(session_id.as_deref(), Some("test-session-1")); + + let _ = shutdown_tx.send(true); + drop(app); + conn_thread.join().unwrap(); + } + + #[tokio::test] + async fn streamable_http_get_requires_connection_and_session_headers() { + let nats_mock = AdvancedMockNatsClient::new(); + let (app, shutdown_tx, conn_thread) = build_test_app(nats_mock); + + let response = app + .clone() + .oneshot( + Request::builder() + .method("GET") + .uri(ACP_ENDPOINT) + .header(ACCEPT, "text/event-stream") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + let _ = shutdown_tx.send(true); + drop(app); + conn_thread.join().unwrap(); + } + + #[tokio::test] + async fn streamable_http_delete_terminates_initialized_connection() { + let nats_mock = AdvancedMockNatsClient::new(); + let _injector = nats_mock.inject_messages(); + nats_mock.set_response( + "acp.agent.initialize", + r#"{"agentCapabilities":{"loadSession":false,"mcpCapabilities":{"http":false,"sse":false},"promptCapabilities":{"audio":false,"embeddedContext":false,"image":false},"sessionCapabilities":{}},"authMethods":[],"protocolVersion":0}"# + .into(), + ); + + let (app, shutdown_tx, conn_thread) = build_test_app(nats_mock); + let initialize = app + .clone() + .oneshot(http_post_request( + r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":0}}"#, + )) + .await + .unwrap(); + let connection_id = initialize + .headers() + .get(ACP_CONNECTION_ID_HEADER) + .unwrap() + .to_str() + .unwrap() + .to_owned(); + let _ = body_text(initialize).await; + + let response = app + .clone() + .oneshot( + Request::builder() + .method("DELETE") + .uri(ACP_ENDPOINT) + .header(ACP_CONNECTION_ID_HEADER, &connection_id) + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::ACCEPTED); + + let _ = shutdown_tx.send(true); + drop(app); + conn_thread.join().unwrap(); + } } diff --git a/rsworkspace/crates/acp-nats-ws/src/transport.rs b/rsworkspace/crates/acp-nats-ws/src/transport.rs new file mode 100644 index 000000000..03f146436 --- /dev/null +++ b/rsworkspace/crates/acp-nats-ws/src/transport.rs @@ -0,0 +1,1021 @@ +use crate::acp_connection_id::{AcpConnectionId, AcpConnectionIdError}; +use crate::connection; +use crate::constants::{ACP_CONNECTION_ID_HEADER, ACP_SESSION_ID_HEADER, X_ACCEL_BUFFERING_HEADER}; +use acp_nats::{StdJsonSerialize, agent::Bridge, client, spawn_notification_forwarder}; +use agent_client_protocol::{AgentSideConnection, RequestId, SessionNotification}; +use axum::extract::FromRequestParts; +use axum::extract::Request; +use axum::extract::State; +use axum::extract::ws::{WebSocket, WebSocketUpgrade}; +use axum::http::header::{ACCEPT, CONTENT_TYPE}; +use axum::http::{HeaderMap, HeaderValue, StatusCode}; +use axum::response::sse::{Event, Sse}; +use axum::response::{IntoResponse, Response}; +use futures_util::stream; +use serde::Deserialize; +use serde_json::Value; +use std::collections::{HashMap, HashSet}; +use std::convert::Infallible; +use std::rc::Rc; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt}; +use tokio::sync::{mpsc, oneshot, watch}; +use tracing::{error, info, warn}; +use trogon_std::time::SystemClock; + +type SseSender = mpsc::UnboundedSender; +type SseReceiver = mpsc::UnboundedReceiver; + +#[derive(Clone)] +pub struct AppState { + pub manager_tx: mpsc::UnboundedSender, + pub shutdown_tx: watch::Sender, +} + +pub struct ConnectionRequest { + pub connection_id: AcpConnectionId, + pub socket: WebSocket, + pub shutdown_rx: watch::Receiver, +} + +pub enum ManagerRequest { + WebSocket(ConnectionRequest), + HttpPost { + connection_id: Option, + session_id: Option, + message: IncomingHttpMessage, + response: oneshot::Sender>, + shutdown_rx: watch::Receiver, + }, + HttpGet { + connection_id: AcpConnectionId, + session_id: acp_nats::AcpSessionId, + response: oneshot::Sender>, + }, + HttpDelete { + connection_id: AcpConnectionId, + response: oneshot::Sender>, + }, +} + +#[derive(Debug)] +pub enum HttpPostOutcome { + Accepted, + Live { + connection_id: AcpConnectionId, + session_id: Option, + stream: SseReceiver, + }, + Buffered { + connection_id: AcpConnectionId, + session_id: Option, + events: Vec, + }, +} + +#[derive(Debug)] +pub enum HttpTransportError { + BadRequest(&'static str), + NotFound(&'static str), + Conflict(&'static str), + UnsupportedMediaType(&'static str), + NotAcceptable(&'static str), + NotImplemented(&'static str), + Internal(&'static str), +} + +impl HttpTransportError { + fn into_response(self) -> Response { + let (status, message) = match self { + Self::BadRequest(message) => (StatusCode::BAD_REQUEST, message), + Self::NotFound(message) => (StatusCode::NOT_FOUND, message), + Self::Conflict(message) => (StatusCode::CONFLICT, message), + Self::UnsupportedMediaType(message) => (StatusCode::UNSUPPORTED_MEDIA_TYPE, message), + Self::NotAcceptable(message) => (StatusCode::NOT_ACCEPTABLE, message), + Self::NotImplemented(message) => (StatusCode::NOT_IMPLEMENTED, message), + Self::Internal(message) => (StatusCode::INTERNAL_SERVER_ERROR, message), + }; + + (status, message).into_response() + } +} + +impl std::fmt::Display for HttpTransportError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::BadRequest(message) + | Self::NotFound(message) + | Self::Conflict(message) + | Self::UnsupportedMediaType(message) + | Self::NotAcceptable(message) + | Self::NotImplemented(message) + | Self::Internal(message) => f.write_str(message), + } + } +} + +impl std::error::Error for HttpTransportError {} + +#[derive(Debug)] +pub struct HttpConnectionHandle { + pub command_tx: mpsc::UnboundedSender, +} + +#[derive(Debug)] +pub enum HttpConnectionCommand { + Post { + session_id: Option, + message: IncomingHttpMessage, + response: oneshot::Sender>, + }, + AttachListener { + session_id: acp_nats::AcpSessionId, + response: oneshot::Sender>, + }, + Close { + response: oneshot::Sender>, + }, +} + +#[derive(Clone, Debug)] +pub(crate) enum SseFrame { + Json(String), +} + +impl SseFrame { + fn into_event(self) -> Event { + match self { + Self::Json(json) => Event::default().data(json), + } + } +} + +#[derive(Debug)] +enum PendingRequest { + Live { + request_id: RequestId, + sender: SseSender, + }, + Buffered { + request_id: RequestId, + events: Vec, + response: oneshot::Sender>, + }, +} + +#[derive(Debug, Deserialize)] +pub struct IncomingHttpMessage { + pub id: Option, + pub method: Option, + pub params: Option, + pub result: Option, + pub error: Option, + #[serde(skip)] + pub raw: String, +} + +impl IncomingHttpMessage { + pub fn parse(raw: String) -> Result { + let trimmed = raw.trim_start(); + if trimmed.starts_with('[') { + return Err(HttpTransportError::NotImplemented( + "batch JSON-RPC requests are not supported", + )); + } + + let mut parsed = serde_json::from_str::(&raw) + .map_err(|_| HttpTransportError::BadRequest("invalid JSON-RPC payload"))?; + parsed.raw = raw; + Ok(parsed) + } + + fn is_request(&self) -> bool { + self.id.is_some() && self.method.is_some() + } + + fn is_notification(&self) -> bool { + self.id.is_none() && self.method.is_some() + } + + fn is_response(&self) -> bool { + self.id.is_some() + && self.method.is_none() + && (self.result.is_some() || self.error.is_some()) + } + + fn method_name(&self) -> Option<&str> { + self.method.as_deref() + } + + fn is_initialize(&self) -> bool { + self.method_name() == Some("initialize") + } + + fn is_session_new(&self) -> bool { + self.method_name() == Some("session/new") + } + + fn requires_session_id(&self) -> bool { + if self.is_response() { + return true; + } + + match self.method_name() { + Some("initialize") | Some("authenticate") | Some("session/new") + | Some("session/list") => false, + Some(method) if method.starts_with("session/") => true, + _ => false, + } + } + + fn params_session_id(&self) -> Result, HttpTransportError> { + let Some(params) = &self.params else { + return Ok(None); + }; + + let Some(session_id) = params.get("sessionId").and_then(Value::as_str) else { + return Ok(None); + }; + + acp_nats::AcpSessionId::new(session_id) + .map(Some) + .map_err(|_| HttpTransportError::BadRequest("invalid sessionId in request body")) + } +} + +#[derive(Debug, Deserialize)] +struct OutgoingHttpMessage { + id: Option, + params: Option, + result: Option, +} + +impl OutgoingHttpMessage { + fn parse(raw: &str) -> Option { + serde_json::from_str(raw).ok() + } + + fn params_session_id(&self) -> Option { + let params = self.params.as_ref()?; + let session_id = params.get("sessionId")?.as_str()?; + acp_nats::AcpSessionId::new(session_id).ok() + } + + fn result_session_id(&self) -> Option { + let result = self.result.as_ref()?; + let session_id = result.get("sessionId")?.as_str()?; + acp_nats::AcpSessionId::new(session_id).ok() + } +} + +pub async fn get(State(state): State, request: Request) -> Response { + if is_websocket_request(request.headers()) { + let (mut parts, _body) = request.into_parts(); + match WebSocketUpgrade::from_request_parts(&mut parts, &state).await { + Ok(ws) => websocket_response(ws, state), + Err(_) => { + HttpTransportError::BadRequest("invalid WebSocket upgrade request").into_response() + } + } + } else { + match http_get(request.headers().clone(), state).await { + Ok(response) => response, + Err(error) => error.into_response(), + } + } +} + +pub async fn legacy_websocket_get(ws: WebSocketUpgrade, State(state): State) -> Response { + websocket_response(ws, state) +} + +pub async fn post(headers: HeaderMap, State(state): State, body: String) -> Response { + match http_post(headers, state, body).await { + Ok(response) => response, + Err(error) => error.into_response(), + } +} + +pub async fn delete(headers: HeaderMap, State(state): State) -> Response { + match http_delete(headers, state).await { + Ok(response) => response, + Err(error) => error.into_response(), + } +} + +fn websocket_response(ws: WebSocketUpgrade, state: AppState) -> Response { + let connection_id = AcpConnectionId::new(); + let response_header = HeaderValue::from_str(&connection_id.to_string()) + .expect("generated ACP connection id must be a valid header value"); + let shutdown_rx = state.shutdown_tx.subscribe(); + let mut response = ws.on_upgrade(move |socket| async move { + if state + .manager_tx + .send(ManagerRequest::WebSocket(ConnectionRequest { + connection_id, + socket, + shutdown_rx, + })) + .is_err() + { + error!("Connection thread is gone; dropping WebSocket"); + } + }); + response + .headers_mut() + .insert(ACP_CONNECTION_ID_HEADER, response_header); + response +} + +async fn http_post( + headers: HeaderMap, + state: AppState, + body: String, +) -> Result { + validate_post_headers(&headers)?; + + let message = IncomingHttpMessage::parse(body)?; + if !(message.is_request() || message.is_notification() || message.is_response()) { + return Err(HttpTransportError::BadRequest( + "invalid JSON-RPC message shape", + )); + } + + let connection_id = parse_connection_id_header(&headers)?; + let session_id = parse_session_id_header(&headers)?; + + validate_http_context(&message, connection_id.as_ref(), session_id.as_ref())?; + + let (response_tx, response_rx) = oneshot::channel(); + state + .manager_tx + .send(ManagerRequest::HttpPost { + connection_id, + session_id, + message, + response: response_tx, + shutdown_rx: state.shutdown_tx.subscribe(), + }) + .map_err(|_| HttpTransportError::Internal("connection manager is unavailable"))?; + + match response_rx + .await + .map_err(|_| HttpTransportError::Internal("connection manager dropped the request"))?? + { + HttpPostOutcome::Accepted => Ok(StatusCode::ACCEPTED.into_response()), + HttpPostOutcome::Live { + connection_id, + session_id, + stream, + } => Ok(build_sse_response(connection_id, session_id, stream)), + HttpPostOutcome::Buffered { + connection_id, + session_id, + events, + } => Ok(build_buffered_sse_response( + connection_id, + session_id, + events, + )), + } +} + +async fn http_get(headers: HeaderMap, state: AppState) -> Result { + validate_get_headers(&headers)?; + + let connection_id = parse_connection_id_header(&headers)?.ok_or( + HttpTransportError::BadRequest("missing Acp-Connection-Id header"), + )?; + let session_id = parse_session_id_header(&headers)?.ok_or(HttpTransportError::BadRequest( + "missing Acp-Session-Id header", + ))?; + + let (response_tx, response_rx) = oneshot::channel(); + state + .manager_tx + .send(ManagerRequest::HttpGet { + connection_id, + session_id, + response: response_tx, + }) + .map_err(|_| HttpTransportError::Internal("connection manager is unavailable"))?; + + let stream = response_rx + .await + .map_err(|_| HttpTransportError::Internal("connection manager dropped the request"))??; + + let mut response = Sse::new(stream::unfold(stream, |mut stream| async move { + stream + .recv() + .await + .map(|item| (Ok::(item.into_event()), stream)) + })) + .into_response(); + response + .headers_mut() + .insert(X_ACCEL_BUFFERING_HEADER, HeaderValue::from_static("no")); + Ok(response) +} + +async fn http_delete(headers: HeaderMap, state: AppState) -> Result { + let connection_id = parse_connection_id_header(&headers)?.ok_or( + HttpTransportError::BadRequest("missing Acp-Connection-Id header"), + )?; + + let (response_tx, response_rx) = oneshot::channel(); + state + .manager_tx + .send(ManagerRequest::HttpDelete { + connection_id, + response: response_tx, + }) + .map_err(|_| HttpTransportError::Internal("connection manager is unavailable"))?; + + response_rx + .await + .map_err(|_| HttpTransportError::Internal("connection manager dropped the request"))??; + + Ok(StatusCode::ACCEPTED.into_response()) +} + +fn validate_post_headers(headers: &HeaderMap) -> Result<(), HttpTransportError> { + match headers + .get(CONTENT_TYPE) + .and_then(|value| value.to_str().ok()) + { + Some(value) if value.eq_ignore_ascii_case("application/json") => {} + _ => { + return Err(HttpTransportError::UnsupportedMediaType( + "Content-Type must be application/json", + )); + } + } + + let accept = headers + .get(ACCEPT) + .and_then(|value| value.to_str().ok()) + .ok_or(HttpTransportError::NotAcceptable( + "Accept must include application/json and text/event-stream", + ))?; + + if !accept_contains(accept, "application/json") || !accept_contains(accept, "text/event-stream") + { + return Err(HttpTransportError::NotAcceptable( + "Accept must include application/json and text/event-stream", + )); + } + + Ok(()) +} + +fn validate_get_headers(headers: &HeaderMap) -> Result<(), HttpTransportError> { + let accept = headers + .get(ACCEPT) + .and_then(|value| value.to_str().ok()) + .ok_or(HttpTransportError::NotAcceptable( + "Accept must include text/event-stream", + ))?; + + if !accept_contains(accept, "text/event-stream") { + return Err(HttpTransportError::NotAcceptable( + "Accept must include text/event-stream", + )); + } + + Ok(()) +} + +fn validate_http_context( + message: &IncomingHttpMessage, + connection_id: Option<&AcpConnectionId>, + session_id: Option<&acp_nats::AcpSessionId>, +) -> Result<(), HttpTransportError> { + if message.is_initialize() { + if connection_id.is_some() { + return Err(HttpTransportError::BadRequest( + "initialize must not include Acp-Connection-Id", + )); + } + if session_id.is_some() { + return Err(HttpTransportError::BadRequest( + "initialize must not include Acp-Session-Id", + )); + } + return Ok(()); + } + + if connection_id.is_none() { + return Err(HttpTransportError::BadRequest( + "missing Acp-Connection-Id header", + )); + } + + let body_session_id = message.params_session_id()?; + if message.requires_session_id() && session_id.is_none() { + return Err(HttpTransportError::BadRequest( + "missing Acp-Session-Id header", + )); + } + + if let (Some(header_session_id), Some(body_session_id)) = (session_id, body_session_id.as_ref()) + { + if header_session_id != body_session_id { + return Err(HttpTransportError::BadRequest( + "Acp-Session-Id header does not match request body sessionId", + )); + } + } + + Ok(()) +} + +fn accept_contains(header: &str, expected: &str) -> bool { + header + .split(',') + .map(str::trim) + .any(|value| value.eq_ignore_ascii_case(expected)) +} + +fn parse_connection_id_header( + headers: &HeaderMap, +) -> Result, HttpTransportError> { + headers + .get(ACP_CONNECTION_ID_HEADER) + .map(|value| { + value + .to_str() + .map_err(|_| HttpTransportError::BadRequest("invalid Acp-Connection-Id header")) + .and_then(|value| { + AcpConnectionId::parse(value).map_err(|error| match error { + AcpConnectionIdError::InvalidUuid(_) => { + HttpTransportError::BadRequest("invalid Acp-Connection-Id header") + } + }) + }) + }) + .transpose() +} + +fn parse_session_id_header( + headers: &HeaderMap, +) -> Result, HttpTransportError> { + headers + .get(ACP_SESSION_ID_HEADER) + .map(|value| { + value + .to_str() + .map_err(|_| HttpTransportError::BadRequest("invalid Acp-Session-Id header")) + .and_then(|value| { + acp_nats::AcpSessionId::new(value).map_err(|_| { + HttpTransportError::BadRequest("invalid Acp-Session-Id header") + }) + }) + }) + .transpose() +} + +fn is_websocket_request(headers: &HeaderMap) -> bool { + headers + .get("upgrade") + .and_then(|value| value.to_str().ok()) + .map(|value| value.eq_ignore_ascii_case("websocket")) + .unwrap_or(false) +} + +fn build_sse_response( + connection_id: AcpConnectionId, + session_id: Option, + stream: SseReceiver, +) -> Response { + let mut response = Sse::new(stream::unfold(stream, |mut stream| async move { + stream + .recv() + .await + .map(|item| (Ok::(item.into_event()), stream)) + })) + .into_response(); + set_transport_headers(response.headers_mut(), &connection_id, session_id.as_ref()); + response + .headers_mut() + .insert(X_ACCEL_BUFFERING_HEADER, HeaderValue::from_static("no")); + response +} + +fn build_buffered_sse_response( + connection_id: AcpConnectionId, + session_id: Option, + events: Vec, +) -> Response { + let stream = stream::iter( + events + .into_iter() + .map(|item| Ok::(item.into_event())), + ); + let mut response = Sse::new(stream).into_response(); + set_transport_headers(response.headers_mut(), &connection_id, session_id.as_ref()); + response + .headers_mut() + .insert(X_ACCEL_BUFFERING_HEADER, HeaderValue::from_static("no")); + response +} + +fn set_transport_headers( + headers: &mut HeaderMap, + connection_id: &AcpConnectionId, + session_id: Option<&acp_nats::AcpSessionId>, +) { + headers.insert( + ACP_CONNECTION_ID_HEADER, + HeaderValue::from_str(&connection_id.to_string()) + .expect("generated ACP connection id must be a valid header value"), + ); + if let Some(session_id) = session_id { + headers.insert( + ACP_SESSION_ID_HEADER, + HeaderValue::from_str(session_id.as_str()) + .expect("generated ACP session id must be a valid header value"), + ); + } +} + +pub async fn run_http_connection( + connection_id: AcpConnectionId, + nats_client: N, + js_client: J, + config: acp_nats::Config, + mut command_rx: mpsc::UnboundedReceiver, + mut shutdown_rx: watch::Receiver, +) where + N: acp_nats::RequestClient + + acp_nats::PublishClient + + acp_nats::FlushClient + + acp_nats::SubscribeClient + + Clone + + 'static, + J: acp_nats::JetStreamPublisher + acp_nats::JetStreamGetStream + 'static, + trogon_nats::jetstream::JsMessageOf: trogon_nats::jetstream::JsRequestMessage, +{ + let (agent_write, mut output_read) = tokio::io::duplex(crate::constants::DUPLEX_BUFFER_SIZE); + let (mut input_write, agent_read) = tokio::io::duplex(crate::constants::DUPLEX_BUFFER_SIZE); + + let incoming = async_compat::Compat::new(agent_read); + let outgoing = async_compat::Compat::new(agent_write); + + let meter = acp_telemetry::meter("acp-nats-ws"); + let (notification_tx, notification_rx) = tokio::sync::mpsc::channel::(64); + let bridge = Rc::new(Bridge::new( + nats_client.clone(), + js_client, + SystemClock, + &meter, + config, + notification_tx, + )); + + let (connection, io_task) = + AgentSideConnection::new(bridge.clone(), outgoing, incoming, |fut| { + tokio::task::spawn_local(fut); + }); + + let connection = Rc::new(connection); + spawn_notification_forwarder(connection.clone(), notification_rx); + + let (input_tx, mut input_rx) = mpsc::unbounded_channel::(); + let input_task = tokio::task::spawn_local(async move { + while let Some(message) = input_rx.recv().await { + if input_write.write_all(message.as_bytes()).await.is_err() { + break; + } + if input_write.write_all(b"\n").await.is_err() { + break; + } + } + }); + + let (output_tx, mut output_rx) = mpsc::unbounded_channel::(); + let output_task = tokio::task::spawn_local(async move { + let mut reader = tokio::io::BufReader::new(&mut output_read); + let mut line = String::new(); + loop { + line.clear(); + match reader.read_line(&mut line).await { + Ok(0) => break, + Ok(_) => { + let trimmed = line.trim_end_matches(['\r', '\n']); + if trimmed.is_empty() { + continue; + } + if output_tx.send(trimmed.to_string()).is_err() { + break; + } + } + Err(_) => break, + } + } + }); + + let mut client_task = tokio::task::spawn_local(client::run( + nats_client, + connection.clone(), + bridge, + StdJsonSerialize, + )); + let mut io_task = tokio::task::spawn_local(io_task); + + let mut pending_request: Option = None; + let mut sessions = HashSet::::new(); + let mut get_listeners = HashMap::>::new(); + + info!(%connection_id, "HTTP connection established"); + + loop { + tokio::select! { + command = command_rx.recv() => { + let Some(command) = command else { + break; + }; + + match command { + HttpConnectionCommand::Post { session_id, message, response } => { + if message.is_request() { + if pending_request.is_some() { + let _ = response.send(Err(HttpTransportError::Conflict( + "only one in-flight HTTP request is supported per ACP connection", + ))); + continue; + } + + if message.is_session_new() { + pending_request = Some(PendingRequest::Buffered { + request_id: message.id.clone().expect("request must have id"), + events: Vec::new(), + response, + }); + let _ = input_tx.send(message.raw); + continue; + } + + if let Some(session_id) = session_id.clone() { + sessions.insert(session_id); + } + + let (stream_tx, stream_rx) = mpsc::unbounded_channel(); + pending_request = Some(PendingRequest::Live { + request_id: message.id.clone().expect("request must have id"), + sender: stream_tx, + }); + let _ = input_tx.send(message.raw); + let _ = response.send(Ok(HttpPostOutcome::Live { + connection_id: connection_id.clone(), + session_id, + stream: stream_rx, + })); + continue; + } + + if let Some(session_id) = session_id.clone() { + sessions.insert(session_id); + } + + if input_tx.send(message.raw).is_err() { + let _ = response.send(Err(HttpTransportError::Internal( + "failed to forward HTTP payload into ACP runtime", + ))); + continue; + } + + let _ = response.send(Ok(HttpPostOutcome::Accepted)); + } + HttpConnectionCommand::AttachListener { session_id, response } => { + if !sessions.contains(&session_id) { + let _ = response.send(Err(HttpTransportError::NotFound( + "unknown ACP session", + ))); + continue; + } + + let (stream_tx, stream_rx) = mpsc::unbounded_channel(); + get_listeners.entry(session_id).or_default().push(stream_tx); + let _ = response.send(Ok(stream_rx)); + } + HttpConnectionCommand::Close { response } => { + let _ = response.send(Ok(())); + break; + } + } + } + outbound = output_rx.recv() => { + let Some(outbound) = outbound else { + break; + }; + + let frame = SseFrame::Json(outbound.clone()); + let parsed = OutgoingHttpMessage::parse(&outbound); + + if let Some(pending) = pending_request.as_mut() { + match pending { + PendingRequest::Live { request_id, sender, .. } => { + if sender.send(frame.clone()).is_err() { + pending_request = None; + continue; + } + + if parsed.as_ref().and_then(|message| message.id.as_ref()) == Some(request_id) { + pending_request = None; + } + continue; + } + PendingRequest::Buffered { request_id, events, .. } => { + events.push(frame); + if parsed.as_ref().and_then(|message| message.id.as_ref()) == Some(request_id) { + let session_id = parsed.and_then(|message| message.result_session_id()); + if let Some(session_id) = session_id.clone() { + sessions.insert(session_id); + } + let events = std::mem::take(events); + if let Some(PendingRequest::Buffered { response, .. }) = pending_request.take() { + let _ = response.send(Ok(HttpPostOutcome::Buffered { + connection_id: connection_id.clone(), + session_id, + events, + })); + } + } + continue; + } + } + } + + let Some(session_id) = parsed.and_then(|message| message.params_session_id()) else { + continue; + }; + + let Some(listeners) = get_listeners.get_mut(&session_id) else { + continue; + }; + + listeners.retain(|listener| listener.send(frame.clone()).is_ok()); + } + result = &mut client_task => { + match result { + Ok(()) => info!(%connection_id, "HTTP client task completed"), + Err(error) => warn!(%connection_id, error = %error, "HTTP client task failed"), + } + break; + } + result = &mut io_task => { + match result { + Ok(Ok(())) => info!(%connection_id, "HTTP IO task completed"), + Ok(Err(error)) => warn!(%connection_id, error = %error, "HTTP IO task failed"), + Err(error) => warn!(%connection_id, error = %error, "HTTP IO task join failed"), + } + break; + } + _ = shutdown_rx.wait_for(|&shutdown| shutdown) => { + info!(%connection_id, "HTTP connection shutting down"); + break; + } + } + } + + if let Some(PendingRequest::Buffered { response, .. }) = pending_request.take() { + let _ = response.send(Err(HttpTransportError::Internal( + "HTTP connection closed before the request completed", + ))); + } + + input_task.abort(); + output_task.abort(); + + if !client_task.is_finished() { + client_task.abort(); + let _ = client_task.await; + } + if !io_task.is_finished() { + io_task.abort(); + let _ = io_task.await; + } + + info!(%connection_id, "HTTP connection closed"); +} + +pub async fn process_manager_request( + request: ManagerRequest, + http_connections: &mut HashMap, + websocket_handles: &mut Vec>, + nats_client: &N, + js_client: &J, + config: &acp_nats::Config, +) where + N: acp_nats::RequestClient + + acp_nats::PublishClient + + acp_nats::FlushClient + + acp_nats::SubscribeClient + + Clone + + Send + + 'static, + J: acp_nats::JetStreamPublisher + acp_nats::JetStreamGetStream + Clone + 'static, + trogon_nats::jetstream::JsMessageOf: trogon_nats::jetstream::JsRequestMessage, +{ + websocket_handles.retain(|handle| !handle.is_finished()); + + match request { + ManagerRequest::WebSocket(request) => { + websocket_handles.push(tokio::task::spawn_local(connection::handle( + request.connection_id, + request.socket, + nats_client.clone(), + js_client.clone(), + config.clone(), + request.shutdown_rx, + ))); + } + ManagerRequest::HttpPost { + connection_id, + session_id, + message, + response, + shutdown_rx, + } => { + let connection_id = match connection_id { + Some(connection_id) => connection_id, + None => { + if !message.is_initialize() { + let _ = response.send(Err(HttpTransportError::BadRequest( + "missing Acp-Connection-Id header", + ))); + return; + } + + let connection_id = AcpConnectionId::new(); + let (command_tx, command_rx) = mpsc::unbounded_channel(); + http_connections.insert( + connection_id.clone(), + HttpConnectionHandle { + command_tx: command_tx.clone(), + }, + ); + tokio::task::spawn_local(run_http_connection( + connection_id.clone(), + nats_client.clone(), + js_client.clone(), + config.clone(), + command_rx, + shutdown_rx, + )); + connection_id + } + }; + + let Some(handle) = http_connections.get(&connection_id) else { + let _ = response.send(Err(HttpTransportError::NotFound("unknown ACP connection"))); + return; + }; + + if handle + .command_tx + .send(HttpConnectionCommand::Post { + session_id, + message, + response, + }) + .is_err() + { + http_connections.remove(&connection_id); + } + } + ManagerRequest::HttpGet { + connection_id, + session_id, + response, + } => { + let Some(handle) = http_connections.get(&connection_id) else { + let _ = response.send(Err(HttpTransportError::NotFound("unknown ACP connection"))); + return; + }; + + if handle + .command_tx + .send(HttpConnectionCommand::AttachListener { + session_id, + response, + }) + .is_err() + { + http_connections.remove(&connection_id); + } + } + ManagerRequest::HttpDelete { + connection_id, + response, + } => { + let Some(handle) = http_connections.remove(&connection_id) else { + let _ = response.send(Err(HttpTransportError::NotFound("unknown ACP connection"))); + return; + }; + + let _ = handle + .command_tx + .send(HttpConnectionCommand::Close { response }); + } + } +} diff --git a/rsworkspace/crates/acp-nats-ws/src/upgrade.rs b/rsworkspace/crates/acp-nats-ws/src/upgrade.rs deleted file mode 100644 index ad577fcd6..000000000 --- a/rsworkspace/crates/acp-nats-ws/src/upgrade.rs +++ /dev/null @@ -1,120 +0,0 @@ -use crate::acp_connection_id::AcpConnectionId; -use crate::constants::ACP_CONNECTION_ID_HEADER; -use axum::extract::State; -use axum::extract::ws::{WebSocket, WebSocketUpgrade}; -use axum::http::HeaderValue; -use axum::response::Response; -use tokio::sync::{mpsc, watch}; -use tracing::error; - -pub struct ConnectionRequest { - pub connection_id: AcpConnectionId, - pub socket: WebSocket, - pub shutdown_rx: watch::Receiver, -} - -#[derive(Clone)] -pub struct UpgradeState { - pub conn_tx: mpsc::UnboundedSender, - pub shutdown_tx: watch::Sender, -} - -pub async fn handle(ws: WebSocketUpgrade, State(state): State) -> Response { - let connection_id = AcpConnectionId::new(); - let response_header = HeaderValue::from_str(&connection_id.to_string()) - .expect("generated ACP connection id must be a valid header value"); - let shutdown_rx = state.shutdown_tx.subscribe(); - let mut response = ws.on_upgrade(move |socket| async move { - if state - .conn_tx - .send(ConnectionRequest { - connection_id, - socket, - shutdown_rx, - }) - .is_err() - { - error!("Connection thread is gone; dropping WebSocket"); - } - }); - response - .headers_mut() - .insert(ACP_CONNECTION_ID_HEADER, response_header); - response -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::constants::{ACP_CONNECTION_ID_HEADER, ACP_ENDPOINT}; - use std::time::Duration; - use tokio::net::TcpListener; - use tokio_tungstenite::connect_async; - - #[tokio::test] - async fn handle_sends_connection_request_through_channel() { - let (shutdown_tx, _shutdown_rx) = watch::channel(false); - let (conn_tx, mut conn_rx) = mpsc::unbounded_channel::(); - - let state = UpgradeState { - conn_tx, - shutdown_tx: shutdown_tx.clone(), - }; - - let app = axum::Router::new() - .route(ACP_ENDPOINT, axum::routing::get(handle)) - .with_state(state); - - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - tokio::spawn(async move { - axum::serve(listener, app).await.unwrap(); - }); - - let url = format!("ws://{}{}", addr, ACP_ENDPOINT); - let (_ws, response) = connect_async(&url).await.unwrap(); - let connection_id = response - .headers() - .get(ACP_CONNECTION_ID_HEADER) - .expect("upgrade response should include Acp-Connection-Id") - .to_str() - .unwrap() - .to_string(); - - let req = tokio::time::timeout(Duration::from_secs(2), conn_rx.recv()) - .await - .expect("timeout waiting for ConnectionRequest") - .expect("channel closed"); - - assert!(!*req.shutdown_rx.borrow()); - assert_eq!(req.connection_id.to_string(), connection_id); - } - - #[tokio::test] - async fn handle_logs_error_when_conn_rx_dropped() { - let (shutdown_tx, _shutdown_rx) = watch::channel(false); - let (conn_tx, conn_rx) = mpsc::unbounded_channel::(); - - let state = UpgradeState { - conn_tx, - shutdown_tx: shutdown_tx.clone(), - }; - - let app = axum::Router::new() - .route(ACP_ENDPOINT, axum::routing::get(handle)) - .with_state(state); - - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - tokio::spawn(async move { - axum::serve(listener, app).await.unwrap(); - }); - - drop(conn_rx); - - let url = format!("ws://{}{}", addr, ACP_ENDPOINT); - let (_ws, _) = connect_async(&url).await.unwrap(); - - tokio::time::sleep(Duration::from_millis(100)).await; - } -} From db2f17a6078d64ded212fb40411ab2e7b683b742 Mon Sep 17 00:00:00 2001 From: Yordis Prieto Date: Wed, 22 Apr 2026 13:38:47 -0400 Subject: [PATCH 03/23] chore(acp-nats-ws): keep lockfile aligned Signed-off-by: Yordis Prieto --- rsworkspace/Cargo.lock | 2 ++ 1 file changed, 2 insertions(+) diff --git a/rsworkspace/Cargo.lock b/rsworkspace/Cargo.lock index 669da1dc4..2f720feae 100644 --- a/rsworkspace/Cargo.lock +++ b/rsworkspace/Cargo.lock @@ -76,9 +76,11 @@ dependencies = [ "clap", "futures-util", "opentelemetry", + "serde", "serde_json", "tokio", "tokio-tungstenite 0.29.0", + "tower", "tracing", "tracing-subscriber", "trogon-nats", From 32deb246dd4230acf0d3ca43ab94f7f632c16a41 Mon Sep 17 00:00:00 2001 From: Yordis Prieto Date: Wed, 22 Apr 2026 14:06:34 -0400 Subject: [PATCH 04/23] fix(acp-nats-ws): restore rust ci signal Signed-off-by: Yordis Prieto --- .../acp-nats-ws/src/acp_connection_id.rs | 18 + .../crates/acp-nats-ws/src/transport.rs | 635 +++++++++++++++++- 2 files changed, 648 insertions(+), 5 deletions(-) diff --git a/rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs b/rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs index 5cc001dee..c00b3037f 100644 --- a/rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs +++ b/rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs @@ -49,6 +49,7 @@ impl std::error::Error for AcpConnectionIdError { #[cfg(test)] mod tests { use super::*; + use std::error::Error as _; #[test] fn new_generates_non_empty_id() { @@ -66,4 +67,21 @@ mod tests { fn parse_rejects_invalid_uuid() { assert!(AcpConnectionId::parse("not-a-uuid").is_err()); } + + #[test] + fn default_generates_non_empty_id() { + assert!(!AcpConnectionId::default().to_string().is_empty()); + } + + #[test] + fn parse_error_displays_context() { + let error = AcpConnectionId::parse("not-a-uuid").unwrap_err(); + assert!(error.to_string().contains("invalid ACP connection id")); + } + + #[test] + fn parse_error_exposes_uuid_source() { + let error = AcpConnectionId::parse("not-a-uuid").unwrap_err(); + assert!(error.source().is_some()); + } } diff --git a/rsworkspace/crates/acp-nats-ws/src/transport.rs b/rsworkspace/crates/acp-nats-ws/src/transport.rs index 03f146436..d379377e4 100644 --- a/rsworkspace/crates/acp-nats-ws/src/transport.rs +++ b/rsworkspace/crates/acp-nats-ws/src/transport.rs @@ -517,12 +517,11 @@ fn validate_http_context( } if let (Some(header_session_id), Some(body_session_id)) = (session_id, body_session_id.as_ref()) + && header_session_id != body_session_id { - if header_session_id != body_session_id { - return Err(HttpTransportError::BadRequest( - "Acp-Session-Id header does not match request body sessionId", - )); - } + return Err(HttpTransportError::BadRequest( + "Acp-Session-Id header does not match request body sessionId", + )); } Ok(()) @@ -1019,3 +1018,629 @@ pub async fn process_manager_request( } } } + +#[cfg(test)] +mod tests { + use super::*; + use acp_nats::Config; + use axum::body::{Body, to_bytes}; + use axum::http::Request as HttpRequest; + use serde_json::{Value, json}; + use tokio::sync::{mpsc, oneshot, watch}; + use trogon_nats::AdvancedMockNatsClient; + + #[derive(Clone)] + struct MockJs { + publisher: trogon_nats::jetstream::MockJetStreamPublisher, + consumer_factory: trogon_nats::jetstream::MockJetStreamConsumerFactory, + } + + impl MockJs { + fn new() -> Self { + Self { + publisher: trogon_nats::jetstream::MockJetStreamPublisher::new(), + consumer_factory: trogon_nats::jetstream::MockJetStreamConsumerFactory::new(), + } + } + } + + impl trogon_nats::jetstream::JetStreamPublisher for MockJs { + type PublishError = trogon_nats::mocks::MockError; + type AckFuture = std::future::Ready< + Result, + >; + + async fn publish_with_headers( + &self, + subject: S, + headers: async_nats::HeaderMap, + payload: bytes::Bytes, + ) -> Result { + self.publisher + .publish_with_headers(subject, headers, payload) + .await + } + } + + impl trogon_nats::jetstream::JetStreamGetStream for MockJs { + type Error = trogon_nats::mocks::MockError; + type Stream = trogon_nats::jetstream::MockJetStreamStream; + + async fn get_stream + Send>( + &self, + stream_name: T, + ) -> Result { + self.consumer_factory.get_stream(stream_name).await + } + } + + fn test_config() -> Config { + Config::new( + acp_nats::AcpPrefix::new("acp").unwrap(), + acp_nats::NatsConfig { + servers: vec!["localhost:4222".to_string()], + auth: trogon_nats::NatsAuth::None, + }, + ) + } + + fn test_state() -> (AppState, mpsc::UnboundedReceiver) { + let (manager_tx, manager_rx) = mpsc::unbounded_channel(); + let (shutdown_tx, _) = watch::channel(false); + ( + AppState { + manager_tx, + shutdown_tx, + }, + manager_rx, + ) + } + + async fn json_event_body(response: Response) -> Vec { + let bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let body = String::from_utf8(bytes.to_vec()).unwrap(); + body.lines() + .filter_map(|line| line.strip_prefix("data: ")) + .map(|json| serde_json::from_str(json).unwrap()) + .collect() + } + + fn post_headers() -> HeaderMap { + let mut headers = HeaderMap::new(); + headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + headers.insert( + ACCEPT, + HeaderValue::from_static("application/json, text/event-stream"), + ); + headers + } + + fn get_headers() -> HeaderMap { + let mut headers = HeaderMap::new(); + headers.insert(ACCEPT, HeaderValue::from_static("text/event-stream")); + headers + } + + fn session_id() -> acp_nats::AcpSessionId { + acp_nats::AcpSessionId::new("session-1").unwrap() + } + + #[test] + fn http_transport_error_into_response_maps_status_codes() { + let cases = [ + ( + HttpTransportError::BadRequest("bad"), + StatusCode::BAD_REQUEST, + ), + ( + HttpTransportError::NotFound("missing"), + StatusCode::NOT_FOUND, + ), + ( + HttpTransportError::Conflict("conflict"), + StatusCode::CONFLICT, + ), + ( + HttpTransportError::UnsupportedMediaType("unsupported"), + StatusCode::UNSUPPORTED_MEDIA_TYPE, + ), + ( + HttpTransportError::NotAcceptable("not-acceptable"), + StatusCode::NOT_ACCEPTABLE, + ), + ( + HttpTransportError::NotImplemented("not-implemented"), + StatusCode::NOT_IMPLEMENTED, + ), + ( + HttpTransportError::Internal("internal"), + StatusCode::INTERNAL_SERVER_ERROR, + ), + ]; + + for (error, status) in cases { + assert_eq!(error.into_response().status(), status); + } + } + + #[test] + fn http_transport_error_display_returns_message() { + assert_eq!(HttpTransportError::BadRequest("bad").to_string(), "bad"); + assert_eq!( + HttpTransportError::Conflict("conflict").to_string(), + "conflict" + ); + assert_eq!( + HttpTransportError::Internal("internal").to_string(), + "internal" + ); + } + + #[test] + fn incoming_http_message_parses_and_classifies_shapes() { + let request = IncomingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":1,"method":"session/new","params":{"cwd":".","mcpServers":[]}}"# + .to_string(), + ) + .unwrap(); + assert!(request.is_request()); + assert!(request.is_session_new()); + + let notification = + IncomingHttpMessage::parse(r#"{"jsonrpc":"2.0","method":"initialized"}"#.to_string()) + .unwrap(); + assert!(notification.is_notification()); + assert!(!notification.requires_session_id()); + + let response = IncomingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":99,"result":{"ok":true}}"#.to_string(), + ) + .unwrap(); + assert!(response.is_response()); + assert!(response.requires_session_id()); + } + + #[test] + fn incoming_http_message_parse_rejects_batch_and_invalid_json() { + let batch = IncomingHttpMessage::parse(r#"[{"jsonrpc":"2.0"}]"#.to_string()).unwrap_err(); + assert!(matches!( + batch, + HttpTransportError::NotImplemented("batch JSON-RPC requests are not supported") + )); + + let invalid = IncomingHttpMessage::parse("{".to_string()).unwrap_err(); + assert!(matches!( + invalid, + HttpTransportError::BadRequest("invalid JSON-RPC payload") + )); + } + + #[test] + fn incoming_http_message_session_id_helpers_handle_valid_and_invalid_values() { + let message = IncomingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":1,"method":"session/prompt","params":{"sessionId":"session-1"}}"# + .to_string(), + ) + .unwrap(); + assert_eq!(message.params_session_id().unwrap(), Some(session_id())); + + let invalid = IncomingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":1,"method":"session/prompt","params":{"sessionId":"bad.session"}}"# + .to_string(), + ) + .unwrap(); + assert!(matches!( + invalid.params_session_id(), + Err(HttpTransportError::BadRequest( + "invalid sessionId in request body" + )) + )); + } + + #[test] + fn outgoing_http_message_extracts_session_ids() { + let outbound = OutgoingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":1,"params":{"sessionId":"session-1"}}"#, + ) + .unwrap(); + assert_eq!(outbound.params_session_id(), Some(session_id())); + + let response = OutgoingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":2,"result":{"sessionId":"session-1"}}"#, + ) + .unwrap(); + assert_eq!(response.result_session_id(), Some(session_id())); + } + + #[test] + fn header_validators_enforce_content_negotiation() { + let valid_post = post_headers(); + assert!(validate_post_headers(&valid_post).is_ok()); + + let mut bad_content_type = valid_post.clone(); + bad_content_type.insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); + assert!(matches!( + validate_post_headers(&bad_content_type), + Err(HttpTransportError::UnsupportedMediaType( + "Content-Type must be application/json" + )) + )); + + let mut bad_accept = HeaderMap::new(); + bad_accept.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + bad_accept.insert(ACCEPT, HeaderValue::from_static("application/json")); + assert!(matches!( + validate_post_headers(&bad_accept), + Err(HttpTransportError::NotAcceptable( + "Accept must include application/json and text/event-stream" + )) + )); + + let valid_get = get_headers(); + assert!(validate_get_headers(&valid_get).is_ok()); + + let mut invalid_get = HeaderMap::new(); + invalid_get.insert(ACCEPT, HeaderValue::from_static("application/json")); + assert!(matches!( + validate_get_headers(&invalid_get), + Err(HttpTransportError::NotAcceptable( + "Accept must include text/event-stream" + )) + )); + } + + #[test] + fn validate_http_context_enforces_connection_and_session_rules() { + let initialize = IncomingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":0}}"# + .to_string(), + ) + .unwrap(); + let connection_id = AcpConnectionId::new(); + let session_id = session_id(); + + assert!(validate_http_context(&initialize, None, None).is_ok()); + assert!(matches!( + validate_http_context(&initialize, Some(&connection_id), None), + Err(HttpTransportError::BadRequest( + "initialize must not include Acp-Connection-Id" + )) + )); + + let initialized = + IncomingHttpMessage::parse(r#"{"jsonrpc":"2.0","method":"initialized"}"#.to_string()) + .unwrap(); + assert!(matches!( + validate_http_context(&initialized, None, None), + Err(HttpTransportError::BadRequest( + "missing Acp-Connection-Id header" + )) + )); + + let prompt = IncomingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":2,"method":"session/prompt","params":{"sessionId":"session-1"}}"# + .to_string(), + ) + .unwrap(); + assert!(matches!( + validate_http_context(&prompt, Some(&connection_id), None), + Err(HttpTransportError::BadRequest( + "missing Acp-Session-Id header" + )) + )); + + let mismatched = IncomingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":2,"method":"session/prompt","params":{"sessionId":"session-2"}}"# + .to_string(), + ) + .unwrap(); + assert!(matches!( + validate_http_context(&mismatched, Some(&connection_id), Some(&session_id)), + Err(HttpTransportError::BadRequest( + "Acp-Session-Id header does not match request body sessionId" + )) + )); + } + + #[test] + fn header_parsers_and_websocket_detection_handle_valid_and_invalid_values() { + let connection_id = AcpConnectionId::new(); + let session_id = session_id(); + let mut headers = HeaderMap::new(); + headers.insert( + ACP_CONNECTION_ID_HEADER, + HeaderValue::from_str(&connection_id.to_string()).unwrap(), + ); + headers.insert( + ACP_SESSION_ID_HEADER, + HeaderValue::from_str(session_id.as_str()).unwrap(), + ); + headers.insert("upgrade", HeaderValue::from_static("websocket")); + + assert_eq!( + parse_connection_id_header(&headers).unwrap(), + Some(connection_id.clone()) + ); + assert_eq!(parse_session_id_header(&headers).unwrap(), Some(session_id)); + assert!(is_websocket_request(&headers)); + + headers.insert( + ACP_CONNECTION_ID_HEADER, + HeaderValue::from_static("not-a-uuid"), + ); + assert!(matches!( + parse_connection_id_header(&headers), + Err(HttpTransportError::BadRequest( + "invalid Acp-Connection-Id header" + )) + )); + } + + #[tokio::test] + async fn http_post_returns_accepted_for_notifications() { + let (state, mut manager_rx) = test_state(); + let connection_id = AcpConnectionId::new(); + let expected_connection_id = connection_id.clone(); + + tokio::spawn(async move { + match manager_rx.recv().await.unwrap() { + ManagerRequest::HttpPost { + connection_id: Some(actual_connection_id), + session_id: None, + message, + response, + .. + } => { + assert_eq!(actual_connection_id, expected_connection_id); + assert!(message.is_notification()); + let _ = response.send(Ok(HttpPostOutcome::Accepted)); + } + _ => panic!("unexpected manager request"), + } + }); + + let mut headers = post_headers(); + headers.insert( + ACP_CONNECTION_ID_HEADER, + HeaderValue::from_str(&connection_id.to_string()).unwrap(), + ); + + let response = http_post( + headers, + state, + r#"{"jsonrpc":"2.0","method":"initialized"}"#.to_string(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::ACCEPTED); + } + + #[tokio::test] + async fn http_post_returns_buffered_sse_with_session_headers() { + let (state, mut manager_rx) = test_state(); + let connection_id = AcpConnectionId::new(); + let session_id = session_id(); + let event = json!({ + "jsonrpc": "2.0", + "id": 2, + "result": { "sessionId": session_id.as_str() } + }); + let expected_event = event.clone(); + + let expected_connection_id = connection_id.clone(); + let expected_session_id = session_id.clone(); + tokio::spawn(async move { + match manager_rx.recv().await.unwrap() { + ManagerRequest::HttpPost { + connection_id: Some(actual_connection_id), + session_id: None, + message, + response, + .. + } => { + assert_eq!(actual_connection_id, expected_connection_id.clone()); + assert!(message.is_session_new()); + let _ = response.send(Ok(HttpPostOutcome::Buffered { + connection_id: expected_connection_id, + session_id: Some(expected_session_id), + events: vec![SseFrame::Json(event.to_string())], + })); + } + _ => panic!("unexpected manager request"), + } + }); + + let mut headers = post_headers(); + headers.insert( + ACP_CONNECTION_ID_HEADER, + HeaderValue::from_str(&connection_id.to_string()).unwrap(), + ); + + let response = http_post( + headers, + state, + r#"{"jsonrpc":"2.0","id":2,"method":"session/new","params":{"cwd":".","mcpServers":[]}}"# + .to_string(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response.headers().get(ACP_CONNECTION_ID_HEADER).unwrap(), + HeaderValue::from_str(&connection_id.to_string()).unwrap() + ); + assert_eq!( + response.headers().get(ACP_SESSION_ID_HEADER).unwrap(), + HeaderValue::from_str(session_id.as_str()).unwrap() + ); + assert_eq!(json_event_body(response).await, vec![expected_event]); + } + + #[tokio::test] + async fn http_get_and_delete_round_trip_through_manager() { + let (state, mut manager_rx) = test_state(); + let connection_id = AcpConnectionId::new(); + let session_id = session_id(); + let expected_connection_id = connection_id.clone(); + let expected_session_id = session_id.clone(); + + tokio::spawn(async move { + match manager_rx.recv().await.unwrap() { + ManagerRequest::HttpGet { + connection_id: actual_connection_id, + session_id: actual_session_id, + response, + } => { + assert_eq!(actual_connection_id, expected_connection_id.clone()); + assert_eq!(actual_session_id, expected_session_id); + let (stream_tx, stream_rx) = mpsc::unbounded_channel(); + let _ = stream_tx.send(SseFrame::Json( + json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { "sessionId": "session-1" } + }) + .to_string(), + )); + drop(stream_tx); + let _ = response.send(Ok(stream_rx)); + } + _ => panic!("unexpected manager request"), + } + match manager_rx.recv().await.unwrap() { + ManagerRequest::HttpDelete { + connection_id: actual_connection_id, + response, + } => { + assert_eq!(actual_connection_id, expected_connection_id); + let _ = response.send(Ok(())); + } + _ => panic!("unexpected manager request"), + } + }); + + let mut headers = get_headers(); + headers.insert( + ACP_CONNECTION_ID_HEADER, + HeaderValue::from_str(&connection_id.to_string()).unwrap(), + ); + headers.insert( + ACP_SESSION_ID_HEADER, + HeaderValue::from_str(session_id.as_str()).unwrap(), + ); + + let response = http_get(headers, state.clone()).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response.headers().get(X_ACCEL_BUFFERING_HEADER).unwrap(), + HeaderValue::from_static("no") + ); + assert_eq!( + json_event_body(response).await, + vec![json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { "sessionId": "session-1" } + })] + ); + + let mut delete_headers = HeaderMap::new(); + delete_headers.insert( + ACP_CONNECTION_ID_HEADER, + HeaderValue::from_str(&connection_id.to_string()).unwrap(), + ); + + let response = http_delete(delete_headers, state).await.unwrap(); + assert_eq!(response.status(), StatusCode::ACCEPTED); + } + + #[tokio::test] + async fn get_rejects_incomplete_websocket_upgrade() { + let (state, _manager_rx) = test_state(); + let request = HttpRequest::builder() + .method("GET") + .uri("/acp") + .header("upgrade", "websocket") + .body(Body::empty()) + .unwrap(); + + let response = get(State(state), request).await; + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + } + + #[tokio::test] + async fn process_manager_request_rejects_invalid_or_unknown_http_targets() { + let nats_client = AdvancedMockNatsClient::new(); + let js_client = MockJs::new(); + let config = test_config(); + let mut http_connections = HashMap::new(); + let mut websocket_handles = Vec::new(); + let (_shutdown_tx, shutdown_rx) = watch::channel(false); + + let (post_response_tx, post_response_rx) = oneshot::channel(); + let post_message = + IncomingHttpMessage::parse(r#"{"jsonrpc":"2.0","method":"initialized"}"#.to_string()) + .unwrap(); + process_manager_request( + ManagerRequest::HttpPost { + connection_id: None, + session_id: None, + message: post_message, + response: post_response_tx, + shutdown_rx: shutdown_rx.clone(), + }, + &mut http_connections, + &mut websocket_handles, + &nats_client, + &js_client, + &config, + ) + .await; + assert!(matches!( + post_response_rx.await.unwrap(), + Err(HttpTransportError::BadRequest( + "missing Acp-Connection-Id header" + )) + )); + + let unknown_connection_id = AcpConnectionId::new(); + let (get_response_tx, get_response_rx) = oneshot::channel(); + process_manager_request( + ManagerRequest::HttpGet { + connection_id: unknown_connection_id.clone(), + session_id: session_id(), + response: get_response_tx, + }, + &mut http_connections, + &mut websocket_handles, + &nats_client, + &js_client, + &config, + ) + .await; + assert!(matches!( + get_response_rx.await.unwrap(), + Err(HttpTransportError::NotFound("unknown ACP connection")) + )); + + let (delete_response_tx, delete_response_rx) = oneshot::channel(); + process_manager_request( + ManagerRequest::HttpDelete { + connection_id: unknown_connection_id, + response: delete_response_tx, + }, + &mut http_connections, + &mut websocket_handles, + &nats_client, + &js_client, + &config, + ) + .await; + assert!(matches!( + delete_response_rx.await.unwrap(), + Err(HttpTransportError::NotFound("unknown ACP connection")) + )); + } +} From 5c738a7d331e9e2c092cf6bf5f16d0086b0cbfdd Mon Sep 17 00:00:00 2001 From: Yordis Prieto Date: Wed, 22 Apr 2026 15:02:18 -0400 Subject: [PATCH 05/23] fix(acp-nats-ws): keep websocket logs accurate Signed-off-by: Yordis Prieto --- .../crates/acp-nats-ws/src/connection.rs | 32 +++++++------------ 1 file changed, 11 insertions(+), 21 deletions(-) diff --git a/rsworkspace/crates/acp-nats-ws/src/connection.rs b/rsworkspace/crates/acp-nats-ws/src/connection.rs index 8e54ced6e..0cfa840d7 100644 --- a/rsworkspace/crates/acp-nats-ws/src/connection.rs +++ b/rsworkspace/crates/acp-nats-ws/src/connection.rs @@ -130,34 +130,24 @@ async fn run_recv_pump( mut ws_recv_write: tokio::io::DuplexStream, ) { while let Some(Ok(msg)) = ws_receiver.next().await { - let bytes = match msg { - Message::Text(t) => bytes::Bytes::from(t), + let text = match msg { + Message::Text(text) => text, Message::Binary(_) => continue, Message::Close(_) => break, _ => continue, }; - match std::str::from_utf8(&bytes) { - Ok(text) => { - let line = text.trim_end_matches(['\r', '\n']); - if line.is_empty() { - continue; - } + let line = text.trim_end_matches(['\r', '\n']); + if line.is_empty() { + continue; + } - if ws_recv_write.write_all(line.as_bytes()).await.is_err() { - break; - } + if ws_recv_write.write_all(line.as_bytes()).await.is_err() { + break; + } - if ws_recv_write.write_all(b"\n").await.is_err() { - break; - } - } - Err(e) => { - warn!( - error = %e, - "Received non-UTF-8 WebSocket message, dropping frame" - ); - } + if ws_recv_write.write_all(b"\n").await.is_err() { + break; } } } From 0646d6bcdc406577cab85d422e91dbd652d8f222 Mon Sep 17 00:00:00 2001 From: Yordis Prieto Date: Wed, 22 Apr 2026 15:17:03 -0400 Subject: [PATCH 06/23] fix(acp-nats-ws): drain HTTP connections on shutdown Signed-off-by: Yordis Prieto --- rsworkspace/crates/acp-nats-ws/src/main.rs | 11 ++- .../crates/acp-nats-ws/src/transport.rs | 70 ++++++++++++++++++- 2 files changed, 78 insertions(+), 3 deletions(-) diff --git a/rsworkspace/crates/acp-nats-ws/src/main.rs b/rsworkspace/crates/acp-nats-ws/src/main.rs index 270b13f28..eb2cc40ae 100644 --- a/rsworkspace/crates/acp-nats-ws/src/main.rs +++ b/rsworkspace/crates/acp-nats-ws/src/main.rs @@ -161,6 +161,7 @@ async fn process_connections( trogon_nats::jetstream::JsMessageOf: trogon_nats::jetstream::JsRequestMessage, { let mut websocket_handles: Vec> = Vec::new(); + let mut http_connection_handles: Vec> = Vec::new(); let mut http_connections = std::collections::HashMap::new(); while let Some(request) = manager_rx.recv().await { @@ -168,6 +169,7 @@ async fn process_connections( request, &mut http_connections, &mut websocket_handles, + &mut http_connection_handles, &nats_client, &js_client, &config, @@ -178,7 +180,11 @@ async fn process_connections( let active = websocket_handles .iter() .filter(|h| !h.is_finished()) - .count(); + .count() + + http_connection_handles + .iter() + .filter(|h| !h.is_finished()) + .count(); info!( active_connections = active, "Connection channel closed, draining active connections" @@ -187,6 +193,9 @@ async fn process_connections( for handle in websocket_handles { let _ = handle.await; } + for handle in http_connection_handles { + let _ = handle.await; + } info!("All connections drained"); } diff --git a/rsworkspace/crates/acp-nats-ws/src/transport.rs b/rsworkspace/crates/acp-nats-ws/src/transport.rs index d379377e4..264f25719 100644 --- a/rsworkspace/crates/acp-nats-ws/src/transport.rs +++ b/rsworkspace/crates/acp-nats-ws/src/transport.rs @@ -901,6 +901,7 @@ pub async fn process_manager_request( request: ManagerRequest, http_connections: &mut HashMap, websocket_handles: &mut Vec>, + http_connection_handles: &mut Vec>, nats_client: &N, js_client: &J, config: &acp_nats::Config, @@ -916,6 +917,7 @@ pub async fn process_manager_request( trogon_nats::jetstream::JsMessageOf: trogon_nats::jetstream::JsRequestMessage, { websocket_handles.retain(|handle| !handle.is_finished()); + http_connection_handles.retain(|handle| !handle.is_finished()); match request { ManagerRequest::WebSocket(request) => { @@ -953,14 +955,14 @@ pub async fn process_manager_request( command_tx: command_tx.clone(), }, ); - tokio::task::spawn_local(run_http_connection( + http_connection_handles.push(tokio::task::spawn_local(run_http_connection( connection_id.clone(), nats_client.clone(), js_client.clone(), config.clone(), command_rx, shutdown_rx, - )); + ))); connection_id } }; @@ -1577,6 +1579,7 @@ mod tests { let config = test_config(); let mut http_connections = HashMap::new(); let mut websocket_handles = Vec::new(); + let mut http_connection_handles = Vec::new(); let (_shutdown_tx, shutdown_rx) = watch::channel(false); let (post_response_tx, post_response_rx) = oneshot::channel(); @@ -1593,6 +1596,7 @@ mod tests { }, &mut http_connections, &mut websocket_handles, + &mut http_connection_handles, &nats_client, &js_client, &config, @@ -1615,6 +1619,7 @@ mod tests { }, &mut http_connections, &mut websocket_handles, + &mut http_connection_handles, &nats_client, &js_client, &config, @@ -1633,6 +1638,7 @@ mod tests { }, &mut http_connections, &mut websocket_handles, + &mut http_connection_handles, &nats_client, &js_client, &config, @@ -1643,4 +1649,64 @@ mod tests { Err(HttpTransportError::NotFound("unknown ACP connection")) )); } + + #[tokio::test(flavor = "current_thread")] + async fn process_manager_request_tracks_http_connection_tasks() { + let local = tokio::task::LocalSet::new(); + local + .run_until(async { + let nats_client = AdvancedMockNatsClient::new(); + let _injector = nats_client.inject_messages(); + nats_client.hang_next_request(); + + let js_client = MockJs::new(); + let config = test_config(); + let mut http_connections = HashMap::new(); + let mut websocket_handles = Vec::new(); + let mut http_connection_handles = Vec::new(); + let (shutdown_tx, shutdown_rx) = watch::channel(false); + + let (response_tx, response_rx) = oneshot::channel(); + let initialize = IncomingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":0}}"# + .to_string(), + ) + .unwrap(); + + process_manager_request( + ManagerRequest::HttpPost { + connection_id: None, + session_id: None, + message: initialize, + response: response_tx, + shutdown_rx, + }, + &mut http_connections, + &mut websocket_handles, + &mut http_connection_handles, + &nats_client, + &js_client, + &config, + ) + .await; + + assert_eq!(http_connections.len(), 1); + assert_eq!(http_connection_handles.len(), 1); + assert!(!http_connection_handles[0].is_finished()); + assert!(matches!( + response_rx.await.unwrap(), + Ok(HttpPostOutcome::Live { .. }) + )); + + let _ = shutdown_tx.send(true); + tokio::time::timeout(std::time::Duration::from_secs(2), async { + while !http_connection_handles[0].is_finished() { + tokio::task::yield_now().await; + } + }) + .await + .expect("HTTP connection task did not finish after shutdown"); + }) + .await; + } } From bb6d8a08c27b60d9e725335ab2824592f948eed5 Mon Sep 17 00:00:00 2001 From: Yordis Prieto Date: Wed, 22 Apr 2026 16:28:48 -0400 Subject: [PATCH 07/23] fix(acp-nats-ws): preserve session stream boundaries Signed-off-by: Yordis Prieto --- .../crates/acp-nats-ws/src/transport.rs | 162 ++++++++++++++++-- 1 file changed, 149 insertions(+), 13 deletions(-) diff --git a/rsworkspace/crates/acp-nats-ws/src/transport.rs b/rsworkspace/crates/acp-nats-ws/src/transport.rs index 264f25719..fccafb4b2 100644 --- a/rsworkspace/crates/acp-nats-ws/src/transport.rs +++ b/rsworkspace/crates/acp-nats-ws/src/transport.rs @@ -153,6 +153,7 @@ impl SseFrame { enum PendingRequest { Live { request_id: RequestId, + session_id: Option, sender: SseSender, }, Buffered { @@ -267,6 +268,62 @@ impl OutgoingHttpMessage { } } +enum LiveFrameOutcome { + Keep, + Clear, + Drop, +} + +fn dispatch_to_get_listeners( + frame: &SseFrame, + session_id: &acp_nats::AcpSessionId, + get_listeners: &mut HashMap>, +) -> bool { + let Some(listeners) = get_listeners.get_mut(session_id) else { + return false; + }; + + listeners.retain(|listener| listener.send(frame.clone()).is_ok()); + !listeners.is_empty() +} + +fn route_live_frame( + frame: &SseFrame, + parsed: Option<&OutgoingHttpMessage>, + request_id: &RequestId, + request_session_id: Option<&acp_nats::AcpSessionId>, + sender: &SseSender, + get_listeners: &mut HashMap>, +) -> LiveFrameOutcome { + if parsed.and_then(|message| message.id.as_ref()) == Some(request_id) { + return if sender.send(frame.clone()).is_ok() { + LiveFrameOutcome::Clear + } else { + LiveFrameOutcome::Drop + }; + } + + if let Some(frame_session_id) = parsed.and_then(OutgoingHttpMessage::params_session_id) { + if request_session_id == Some(&frame_session_id) { + return if sender.send(frame.clone()).is_ok() { + LiveFrameOutcome::Keep + } else { + LiveFrameOutcome::Drop + }; + } + + if dispatch_to_get_listeners(frame, &frame_session_id, get_listeners) { + return LiveFrameOutcome::Keep; + } + } + + if sender.send(frame.clone()).is_ok() { + LiveFrameOutcome::Keep + } else { + LiveFrameOutcome::Drop + } +} + pub async fn get(State(state): State, request: Request) -> Response { if is_websocket_request(request.headers()) { let (mut parts, _body) = request.into_parts(); @@ -760,6 +817,7 @@ pub async fn run_http_connection( let (stream_tx, stream_rx) = mpsc::unbounded_channel(); pending_request = Some(PendingRequest::Live { request_id: message.id.clone().expect("request must have id"), + session_id: session_id.clone(), sender: stream_tx, }); let _ = input_tx.send(message.raw); @@ -812,14 +870,23 @@ pub async fn run_http_connection( if let Some(pending) = pending_request.as_mut() { match pending { - PendingRequest::Live { request_id, sender, .. } => { - if sender.send(frame.clone()).is_err() { - pending_request = None; - continue; - } - - if parsed.as_ref().and_then(|message| message.id.as_ref()) == Some(request_id) { - pending_request = None; + PendingRequest::Live { + request_id, + session_id, + sender, + } => { + match route_live_frame( + &frame, + parsed.as_ref(), + request_id, + session_id.as_ref(), + sender, + &mut get_listeners, + ) { + LiveFrameOutcome::Keep => {} + LiveFrameOutcome::Clear | LiveFrameOutcome::Drop => { + pending_request = None; + } } continue; } @@ -848,11 +915,7 @@ pub async fn run_http_connection( continue; }; - let Some(listeners) = get_listeners.get_mut(&session_id) else { - continue; - }; - - listeners.retain(|listener| listener.send(frame.clone()).is_ok()); + let _ = dispatch_to_get_listeners(&frame, &session_id, &mut get_listeners); } result = &mut client_task => { match result { @@ -1254,6 +1317,79 @@ mod tests { assert_eq!(response.result_session_id(), Some(session_id())); } + #[test] + fn route_live_frame_keeps_same_session_notifications_on_post_stream() { + let frame = SseFrame::Json( + r#"{"jsonrpc":"2.0","method":"session/update","params":{"sessionId":"session-1"}}"# + .to_string(), + ); + let parsed = OutgoingHttpMessage::parse(match &frame { + SseFrame::Json(json) => json, + }) + .unwrap(); + let request_id = RequestId::Number(1); + let request_session_id = session_id(); + let (live_tx, mut live_rx) = mpsc::unbounded_channel(); + let (get_tx, mut get_rx) = mpsc::unbounded_channel(); + let mut get_listeners = HashMap::new(); + get_listeners.insert(request_session_id.clone(), vec![get_tx]); + + let outcome = route_live_frame( + &frame, + Some(&parsed), + &request_id, + Some(&request_session_id), + &live_tx, + &mut get_listeners, + ); + + assert!(matches!(outcome, LiveFrameOutcome::Keep)); + match live_rx.try_recv().unwrap() { + SseFrame::Json(json) => assert!(json.contains(r#""sessionId":"session-1""#)), + } + assert!(matches!( + get_rx.try_recv(), + Err(tokio::sync::mpsc::error::TryRecvError::Empty) + )); + } + + #[test] + fn route_live_frame_sends_other_session_notifications_to_get_listeners() { + let frame = SseFrame::Json( + r#"{"jsonrpc":"2.0","method":"session/update","params":{"sessionId":"session-2"}}"# + .to_string(), + ); + let parsed = OutgoingHttpMessage::parse(match &frame { + SseFrame::Json(json) => json, + }) + .unwrap(); + let request_id = RequestId::Number(1); + let request_session_id = session_id(); + let other_session_id = acp_nats::AcpSessionId::new("session-2").unwrap(); + let (live_tx, mut live_rx) = mpsc::unbounded_channel(); + let (get_tx, mut get_rx) = mpsc::unbounded_channel(); + let mut get_listeners = HashMap::new(); + get_listeners.insert(other_session_id, vec![get_tx]); + + let outcome = route_live_frame( + &frame, + Some(&parsed), + &request_id, + Some(&request_session_id), + &live_tx, + &mut get_listeners, + ); + + assert!(matches!(outcome, LiveFrameOutcome::Keep)); + assert!(matches!( + live_rx.try_recv(), + Err(tokio::sync::mpsc::error::TryRecvError::Empty) + )); + match get_rx.try_recv().unwrap() { + SseFrame::Json(json) => assert!(json.contains(r#""sessionId":"session-2""#)), + } + } + #[test] fn header_validators_enforce_content_negotiation() { let valid_post = post_headers(); From a252d07adfb11d6f291bab8e22a538a28d64f914 Mon Sep 17 00:00:00 2001 From: Yordis Prieto Date: Wed, 22 Apr 2026 16:31:54 -0400 Subject: [PATCH 08/23] fix(acp-nats-ws): accept proxy-shaped headers Signed-off-by: Yordis Prieto --- rsworkspace/crates/acp-nats-ws/src/transport.rs | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/rsworkspace/crates/acp-nats-ws/src/transport.rs b/rsworkspace/crates/acp-nats-ws/src/transport.rs index fccafb4b2..fe0268f6b 100644 --- a/rsworkspace/crates/acp-nats-ws/src/transport.rs +++ b/rsworkspace/crates/acp-nats-ws/src/transport.rs @@ -587,8 +587,9 @@ fn validate_http_context( fn accept_contains(header: &str, expected: &str) -> bool { header .split(',') + .filter_map(|value| value.split(';').next()) .map(str::trim) - .any(|value| value.eq_ignore_ascii_case(expected)) + .any(|media_type| media_type.eq_ignore_ascii_case(expected)) } fn parse_connection_id_header( @@ -1417,6 +1418,18 @@ mod tests { let valid_get = get_headers(); assert!(validate_get_headers(&valid_get).is_ok()); + let mut valid_post_with_q = HeaderMap::new(); + valid_post_with_q.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + valid_post_with_q.insert( + ACCEPT, + HeaderValue::from_static("application/json;q=0.9, text/event-stream"), + ); + assert!(validate_post_headers(&valid_post_with_q).is_ok()); + + let mut valid_get_with_q = HeaderMap::new(); + valid_get_with_q.insert(ACCEPT, HeaderValue::from_static("text/event-stream; q=0.5")); + assert!(validate_get_headers(&valid_get_with_q).is_ok()); + let mut invalid_get = HeaderMap::new(); invalid_get.insert(ACCEPT, HeaderValue::from_static("application/json")); assert!(matches!( From d7d81ac2b88bbbd4f953a667e40ba934ff575957 Mon Sep 17 00:00:00 2001 From: Yordis Prieto Date: Wed, 22 Apr 2026 16:49:28 -0400 Subject: [PATCH 09/23] fix(acp-nats-ws): protect HTTP streams from slow clients Signed-off-by: Yordis Prieto --- .../crates/acp-nats-ws/src/transport.rs | 143 ++++++++++++++---- 1 file changed, 116 insertions(+), 27 deletions(-) diff --git a/rsworkspace/crates/acp-nats-ws/src/transport.rs b/rsworkspace/crates/acp-nats-ws/src/transport.rs index fe0268f6b..bbcb1e31d 100644 --- a/rsworkspace/crates/acp-nats-ws/src/transport.rs +++ b/rsworkspace/crates/acp-nats-ws/src/transport.rs @@ -22,8 +22,10 @@ use tokio::sync::{mpsc, oneshot, watch}; use tracing::{error, info, warn}; use trogon_std::time::SystemClock; -type SseSender = mpsc::UnboundedSender; -type SseReceiver = mpsc::UnboundedReceiver; +const HTTP_CHANNEL_CAPACITY: usize = 64; + +type SseSender = mpsc::Sender; +type SseReceiver = mpsc::Receiver; #[derive(Clone)] pub struct AppState { @@ -274,17 +276,61 @@ enum LiveFrameOutcome { Drop, } +enum ListenerDispatch { + Missing, + Delivered, + Dropped, +} + +fn try_send_sse_frame(sender: &SseSender, frame: &SseFrame) -> bool { + match sender.try_send(frame.clone()) { + Ok(()) => true, + Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => false, + Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => false, + } +} + fn dispatch_to_get_listeners( frame: &SseFrame, session_id: &acp_nats::AcpSessionId, get_listeners: &mut HashMap>, -) -> bool { - let Some(listeners) = get_listeners.get_mut(session_id) else { - return false; +) -> ListenerDispatch { + let Some(_) = get_listeners.get(session_id) else { + return ListenerDispatch::Missing; }; - listeners.retain(|listener| listener.send(frame.clone()).is_ok()); - !listeners.is_empty() + let (outcome, remove_session) = { + let listeners = get_listeners + .get_mut(session_id) + .expect("session listeners must exist after presence check"); + let mut delivered = false; + listeners.retain(|listener| match listener.try_send(frame.clone()) { + Ok(()) => { + delivered = true; + true + } + Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => { + warn!(session_id = %session_id, "Dropping stalled HTTP SSE listener"); + false + } + Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => false, + }); + + ( + if delivered { + ListenerDispatch::Delivered + } else { + ListenerDispatch::Dropped + }, + listeners.is_empty(), + ) + }; + + if remove_session { + get_listeners.remove(session_id); + } + + outcome } fn route_live_frame( @@ -296,30 +342,36 @@ fn route_live_frame( get_listeners: &mut HashMap>, ) -> LiveFrameOutcome { if parsed.and_then(|message| message.id.as_ref()) == Some(request_id) { - return if sender.send(frame.clone()).is_ok() { + return if try_send_sse_frame(sender, frame) { LiveFrameOutcome::Clear } else { + warn!(request_id = ?request_id, "Dropping stalled HTTP POST response stream"); LiveFrameOutcome::Drop }; } if let Some(frame_session_id) = parsed.and_then(OutgoingHttpMessage::params_session_id) { if request_session_id == Some(&frame_session_id) { - return if sender.send(frame.clone()).is_ok() { + return if try_send_sse_frame(sender, frame) { LiveFrameOutcome::Keep } else { + warn!(request_id = ?request_id, session_id = %frame_session_id, "Dropping stalled HTTP POST response stream"); LiveFrameOutcome::Drop }; } - if dispatch_to_get_listeners(frame, &frame_session_id, get_listeners) { - return LiveFrameOutcome::Keep; + match dispatch_to_get_listeners(frame, &frame_session_id, get_listeners) { + ListenerDispatch::Delivered | ListenerDispatch::Dropped => { + return LiveFrameOutcome::Keep; + } + ListenerDispatch::Missing => {} } } - if sender.send(frame.clone()).is_ok() { + if try_send_sse_frame(sender, frame) { LiveFrameOutcome::Keep } else { + warn!(request_id = ?request_id, "Dropping stalled HTTP POST response stream"); LiveFrameOutcome::Drop } } @@ -736,7 +788,7 @@ pub async fn run_http_connection( let connection = Rc::new(connection); spawn_notification_forwarder(connection.clone(), notification_rx); - let (input_tx, mut input_rx) = mpsc::unbounded_channel::(); + let (input_tx, mut input_rx) = mpsc::channel::(HTTP_CHANNEL_CAPACITY); let input_task = tokio::task::spawn_local(async move { while let Some(message) = input_rx.recv().await { if input_write.write_all(message.as_bytes()).await.is_err() { @@ -748,7 +800,7 @@ pub async fn run_http_connection( } }); - let (output_tx, mut output_rx) = mpsc::unbounded_channel::(); + let (output_tx, mut output_rx) = mpsc::channel::(HTTP_CHANNEL_CAPACITY); let output_task = tokio::task::spawn_local(async move { let mut reader = tokio::io::BufReader::new(&mut output_read); let mut line = String::new(); @@ -761,7 +813,7 @@ pub async fn run_http_connection( if trimmed.is_empty() { continue; } - if output_tx.send(trimmed.to_string()).is_err() { + if output_tx.send(trimmed.to_string()).await.is_err() { break; } } @@ -807,7 +859,14 @@ pub async fn run_http_connection( events: Vec::new(), response, }); - let _ = input_tx.send(message.raw); + if input_tx.try_send(message.raw).is_err() + && let Some(PendingRequest::Buffered { response, .. }) = + pending_request.take() + { + let _ = response.send(Err(HttpTransportError::Internal( + "ACP runtime input queue is full", + ))); + } continue; } @@ -815,13 +874,19 @@ pub async fn run_http_connection( sessions.insert(session_id); } - let (stream_tx, stream_rx) = mpsc::unbounded_channel(); + let (stream_tx, stream_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); pending_request = Some(PendingRequest::Live { request_id: message.id.clone().expect("request must have id"), session_id: session_id.clone(), sender: stream_tx, }); - let _ = input_tx.send(message.raw); + if input_tx.try_send(message.raw).is_err() { + pending_request = None; + let _ = response.send(Err(HttpTransportError::Internal( + "ACP runtime input queue is full", + ))); + continue; + } let _ = response.send(Ok(HttpPostOutcome::Live { connection_id: connection_id.clone(), session_id, @@ -834,9 +899,9 @@ pub async fn run_http_connection( sessions.insert(session_id); } - if input_tx.send(message.raw).is_err() { + if input_tx.try_send(message.raw).is_err() { let _ = response.send(Err(HttpTransportError::Internal( - "failed to forward HTTP payload into ACP runtime", + "ACP runtime input queue is full", ))); continue; } @@ -851,7 +916,7 @@ pub async fn run_http_connection( continue; } - let (stream_tx, stream_rx) = mpsc::unbounded_channel(); + let (stream_tx, stream_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); get_listeners.entry(session_id).or_default().push(stream_tx); let _ = response.send(Ok(stream_rx)); } @@ -1330,8 +1395,8 @@ mod tests { .unwrap(); let request_id = RequestId::Number(1); let request_session_id = session_id(); - let (live_tx, mut live_rx) = mpsc::unbounded_channel(); - let (get_tx, mut get_rx) = mpsc::unbounded_channel(); + let (live_tx, mut live_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); + let (get_tx, mut get_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); let mut get_listeners = HashMap::new(); get_listeners.insert(request_session_id.clone(), vec![get_tx]); @@ -1367,8 +1432,8 @@ mod tests { let request_id = RequestId::Number(1); let request_session_id = session_id(); let other_session_id = acp_nats::AcpSessionId::new("session-2").unwrap(); - let (live_tx, mut live_rx) = mpsc::unbounded_channel(); - let (get_tx, mut get_rx) = mpsc::unbounded_channel(); + let (live_tx, mut live_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); + let (get_tx, mut get_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); let mut get_listeners = HashMap::new(); get_listeners.insert(other_session_id, vec![get_tx]); @@ -1391,6 +1456,30 @@ mod tests { } } + #[test] + fn dispatch_to_get_listeners_drops_full_listener() { + let session_id = session_id(); + let mut get_listeners = HashMap::new(); + let (listener_tx, mut listener_rx) = mpsc::channel(1); + listener_tx + .try_send(SseFrame::Json( + r#"{"jsonrpc":"2.0","method":"session/update"}"#.to_string(), + )) + .unwrap(); + get_listeners.insert(session_id.clone(), vec![listener_tx]); + + let frame = SseFrame::Json(r#"{"jsonrpc":"2.0","method":"session/update"}"#.to_string()); + let outcome = dispatch_to_get_listeners(&frame, &session_id, &mut get_listeners); + + assert!(matches!(outcome, ListenerDispatch::Dropped)); + assert!(!get_listeners.contains_key(&session_id)); + assert!(matches!(listener_rx.try_recv(), Ok(SseFrame::Json(_)))); + assert!(matches!( + listener_rx.try_recv(), + Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) + )); + } + #[test] fn header_validators_enforce_content_negotiation() { let valid_post = post_headers(); @@ -1646,8 +1735,8 @@ mod tests { } => { assert_eq!(actual_connection_id, expected_connection_id.clone()); assert_eq!(actual_session_id, expected_session_id); - let (stream_tx, stream_rx) = mpsc::unbounded_channel(); - let _ = stream_tx.send(SseFrame::Json( + let (stream_tx, stream_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); + let _ = stream_tx.try_send(SseFrame::Json( json!({ "jsonrpc": "2.0", "method": "session/update", From 38fbedcde41d1592677f47e42ce9e78e93b78635 Mon Sep 17 00:00:00 2001 From: Yordis Prieto Date: Wed, 22 Apr 2026 17:24:11 -0400 Subject: [PATCH 10/23] fix(acp-nats-ws): preserve HTTP transport failure context Signed-off-by: Yordis Prieto --- rsworkspace/Cargo.lock | 1 + rsworkspace/Cargo.toml | 1 + rsworkspace/crates/acp-nats-ws/Cargo.toml | 1 + .../crates/acp-nats-ws/src/transport.rs | 414 ++++++++++++------ 4 files changed, 273 insertions(+), 144 deletions(-) diff --git a/rsworkspace/Cargo.lock b/rsworkspace/Cargo.lock index 2f720feae..ef886ab40 100644 --- a/rsworkspace/Cargo.lock +++ b/rsworkspace/Cargo.lock @@ -78,6 +78,7 @@ dependencies = [ "opentelemetry", "serde", "serde_json", + "thiserror 2.0.18", "tokio", "tokio-tungstenite 0.29.0", "tower", diff --git a/rsworkspace/Cargo.toml b/rsworkspace/Cargo.toml index d4be98c9b..1753406d1 100644 --- a/rsworkspace/Cargo.toml +++ b/rsworkspace/Cargo.toml @@ -59,6 +59,7 @@ opentelemetry_sdk = "=0.31.0" tracing = "=0.1.44" tracing-opentelemetry = "=0.32.1" tracing-subscriber = "=0.3.23" +thiserror = "=2.0.18" # Discord twilight-gateway = { version = "=0.17.1", default-features = false, features = ["rustls-native-roots"] } diff --git a/rsworkspace/crates/acp-nats-ws/Cargo.toml b/rsworkspace/crates/acp-nats-ws/Cargo.toml index 9850b06d3..62367ef8d 100644 --- a/rsworkspace/crates/acp-nats-ws/Cargo.toml +++ b/rsworkspace/crates/acp-nats-ws/Cargo.toml @@ -20,6 +20,7 @@ opentelemetry = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "macros", "signal", "net", "sync", "io-util"] } +thiserror = { workspace = true } tracing = { workspace = true } trogon-nats = { workspace = true } trogon-std = { workspace = true, features = ["telemetry-http"] } diff --git a/rsworkspace/crates/acp-nats-ws/src/transport.rs b/rsworkspace/crates/acp-nats-ws/src/transport.rs index bbcb1e31d..76bb13e7f 100644 --- a/rsworkspace/crates/acp-nats-ws/src/transport.rs +++ b/rsworkspace/crates/acp-nats-ws/src/transport.rs @@ -1,4 +1,4 @@ -use crate::acp_connection_id::{AcpConnectionId, AcpConnectionIdError}; +use crate::acp_connection_id::AcpConnectionId; use crate::connection; use crate::constants::{ACP_CONNECTION_ID_HEADER, ACP_SESSION_ID_HEADER, X_ACCEL_BUFFERING_HEADER}; use acp_nats::{StdJsonSerialize, agent::Bridge, client, spawn_notification_forwarder}; @@ -17,6 +17,7 @@ use serde_json::Value; use std::collections::{HashMap, HashSet}; use std::convert::Infallible; use std::rc::Rc; +use thiserror::Error; use tokio::io::{AsyncBufReadExt, AsyncWriteExt}; use tokio::sync::{mpsc, oneshot, watch}; use tracing::{error, info, warn}; @@ -24,6 +25,7 @@ use trogon_std::time::SystemClock; const HTTP_CHANNEL_CAPACITY: usize = 64; +type BoxError = Box; type SseSender = mpsc::Sender; type SseReceiver = mpsc::Receiver; @@ -74,48 +76,139 @@ pub enum HttpPostOutcome { }, } -#[derive(Debug)] +#[derive(Debug, Error)] pub enum HttpTransportError { - BadRequest(&'static str), - NotFound(&'static str), - Conflict(&'static str), - UnsupportedMediaType(&'static str), - NotAcceptable(&'static str), - NotImplemented(&'static str), - Internal(&'static str), + #[error("{message}")] + BadRequest { + message: &'static str, + #[source] + source: Option, + }, + #[error("{message}")] + NotFound { + message: &'static str, + #[source] + source: Option, + }, + #[error("{message}")] + Conflict { + message: &'static str, + #[source] + source: Option, + }, + #[error("{message}")] + UnsupportedMediaType { + message: &'static str, + #[source] + source: Option, + }, + #[error("{message}")] + NotAcceptable { + message: &'static str, + #[source] + source: Option, + }, + #[error("{message}")] + NotImplemented { + message: &'static str, + #[source] + source: Option, + }, + #[error("{message}")] + Internal { + message: &'static str, + #[source] + source: Option, + }, } impl HttpTransportError { - fn into_response(self) -> Response { - let (status, message) = match self { - Self::BadRequest(message) => (StatusCode::BAD_REQUEST, message), - Self::NotFound(message) => (StatusCode::NOT_FOUND, message), - Self::Conflict(message) => (StatusCode::CONFLICT, message), - Self::UnsupportedMediaType(message) => (StatusCode::UNSUPPORTED_MEDIA_TYPE, message), - Self::NotAcceptable(message) => (StatusCode::NOT_ACCEPTABLE, message), - Self::NotImplemented(message) => (StatusCode::NOT_IMPLEMENTED, message), - Self::Internal(message) => (StatusCode::INTERNAL_SERVER_ERROR, message), - }; + fn bad_request(message: &'static str) -> Self { + Self::BadRequest { + message, + source: None, + } + } - (status, message).into_response() + fn bad_request_with( + message: &'static str, + source: impl std::error::Error + Send + 'static, + ) -> Self { + Self::BadRequest { + message, + source: Some(Box::new(source)), + } + } + + fn not_found(message: &'static str) -> Self { + Self::NotFound { + message, + source: None, + } + } + + fn conflict(message: &'static str) -> Self { + Self::Conflict { + message, + source: None, + } + } + + fn unsupported_media_type(message: &'static str) -> Self { + Self::UnsupportedMediaType { + message, + source: None, + } + } + + fn not_acceptable(message: &'static str) -> Self { + Self::NotAcceptable { + message, + source: None, + } } -} -impl std::fmt::Display for HttpTransportError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn not_implemented(message: &'static str) -> Self { + Self::NotImplemented { + message, + source: None, + } + } + + fn internal(message: &'static str) -> Self { + Self::Internal { + message, + source: None, + } + } + + fn internal_with( + message: &'static str, + source: impl std::error::Error + Send + 'static, + ) -> Self { + Self::Internal { + message, + source: Some(Box::new(source)), + } + } + + fn status_code(&self) -> StatusCode { match self { - Self::BadRequest(message) - | Self::NotFound(message) - | Self::Conflict(message) - | Self::UnsupportedMediaType(message) - | Self::NotAcceptable(message) - | Self::NotImplemented(message) - | Self::Internal(message) => f.write_str(message), + Self::BadRequest { .. } => StatusCode::BAD_REQUEST, + Self::NotFound { .. } => StatusCode::NOT_FOUND, + Self::Conflict { .. } => StatusCode::CONFLICT, + Self::UnsupportedMediaType { .. } => StatusCode::UNSUPPORTED_MEDIA_TYPE, + Self::NotAcceptable { .. } => StatusCode::NOT_ACCEPTABLE, + Self::NotImplemented { .. } => StatusCode::NOT_IMPLEMENTED, + Self::Internal { .. } => StatusCode::INTERNAL_SERVER_ERROR, } } -} -impl std::error::Error for HttpTransportError {} + fn into_response(self) -> Response { + let status = self.status_code(); + (status, self.to_string()).into_response() + } +} #[derive(Debug)] pub struct HttpConnectionHandle { @@ -180,13 +273,13 @@ impl IncomingHttpMessage { pub fn parse(raw: String) -> Result { let trimmed = raw.trim_start(); if trimmed.starts_with('[') { - return Err(HttpTransportError::NotImplemented( + return Err(HttpTransportError::not_implemented( "batch JSON-RPC requests are not supported", )); } let mut parsed = serde_json::from_str::(&raw) - .map_err(|_| HttpTransportError::BadRequest("invalid JSON-RPC payload"))?; + .map_err(|error| HttpTransportError::bad_request_with("invalid JSON-RPC payload", error))?; parsed.raw = raw; Ok(parsed) } @@ -241,7 +334,9 @@ impl IncomingHttpMessage { acp_nats::AcpSessionId::new(session_id) .map(Some) - .map_err(|_| HttpTransportError::BadRequest("invalid sessionId in request body")) + .map_err(|error| { + HttpTransportError::bad_request_with("invalid sessionId in request body", error) + }) } } @@ -382,7 +477,8 @@ pub async fn get(State(state): State, request: Request) -> Response { match WebSocketUpgrade::from_request_parts(&mut parts, &state).await { Ok(ws) => websocket_response(ws, state), Err(_) => { - HttpTransportError::BadRequest("invalid WebSocket upgrade request").into_response() + HttpTransportError::bad_request("invalid WebSocket upgrade request") + .into_response() } } } else { @@ -444,9 +540,7 @@ async fn http_post( let message = IncomingHttpMessage::parse(body)?; if !(message.is_request() || message.is_notification() || message.is_response()) { - return Err(HttpTransportError::BadRequest( - "invalid JSON-RPC message shape", - )); + return Err(HttpTransportError::bad_request("invalid JSON-RPC message shape")); } let connection_id = parse_connection_id_header(&headers)?; @@ -464,11 +558,15 @@ async fn http_post( response: response_tx, shutdown_rx: state.shutdown_tx.subscribe(), }) - .map_err(|_| HttpTransportError::Internal("connection manager is unavailable"))?; + .map_err(|error| { + HttpTransportError::internal_with("connection manager is unavailable", error) + })?; match response_rx .await - .map_err(|_| HttpTransportError::Internal("connection manager dropped the request"))?? + .map_err(|error| { + HttpTransportError::internal_with("connection manager dropped the request", error) + })?? { HttpPostOutcome::Accepted => Ok(StatusCode::ACCEPTED.into_response()), HttpPostOutcome::Live { @@ -491,12 +589,10 @@ async fn http_post( async fn http_get(headers: HeaderMap, state: AppState) -> Result { validate_get_headers(&headers)?; - let connection_id = parse_connection_id_header(&headers)?.ok_or( - HttpTransportError::BadRequest("missing Acp-Connection-Id header"), - )?; - let session_id = parse_session_id_header(&headers)?.ok_or(HttpTransportError::BadRequest( - "missing Acp-Session-Id header", - ))?; + let connection_id = parse_connection_id_header(&headers)? + .ok_or(HttpTransportError::bad_request("missing Acp-Connection-Id header"))?; + let session_id = parse_session_id_header(&headers)? + .ok_or(HttpTransportError::bad_request("missing Acp-Session-Id header"))?; let (response_tx, response_rx) = oneshot::channel(); state @@ -506,11 +602,15 @@ async fn http_get(headers: HeaderMap, state: AppState) -> Result Result Result { - let connection_id = parse_connection_id_header(&headers)?.ok_or( - HttpTransportError::BadRequest("missing Acp-Connection-Id header"), - )?; + let connection_id = parse_connection_id_header(&headers)? + .ok_or(HttpTransportError::bad_request("missing Acp-Connection-Id header"))?; let (response_tx, response_rx) = oneshot::channel(); state @@ -537,11 +636,15 @@ async fn http_delete(headers: HeaderMap, state: AppState) -> Result Result<(), HttpTransportError> { Some(value) if value.eq_ignore_ascii_case("application/json") => {} _ => { - return Err(HttpTransportError::UnsupportedMediaType( + return Err(HttpTransportError::unsupported_media_type( "Content-Type must be application/json", )); } @@ -562,13 +665,13 @@ fn validate_post_headers(headers: &HeaderMap) -> Result<(), HttpTransportError> let accept = headers .get(ACCEPT) .and_then(|value| value.to_str().ok()) - .ok_or(HttpTransportError::NotAcceptable( + .ok_or(HttpTransportError::not_acceptable( "Accept must include application/json and text/event-stream", ))?; if !accept_contains(accept, "application/json") || !accept_contains(accept, "text/event-stream") { - return Err(HttpTransportError::NotAcceptable( + return Err(HttpTransportError::not_acceptable( "Accept must include application/json and text/event-stream", )); } @@ -580,12 +683,12 @@ fn validate_get_headers(headers: &HeaderMap) -> Result<(), HttpTransportError> { let accept = headers .get(ACCEPT) .and_then(|value| value.to_str().ok()) - .ok_or(HttpTransportError::NotAcceptable( + .ok_or(HttpTransportError::not_acceptable( "Accept must include text/event-stream", ))?; if !accept_contains(accept, "text/event-stream") { - return Err(HttpTransportError::NotAcceptable( + return Err(HttpTransportError::not_acceptable( "Accept must include text/event-stream", )); } @@ -600,12 +703,12 @@ fn validate_http_context( ) -> Result<(), HttpTransportError> { if message.is_initialize() { if connection_id.is_some() { - return Err(HttpTransportError::BadRequest( + return Err(HttpTransportError::bad_request( "initialize must not include Acp-Connection-Id", )); } if session_id.is_some() { - return Err(HttpTransportError::BadRequest( + return Err(HttpTransportError::bad_request( "initialize must not include Acp-Session-Id", )); } @@ -613,14 +716,14 @@ fn validate_http_context( } if connection_id.is_none() { - return Err(HttpTransportError::BadRequest( + return Err(HttpTransportError::bad_request( "missing Acp-Connection-Id header", )); } let body_session_id = message.params_session_id()?; if message.requires_session_id() && session_id.is_none() { - return Err(HttpTransportError::BadRequest( + return Err(HttpTransportError::bad_request( "missing Acp-Session-Id header", )); } @@ -628,7 +731,7 @@ fn validate_http_context( if let (Some(header_session_id), Some(body_session_id)) = (session_id, body_session_id.as_ref()) && header_session_id != body_session_id { - return Err(HttpTransportError::BadRequest( + return Err(HttpTransportError::bad_request( "Acp-Session-Id header does not match request body sessionId", )); } @@ -652,12 +755,15 @@ fn parse_connection_id_header( .map(|value| { value .to_str() - .map_err(|_| HttpTransportError::BadRequest("invalid Acp-Connection-Id header")) + .map_err(|error| { + HttpTransportError::bad_request_with("invalid Acp-Connection-Id header", error) + }) .and_then(|value| { - AcpConnectionId::parse(value).map_err(|error| match error { - AcpConnectionIdError::InvalidUuid(_) => { - HttpTransportError::BadRequest("invalid Acp-Connection-Id header") - } + AcpConnectionId::parse(value).map_err(|error| { + HttpTransportError::bad_request_with( + "invalid Acp-Connection-Id header", + error, + ) }) }) }) @@ -672,10 +778,15 @@ fn parse_session_id_header( .map(|value| { value .to_str() - .map_err(|_| HttpTransportError::BadRequest("invalid Acp-Session-Id header")) + .map_err(|error| { + HttpTransportError::bad_request_with("invalid Acp-Session-Id header", error) + }) .and_then(|value| { - acp_nats::AcpSessionId::new(value).map_err(|_| { - HttpTransportError::BadRequest("invalid Acp-Session-Id header") + acp_nats::AcpSessionId::new(value).map_err(|error| { + HttpTransportError::bad_request_with( + "invalid Acp-Session-Id header", + error, + ) }) }) }) @@ -847,7 +958,7 @@ pub async fn run_http_connection( HttpConnectionCommand::Post { session_id, message, response } => { if message.is_request() { if pending_request.is_some() { - let _ = response.send(Err(HttpTransportError::Conflict( + let _ = response.send(Err(HttpTransportError::conflict( "only one in-flight HTTP request is supported per ACP connection", ))); continue; @@ -863,7 +974,7 @@ pub async fn run_http_connection( && let Some(PendingRequest::Buffered { response, .. }) = pending_request.take() { - let _ = response.send(Err(HttpTransportError::Internal( + let _ = response.send(Err(HttpTransportError::internal( "ACP runtime input queue is full", ))); } @@ -882,7 +993,7 @@ pub async fn run_http_connection( }); if input_tx.try_send(message.raw).is_err() { pending_request = None; - let _ = response.send(Err(HttpTransportError::Internal( + let _ = response.send(Err(HttpTransportError::internal( "ACP runtime input queue is full", ))); continue; @@ -900,7 +1011,7 @@ pub async fn run_http_connection( } if input_tx.try_send(message.raw).is_err() { - let _ = response.send(Err(HttpTransportError::Internal( + let _ = response.send(Err(HttpTransportError::internal( "ACP runtime input queue is full", ))); continue; @@ -910,9 +1021,8 @@ pub async fn run_http_connection( } HttpConnectionCommand::AttachListener { session_id, response } => { if !sessions.contains(&session_id) { - let _ = response.send(Err(HttpTransportError::NotFound( - "unknown ACP session", - ))); + let _ = response + .send(Err(HttpTransportError::not_found("unknown ACP session"))); continue; } @@ -1006,7 +1116,7 @@ pub async fn run_http_connection( } if let Some(PendingRequest::Buffered { response, .. }) = pending_request.take() { - let _ = response.send(Err(HttpTransportError::Internal( + let _ = response.send(Err(HttpTransportError::internal( "HTTP connection closed before the request completed", ))); } @@ -1070,7 +1180,7 @@ pub async fn process_manager_request( Some(connection_id) => connection_id, None => { if !message.is_initialize() { - let _ = response.send(Err(HttpTransportError::BadRequest( + let _ = response.send(Err(HttpTransportError::bad_request( "missing Acp-Connection-Id header", ))); return; @@ -1097,7 +1207,7 @@ pub async fn process_manager_request( }; let Some(handle) = http_connections.get(&connection_id) else { - let _ = response.send(Err(HttpTransportError::NotFound("unknown ACP connection"))); + let _ = response.send(Err(HttpTransportError::not_found("unknown ACP connection"))); return; }; @@ -1119,7 +1229,7 @@ pub async fn process_manager_request( response, } => { let Some(handle) = http_connections.get(&connection_id) else { - let _ = response.send(Err(HttpTransportError::NotFound("unknown ACP connection"))); + let _ = response.send(Err(HttpTransportError::not_found("unknown ACP connection"))); return; }; @@ -1139,7 +1249,7 @@ pub async fn process_manager_request( response, } => { let Some(handle) = http_connections.remove(&connection_id) else { - let _ = response.send(Err(HttpTransportError::NotFound("unknown ACP connection"))); + let _ = response.send(Err(HttpTransportError::not_found("unknown ACP connection"))); return; }; @@ -1259,32 +1369,23 @@ mod tests { #[test] fn http_transport_error_into_response_maps_status_codes() { let cases = [ + (HttpTransportError::bad_request("bad"), StatusCode::BAD_REQUEST), + (HttpTransportError::not_found("missing"), StatusCode::NOT_FOUND), + (HttpTransportError::conflict("conflict"), StatusCode::CONFLICT), ( - HttpTransportError::BadRequest("bad"), - StatusCode::BAD_REQUEST, - ), - ( - HttpTransportError::NotFound("missing"), - StatusCode::NOT_FOUND, - ), - ( - HttpTransportError::Conflict("conflict"), - StatusCode::CONFLICT, - ), - ( - HttpTransportError::UnsupportedMediaType("unsupported"), + HttpTransportError::unsupported_media_type("unsupported"), StatusCode::UNSUPPORTED_MEDIA_TYPE, ), ( - HttpTransportError::NotAcceptable("not-acceptable"), + HttpTransportError::not_acceptable("not-acceptable"), StatusCode::NOT_ACCEPTABLE, ), ( - HttpTransportError::NotImplemented("not-implemented"), + HttpTransportError::not_implemented("not-implemented"), StatusCode::NOT_IMPLEMENTED, ), ( - HttpTransportError::Internal("internal"), + HttpTransportError::internal("internal"), StatusCode::INTERNAL_SERVER_ERROR, ), ]; @@ -1292,19 +1393,22 @@ mod tests { for (error, status) in cases { assert_eq!(error.into_response().status(), status); } + + let sourced = IncomingHttpMessage::parse("{".to_string()).unwrap_err(); + assert_eq!(sourced.into_response().status(), StatusCode::BAD_REQUEST); } #[test] - fn http_transport_error_display_returns_message() { - assert_eq!(HttpTransportError::BadRequest("bad").to_string(), "bad"); - assert_eq!( - HttpTransportError::Conflict("conflict").to_string(), - "conflict" - ); - assert_eq!( - HttpTransportError::Internal("internal").to_string(), - "internal" - ); + fn http_transport_error_display_and_source_chain_return_expected_values() { + use std::error::Error as _; + + assert_eq!(HttpTransportError::bad_request("bad").to_string(), "bad"); + assert_eq!(HttpTransportError::conflict("conflict").to_string(), "conflict"); + assert_eq!(HttpTransportError::internal("internal").to_string(), "internal"); + + let invalid_json = IncomingHttpMessage::parse("{".to_string()).unwrap_err(); + assert_eq!(invalid_json.to_string(), "invalid JSON-RPC payload"); + assert!(invalid_json.source().is_some()); } #[test] @@ -1336,13 +1440,19 @@ mod tests { let batch = IncomingHttpMessage::parse(r#"[{"jsonrpc":"2.0"}]"#.to_string()).unwrap_err(); assert!(matches!( batch, - HttpTransportError::NotImplemented("batch JSON-RPC requests are not supported") + HttpTransportError::NotImplemented { + message: "batch JSON-RPC requests are not supported", + source: None, + } )); let invalid = IncomingHttpMessage::parse("{".to_string()).unwrap_err(); assert!(matches!( invalid, - HttpTransportError::BadRequest("invalid JSON-RPC payload") + HttpTransportError::BadRequest { + message: "invalid JSON-RPC payload", + source: Some(_), + } )); } @@ -1362,9 +1472,10 @@ mod tests { .unwrap(); assert!(matches!( invalid.params_session_id(), - Err(HttpTransportError::BadRequest( - "invalid sessionId in request body" - )) + Err(HttpTransportError::BadRequest { + message: "invalid sessionId in request body", + source: Some(_), + }) )); } @@ -1489,9 +1600,10 @@ mod tests { bad_content_type.insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); assert!(matches!( validate_post_headers(&bad_content_type), - Err(HttpTransportError::UnsupportedMediaType( - "Content-Type must be application/json" - )) + Err(HttpTransportError::UnsupportedMediaType { + message: "Content-Type must be application/json", + source: None, + }) )); let mut bad_accept = HeaderMap::new(); @@ -1499,9 +1611,10 @@ mod tests { bad_accept.insert(ACCEPT, HeaderValue::from_static("application/json")); assert!(matches!( validate_post_headers(&bad_accept), - Err(HttpTransportError::NotAcceptable( - "Accept must include application/json and text/event-stream" - )) + Err(HttpTransportError::NotAcceptable { + message: "Accept must include application/json and text/event-stream", + source: None, + }) )); let valid_get = get_headers(); @@ -1523,9 +1636,10 @@ mod tests { invalid_get.insert(ACCEPT, HeaderValue::from_static("application/json")); assert!(matches!( validate_get_headers(&invalid_get), - Err(HttpTransportError::NotAcceptable( - "Accept must include text/event-stream" - )) + Err(HttpTransportError::NotAcceptable { + message: "Accept must include text/event-stream", + source: None, + }) )); } @@ -1542,9 +1656,10 @@ mod tests { assert!(validate_http_context(&initialize, None, None).is_ok()); assert!(matches!( validate_http_context(&initialize, Some(&connection_id), None), - Err(HttpTransportError::BadRequest( - "initialize must not include Acp-Connection-Id" - )) + Err(HttpTransportError::BadRequest { + message: "initialize must not include Acp-Connection-Id", + source: None, + }) )); let initialized = @@ -1552,9 +1667,10 @@ mod tests { .unwrap(); assert!(matches!( validate_http_context(&initialized, None, None), - Err(HttpTransportError::BadRequest( - "missing Acp-Connection-Id header" - )) + Err(HttpTransportError::BadRequest { + message: "missing Acp-Connection-Id header", + source: None, + }) )); let prompt = IncomingHttpMessage::parse( @@ -1564,9 +1680,10 @@ mod tests { .unwrap(); assert!(matches!( validate_http_context(&prompt, Some(&connection_id), None), - Err(HttpTransportError::BadRequest( - "missing Acp-Session-Id header" - )) + Err(HttpTransportError::BadRequest { + message: "missing Acp-Session-Id header", + source: None, + }) )); let mismatched = IncomingHttpMessage::parse( @@ -1576,9 +1693,10 @@ mod tests { .unwrap(); assert!(matches!( validate_http_context(&mismatched, Some(&connection_id), Some(&session_id)), - Err(HttpTransportError::BadRequest( - "Acp-Session-Id header does not match request body sessionId" - )) + Err(HttpTransportError::BadRequest { + message: "Acp-Session-Id header does not match request body sessionId", + source: None, + }) )); } @@ -1610,9 +1728,10 @@ mod tests { ); assert!(matches!( parse_connection_id_header(&headers), - Err(HttpTransportError::BadRequest( - "invalid Acp-Connection-Id header" - )) + Err(HttpTransportError::BadRequest { + message: "invalid Acp-Connection-Id header", + source: Some(_), + }) )); } @@ -1842,9 +1961,10 @@ mod tests { .await; assert!(matches!( post_response_rx.await.unwrap(), - Err(HttpTransportError::BadRequest( - "missing Acp-Connection-Id header" - )) + Err(HttpTransportError::BadRequest { + message: "missing Acp-Connection-Id header", + source: None, + }) )); let unknown_connection_id = AcpConnectionId::new(); @@ -1865,7 +1985,10 @@ mod tests { .await; assert!(matches!( get_response_rx.await.unwrap(), - Err(HttpTransportError::NotFound("unknown ACP connection")) + Err(HttpTransportError::NotFound { + message: "unknown ACP connection", + source: None, + }) )); let (delete_response_tx, delete_response_rx) = oneshot::channel(); @@ -1884,7 +2007,10 @@ mod tests { .await; assert!(matches!( delete_response_rx.await.unwrap(), - Err(HttpTransportError::NotFound("unknown ACP connection")) + Err(HttpTransportError::NotFound { + message: "unknown ACP connection", + source: None, + }) )); } From 92146c6faf0ded89c8759428ea800c7fb6151112 Mon Sep 17 00:00:00 2001 From: Yordis Prieto Date: Wed, 22 Apr 2026 17:26:58 -0400 Subject: [PATCH 11/23] fix(acp-nats-ws): keep GET stream headers consistent Signed-off-by: Yordis Prieto --- rsworkspace/crates/acp-nats-ws/src/transport.rs | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/rsworkspace/crates/acp-nats-ws/src/transport.rs b/rsworkspace/crates/acp-nats-ws/src/transport.rs index 76bb13e7f..76aea9df0 100644 --- a/rsworkspace/crates/acp-nats-ws/src/transport.rs +++ b/rsworkspace/crates/acp-nats-ws/src/transport.rs @@ -598,8 +598,8 @@ async fn http_get(headers: HeaderMap, state: AppState) -> Result Result(item.into_event()), stream)) })) .into_response(); + set_transport_headers(response.headers_mut(), &connection_id, Some(&session_id)); response .headers_mut() .insert(X_ACCEL_BUFFERING_HEADER, HeaderValue::from_static("no")); @@ -1892,6 +1893,14 @@ mod tests { let response = http_get(headers, state.clone()).await.unwrap(); assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response.headers().get(ACP_CONNECTION_ID_HEADER).unwrap(), + HeaderValue::from_str(&connection_id.to_string()).unwrap() + ); + assert_eq!( + response.headers().get(ACP_SESSION_ID_HEADER).unwrap(), + HeaderValue::from_str(session_id.as_str()).unwrap() + ); assert_eq!( response.headers().get(X_ACCEL_BUFFERING_HEADER).unwrap(), HeaderValue::from_static("no") From bd7602774c8877b814bfbe3dea5e4dc8180e3868 Mon Sep 17 00:00:00 2001 From: Yordis Prieto Date: Wed, 22 Apr 2026 17:36:46 -0400 Subject: [PATCH 12/23] fix(acp-nats-ws): accept null JSON-RPC responses Signed-off-by: Yordis Prieto --- .../crates/acp-nats-ws/src/transport.rs | 88 ++++++++++++++++--- 1 file changed, 78 insertions(+), 10 deletions(-) diff --git a/rsworkspace/crates/acp-nats-ws/src/transport.rs b/rsworkspace/crates/acp-nats-ws/src/transport.rs index 76aea9df0..7ec75effc 100644 --- a/rsworkspace/crates/acp-nats-ws/src/transport.rs +++ b/rsworkspace/crates/acp-nats-ws/src/transport.rs @@ -42,7 +42,7 @@ pub struct ConnectionRequest { } pub enum ManagerRequest { - WebSocket(ConnectionRequest), + WebSocket(Box), HttpPost { connection_id: Option, session_id: Option, @@ -263,10 +263,12 @@ pub struct IncomingHttpMessage { pub id: Option, pub method: Option, pub params: Option, - pub result: Option, - pub error: Option, #[serde(skip)] pub raw: String, + #[serde(skip)] + has_result: bool, + #[serde(skip)] + has_error: bool, } impl IncomingHttpMessage { @@ -278,9 +280,17 @@ impl IncomingHttpMessage { )); } - let mut parsed = serde_json::from_str::(&raw) + let value = serde_json::from_str::(&raw) + .map_err(|error| HttpTransportError::bad_request_with("invalid JSON-RPC payload", error))?; + let (has_result, has_error) = value + .as_object() + .map(|object| (object.contains_key("result"), object.contains_key("error"))) + .unwrap_or((false, false)); + let mut parsed = serde_json::from_value::(value) .map_err(|error| HttpTransportError::bad_request_with("invalid JSON-RPC payload", error))?; parsed.raw = raw; + parsed.has_result = has_result; + parsed.has_error = has_error; Ok(parsed) } @@ -295,7 +305,7 @@ impl IncomingHttpMessage { fn is_response(&self) -> bool { self.id.is_some() && self.method.is_none() - && (self.result.is_some() || self.error.is_some()) + && (self.has_result || self.has_error) } fn method_name(&self) -> Option<&str> { @@ -515,11 +525,11 @@ fn websocket_response(ws: WebSocketUpgrade, state: AppState) -> Response { let mut response = ws.on_upgrade(move |socket| async move { if state .manager_tx - .send(ManagerRequest::WebSocket(ConnectionRequest { + .send(ManagerRequest::WebSocket(Box::new(ConnectionRequest { connection_id, socket, shutdown_rx, - })) + }))) .is_err() { error!("Connection thread is gone; dropping WebSocket"); @@ -1161,13 +1171,18 @@ pub async fn process_manager_request( match request { ManagerRequest::WebSocket(request) => { + let ConnectionRequest { + connection_id, + socket, + shutdown_rx, + } = *request; websocket_handles.push(tokio::task::spawn_local(connection::handle( - request.connection_id, - request.socket, + connection_id, + socket, nats_client.clone(), js_client.clone(), config.clone(), - request.shutdown_rx, + shutdown_rx, ))); } ManagerRequest::HttpPost { @@ -1434,6 +1449,12 @@ mod tests { .unwrap(); assert!(response.is_response()); assert!(response.requires_session_id()); + + let null_response = + IncomingHttpMessage::parse(r#"{"jsonrpc":"2.0","id":100,"result":null}"#.to_string()) + .unwrap(); + assert!(null_response.is_response()); + assert!(null_response.requires_session_id()); } #[test] @@ -1776,6 +1797,53 @@ mod tests { assert_eq!(response.status(), StatusCode::ACCEPTED); } + #[tokio::test] + async fn http_post_accepts_null_result_responses() { + let (state, mut manager_rx) = test_state(); + let connection_id = AcpConnectionId::new(); + let session_id = session_id(); + let expected_connection_id = connection_id.clone(); + let expected_session_id = session_id.clone(); + + tokio::spawn(async move { + match manager_rx.recv().await.unwrap() { + ManagerRequest::HttpPost { + connection_id: Some(actual_connection_id), + session_id: Some(actual_session_id), + message, + response, + .. + } => { + assert_eq!(actual_connection_id, expected_connection_id); + assert_eq!(actual_session_id, expected_session_id); + assert!(message.is_response()); + let _ = response.send(Ok(HttpPostOutcome::Accepted)); + } + _ => panic!("unexpected manager request"), + } + }); + + let mut headers = post_headers(); + headers.insert( + ACP_CONNECTION_ID_HEADER, + HeaderValue::from_str(&connection_id.to_string()).unwrap(), + ); + headers.insert( + ACP_SESSION_ID_HEADER, + HeaderValue::from_str(session_id.as_str()).unwrap(), + ); + + let response = http_post( + headers, + state, + r#"{"jsonrpc":"2.0","id":1,"result":null}"#.to_string(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::ACCEPTED); + } + #[tokio::test] async fn http_post_returns_buffered_sse_with_session_headers() { let (state, mut manager_rx) = test_state(); From 55623d6ec7dbd92e34f5667d9a11882ff73dd630 Mon Sep 17 00:00:00 2001 From: Yordis Prieto Date: Wed, 22 Apr 2026 17:45:29 -0400 Subject: [PATCH 13/23] fix(acp-nats-ws): close transport review gaps Signed-off-by: Yordis Prieto --- .../crates/acp-nats-ws/src/transport.rs | 255 ++++++++++++++---- 1 file changed, 198 insertions(+), 57 deletions(-) diff --git a/rsworkspace/crates/acp-nats-ws/src/transport.rs b/rsworkspace/crates/acp-nats-ws/src/transport.rs index 7ec75effc..1a2e7fdb9 100644 --- a/rsworkspace/crates/acp-nats-ws/src/transport.rs +++ b/rsworkspace/crates/acp-nats-ws/src/transport.rs @@ -280,14 +280,16 @@ impl IncomingHttpMessage { )); } - let value = serde_json::from_str::(&raw) - .map_err(|error| HttpTransportError::bad_request_with("invalid JSON-RPC payload", error))?; + let value = serde_json::from_str::(&raw).map_err(|error| { + HttpTransportError::bad_request_with("invalid JSON-RPC payload", error) + })?; let (has_result, has_error) = value .as_object() .map(|object| (object.contains_key("result"), object.contains_key("error"))) .unwrap_or((false, false)); - let mut parsed = serde_json::from_value::(value) - .map_err(|error| HttpTransportError::bad_request_with("invalid JSON-RPC payload", error))?; + let mut parsed = serde_json::from_value::(value).map_err(|error| { + HttpTransportError::bad_request_with("invalid JSON-RPC payload", error) + })?; parsed.raw = raw; parsed.has_result = has_result; parsed.has_error = has_error; @@ -303,9 +305,7 @@ impl IncomingHttpMessage { } fn is_response(&self) -> bool { - self.id.is_some() - && self.method.is_none() - && (self.has_result || self.has_error) + self.id.is_some() && self.method.is_none() && (self.has_result || self.has_error) } fn method_name(&self) -> Option<&str> { @@ -381,6 +381,14 @@ enum LiveFrameOutcome { Drop, } +enum BufferedFrameOutcome { + Buffered, + Routed, + Finalize { + session_id: Option, + }, +} + enum ListenerDispatch { Missing, Delivered, @@ -481,14 +489,40 @@ fn route_live_frame( } } +fn route_buffered_frame( + frame: &SseFrame, + parsed: Option<&OutgoingHttpMessage>, + request_id: &RequestId, + events: &mut Vec, + get_listeners: &mut HashMap>, +) -> BufferedFrameOutcome { + if parsed.and_then(|message| message.id.as_ref()) == Some(request_id) { + events.push(frame.clone()); + return BufferedFrameOutcome::Finalize { + session_id: parsed.and_then(OutgoingHttpMessage::result_session_id), + }; + } + + if let Some(frame_session_id) = parsed.and_then(OutgoingHttpMessage::params_session_id) { + match dispatch_to_get_listeners(frame, &frame_session_id, get_listeners) { + ListenerDispatch::Delivered | ListenerDispatch::Dropped => { + return BufferedFrameOutcome::Routed; + } + ListenerDispatch::Missing => {} + } + } + + events.push(frame.clone()); + BufferedFrameOutcome::Buffered +} + pub async fn get(State(state): State, request: Request) -> Response { if is_websocket_request(request.headers()) { let (mut parts, _body) = request.into_parts(); match WebSocketUpgrade::from_request_parts(&mut parts, &state).await { Ok(ws) => websocket_response(ws, state), Err(_) => { - HttpTransportError::bad_request("invalid WebSocket upgrade request") - .into_response() + HttpTransportError::bad_request("invalid WebSocket upgrade request").into_response() } } } else { @@ -550,7 +584,9 @@ async fn http_post( let message = IncomingHttpMessage::parse(body)?; if !(message.is_request() || message.is_notification() || message.is_response()) { - return Err(HttpTransportError::bad_request("invalid JSON-RPC message shape")); + return Err(HttpTransportError::bad_request( + "invalid JSON-RPC message shape", + )); } let connection_id = parse_connection_id_header(&headers)?; @@ -572,12 +608,9 @@ async fn http_post( HttpTransportError::internal_with("connection manager is unavailable", error) })?; - match response_rx - .await - .map_err(|error| { - HttpTransportError::internal_with("connection manager dropped the request", error) - })?? - { + match response_rx.await.map_err(|error| { + HttpTransportError::internal_with("connection manager dropped the request", error) + })?? { HttpPostOutcome::Accepted => Ok(StatusCode::ACCEPTED.into_response()), HttpPostOutcome::Live { connection_id, @@ -599,10 +632,12 @@ async fn http_post( async fn http_get(headers: HeaderMap, state: AppState) -> Result { validate_get_headers(&headers)?; - let connection_id = parse_connection_id_header(&headers)? - .ok_or(HttpTransportError::bad_request("missing Acp-Connection-Id header"))?; - let session_id = parse_session_id_header(&headers)? - .ok_or(HttpTransportError::bad_request("missing Acp-Session-Id header"))?; + let connection_id = parse_connection_id_header(&headers)?.ok_or( + HttpTransportError::bad_request("missing Acp-Connection-Id header"), + )?; + let session_id = parse_session_id_header(&headers)?.ok_or(HttpTransportError::bad_request( + "missing Acp-Session-Id header", + ))?; let (response_tx, response_rx) = oneshot::channel(); state @@ -616,11 +651,9 @@ async fn http_get(headers: HeaderMap, state: AppState) -> Result Result Result { - let connection_id = parse_connection_id_header(&headers)? - .ok_or(HttpTransportError::bad_request("missing Acp-Connection-Id header"))?; + let connection_id = parse_connection_id_header(&headers)?.ok_or( + HttpTransportError::bad_request("missing Acp-Connection-Id header"), + )?; let (response_tx, response_rx) = oneshot::channel(); state @@ -651,11 +685,9 @@ async fn http_delete(headers: HeaderMap, state: AppState) -> Result Result<(), HttpTransportError> .get(CONTENT_TYPE) .and_then(|value| value.to_str().ok()) { - Some(value) if value.eq_ignore_ascii_case("application/json") => {} + Some(value) if media_type_matches(value, "application/json") => {} _ => { return Err(HttpTransportError::unsupported_media_type( "Content-Type must be application/json", @@ -753,9 +785,16 @@ fn validate_http_context( fn accept_contains(header: &str, expected: &str) -> bool { header .split(',') - .filter_map(|value| value.split(';').next()) .map(str::trim) - .any(|media_type| media_type.eq_ignore_ascii_case(expected)) + .any(|media_type| media_type_matches(media_type, expected)) +} + +fn media_type_matches(header_value: &str, expected: &str) -> bool { + header_value + .split(';') + .next() + .map(str::trim) + .is_some_and(|media_type| media_type.eq_ignore_ascii_case(expected)) } fn parse_connection_id_header( @@ -794,10 +833,7 @@ fn parse_session_id_header( }) .and_then(|value| { acp_nats::AcpSessionId::new(value).map_err(|error| { - HttpTransportError::bad_request_with( - "invalid Acp-Session-Id header", - error, - ) + HttpTransportError::bad_request_with("invalid Acp-Session-Id header", error) }) }) }) @@ -1078,19 +1114,28 @@ pub async fn run_http_connection( continue; } PendingRequest::Buffered { request_id, events, .. } => { - events.push(frame); - if parsed.as_ref().and_then(|message| message.id.as_ref()) == Some(request_id) { - let session_id = parsed.and_then(|message| message.result_session_id()); - if let Some(session_id) = session_id.clone() { - sessions.insert(session_id); - } - let events = std::mem::take(events); - if let Some(PendingRequest::Buffered { response, .. }) = pending_request.take() { - let _ = response.send(Ok(HttpPostOutcome::Buffered { - connection_id: connection_id.clone(), - session_id, - events, - })); + match route_buffered_frame( + &frame, + parsed.as_ref(), + request_id, + events, + &mut get_listeners, + ) { + BufferedFrameOutcome::Buffered | BufferedFrameOutcome::Routed => {} + BufferedFrameOutcome::Finalize { session_id } => { + if let Some(session_id) = session_id.clone() { + sessions.insert(session_id); + } + let events = std::mem::take(events); + if let Some(PendingRequest::Buffered { response, .. }) = + pending_request.take() + { + let _ = response.send(Ok(HttpPostOutcome::Buffered { + connection_id: connection_id.clone(), + session_id, + events, + })); + } } } continue; @@ -1168,6 +1213,7 @@ pub async fn process_manager_request( { websocket_handles.retain(|handle| !handle.is_finished()); http_connection_handles.retain(|handle| !handle.is_finished()); + http_connections.retain(|_, handle| !handle.command_tx.is_closed()); match request { ManagerRequest::WebSocket(request) => { @@ -1385,9 +1431,18 @@ mod tests { #[test] fn http_transport_error_into_response_maps_status_codes() { let cases = [ - (HttpTransportError::bad_request("bad"), StatusCode::BAD_REQUEST), - (HttpTransportError::not_found("missing"), StatusCode::NOT_FOUND), - (HttpTransportError::conflict("conflict"), StatusCode::CONFLICT), + ( + HttpTransportError::bad_request("bad"), + StatusCode::BAD_REQUEST, + ), + ( + HttpTransportError::not_found("missing"), + StatusCode::NOT_FOUND, + ), + ( + HttpTransportError::conflict("conflict"), + StatusCode::CONFLICT, + ), ( HttpTransportError::unsupported_media_type("unsupported"), StatusCode::UNSUPPORTED_MEDIA_TYPE, @@ -1419,8 +1474,14 @@ mod tests { use std::error::Error as _; assert_eq!(HttpTransportError::bad_request("bad").to_string(), "bad"); - assert_eq!(HttpTransportError::conflict("conflict").to_string(), "conflict"); - assert_eq!(HttpTransportError::internal("internal").to_string(), "internal"); + assert_eq!( + HttpTransportError::conflict("conflict").to_string(), + "conflict" + ); + assert_eq!( + HttpTransportError::internal("internal").to_string(), + "internal" + ); let invalid_json = IncomingHttpMessage::parse("{".to_string()).unwrap_err(); assert_eq!(invalid_json.to_string(), "invalid JSON-RPC payload"); @@ -1589,6 +1650,38 @@ mod tests { } } + #[test] + fn route_buffered_frame_sends_other_session_notifications_to_get_listeners() { + let frame = SseFrame::Json( + r#"{"jsonrpc":"2.0","method":"session/update","params":{"sessionId":"session-2"}}"# + .to_string(), + ); + let parsed = OutgoingHttpMessage::parse(match &frame { + SseFrame::Json(json) => json, + }) + .unwrap(); + let request_id = RequestId::Number(1); + let other_session_id = acp_nats::AcpSessionId::new("session-2").unwrap(); + let (get_tx, mut get_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); + let mut events = Vec::new(); + let mut get_listeners = HashMap::new(); + get_listeners.insert(other_session_id, vec![get_tx]); + + let outcome = route_buffered_frame( + &frame, + Some(&parsed), + &request_id, + &mut events, + &mut get_listeners, + ); + + assert!(matches!(outcome, BufferedFrameOutcome::Routed)); + assert!(events.is_empty()); + match get_rx.try_recv().unwrap() { + SseFrame::Json(json) => assert!(json.contains(r#""sessionId":"session-2""#)), + } + } + #[test] fn dispatch_to_get_listeners_drops_full_listener() { let session_id = session_id(); @@ -1618,6 +1711,13 @@ mod tests { let valid_post = post_headers(); assert!(validate_post_headers(&valid_post).is_ok()); + let mut valid_post_with_charset = post_headers(); + valid_post_with_charset.insert( + CONTENT_TYPE, + HeaderValue::from_static("application/json; charset=utf-8"), + ); + assert!(validate_post_headers(&valid_post_with_charset).is_ok()); + let mut bad_content_type = valid_post.clone(); bad_content_type.insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); assert!(matches!( @@ -2091,6 +2191,47 @@ mod tests { )); } + #[tokio::test] + async fn process_manager_request_prunes_closed_http_connections() { + let nats_client = AdvancedMockNatsClient::new(); + let js_client = MockJs::new(); + let config = test_config(); + let mut http_connections = HashMap::new(); + let mut websocket_handles = Vec::new(); + let mut http_connection_handles = Vec::new(); + + let stale_connection_id = AcpConnectionId::new(); + let (command_tx, command_rx) = mpsc::unbounded_channel(); + drop(command_rx); + http_connections.insert(stale_connection_id, HttpConnectionHandle { command_tx }); + + let unknown_connection_id = AcpConnectionId::new(); + let (response_tx, response_rx) = oneshot::channel(); + process_manager_request( + ManagerRequest::HttpGet { + connection_id: unknown_connection_id, + session_id: session_id(), + response: response_tx, + }, + &mut http_connections, + &mut websocket_handles, + &mut http_connection_handles, + &nats_client, + &js_client, + &config, + ) + .await; + + assert!(http_connections.is_empty()); + assert!(matches!( + response_rx.await.unwrap(), + Err(HttpTransportError::NotFound { + message: "unknown ACP connection", + source: None, + }) + )); + } + #[tokio::test(flavor = "current_thread")] async fn process_manager_request_tracks_http_connection_tasks() { let local = tokio::task::LocalSet::new(); From ff6742a559609e5d8118eb1f9f002ca7a9dd7cf8 Mon Sep 17 00:00:00 2001 From: Yordis Prieto Date: Wed, 22 Apr 2026 18:04:37 -0400 Subject: [PATCH 14/23] refactor(acp-nats-ws): align transport errors with workspace style Signed-off-by: Yordis Prieto --- rsworkspace/Cargo.lock | 1 - rsworkspace/Cargo.toml | 1 - rsworkspace/crates/acp-nats-ws/Cargo.toml | 1 - .../crates/acp-nats-ws/src/transport.rs | 55 +++++++++++++------ 4 files changed, 39 insertions(+), 19 deletions(-) diff --git a/rsworkspace/Cargo.lock b/rsworkspace/Cargo.lock index ef886ab40..2f720feae 100644 --- a/rsworkspace/Cargo.lock +++ b/rsworkspace/Cargo.lock @@ -78,7 +78,6 @@ dependencies = [ "opentelemetry", "serde", "serde_json", - "thiserror 2.0.18", "tokio", "tokio-tungstenite 0.29.0", "tower", diff --git a/rsworkspace/Cargo.toml b/rsworkspace/Cargo.toml index 1753406d1..d4be98c9b 100644 --- a/rsworkspace/Cargo.toml +++ b/rsworkspace/Cargo.toml @@ -59,7 +59,6 @@ opentelemetry_sdk = "=0.31.0" tracing = "=0.1.44" tracing-opentelemetry = "=0.32.1" tracing-subscriber = "=0.3.23" -thiserror = "=2.0.18" # Discord twilight-gateway = { version = "=0.17.1", default-features = false, features = ["rustls-native-roots"] } diff --git a/rsworkspace/crates/acp-nats-ws/Cargo.toml b/rsworkspace/crates/acp-nats-ws/Cargo.toml index 62367ef8d..9850b06d3 100644 --- a/rsworkspace/crates/acp-nats-ws/Cargo.toml +++ b/rsworkspace/crates/acp-nats-ws/Cargo.toml @@ -20,7 +20,6 @@ opentelemetry = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "macros", "signal", "net", "sync", "io-util"] } -thiserror = { workspace = true } tracing = { workspace = true } trogon-nats = { workspace = true } trogon-std = { workspace = true, features = ["telemetry-http"] } diff --git a/rsworkspace/crates/acp-nats-ws/src/transport.rs b/rsworkspace/crates/acp-nats-ws/src/transport.rs index 1a2e7fdb9..3c8c88cfd 100644 --- a/rsworkspace/crates/acp-nats-ws/src/transport.rs +++ b/rsworkspace/crates/acp-nats-ws/src/transport.rs @@ -17,7 +17,6 @@ use serde_json::Value; use std::collections::{HashMap, HashSet}; use std::convert::Infallible; use std::rc::Rc; -use thiserror::Error; use tokio::io::{AsyncBufReadExt, AsyncWriteExt}; use tokio::sync::{mpsc, oneshot, watch}; use tracing::{error, info, warn}; @@ -76,53 +75,77 @@ pub enum HttpPostOutcome { }, } -#[derive(Debug, Error)] +#[derive(Debug)] pub enum HttpTransportError { - #[error("{message}")] BadRequest { message: &'static str, - #[source] source: Option, }, - #[error("{message}")] NotFound { message: &'static str, - #[source] source: Option, }, - #[error("{message}")] Conflict { message: &'static str, - #[source] source: Option, }, - #[error("{message}")] UnsupportedMediaType { message: &'static str, - #[source] source: Option, }, - #[error("{message}")] NotAcceptable { message: &'static str, - #[source] source: Option, }, - #[error("{message}")] NotImplemented { message: &'static str, - #[source] source: Option, }, - #[error("{message}")] Internal { message: &'static str, - #[source] source: Option, }, } +impl std::fmt::Display for HttpTransportError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.message()) + } +} + +impl std::error::Error for HttpTransportError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.source_ref() + } +} + impl HttpTransportError { + fn message(&self) -> &'static str { + match self { + Self::BadRequest { message, .. } + | Self::NotFound { message, .. } + | Self::Conflict { message, .. } + | Self::UnsupportedMediaType { message, .. } + | Self::NotAcceptable { message, .. } + | Self::NotImplemented { message, .. } + | Self::Internal { message, .. } => message, + } + } + + fn source_ref(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::BadRequest { source, .. } + | Self::NotFound { source, .. } + | Self::Conflict { source, .. } + | Self::UnsupportedMediaType { source, .. } + | Self::NotAcceptable { source, .. } + | Self::NotImplemented { source, .. } + | Self::Internal { source, .. } => source + .as_deref() + .map(|source| source as &(dyn std::error::Error + 'static)), + } + } + fn bad_request(message: &'static str) -> Self { Self::BadRequest { message, From 8b56be406566326b1d30a224f48485d91a2eb505 Mon Sep 17 00:00:00 2001 From: Yordis Prieto Date: Wed, 22 Apr 2026 18:05:52 -0400 Subject: [PATCH 15/23] test(acp-nats-ws): keep manual transport errors covered Signed-off-by: Yordis Prieto --- .../crates/acp-nats-ws/src/transport.rs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/rsworkspace/crates/acp-nats-ws/src/transport.rs b/rsworkspace/crates/acp-nats-ws/src/transport.rs index 3c8c88cfd..5ccb49c44 100644 --- a/rsworkspace/crates/acp-nats-ws/src/transport.rs +++ b/rsworkspace/crates/acp-nats-ws/src/transport.rs @@ -1497,14 +1497,33 @@ mod tests { use std::error::Error as _; assert_eq!(HttpTransportError::bad_request("bad").to_string(), "bad"); + assert!(HttpTransportError::bad_request("bad").source().is_none()); + assert!(HttpTransportError::not_found("missing").source().is_none()); assert_eq!( HttpTransportError::conflict("conflict").to_string(), "conflict" ); + assert!(HttpTransportError::conflict("conflict").source().is_none()); + assert!( + HttpTransportError::unsupported_media_type("unsupported") + .source() + .is_none() + ); + assert!( + HttpTransportError::not_acceptable("not-acceptable") + .source() + .is_none() + ); + assert!( + HttpTransportError::not_implemented("not-implemented") + .source() + .is_none() + ); assert_eq!( HttpTransportError::internal("internal").to_string(), "internal" ); + assert!(HttpTransportError::internal("internal").source().is_none()); let invalid_json = IncomingHttpMessage::parse("{".to_string()).unwrap_err(); assert_eq!(invalid_json.to_string(), "invalid JSON-RPC payload"); From c73e94c0adb1ee13ddbeaca42c998f7c0c4e9573 Mon Sep 17 00:00:00 2001 From: Yordis Prieto Date: Wed, 22 Apr 2026 18:21:54 -0400 Subject: [PATCH 16/23] refactor(acp-nats-ws): drop pre-spec websocket alias Signed-off-by: Yordis Prieto --- rsworkspace/crates/acp-nats-ws/README.md | 2 +- rsworkspace/crates/acp-nats-ws/src/constants.rs | 1 - rsworkspace/crates/acp-nats-ws/src/main.rs | 10 +--------- rsworkspace/crates/acp-nats-ws/src/transport.rs | 4 ---- 4 files changed, 2 insertions(+), 15 deletions(-) diff --git a/rsworkspace/crates/acp-nats-ws/README.md b/rsworkspace/crates/acp-nats-ws/README.md index e08d933f6..61a3b1fa4 100644 --- a/rsworkspace/crates/acp-nats-ws/README.md +++ b/rsworkspace/crates/acp-nats-ws/README.md @@ -15,7 +15,7 @@ graph LR ## Features - Streamable HTTP transport on `/acp` with session-scoped SSE listeners -- WebSocket upgrade on `/acp` plus a legacy `/ws` compatibility alias +- WebSocket upgrade on `/acp` - Multiple concurrent ACP connections sharing the same NATS bridge - OpenTelemetry integration (logs, metrics, traces) - Graceful shutdown (SIGINT/SIGTERM) with per-connection drain diff --git a/rsworkspace/crates/acp-nats-ws/src/constants.rs b/rsworkspace/crates/acp-nats-ws/src/constants.rs index 501c15547..f9bebcb5d 100644 --- a/rsworkspace/crates/acp-nats-ws/src/constants.rs +++ b/rsworkspace/crates/acp-nats-ws/src/constants.rs @@ -6,6 +6,5 @@ pub const ACP_SESSION_ID_HEADER: &str = "acp-session-id"; pub const DEFAULT_HOST: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); pub const DEFAULT_PORT: u16 = 8080; pub const DUPLEX_BUFFER_SIZE: usize = 64 * 1024; -pub const LEGACY_WS_ENDPOINT: &str = "/ws"; pub const THREAD_NAME: &str = "acp-ws-local"; pub const X_ACCEL_BUFFERING_HEADER: &str = "x-accel-buffering"; diff --git a/rsworkspace/crates/acp-nats-ws/src/main.rs b/rsworkspace/crates/acp-nats-ws/src/main.rs index eb2cc40ae..873cdc270 100644 --- a/rsworkspace/crates/acp-nats-ws/src/main.rs +++ b/rsworkspace/crates/acp-nats-ws/src/main.rs @@ -55,10 +55,6 @@ async fn main() -> Result<(), Box> { .post(transport::post) .delete(transport::delete), ) - .route( - LEGACY_WS_ENDPOINT, - axum::routing::get(transport::legacy_websocket_get), - ) .with_state(state), ); @@ -98,7 +94,7 @@ async fn main() -> Result<(), Box> { #[cfg(coverage)] fn main() {} -use constants::{ACP_ENDPOINT, LEGACY_WS_ENDPOINT, THREAD_NAME}; +use constants::{ACP_ENDPOINT, THREAD_NAME}; /// Runs a single-threaded tokio runtime with a /// `LocalSet`. All WebSocket connections are processed here because the ACP @@ -280,10 +276,6 @@ mod tests { .post(transport::post) .delete(transport::delete), ) - .route( - LEGACY_WS_ENDPOINT, - axum::routing::get(transport::legacy_websocket_get), - ) .with_state(state) } diff --git a/rsworkspace/crates/acp-nats-ws/src/transport.rs b/rsworkspace/crates/acp-nats-ws/src/transport.rs index 5ccb49c44..0337d3a81 100644 --- a/rsworkspace/crates/acp-nats-ws/src/transport.rs +++ b/rsworkspace/crates/acp-nats-ws/src/transport.rs @@ -556,10 +556,6 @@ pub async fn get(State(state): State, request: Request) -> Response { } } -pub async fn legacy_websocket_get(ws: WebSocketUpgrade, State(state): State) -> Response { - websocket_response(ws, state) -} - pub async fn post(headers: HeaderMap, State(state): State, body: String) -> Response { match http_post(headers, state, body).await { Ok(response) => response, From 7b3306a55ccdd9541c24deb531327428f706a119 Mon Sep 17 00:00:00 2001 From: Yordis Prieto Date: Wed, 22 Apr 2026 18:23:33 -0400 Subject: [PATCH 17/23] test(acp-nats-ws): cover the single-endpoint router Signed-off-by: Yordis Prieto --- rsworkspace/crates/acp-nats-ws/src/main.rs | 24 ++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/rsworkspace/crates/acp-nats-ws/src/main.rs b/rsworkspace/crates/acp-nats-ws/src/main.rs index 873cdc270..9bf4fc014 100644 --- a/rsworkspace/crates/acp-nats-ws/src/main.rs +++ b/rsworkspace/crates/acp-nats-ws/src/main.rs @@ -589,6 +589,30 @@ mod tests { conn_thread.join().unwrap(); } + #[tokio::test] + async fn legacy_websocket_alias_is_not_routed() { + let nats_mock = AdvancedMockNatsClient::new(); + let (app, shutdown_tx, conn_thread) = build_test_app(nats_mock); + + let response = app + .clone() + .oneshot( + Request::builder() + .method("GET") + .uri("/ws") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::NOT_FOUND); + + let _ = shutdown_tx.send(true); + drop(app); + conn_thread.join().unwrap(); + } + #[tokio::test] async fn streamable_http_delete_terminates_initialized_connection() { let nats_mock = AdvancedMockNatsClient::new(); From f5f83ffef80c2fa17751acf866a3512dbb67e5e6 Mon Sep 17 00:00:00 2001 From: Yordis Prieto Date: Wed, 22 Apr 2026 18:57:26 -0400 Subject: [PATCH 18/23] fix(acp-nats-ws): prevent protocol drift on HTTP Signed-off-by: Yordis Prieto --- rsworkspace/crates/acp-nats-ws/README.md | 2 + .../crates/acp-nats-ws/src/constants.rs | 1 + rsworkspace/crates/acp-nats-ws/src/main.rs | 74 ++++- .../crates/acp-nats-ws/src/transport.rs | 289 ++++++++++++++++-- 4 files changed, 341 insertions(+), 25 deletions(-) diff --git a/rsworkspace/crates/acp-nats-ws/README.md b/rsworkspace/crates/acp-nats-ws/README.md index 61a3b1fa4..c0575792f 100644 --- a/rsworkspace/crates/acp-nats-ws/README.md +++ b/rsworkspace/crates/acp-nats-ws/README.md @@ -49,6 +49,8 @@ curl -i \ `POST /acp` returns an SSE response for JSON-RPC requests, `GET /acp` opens a session-scoped SSE listener with `Acp-Connection-Id` and `Acp-Session-Id`, and `DELETE /acp` terminates a connection. The WebSocket upgrade response and HTTP initialize response both include `Acp-Connection-Id`. +After `initialize`, HTTP clients may send `Acp-Protocol-Version` on `POST`/`GET`/`DELETE`. When present, it must match the negotiated ACP protocol version for that connection. + ## Configuration ### WebSocket Server diff --git a/rsworkspace/crates/acp-nats-ws/src/constants.rs b/rsworkspace/crates/acp-nats-ws/src/constants.rs index f9bebcb5d..767031bc6 100644 --- a/rsworkspace/crates/acp-nats-ws/src/constants.rs +++ b/rsworkspace/crates/acp-nats-ws/src/constants.rs @@ -2,6 +2,7 @@ use std::net::{IpAddr, Ipv4Addr}; pub const ACP_CONNECTION_ID_HEADER: &str = "acp-connection-id"; pub const ACP_ENDPOINT: &str = "/acp"; +pub const ACP_PROTOCOL_VERSION_HEADER: &str = "acp-protocol-version"; pub const ACP_SESSION_ID_HEADER: &str = "acp-session-id"; pub const DEFAULT_HOST: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); pub const DEFAULT_PORT: u16 = 8080; diff --git a/rsworkspace/crates/acp-nats-ws/src/main.rs b/rsworkspace/crates/acp-nats-ws/src/main.rs index 9bf4fc014..1e26777d2 100644 --- a/rsworkspace/crates/acp-nats-ws/src/main.rs +++ b/rsworkspace/crates/acp-nats-ws/src/main.rs @@ -199,7 +199,9 @@ async fn process_connections( #[cfg(test)] mod tests { use super::*; - use crate::constants::{ACP_CONNECTION_ID_HEADER, ACP_SESSION_ID_HEADER}; + use crate::constants::{ + ACP_CONNECTION_ID_HEADER, ACP_PROTOCOL_VERSION_HEADER, ACP_SESSION_ID_HEADER, + }; use acp_nats::Config; use axum::body::{Body, to_bytes}; use axum::http::header::{ACCEPT, CONTENT_TYPE}; @@ -528,6 +530,7 @@ mod tests { .header(CONTENT_TYPE, "application/json") .header(ACCEPT, "application/json, text/event-stream") .header(ACP_CONNECTION_ID_HEADER, &connection_id) + .header(ACP_PROTOCOL_VERSION_HEADER, "0") .body(Body::from( r#"{"jsonrpc":"2.0","id":2,"method":"session/new","params":{"cwd":".","mcpServers":[]}}"#, )) @@ -552,6 +555,13 @@ mod tests { .get(ACP_SESSION_ID_HEADER) .and_then(|value| value.to_str().ok()) .map(str::to_owned); + assert_eq!( + session_new + .headers() + .get(ACP_PROTOCOL_VERSION_HEADER) + .and_then(|value| value.to_str().ok()), + Some("0") + ); let body = body_text(session_new).await; let events = sse_events(&body); assert_eq!(events.len(), 1); @@ -647,6 +657,7 @@ mod tests { .method("DELETE") .uri(ACP_ENDPOINT) .header(ACP_CONNECTION_ID_HEADER, &connection_id) + .header(ACP_PROTOCOL_VERSION_HEADER, "0") .body(Body::empty()) .unwrap(), ) @@ -654,6 +665,67 @@ mod tests { .unwrap(); assert_eq!(response.status(), StatusCode::ACCEPTED); + assert_eq!( + response + .headers() + .get(ACP_PROTOCOL_VERSION_HEADER) + .and_then(|value| value.to_str().ok()), + Some("0") + ); + + let _ = shutdown_tx.send(true); + drop(app); + conn_thread.join().unwrap(); + } + + #[tokio::test] + async fn streamable_http_rejects_mismatched_protocol_version_after_initialize() { + let nats_mock = AdvancedMockNatsClient::new(); + let _injector = nats_mock.inject_messages(); + nats_mock.set_response( + "acp.agent.initialize", + r#"{"agentCapabilities":{"loadSession":false,"mcpCapabilities":{"http":false,"sse":false},"promptCapabilities":{"audio":false,"embeddedContext":false,"image":false},"sessionCapabilities":{}},"authMethods":[],"protocolVersion":0}"# + .into(), + ); + + let (app, shutdown_tx, conn_thread) = build_test_app(nats_mock); + let initialize = app + .clone() + .oneshot(http_post_request( + r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":0}}"#, + )) + .await + .unwrap(); + let connection_id = initialize + .headers() + .get(ACP_CONNECTION_ID_HEADER) + .unwrap() + .to_str() + .unwrap() + .to_owned(); + let _ = body_text(initialize).await; + + let response = app + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri(ACP_ENDPOINT) + .header(CONTENT_TYPE, "application/json") + .header(ACCEPT, "application/json, text/event-stream") + .header(ACP_CONNECTION_ID_HEADER, &connection_id) + .header(ACP_PROTOCOL_VERSION_HEADER, "1") + .body(Body::from(r#"{"jsonrpc":"2.0","method":"initialized"}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + assert_eq!( + body_text(response).await, + "Acp-Protocol-Version header does not match initialized protocol version" + ); let _ = shutdown_tx.send(true); drop(app); diff --git a/rsworkspace/crates/acp-nats-ws/src/transport.rs b/rsworkspace/crates/acp-nats-ws/src/transport.rs index 0337d3a81..c8af0a913 100644 --- a/rsworkspace/crates/acp-nats-ws/src/transport.rs +++ b/rsworkspace/crates/acp-nats-ws/src/transport.rs @@ -1,8 +1,11 @@ use crate::acp_connection_id::AcpConnectionId; use crate::connection; -use crate::constants::{ACP_CONNECTION_ID_HEADER, ACP_SESSION_ID_HEADER, X_ACCEL_BUFFERING_HEADER}; +use crate::constants::{ + ACP_CONNECTION_ID_HEADER, ACP_PROTOCOL_VERSION_HEADER, ACP_SESSION_ID_HEADER, + X_ACCEL_BUFFERING_HEADER, +}; use acp_nats::{StdJsonSerialize, agent::Bridge, client, spawn_notification_forwarder}; -use agent_client_protocol::{AgentSideConnection, RequestId, SessionNotification}; +use agent_client_protocol::{AgentSideConnection, ProtocolVersion, RequestId, SessionNotification}; use axum::extract::FromRequestParts; use axum::extract::Request; use axum::extract::State; @@ -44,6 +47,7 @@ pub enum ManagerRequest { WebSocket(Box), HttpPost { connection_id: Option, + protocol_version: Option, session_id: Option, message: IncomingHttpMessage, response: oneshot::Sender>, @@ -52,11 +56,13 @@ pub enum ManagerRequest { HttpGet { connection_id: AcpConnectionId, session_id: acp_nats::AcpSessionId, - response: oneshot::Sender>, + protocol_version: Option, + response: oneshot::Sender>, }, HttpDelete { connection_id: AcpConnectionId, - response: oneshot::Sender>, + protocol_version: Option, + response: oneshot::Sender, HttpTransportError>>, }, } @@ -65,16 +71,24 @@ pub enum HttpPostOutcome { Accepted, Live { connection_id: AcpConnectionId, + protocol_version: Option, session_id: Option, stream: SseReceiver, }, Buffered { connection_id: AcpConnectionId, + protocol_version: Option, session_id: Option, events: Vec, }, } +#[derive(Debug)] +pub struct HttpGetOutcome { + pub protocol_version: Option, + pub stream: SseReceiver, +} + #[derive(Debug)] pub enum HttpTransportError { BadRequest { @@ -241,16 +255,19 @@ pub struct HttpConnectionHandle { #[derive(Debug)] pub enum HttpConnectionCommand { Post { + protocol_version: Option, session_id: Option, message: IncomingHttpMessage, response: oneshot::Sender>, }, AttachListener { session_id: acp_nats::AcpSessionId, - response: oneshot::Sender>, + protocol_version: Option, + response: oneshot::Sender>, }, Close { - response: oneshot::Sender>, + protocol_version: Option, + response: oneshot::Sender, HttpTransportError>>, }, } @@ -271,6 +288,7 @@ impl SseFrame { enum PendingRequest { Live { request_id: RequestId, + capture_protocol_version: bool, session_id: Option, sender: SseSender, }, @@ -396,6 +414,12 @@ impl OutgoingHttpMessage { let session_id = result.get("sessionId")?.as_str()?; acp_nats::AcpSessionId::new(session_id).ok() } + + fn result_protocol_version(&self) -> Option { + let result = self.result.as_ref()?; + let protocol_version = result.get("protocolVersion")?; + serde_json::from_value(protocol_version.clone()).ok() + } } enum LiveFrameOutcome { @@ -609,6 +633,7 @@ async fn http_post( } let connection_id = parse_connection_id_header(&headers)?; + let protocol_version = parse_protocol_version_header(&headers)?; let session_id = parse_session_id_header(&headers)?; validate_http_context(&message, connection_id.as_ref(), session_id.as_ref())?; @@ -618,6 +643,7 @@ async fn http_post( .manager_tx .send(ManagerRequest::HttpPost { connection_id, + protocol_version, session_id, message, response: response_tx, @@ -633,15 +659,23 @@ async fn http_post( HttpPostOutcome::Accepted => Ok(StatusCode::ACCEPTED.into_response()), HttpPostOutcome::Live { connection_id, + protocol_version, session_id, stream, - } => Ok(build_sse_response(connection_id, session_id, stream)), + } => Ok(build_sse_response( + connection_id, + protocol_version, + session_id, + stream, + )), HttpPostOutcome::Buffered { connection_id, + protocol_version, session_id, events, } => Ok(build_buffered_sse_response( connection_id, + protocol_version, session_id, events, )), @@ -654,6 +688,7 @@ async fn http_get(headers: HeaderMap, state: AppState) -> Result Result(item.into_event()), stream)) })) .into_response(); - set_transport_headers(response.headers_mut(), &connection_id, Some(&session_id)); + set_transport_headers( + response.headers_mut(), + &connection_id, + outcome.protocol_version.as_ref(), + Some(&session_id), + ); response .headers_mut() .insert(X_ACCEL_BUFFERING_HEADER, HeaderValue::from_static("no")); @@ -692,23 +733,27 @@ async fn http_delete(headers: HeaderMap, state: AppState) -> Result Result<(), HttpTransportError> { @@ -839,6 +884,35 @@ fn parse_connection_id_header( .transpose() } +fn parse_protocol_version_header( + headers: &HeaderMap, +) -> Result, HttpTransportError> { + headers + .get(ACP_PROTOCOL_VERSION_HEADER) + .map(|value| { + value + .to_str() + .map_err(|error| { + HttpTransportError::bad_request_with( + "invalid Acp-Protocol-Version header", + error, + ) + }) + .and_then(|value| { + value + .parse::() + .map(ProtocolVersion::from) + .map_err(|error| { + HttpTransportError::bad_request_with( + "invalid Acp-Protocol-Version header", + error, + ) + }) + }) + }) + .transpose() +} + fn parse_session_id_header( headers: &HeaderMap, ) -> Result, HttpTransportError> { @@ -869,6 +943,7 @@ fn is_websocket_request(headers: &HeaderMap) -> bool { fn build_sse_response( connection_id: AcpConnectionId, + protocol_version: Option, session_id: Option, stream: SseReceiver, ) -> Response { @@ -879,7 +954,12 @@ fn build_sse_response( .map(|item| (Ok::(item.into_event()), stream)) })) .into_response(); - set_transport_headers(response.headers_mut(), &connection_id, session_id.as_ref()); + set_transport_headers( + response.headers_mut(), + &connection_id, + protocol_version.as_ref(), + session_id.as_ref(), + ); response .headers_mut() .insert(X_ACCEL_BUFFERING_HEADER, HeaderValue::from_static("no")); @@ -888,6 +968,7 @@ fn build_sse_response( fn build_buffered_sse_response( connection_id: AcpConnectionId, + protocol_version: Option, session_id: Option, events: Vec, ) -> Response { @@ -897,7 +978,12 @@ fn build_buffered_sse_response( .map(|item| Ok::(item.into_event())), ); let mut response = Sse::new(stream).into_response(); - set_transport_headers(response.headers_mut(), &connection_id, session_id.as_ref()); + set_transport_headers( + response.headers_mut(), + &connection_id, + protocol_version.as_ref(), + session_id.as_ref(), + ); response .headers_mut() .insert(X_ACCEL_BUFFERING_HEADER, HeaderValue::from_static("no")); @@ -907,6 +993,7 @@ fn build_buffered_sse_response( fn set_transport_headers( headers: &mut HeaderMap, connection_id: &AcpConnectionId, + protocol_version: Option<&ProtocolVersion>, session_id: Option<&acp_nats::AcpSessionId>, ) { headers.insert( @@ -914,6 +1001,7 @@ fn set_transport_headers( HeaderValue::from_str(&connection_id.to_string()) .expect("generated ACP connection id must be a valid header value"), ); + set_protocol_version_header(headers, protocol_version); if let Some(session_id) = session_id { headers.insert( ACP_SESSION_ID_HEADER, @@ -923,6 +1011,34 @@ fn set_transport_headers( } } +fn set_protocol_version_header( + headers: &mut HeaderMap, + protocol_version: Option<&ProtocolVersion>, +) { + if let Some(protocol_version) = protocol_version { + headers.insert( + ACP_PROTOCOL_VERSION_HEADER, + HeaderValue::from_str(&protocol_version.to_string()) + .expect("ACP protocol version must be a valid header value"), + ); + } +} + +fn validate_protocol_version_header( + provided: Option<&ProtocolVersion>, + negotiated: Option<&ProtocolVersion>, +) -> Result<(), HttpTransportError> { + if let (Some(provided), Some(negotiated)) = (provided, negotiated) + && provided != negotiated + { + return Err(HttpTransportError::bad_request( + "Acp-Protocol-Version header does not match initialized protocol version", + )); + } + + Ok(()) +} + pub async fn run_http_connection( connection_id: AcpConnectionId, nats_client: N, @@ -1008,6 +1124,7 @@ pub async fn run_http_connection( let mut io_task = tokio::task::spawn_local(io_task); let mut pending_request: Option = None; + let mut protocol_version: Option = None; let mut sessions = HashSet::::new(); let mut get_listeners = HashMap::>::new(); @@ -1021,7 +1138,15 @@ pub async fn run_http_connection( }; match command { - HttpConnectionCommand::Post { session_id, message, response } => { + HttpConnectionCommand::Post { protocol_version: header_protocol_version, session_id, message, response } => { + if let Err(error) = validate_protocol_version_header( + header_protocol_version.as_ref(), + protocol_version.as_ref(), + ) { + let _ = response.send(Err(error)); + continue; + } + if message.is_request() { if pending_request.is_some() { let _ = response.send(Err(HttpTransportError::conflict( @@ -1054,6 +1179,7 @@ pub async fn run_http_connection( let (stream_tx, stream_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); pending_request = Some(PendingRequest::Live { request_id: message.id.clone().expect("request must have id"), + capture_protocol_version: message.is_initialize(), session_id: session_id.clone(), sender: stream_tx, }); @@ -1066,6 +1192,7 @@ pub async fn run_http_connection( } let _ = response.send(Ok(HttpPostOutcome::Live { connection_id: connection_id.clone(), + protocol_version: protocol_version.clone(), session_id, stream: stream_rx, })); @@ -1085,7 +1212,19 @@ pub async fn run_http_connection( let _ = response.send(Ok(HttpPostOutcome::Accepted)); } - HttpConnectionCommand::AttachListener { session_id, response } => { + HttpConnectionCommand::AttachListener { + session_id, + protocol_version: header_protocol_version, + response, + } => { + if let Err(error) = validate_protocol_version_header( + header_protocol_version.as_ref(), + protocol_version.as_ref(), + ) { + let _ = response.send(Err(error)); + continue; + } + if !sessions.contains(&session_id) { let _ = response .send(Err(HttpTransportError::not_found("unknown ACP session"))); @@ -1094,10 +1233,24 @@ pub async fn run_http_connection( let (stream_tx, stream_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); get_listeners.entry(session_id).or_default().push(stream_tx); - let _ = response.send(Ok(stream_rx)); + let _ = response.send(Ok(HttpGetOutcome { + protocol_version: protocol_version.clone(), + stream: stream_rx, + })); } - HttpConnectionCommand::Close { response } => { - let _ = response.send(Ok(())); + HttpConnectionCommand::Close { + protocol_version: header_protocol_version, + response, + } => { + if let Err(error) = validate_protocol_version_header( + header_protocol_version.as_ref(), + protocol_version.as_ref(), + ) { + let _ = response.send(Err(error)); + continue; + } + + let _ = response.send(Ok(protocol_version.clone())); break; } } @@ -1114,9 +1267,19 @@ pub async fn run_http_connection( match pending { PendingRequest::Live { request_id, + capture_protocol_version, session_id, sender, } => { + if *capture_protocol_version + && parsed.as_ref().and_then(|message| message.id.as_ref()) + == Some(request_id) + { + protocol_version = parsed + .as_ref() + .and_then(OutgoingHttpMessage::result_protocol_version); + } + match route_live_frame( &frame, parsed.as_ref(), @@ -1151,6 +1314,7 @@ pub async fn run_http_connection( { let _ = response.send(Ok(HttpPostOutcome::Buffered { connection_id: connection_id.clone(), + protocol_version: protocol_version.clone(), session_id, events, })); @@ -1252,6 +1416,7 @@ pub async fn process_manager_request( } ManagerRequest::HttpPost { connection_id, + protocol_version, session_id, message, response, @@ -1295,6 +1460,7 @@ pub async fn process_manager_request( if handle .command_tx .send(HttpConnectionCommand::Post { + protocol_version, session_id, message, response, @@ -1307,6 +1473,7 @@ pub async fn process_manager_request( ManagerRequest::HttpGet { connection_id, session_id, + protocol_version, response, } => { let Some(handle) = http_connections.get(&connection_id) else { @@ -1318,6 +1485,7 @@ pub async fn process_manager_request( .command_tx .send(HttpConnectionCommand::AttachListener { session_id, + protocol_version, response, }) .is_err() @@ -1327,6 +1495,7 @@ pub async fn process_manager_request( } ManagerRequest::HttpDelete { connection_id, + protocol_version, response, } => { let Some(handle) = http_connections.remove(&connection_id) else { @@ -1334,9 +1503,10 @@ pub async fn process_manager_request( return; }; - let _ = handle - .command_tx - .send(HttpConnectionCommand::Close { response }); + let _ = handle.command_tx.send(HttpConnectionCommand::Close { + protocol_version, + response, + }); } } } @@ -1860,6 +2030,31 @@ mod tests { )); } + #[test] + fn validate_protocol_version_header_allows_missing_and_rejects_mismatches() { + assert!(validate_protocol_version_header(None, None).is_ok()); + assert!(validate_protocol_version_header(Some(&ProtocolVersion::V0), None).is_ok()); + assert!(validate_protocol_version_header(None, Some(&ProtocolVersion::V0)).is_ok()); + assert!( + validate_protocol_version_header( + Some(&ProtocolVersion::V0), + Some(&ProtocolVersion::V0) + ) + .is_ok() + ); + + assert!(matches!( + validate_protocol_version_header( + Some(&ProtocolVersion::V1), + Some(&ProtocolVersion::V0) + ), + Err(HttpTransportError::BadRequest { + message: "Acp-Protocol-Version header does not match initialized protocol version", + source: None, + }) + )); + } + #[test] fn header_parsers_and_websocket_detection_handle_valid_and_invalid_values() { let connection_id = AcpConnectionId::new(); @@ -1873,6 +2068,7 @@ mod tests { ACP_SESSION_ID_HEADER, HeaderValue::from_str(session_id.as_str()).unwrap(), ); + headers.insert(ACP_PROTOCOL_VERSION_HEADER, HeaderValue::from_static("0")); headers.insert("upgrade", HeaderValue::from_static("websocket")); assert_eq!( @@ -1880,6 +2076,10 @@ mod tests { Some(connection_id.clone()) ); assert_eq!(parse_session_id_header(&headers).unwrap(), Some(session_id)); + assert_eq!( + parse_protocol_version_header(&headers).unwrap(), + Some(ProtocolVersion::V0) + ); assert!(is_websocket_request(&headers)); headers.insert( @@ -1893,6 +2093,22 @@ mod tests { source: Some(_), }) )); + + headers.insert( + ACP_CONNECTION_ID_HEADER, + HeaderValue::from_str(&connection_id.to_string()).unwrap(), + ); + headers.insert( + ACP_PROTOCOL_VERSION_HEADER, + HeaderValue::from_static("not-a-version"), + ); + assert!(matches!( + parse_protocol_version_header(&headers), + Err(HttpTransportError::BadRequest { + message: "invalid Acp-Protocol-Version header", + source: Some(_), + }) + )); } #[tokio::test] @@ -2009,6 +2225,7 @@ mod tests { assert!(message.is_session_new()); let _ = response.send(Ok(HttpPostOutcome::Buffered { connection_id: expected_connection_id, + protocol_version: Some(ProtocolVersion::V0), session_id: Some(expected_session_id), events: vec![SseFrame::Json(event.to_string())], })); @@ -2041,6 +2258,10 @@ mod tests { response.headers().get(ACP_SESSION_ID_HEADER).unwrap(), HeaderValue::from_str(session_id.as_str()).unwrap() ); + assert_eq!( + response.headers().get(ACP_PROTOCOL_VERSION_HEADER).unwrap(), + HeaderValue::from_static("0") + ); assert_eq!(json_event_body(response).await, vec![expected_event]); } @@ -2057,10 +2278,12 @@ mod tests { ManagerRequest::HttpGet { connection_id: actual_connection_id, session_id: actual_session_id, + protocol_version, response, } => { assert_eq!(actual_connection_id, expected_connection_id.clone()); assert_eq!(actual_session_id, expected_session_id); + assert_eq!(protocol_version, None); let (stream_tx, stream_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); let _ = stream_tx.try_send(SseFrame::Json( json!({ @@ -2071,17 +2294,22 @@ mod tests { .to_string(), )); drop(stream_tx); - let _ = response.send(Ok(stream_rx)); + let _ = response.send(Ok(HttpGetOutcome { + protocol_version: Some(ProtocolVersion::V0), + stream: stream_rx, + })); } _ => panic!("unexpected manager request"), } match manager_rx.recv().await.unwrap() { ManagerRequest::HttpDelete { connection_id: actual_connection_id, + protocol_version, response, } => { assert_eq!(actual_connection_id, expected_connection_id); - let _ = response.send(Ok(())); + assert_eq!(protocol_version, None); + let _ = response.send(Ok(Some(ProtocolVersion::V0))); } _ => panic!("unexpected manager request"), } @@ -2107,6 +2335,10 @@ mod tests { response.headers().get(ACP_SESSION_ID_HEADER).unwrap(), HeaderValue::from_str(session_id.as_str()).unwrap() ); + assert_eq!( + response.headers().get(ACP_PROTOCOL_VERSION_HEADER).unwrap(), + HeaderValue::from_static("0") + ); assert_eq!( response.headers().get(X_ACCEL_BUFFERING_HEADER).unwrap(), HeaderValue::from_static("no") @@ -2128,6 +2360,10 @@ mod tests { let response = http_delete(delete_headers, state).await.unwrap(); assert_eq!(response.status(), StatusCode::ACCEPTED); + assert_eq!( + response.headers().get(ACP_PROTOCOL_VERSION_HEADER).unwrap(), + HeaderValue::from_static("0") + ); } #[tokio::test] @@ -2161,6 +2397,7 @@ mod tests { process_manager_request( ManagerRequest::HttpPost { connection_id: None, + protocol_version: None, session_id: None, message: post_message, response: post_response_tx, @@ -2188,6 +2425,7 @@ mod tests { ManagerRequest::HttpGet { connection_id: unknown_connection_id.clone(), session_id: session_id(), + protocol_version: None, response: get_response_tx, }, &mut http_connections, @@ -2210,6 +2448,7 @@ mod tests { process_manager_request( ManagerRequest::HttpDelete { connection_id: unknown_connection_id, + protocol_version: None, response: delete_response_tx, }, &mut http_connections, @@ -2249,6 +2488,7 @@ mod tests { ManagerRequest::HttpGet { connection_id: unknown_connection_id, session_id: session_id(), + protocol_version: None, response: response_tx, }, &mut http_connections, @@ -2296,6 +2536,7 @@ mod tests { process_manager_request( ManagerRequest::HttpPost { connection_id: None, + protocol_version: None, session_id: None, message: initialize, response: response_tx, From 32a0c517c3185a2ab3ff780eadea30ee157e381b Mon Sep 17 00:00:00 2001 From: Yordis Prieto Date: Wed, 22 Apr 2026 19:16:49 -0400 Subject: [PATCH 19/23] fix(acp-nats-ws): close transport compliance gaps Signed-off-by: Yordis Prieto --- rsworkspace/crates/acp-nats-ws/README.md | 2 + rsworkspace/crates/acp-nats-ws/src/main.rs | 5 + .../crates/acp-nats-ws/src/transport.rs | 330 +++++++++++++++--- 3 files changed, 290 insertions(+), 47 deletions(-) diff --git a/rsworkspace/crates/acp-nats-ws/README.md b/rsworkspace/crates/acp-nats-ws/README.md index c0575792f..c30f99a05 100644 --- a/rsworkspace/crates/acp-nats-ws/README.md +++ b/rsworkspace/crates/acp-nats-ws/README.md @@ -51,6 +51,8 @@ curl -i \ After `initialize`, HTTP clients may send `Acp-Protocol-Version` on `POST`/`GET`/`DELETE`. When present, it must match the negotiated ACP protocol version for that connection. +When clients send an `Origin` header, `/acp` validates it against the bound host and rejects disallowed origins with `403 Forbidden`. Streamable HTTP `POST` SSE responses also emit a priming SSE event ID before JSON-RPC payloads and attach event IDs to streamed JSON events. + ## Configuration ### WebSocket Server diff --git a/rsworkspace/crates/acp-nats-ws/src/main.rs b/rsworkspace/crates/acp-nats-ws/src/main.rs index 1e26777d2..eaa6542fc 100644 --- a/rsworkspace/crates/acp-nats-ws/src/main.rs +++ b/rsworkspace/crates/acp-nats-ws/src/main.rs @@ -43,6 +43,7 @@ async fn main() -> Result<(), Box> { .spawn(move || run_connection_thread(manager_rx, nats_client, js_client, ws_config.acp))?; let state = AppState { + bind_host: ws_config.host, manager_tx, shutdown_tx: shutdown_tx.clone(), }; @@ -208,6 +209,7 @@ mod tests { use axum::http::{Request, StatusCode}; use futures_util::{SinkExt, StreamExt}; use serde_json::Value; + use std::net::{IpAddr, Ipv4Addr}; use std::time::Duration; use tokio::net::TcpListener; use tokio_tungstenite::connect_async; @@ -303,6 +305,7 @@ mod tests { let (manager_tx, manager_rx) = mpsc::unbounded_channel::(); let conn_thread = spawn_connection_thread(nats_mock, manager_rx); let app = test_app(AppState { + bind_host: IpAddr::V4(Ipv4Addr::LOCALHOST), manager_tx, shutdown_tx: shutdown_tx.clone(), }); @@ -321,6 +324,7 @@ mod tests { let (manager_tx, manager_rx) = mpsc::unbounded_channel::(); let conn_thread = spawn_connection_thread(nats_mock, manager_rx); let app = test_app(AppState { + bind_host: IpAddr::V4(Ipv4Addr::LOCALHOST), manager_tx, shutdown_tx: shutdown_tx.clone(), }); @@ -347,6 +351,7 @@ mod tests { fn sse_events(body: &str) -> Vec { body.lines() .filter_map(|line| line.strip_prefix("data: ")) + .filter(|json| !json.is_empty()) .map(|json| serde_json::from_str(json).unwrap()) .collect() } diff --git a/rsworkspace/crates/acp-nats-ws/src/transport.rs b/rsworkspace/crates/acp-nats-ws/src/transport.rs index c8af0a913..f25fdc019 100644 --- a/rsworkspace/crates/acp-nats-ws/src/transport.rs +++ b/rsworkspace/crates/acp-nats-ws/src/transport.rs @@ -10,20 +10,22 @@ use axum::extract::FromRequestParts; use axum::extract::Request; use axum::extract::State; use axum::extract::ws::{WebSocket, WebSocketUpgrade}; -use axum::http::header::{ACCEPT, CONTENT_TYPE}; -use axum::http::{HeaderMap, HeaderValue, StatusCode}; +use axum::http::header::{ACCEPT, CONTENT_TYPE, HOST, ORIGIN}; +use axum::http::{HeaderMap, HeaderValue, StatusCode, Uri, uri::Authority}; use axum::response::sse::{Event, Sse}; use axum::response::{IntoResponse, Response}; -use futures_util::stream; +use futures_util::{StreamExt, stream}; use serde::Deserialize; use serde_json::Value; use std::collections::{HashMap, HashSet}; use std::convert::Infallible; +use std::net::IpAddr; use std::rc::Rc; use tokio::io::{AsyncBufReadExt, AsyncWriteExt}; use tokio::sync::{mpsc, oneshot, watch}; use tracing::{error, info, warn}; use trogon_std::time::SystemClock; +use uuid::Uuid; const HTTP_CHANNEL_CAPACITY: usize = 64; @@ -33,6 +35,7 @@ type SseReceiver = mpsc::Receiver; #[derive(Clone)] pub struct AppState { + pub bind_host: IpAddr, pub manager_tx: mpsc::UnboundedSender, pub shutdown_tx: watch::Sender, } @@ -103,6 +106,10 @@ pub enum HttpTransportError { message: &'static str, source: Option, }, + Forbidden { + message: &'static str, + source: Option, + }, UnsupportedMediaType { message: &'static str, source: Option, @@ -139,6 +146,7 @@ impl HttpTransportError { Self::BadRequest { message, .. } | Self::NotFound { message, .. } | Self::Conflict { message, .. } + | Self::Forbidden { message, .. } | Self::UnsupportedMediaType { message, .. } | Self::NotAcceptable { message, .. } | Self::NotImplemented { message, .. } @@ -151,6 +159,7 @@ impl HttpTransportError { Self::BadRequest { source, .. } | Self::NotFound { source, .. } | Self::Conflict { source, .. } + | Self::Forbidden { source, .. } | Self::UnsupportedMediaType { source, .. } | Self::NotAcceptable { source, .. } | Self::NotImplemented { source, .. } @@ -191,6 +200,23 @@ impl HttpTransportError { } } + fn forbidden(message: &'static str) -> Self { + Self::Forbidden { + message, + source: None, + } + } + + fn forbidden_with( + message: &'static str, + source: impl std::error::Error + Send + 'static, + ) -> Self { + Self::Forbidden { + message, + source: Some(Box::new(source)), + } + } + fn unsupported_media_type(message: &'static str) -> Self { Self::UnsupportedMediaType { message, @@ -234,6 +260,7 @@ impl HttpTransportError { Self::BadRequest { .. } => StatusCode::BAD_REQUEST, Self::NotFound { .. } => StatusCode::NOT_FOUND, Self::Conflict { .. } => StatusCode::CONFLICT, + Self::Forbidden { .. } => StatusCode::FORBIDDEN, Self::UnsupportedMediaType { .. } => StatusCode::UNSUPPORTED_MEDIA_TYPE, Self::NotAcceptable { .. } => StatusCode::NOT_ACCEPTABLE, Self::NotImplemented { .. } => StatusCode::NOT_IMPLEMENTED, @@ -273,13 +300,40 @@ pub enum HttpConnectionCommand { #[derive(Clone, Debug)] pub(crate) enum SseFrame { - Json(String), + Empty { + event_id: String, + }, + Json { + event_id: Option, + json: String, + }, } impl SseFrame { + fn prime() -> Self { + Self::Empty { + event_id: Uuid::new_v4().to_string(), + } + } + + fn json(json: String) -> Self { + Self::Json { + event_id: Some(Uuid::new_v4().to_string()), + json, + } + } + fn into_event(self) -> Event { match self { - Self::Json(json) => Event::default().data(json), + Self::Empty { event_id } => Event::default().id(event_id).data(""), + Self::Json { event_id, json } => { + let event = Event::default().data(json); + if let Some(event_id) = event_id { + event.id(event_id) + } else { + event + } + } } } } @@ -464,17 +518,26 @@ fn dispatch_to_get_listeners( .get_mut(session_id) .expect("session listeners must exist after presence check"); let mut delivered = false; - listeners.retain(|listener| match listener.try_send(frame.clone()) { - Ok(()) => { - delivered = true; - true + let mut retained = Vec::with_capacity(listeners.len()); + + for listener in listeners.drain(..) { + if delivered { + retained.push(listener); + continue; } - Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => { - warn!(session_id = %session_id, "Dropping stalled HTTP SSE listener"); - false + + match listener.try_send(frame.clone()) { + Ok(()) => { + delivered = true; + retained.push(listener); + } + Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => { + warn!(session_id = %session_id, "Dropping stalled HTTP SSE listener"); + } + Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {} } - Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => false, - }); + } + *listeners = retained; ( if delivered { @@ -564,6 +627,10 @@ fn route_buffered_frame( } pub async fn get(State(state): State, request: Request) -> Response { + if let Err(error) = validate_origin(request.headers(), state.bind_host) { + return error.into_response(); + } + if is_websocket_request(request.headers()) { let (mut parts, _body) = request.into_parts(); match WebSocketUpgrade::from_request_parts(&mut parts, &state).await { @@ -581,6 +648,10 @@ pub async fn get(State(state): State, request: Request) -> Response { } pub async fn post(headers: HeaderMap, State(state): State, body: String) -> Response { + if let Err(error) = validate_origin(&headers, state.bind_host) { + return error.into_response(); + } + match http_post(headers, state, body).await { Ok(response) => response, Err(error) => error.into_response(), @@ -588,6 +659,10 @@ pub async fn post(headers: HeaderMap, State(state): State, body: Strin } pub async fn delete(headers: HeaderMap, State(state): State) -> Response { + if let Err(error) = validate_origin(&headers, state.bind_host) { + return error.into_response(); + } + match http_delete(headers, state).await { Ok(response) => response, Err(error) => error.into_response(), @@ -861,6 +936,52 @@ fn media_type_matches(header_value: &str, expected: &str) -> bool { .is_some_and(|media_type| media_type.eq_ignore_ascii_case(expected)) } +fn validate_origin(headers: &HeaderMap, bind_host: IpAddr) -> Result<(), HttpTransportError> { + let Some(origin) = headers.get(ORIGIN) else { + return Ok(()); + }; + + let origin = origin + .to_str() + .map_err(|error| HttpTransportError::forbidden_with("invalid Origin header", error))?; + let origin = origin + .parse::() + .map_err(|error| HttpTransportError::forbidden_with("invalid Origin header", error))?; + let Some(origin_host) = origin.host() else { + return Err(HttpTransportError::forbidden("invalid Origin header")); + }; + + let request_host = headers + .get(HOST) + .and_then(|value| value.to_str().ok()) + .and_then(|value| value.parse::().ok()) + .map(|authority| authority.host().to_owned()); + + let allowed = if bind_host.is_loopback() { + is_loopback_host(origin_host) + } else if bind_host.is_unspecified() { + is_loopback_host(origin_host) + || request_host + .as_deref() + .is_some_and(|host| host.eq_ignore_ascii_case(origin_host)) + } else { + origin_host.eq_ignore_ascii_case(&bind_host.to_string()) + || request_host + .as_deref() + .is_some_and(|host| host.eq_ignore_ascii_case(origin_host)) + }; + + if allowed { + Ok(()) + } else { + Err(HttpTransportError::forbidden("Origin is not allowed")) + } +} + +fn is_loopback_host(host: &str) -> bool { + matches!(host, "localhost" | "127.0.0.1" | "::1" | "[::1]") +} + fn parse_connection_id_header( headers: &HeaderMap, ) -> Result, HttpTransportError> { @@ -947,13 +1068,16 @@ fn build_sse_response( session_id: Option, stream: SseReceiver, ) -> Response { - let mut response = Sse::new(stream::unfold(stream, |mut stream| async move { - stream - .recv() - .await - .map(|item| (Ok::(item.into_event()), stream)) - })) - .into_response(); + let primed_stream = + stream::once(async { Ok::(SseFrame::prime().into_event()) }).chain( + stream::unfold(stream, |mut stream| async move { + stream + .recv() + .await + .map(|item| (Ok::(item.into_event()), stream)) + }), + ); + let mut response = Sse::new(primed_stream).into_response(); set_transport_headers( response.headers_mut(), &connection_id, @@ -973,8 +1097,8 @@ fn build_buffered_sse_response( events: Vec, ) -> Response { let stream = stream::iter( - events - .into_iter() + std::iter::once(SseFrame::prime()) + .chain(events) .map(|item| Ok::(item.into_event())), ); let mut response = Sse::new(stream).into_response(); @@ -1260,7 +1384,7 @@ pub async fn run_http_connection( break; }; - let frame = SseFrame::Json(outbound.clone()); + let frame = SseFrame::json(outbound.clone()); let parsed = OutgoingHttpMessage::parse(&outbound); if let Some(pending) = pending_request.as_mut() { @@ -1517,7 +1641,9 @@ mod tests { use acp_nats::Config; use axum::body::{Body, to_bytes}; use axum::http::Request as HttpRequest; + use axum::http::header::{HOST, ORIGIN}; use serde_json::{Value, json}; + use std::net::{IpAddr, Ipv4Addr}; use tokio::sync::{mpsc, oneshot, watch}; use trogon_nats::AdvancedMockNatsClient; @@ -1581,6 +1707,7 @@ mod tests { let (shutdown_tx, _) = watch::channel(false); ( AppState { + bind_host: IpAddr::V4(Ipv4Addr::LOCALHOST), manager_tx, shutdown_tx, }, @@ -1593,10 +1720,18 @@ mod tests { let body = String::from_utf8(bytes.to_vec()).unwrap(); body.lines() .filter_map(|line| line.strip_prefix("data: ")) + .filter(|json| !json.is_empty()) .map(|json| serde_json::from_str(json).unwrap()) .collect() } + fn frame_json(frame: &SseFrame) -> &str { + match frame { + SseFrame::Json { json, .. } => json, + SseFrame::Empty { .. } => panic!("expected JSON SSE frame"), + } + } + fn post_headers() -> HeaderMap { let mut headers = HeaderMap::new(); headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); @@ -1632,6 +1767,10 @@ mod tests { HttpTransportError::conflict("conflict"), StatusCode::CONFLICT, ), + ( + HttpTransportError::forbidden("forbidden"), + StatusCode::FORBIDDEN, + ), ( HttpTransportError::unsupported_media_type("unsupported"), StatusCode::UNSUPPORTED_MEDIA_TYPE, @@ -1670,6 +1809,15 @@ mod tests { "conflict" ); assert!(HttpTransportError::conflict("conflict").source().is_none()); + assert_eq!( + HttpTransportError::forbidden("forbidden").to_string(), + "forbidden" + ); + assert!( + HttpTransportError::forbidden("forbidden") + .source() + .is_none() + ); assert!( HttpTransportError::unsupported_media_type("unsupported") .source() @@ -1787,14 +1935,11 @@ mod tests { #[test] fn route_live_frame_keeps_same_session_notifications_on_post_stream() { - let frame = SseFrame::Json( + let frame = SseFrame::json( r#"{"jsonrpc":"2.0","method":"session/update","params":{"sessionId":"session-1"}}"# .to_string(), ); - let parsed = OutgoingHttpMessage::parse(match &frame { - SseFrame::Json(json) => json, - }) - .unwrap(); + let parsed = OutgoingHttpMessage::parse(frame_json(&frame)).unwrap(); let request_id = RequestId::Number(1); let request_session_id = session_id(); let (live_tx, mut live_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); @@ -1813,7 +1958,8 @@ mod tests { assert!(matches!(outcome, LiveFrameOutcome::Keep)); match live_rx.try_recv().unwrap() { - SseFrame::Json(json) => assert!(json.contains(r#""sessionId":"session-1""#)), + SseFrame::Json { json, .. } => assert!(json.contains(r#""sessionId":"session-1""#)), + SseFrame::Empty { .. } => panic!("expected JSON SSE frame"), } assert!(matches!( get_rx.try_recv(), @@ -1823,14 +1969,11 @@ mod tests { #[test] fn route_live_frame_sends_other_session_notifications_to_get_listeners() { - let frame = SseFrame::Json( + let frame = SseFrame::json( r#"{"jsonrpc":"2.0","method":"session/update","params":{"sessionId":"session-2"}}"# .to_string(), ); - let parsed = OutgoingHttpMessage::parse(match &frame { - SseFrame::Json(json) => json, - }) - .unwrap(); + let parsed = OutgoingHttpMessage::parse(frame_json(&frame)).unwrap(); let request_id = RequestId::Number(1); let request_session_id = session_id(); let other_session_id = acp_nats::AcpSessionId::new("session-2").unwrap(); @@ -1854,20 +1997,18 @@ mod tests { Err(tokio::sync::mpsc::error::TryRecvError::Empty) )); match get_rx.try_recv().unwrap() { - SseFrame::Json(json) => assert!(json.contains(r#""sessionId":"session-2""#)), + SseFrame::Json { json, .. } => assert!(json.contains(r#""sessionId":"session-2""#)), + SseFrame::Empty { .. } => panic!("expected JSON SSE frame"), } } #[test] fn route_buffered_frame_sends_other_session_notifications_to_get_listeners() { - let frame = SseFrame::Json( + let frame = SseFrame::json( r#"{"jsonrpc":"2.0","method":"session/update","params":{"sessionId":"session-2"}}"# .to_string(), ); - let parsed = OutgoingHttpMessage::parse(match &frame { - SseFrame::Json(json) => json, - }) - .unwrap(); + let parsed = OutgoingHttpMessage::parse(frame_json(&frame)).unwrap(); let request_id = RequestId::Number(1); let other_session_id = acp_nats::AcpSessionId::new("session-2").unwrap(); let (get_tx, mut get_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); @@ -1886,7 +2027,8 @@ mod tests { assert!(matches!(outcome, BufferedFrameOutcome::Routed)); assert!(events.is_empty()); match get_rx.try_recv().unwrap() { - SseFrame::Json(json) => assert!(json.contains(r#""sessionId":"session-2""#)), + SseFrame::Json { json, .. } => assert!(json.contains(r#""sessionId":"session-2""#)), + SseFrame::Empty { .. } => panic!("expected JSON SSE frame"), } } @@ -1896,24 +2038,46 @@ mod tests { let mut get_listeners = HashMap::new(); let (listener_tx, mut listener_rx) = mpsc::channel(1); listener_tx - .try_send(SseFrame::Json( + .try_send(SseFrame::json( r#"{"jsonrpc":"2.0","method":"session/update"}"#.to_string(), )) .unwrap(); get_listeners.insert(session_id.clone(), vec![listener_tx]); - let frame = SseFrame::Json(r#"{"jsonrpc":"2.0","method":"session/update"}"#.to_string()); + let frame = SseFrame::json(r#"{"jsonrpc":"2.0","method":"session/update"}"#.to_string()); let outcome = dispatch_to_get_listeners(&frame, &session_id, &mut get_listeners); assert!(matches!(outcome, ListenerDispatch::Dropped)); assert!(!get_listeners.contains_key(&session_id)); - assert!(matches!(listener_rx.try_recv(), Ok(SseFrame::Json(_)))); + assert!(matches!(listener_rx.try_recv(), Ok(SseFrame::Json { .. }))); assert!(matches!( listener_rx.try_recv(), Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) )); } + #[test] + fn dispatch_to_get_listeners_delivers_each_message_on_only_one_stream() { + let session_id = session_id(); + let mut get_listeners = HashMap::new(); + let (first_tx, mut first_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); + let (second_tx, mut second_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); + get_listeners.insert(session_id.clone(), vec![first_tx, second_tx]); + + let frame = SseFrame::json( + r#"{"jsonrpc":"2.0","method":"session/update","params":{"sessionId":"session-1"}}"# + .to_string(), + ); + let outcome = dispatch_to_get_listeners(&frame, &session_id, &mut get_listeners); + + assert!(matches!(outcome, ListenerDispatch::Delivered)); + assert!(matches!(first_rx.try_recv(), Ok(SseFrame::Json { .. }))); + assert!(matches!( + second_rx.try_recv(), + Err(tokio::sync::mpsc::error::TryRecvError::Empty) + )); + } + #[test] fn header_validators_enforce_content_negotiation() { let valid_post = post_headers(); @@ -1973,6 +2137,41 @@ mod tests { )); } + #[test] + fn validate_origin_allows_loopback_and_matching_hosts() { + let mut loopback = HeaderMap::new(); + loopback.insert(ORIGIN, HeaderValue::from_static("http://localhost:3000")); + assert!(validate_origin(&loopback, IpAddr::V4(Ipv4Addr::LOCALHOST)).is_ok()); + + let mut proxied = HeaderMap::new(); + proxied.insert(ORIGIN, HeaderValue::from_static("https://example.com")); + proxied.insert(HOST, HeaderValue::from_static("example.com")); + assert!(validate_origin(&proxied, IpAddr::V4(Ipv4Addr::UNSPECIFIED)).is_ok()); + } + + #[test] + fn validate_origin_rejects_invalid_and_remote_origins() { + let mut invalid = HeaderMap::new(); + invalid.insert(ORIGIN, HeaderValue::from_static("not a uri")); + assert!(matches!( + validate_origin(&invalid, IpAddr::V4(Ipv4Addr::LOCALHOST)), + Err(HttpTransportError::Forbidden { + message: "invalid Origin header", + source: Some(_), + }) + )); + + let mut remote = HeaderMap::new(); + remote.insert(ORIGIN, HeaderValue::from_static("https://evil.example")); + assert!(matches!( + validate_origin(&remote, IpAddr::V4(Ipv4Addr::LOCALHOST)), + Err(HttpTransportError::Forbidden { + message: "Origin is not allowed", + source: None, + }) + )); + } + #[test] fn validate_http_context_enforces_connection_and_session_rules() { let initialize = IncomingHttpMessage::parse( @@ -2227,7 +2426,7 @@ mod tests { connection_id: expected_connection_id, protocol_version: Some(ProtocolVersion::V0), session_id: Some(expected_session_id), - events: vec![SseFrame::Json(event.to_string())], + events: vec![SseFrame::json(event.to_string())], })); } _ => panic!("unexpected manager request"), @@ -2285,7 +2484,7 @@ mod tests { assert_eq!(actual_session_id, expected_session_id); assert_eq!(protocol_version, None); let (stream_tx, stream_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); - let _ = stream_tx.try_send(SseFrame::Json( + let _ = stream_tx.try_send(SseFrame::json( json!({ "jsonrpc": "2.0", "method": "session/update", @@ -2380,6 +2579,43 @@ mod tests { assert_eq!(response.status(), StatusCode::BAD_REQUEST); } + #[tokio::test] + async fn post_rejects_disallowed_origin_before_processing() { + let (state, _manager_rx) = test_state(); + let mut headers = post_headers(); + headers.insert(ORIGIN, HeaderValue::from_static("https://evil.example")); + + let response = post( + headers, + State(state), + r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":0}}"# + .to_string(), + ) + .await; + + assert_eq!(response.status(), StatusCode::FORBIDDEN); + let bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap(); + assert_eq!( + String::from_utf8(bytes.to_vec()).unwrap(), + "Origin is not allowed" + ); + } + + #[tokio::test] + async fn get_rejects_websocket_upgrade_with_disallowed_origin() { + let (state, _manager_rx) = test_state(); + let request = HttpRequest::builder() + .method("GET") + .uri("/acp") + .header("upgrade", "websocket") + .header(ORIGIN, "https://evil.example") + .body(Body::empty()) + .unwrap(); + + let response = get(State(state), request).await; + assert_eq!(response.status(), StatusCode::FORBIDDEN); + } + #[tokio::test] async fn process_manager_request_rejects_invalid_or_unknown_http_targets() { let nats_client = AdvancedMockNatsClient::new(); From 0e92e05b5f372aedb456b168c8e371bf4d8a0faa Mon Sep 17 00:00:00 2001 From: Yordis Prieto Date: Wed, 22 Apr 2026 19:46:52 -0400 Subject: [PATCH 20/23] fix(acp-nats-ws): honor session activation flow Signed-off-by: Yordis Prieto --- rsworkspace/crates/acp-nats-ws/src/main.rs | 84 ++++++++++++++++++- .../crates/acp-nats-ws/src/transport.rs | 62 ++++++++++++-- 2 files changed, 137 insertions(+), 9 deletions(-) diff --git a/rsworkspace/crates/acp-nats-ws/src/main.rs b/rsworkspace/crates/acp-nats-ws/src/main.rs index eaa6542fc..4be6b0e95 100644 --- a/rsworkspace/crates/acp-nats-ws/src/main.rs +++ b/rsworkspace/crates/acp-nats-ws/src/main.rs @@ -208,7 +208,7 @@ mod tests { use axum::http::header::{ACCEPT, CONTENT_TYPE}; use axum::http::{Request, StatusCode}; use futures_util::{SinkExt, StreamExt}; - use serde_json::Value; + use serde_json::{Value, json}; use std::net::{IpAddr, Ipv4Addr}; use std::time::Duration; use tokio::net::TcpListener; @@ -579,6 +579,88 @@ mod tests { conn_thread.join().unwrap(); } + #[tokio::test] + async fn streamable_http_session_load_uses_request_session_id_header() { + let nats_mock = AdvancedMockNatsClient::new(); + let _injector = nats_mock.inject_messages(); + nats_mock.set_response( + "acp.agent.initialize", + r#"{"agentCapabilities":{"loadSession":false,"mcpCapabilities":{"http":false,"sse":false},"promptCapabilities":{"audio":false,"embeddedContext":false,"image":false},"sessionCapabilities":{}},"authMethods":[],"protocolVersion":0}"# + .into(), + ); + nats_mock.set_response("acp.session.test-session-1.agent.load", "{}".into()); + + let (app, shutdown_tx, conn_thread) = build_test_app(nats_mock); + let initialize = app + .clone() + .oneshot(http_post_request( + r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":0}}"#, + )) + .await + .unwrap(); + let connection_id = initialize + .headers() + .get(ACP_CONNECTION_ID_HEADER) + .unwrap() + .to_str() + .unwrap() + .to_owned(); + let _ = body_text(initialize).await; + + let session_load = app + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri(ACP_ENDPOINT) + .header(CONTENT_TYPE, "application/json") + .header(ACCEPT, "application/json, text/event-stream") + .header(ACP_CONNECTION_ID_HEADER, &connection_id) + .header(ACP_PROTOCOL_VERSION_HEADER, "0") + .body(Body::from( + r#"{"jsonrpc":"2.0","id":2,"method":"session/load","params":{"sessionId":"test-session-1","cwd":"."}}"#, + )) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(session_load.status(), StatusCode::OK); + assert_eq!( + session_load + .headers() + .get(ACP_CONNECTION_ID_HEADER) + .unwrap() + .to_str() + .unwrap(), + connection_id + ); + assert_eq!( + session_load + .headers() + .get(ACP_SESSION_ID_HEADER) + .and_then(|value| value.to_str().ok()), + Some("test-session-1") + ); + assert_eq!( + session_load + .headers() + .get(ACP_PROTOCOL_VERSION_HEADER) + .and_then(|value| value.to_str().ok()), + Some("0") + ); + + let body = body_text(session_load).await; + let events = sse_events(&body); + assert_eq!(events.len(), 1); + assert_eq!(events[0]["id"], 2); + assert_eq!(events[0]["result"], json!(null)); + + let _ = shutdown_tx.send(true); + drop(app); + conn_thread.join().unwrap(); + } + #[tokio::test] async fn streamable_http_get_requires_connection_and_session_headers() { let nats_mock = AdvancedMockNatsClient::new(); diff --git a/rsworkspace/crates/acp-nats-ws/src/transport.rs b/rsworkspace/crates/acp-nats-ws/src/transport.rs index f25fdc019..f7415af6a 100644 --- a/rsworkspace/crates/acp-nats-ws/src/transport.rs +++ b/rsworkspace/crates/acp-nats-ws/src/transport.rs @@ -348,6 +348,7 @@ enum PendingRequest { }, Buffered { request_id: RequestId, + fallback_session_id: Option, events: Vec, response: oneshot::Sender>, }, @@ -411,8 +412,14 @@ impl IncomingHttpMessage { self.method_name() == Some("initialize") } - fn is_session_new(&self) -> bool { - self.method_name() == Some("session/new") + fn activates_session(&self) -> bool { + matches!( + self.method_name(), + Some("session/new") + | Some("session/load") + | Some("session/resume") + | Some("session/fork") + ) } fn requires_session_id(&self) -> bool { @@ -421,8 +428,8 @@ impl IncomingHttpMessage { } match self.method_name() { - Some("initialize") | Some("authenticate") | Some("session/new") - | Some("session/list") => false, + Some("initialize") | Some("authenticate") | Some("session/list") => false, + _ if self.activates_session() => false, Some(method) if method.starts_with("session/") => true, _ => false, } @@ -1279,9 +1286,10 @@ pub async fn run_http_connection( continue; } - if message.is_session_new() { + if message.activates_session() { pending_request = Some(PendingRequest::Buffered { request_id: message.id.clone().expect("request must have id"), + fallback_session_id: message.params_session_id().ok().flatten(), events: Vec::new(), response, }); @@ -1419,7 +1427,12 @@ pub async fn run_http_connection( } continue; } - PendingRequest::Buffered { request_id, events, .. } => { + PendingRequest::Buffered { + request_id, + fallback_session_id, + events, + .. + } => { match route_buffered_frame( &frame, parsed.as_ref(), @@ -1429,6 +1442,8 @@ pub async fn run_http_connection( ) { BufferedFrameOutcome::Buffered | BufferedFrameOutcome::Routed => {} BufferedFrameOutcome::Finalize { session_id } => { + let session_id = + session_id.or_else(|| fallback_session_id.clone()); if let Some(session_id) = session_id.clone() { sessions.insert(session_id); } @@ -1852,7 +1867,17 @@ mod tests { ) .unwrap(); assert!(request.is_request()); - assert!(request.is_session_new()); + assert!(request.activates_session()); + + for raw in [ + r#"{"jsonrpc":"2.0","id":2,"method":"session/load","params":{"sessionId":"session-1","cwd":"."}}"#, + r#"{"jsonrpc":"2.0","id":3,"method":"session/resume","params":{"sessionId":"session-1","cwd":"."}}"#, + r#"{"jsonrpc":"2.0","id":4,"method":"session/fork","params":{"sessionId":"session-1","cwd":"."}}"#, + ] { + let activating = IncomingHttpMessage::parse(raw.to_string()).unwrap(); + assert!(activating.activates_session()); + assert!(!activating.requires_session_id()); + } let notification = IncomingHttpMessage::parse(r#"{"jsonrpc":"2.0","method":"initialized"}"#.to_string()) @@ -2227,6 +2252,27 @@ mod tests { source: None, }) )); + + let load = IncomingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":2,"method":"session/load","params":{"sessionId":"session-1","cwd":"."}}"# + .to_string(), + ) + .unwrap(); + assert!(validate_http_context(&load, Some(&connection_id), None).is_ok()); + + let resume = IncomingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":2,"method":"session/resume","params":{"sessionId":"session-1","cwd":"."}}"# + .to_string(), + ) + .unwrap(); + assert!(validate_http_context(&resume, Some(&connection_id), None).is_ok()); + + let fork = IncomingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":2,"method":"session/fork","params":{"sessionId":"session-1","cwd":"."}}"# + .to_string(), + ) + .unwrap(); + assert!(validate_http_context(&fork, Some(&connection_id), None).is_ok()); } #[test] @@ -2421,7 +2467,7 @@ mod tests { .. } => { assert_eq!(actual_connection_id, expected_connection_id.clone()); - assert!(message.is_session_new()); + assert!(message.activates_session()); let _ = response.send(Ok(HttpPostOutcome::Buffered { connection_id: expected_connection_id, protocol_version: Some(ProtocolVersion::V0), From 332230030f475b776dd893d1017a7742047f6bbb Mon Sep 17 00:00:00 2001 From: Yordis Prieto Date: Wed, 22 Apr 2026 22:05:10 -0400 Subject: [PATCH 21/23] refactor(acp-nats-ws): make explicit ids the primary constructor Signed-off-by: Yordis Prieto --- .../acp-nats-ws/src/acp_connection_id.rs | 22 ++++++++++++++----- .../crates/acp-nats-ws/src/transport.rs | 22 +++++++++---------- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs b/rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs index c00b3037f..b01f7f3bd 100644 --- a/rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs +++ b/rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs @@ -2,20 +2,24 @@ pub struct AcpConnectionId(uuid::Uuid); impl AcpConnectionId { - pub fn new() -> Self { + pub fn new(uuid: uuid::Uuid) -> Self { + Self(uuid) + } + + pub fn now_v7() -> Self { Self(uuid::Uuid::now_v7()) } pub fn parse(s: &str) -> Result { uuid::Uuid::parse_str(s) - .map(Self) + .map(Self::new) .map_err(AcpConnectionIdError::InvalidUuid) } } impl Default for AcpConnectionId { fn default() -> Self { - Self::new() + Self::now_v7() } } @@ -52,13 +56,19 @@ mod tests { use std::error::Error as _; #[test] - fn new_generates_non_empty_id() { - assert!(!AcpConnectionId::new().to_string().is_empty()); + fn new_wraps_existing_uuid() { + let uuid = uuid::Uuid::nil(); + assert_eq!(AcpConnectionId::new(uuid).to_string(), uuid.to_string()); + } + + #[test] + fn now_v7_generates_non_empty_id() { + assert!(!AcpConnectionId::now_v7().to_string().is_empty()); } #[test] fn parse_round_trips_uuid() { - let id = AcpConnectionId::new(); + let id = AcpConnectionId::now_v7(); let parsed = AcpConnectionId::parse(&id.to_string()).unwrap(); assert_eq!(parsed, id); } diff --git a/rsworkspace/crates/acp-nats-ws/src/transport.rs b/rsworkspace/crates/acp-nats-ws/src/transport.rs index f7415af6a..c3b29be3d 100644 --- a/rsworkspace/crates/acp-nats-ws/src/transport.rs +++ b/rsworkspace/crates/acp-nats-ws/src/transport.rs @@ -677,7 +677,7 @@ pub async fn delete(headers: HeaderMap, State(state): State) -> Respon } fn websocket_response(ws: WebSocketUpgrade, state: AppState) -> Response { - let connection_id = AcpConnectionId::new(); + let connection_id = AcpConnectionId::now_v7(); let response_header = HeaderValue::from_str(&connection_id.to_string()) .expect("generated ACP connection id must be a valid header value"); let shutdown_rx = state.shutdown_tx.subscribe(); @@ -1571,7 +1571,7 @@ pub async fn process_manager_request( return; } - let connection_id = AcpConnectionId::new(); + let connection_id = AcpConnectionId::now_v7(); let (command_tx, command_rx) = mpsc::unbounded_channel(); http_connections.insert( connection_id.clone(), @@ -2204,7 +2204,7 @@ mod tests { .to_string(), ) .unwrap(); - let connection_id = AcpConnectionId::new(); + let connection_id = AcpConnectionId::now_v7(); let session_id = session_id(); assert!(validate_http_context(&initialize, None, None).is_ok()); @@ -2302,7 +2302,7 @@ mod tests { #[test] fn header_parsers_and_websocket_detection_handle_valid_and_invalid_values() { - let connection_id = AcpConnectionId::new(); + let connection_id = AcpConnectionId::now_v7(); let session_id = session_id(); let mut headers = HeaderMap::new(); headers.insert( @@ -2359,7 +2359,7 @@ mod tests { #[tokio::test] async fn http_post_returns_accepted_for_notifications() { let (state, mut manager_rx) = test_state(); - let connection_id = AcpConnectionId::new(); + let connection_id = AcpConnectionId::now_v7(); let expected_connection_id = connection_id.clone(); tokio::spawn(async move { @@ -2399,7 +2399,7 @@ mod tests { #[tokio::test] async fn http_post_accepts_null_result_responses() { let (state, mut manager_rx) = test_state(); - let connection_id = AcpConnectionId::new(); + let connection_id = AcpConnectionId::now_v7(); let session_id = session_id(); let expected_connection_id = connection_id.clone(); let expected_session_id = session_id.clone(); @@ -2446,7 +2446,7 @@ mod tests { #[tokio::test] async fn http_post_returns_buffered_sse_with_session_headers() { let (state, mut manager_rx) = test_state(); - let connection_id = AcpConnectionId::new(); + let connection_id = AcpConnectionId::now_v7(); let session_id = session_id(); let event = json!({ "jsonrpc": "2.0", @@ -2513,7 +2513,7 @@ mod tests { #[tokio::test] async fn http_get_and_delete_round_trip_through_manager() { let (state, mut manager_rx) = test_state(); - let connection_id = AcpConnectionId::new(); + let connection_id = AcpConnectionId::now_v7(); let session_id = session_id(); let expected_connection_id = connection_id.clone(); let expected_session_id = session_id.clone(); @@ -2701,7 +2701,7 @@ mod tests { }) )); - let unknown_connection_id = AcpConnectionId::new(); + let unknown_connection_id = AcpConnectionId::now_v7(); let (get_response_tx, get_response_rx) = oneshot::channel(); process_manager_request( ManagerRequest::HttpGet { @@ -2759,12 +2759,12 @@ mod tests { let mut websocket_handles = Vec::new(); let mut http_connection_handles = Vec::new(); - let stale_connection_id = AcpConnectionId::new(); + let stale_connection_id = AcpConnectionId::now_v7(); let (command_tx, command_rx) = mpsc::unbounded_channel(); drop(command_rx); http_connections.insert(stale_connection_id, HttpConnectionHandle { command_tx }); - let unknown_connection_id = AcpConnectionId::new(); + let unknown_connection_id = AcpConnectionId::now_v7(); let (response_tx, response_rx) = oneshot::channel(); process_manager_request( ManagerRequest::HttpGet { From 8898cb132ada66a271401303501d2fe82110c97f Mon Sep 17 00:00:00 2001 From: Yordis Prieto Date: Wed, 22 Apr 2026 22:07:15 -0400 Subject: [PATCH 22/23] refactor(acp-nats-ws): keep generated ids behind default Signed-off-by: Yordis Prieto --- .../acp-nats-ws/src/acp_connection_id.rs | 13 ++--------- .../crates/acp-nats-ws/src/transport.rs | 22 +++++++++---------- 2 files changed, 13 insertions(+), 22 deletions(-) diff --git a/rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs b/rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs index b01f7f3bd..47bd2fc99 100644 --- a/rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs +++ b/rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs @@ -6,10 +6,6 @@ impl AcpConnectionId { Self(uuid) } - pub fn now_v7() -> Self { - Self(uuid::Uuid::now_v7()) - } - pub fn parse(s: &str) -> Result { uuid::Uuid::parse_str(s) .map(Self::new) @@ -19,7 +15,7 @@ impl AcpConnectionId { impl Default for AcpConnectionId { fn default() -> Self { - Self::now_v7() + Self::new(uuid::Uuid::now_v7()) } } @@ -61,14 +57,9 @@ mod tests { assert_eq!(AcpConnectionId::new(uuid).to_string(), uuid.to_string()); } - #[test] - fn now_v7_generates_non_empty_id() { - assert!(!AcpConnectionId::now_v7().to_string().is_empty()); - } - #[test] fn parse_round_trips_uuid() { - let id = AcpConnectionId::now_v7(); + let id = AcpConnectionId::default(); let parsed = AcpConnectionId::parse(&id.to_string()).unwrap(); assert_eq!(parsed, id); } diff --git a/rsworkspace/crates/acp-nats-ws/src/transport.rs b/rsworkspace/crates/acp-nats-ws/src/transport.rs index c3b29be3d..0c47ac26e 100644 --- a/rsworkspace/crates/acp-nats-ws/src/transport.rs +++ b/rsworkspace/crates/acp-nats-ws/src/transport.rs @@ -677,7 +677,7 @@ pub async fn delete(headers: HeaderMap, State(state): State) -> Respon } fn websocket_response(ws: WebSocketUpgrade, state: AppState) -> Response { - let connection_id = AcpConnectionId::now_v7(); + let connection_id = AcpConnectionId::default(); let response_header = HeaderValue::from_str(&connection_id.to_string()) .expect("generated ACP connection id must be a valid header value"); let shutdown_rx = state.shutdown_tx.subscribe(); @@ -1571,7 +1571,7 @@ pub async fn process_manager_request( return; } - let connection_id = AcpConnectionId::now_v7(); + let connection_id = AcpConnectionId::default(); let (command_tx, command_rx) = mpsc::unbounded_channel(); http_connections.insert( connection_id.clone(), @@ -2204,7 +2204,7 @@ mod tests { .to_string(), ) .unwrap(); - let connection_id = AcpConnectionId::now_v7(); + let connection_id = AcpConnectionId::default(); let session_id = session_id(); assert!(validate_http_context(&initialize, None, None).is_ok()); @@ -2302,7 +2302,7 @@ mod tests { #[test] fn header_parsers_and_websocket_detection_handle_valid_and_invalid_values() { - let connection_id = AcpConnectionId::now_v7(); + let connection_id = AcpConnectionId::default(); let session_id = session_id(); let mut headers = HeaderMap::new(); headers.insert( @@ -2359,7 +2359,7 @@ mod tests { #[tokio::test] async fn http_post_returns_accepted_for_notifications() { let (state, mut manager_rx) = test_state(); - let connection_id = AcpConnectionId::now_v7(); + let connection_id = AcpConnectionId::default(); let expected_connection_id = connection_id.clone(); tokio::spawn(async move { @@ -2399,7 +2399,7 @@ mod tests { #[tokio::test] async fn http_post_accepts_null_result_responses() { let (state, mut manager_rx) = test_state(); - let connection_id = AcpConnectionId::now_v7(); + let connection_id = AcpConnectionId::default(); let session_id = session_id(); let expected_connection_id = connection_id.clone(); let expected_session_id = session_id.clone(); @@ -2446,7 +2446,7 @@ mod tests { #[tokio::test] async fn http_post_returns_buffered_sse_with_session_headers() { let (state, mut manager_rx) = test_state(); - let connection_id = AcpConnectionId::now_v7(); + let connection_id = AcpConnectionId::default(); let session_id = session_id(); let event = json!({ "jsonrpc": "2.0", @@ -2513,7 +2513,7 @@ mod tests { #[tokio::test] async fn http_get_and_delete_round_trip_through_manager() { let (state, mut manager_rx) = test_state(); - let connection_id = AcpConnectionId::now_v7(); + let connection_id = AcpConnectionId::default(); let session_id = session_id(); let expected_connection_id = connection_id.clone(); let expected_session_id = session_id.clone(); @@ -2701,7 +2701,7 @@ mod tests { }) )); - let unknown_connection_id = AcpConnectionId::now_v7(); + let unknown_connection_id = AcpConnectionId::default(); let (get_response_tx, get_response_rx) = oneshot::channel(); process_manager_request( ManagerRequest::HttpGet { @@ -2759,12 +2759,12 @@ mod tests { let mut websocket_handles = Vec::new(); let mut http_connection_handles = Vec::new(); - let stale_connection_id = AcpConnectionId::now_v7(); + let stale_connection_id = AcpConnectionId::default(); let (command_tx, command_rx) = mpsc::unbounded_channel(); drop(command_rx); http_connections.insert(stale_connection_id, HttpConnectionHandle { command_tx }); - let unknown_connection_id = AcpConnectionId::now_v7(); + let unknown_connection_id = AcpConnectionId::default(); let (response_tx, response_rx) = oneshot::channel(); process_manager_request( ManagerRequest::HttpGet { From 8805ee25e91a5a941af747df60f6f10ed84f4588 Mon Sep 17 00:00:00 2001 From: Yordis Prieto Date: Wed, 22 Apr 2026 22:36:15 -0400 Subject: [PATCH 23/23] refactor(acp-nats-ws): keep transport defaults centralized Signed-off-by: Yordis Prieto --- rsworkspace/crates/acp-nats-ws/src/constants.rs | 1 + rsworkspace/crates/acp-nats-ws/src/transport.rs | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/rsworkspace/crates/acp-nats-ws/src/constants.rs b/rsworkspace/crates/acp-nats-ws/src/constants.rs index 767031bc6..93c278bad 100644 --- a/rsworkspace/crates/acp-nats-ws/src/constants.rs +++ b/rsworkspace/crates/acp-nats-ws/src/constants.rs @@ -7,5 +7,6 @@ pub const ACP_SESSION_ID_HEADER: &str = "acp-session-id"; pub const DEFAULT_HOST: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); pub const DEFAULT_PORT: u16 = 8080; pub const DUPLEX_BUFFER_SIZE: usize = 64 * 1024; +pub const HTTP_CHANNEL_CAPACITY: usize = 64; pub const THREAD_NAME: &str = "acp-ws-local"; pub const X_ACCEL_BUFFERING_HEADER: &str = "x-accel-buffering"; diff --git a/rsworkspace/crates/acp-nats-ws/src/transport.rs b/rsworkspace/crates/acp-nats-ws/src/transport.rs index 0c47ac26e..9c710b49a 100644 --- a/rsworkspace/crates/acp-nats-ws/src/transport.rs +++ b/rsworkspace/crates/acp-nats-ws/src/transport.rs @@ -2,7 +2,7 @@ use crate::acp_connection_id::AcpConnectionId; use crate::connection; use crate::constants::{ ACP_CONNECTION_ID_HEADER, ACP_PROTOCOL_VERSION_HEADER, ACP_SESSION_ID_HEADER, - X_ACCEL_BUFFERING_HEADER, + HTTP_CHANNEL_CAPACITY, X_ACCEL_BUFFERING_HEADER, }; use acp_nats::{StdJsonSerialize, agent::Bridge, client, spawn_notification_forwarder}; use agent_client_protocol::{AgentSideConnection, ProtocolVersion, RequestId, SessionNotification}; @@ -27,8 +27,6 @@ use tracing::{error, info, warn}; use trogon_std::time::SystemClock; use uuid::Uuid; -const HTTP_CHANNEL_CAPACITY: usize = 64; - type BoxError = Box; type SseSender = mpsc::Sender; type SseReceiver = mpsc::Receiver;