diff --git a/rsworkspace/Cargo.lock b/rsworkspace/Cargo.lock index 9ed295afb..2f720feae 100644 --- a/rsworkspace/Cargo.lock +++ b/rsworkspace/Cargo.lock @@ -76,13 +76,16 @@ dependencies = [ "clap", "futures-util", "opentelemetry", + "serde", "serde_json", "tokio", "tokio-tungstenite 0.29.0", + "tower", "tracing", "tracing-subscriber", "trogon-nats", "trogon-std", + "uuid", ] [[package]] diff --git a/rsworkspace/crates/acp-nats-ws/Cargo.toml b/rsworkspace/crates/acp-nats-ws/Cargo.toml index c5451d991..9850b06d3 100644 --- a/rsworkspace/crates/acp-nats-ws/Cargo.toml +++ b/rsworkspace/crates/acp-nats-ws/Cargo.toml @@ -17,13 +17,16 @@ bytes = { workspace = true } clap = { workspace = true, features = ["env"] } futures-util = { workspace = true, features = ["sink"] } opentelemetry = { workspace = true } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "macros", "signal", "net", "sync", "io-util"] } tracing = { workspace = true } trogon-nats = { workspace = true } trogon-std = { workspace = true, features = ["telemetry-http"] } +uuid = { workspace = true } [dev-dependencies] -serde_json = { workspace = true } +tower = { version = "0.5", features = ["util"] } tokio-tungstenite = { workspace = true } tracing-subscriber = { workspace = true, features = ["fmt"] } trogon-nats = { workspace = true, features = ["test-support"] } diff --git a/rsworkspace/crates/acp-nats-ws/README.md b/rsworkspace/crates/acp-nats-ws/README.md index 846fc17b6..c30f99a05 100644 --- a/rsworkspace/crates/acp-nats-ws/README.md +++ b/rsworkspace/crates/acp-nats-ws/README.md @@ -1,21 +1,22 @@ -# ACP NATS WebSocket +# ACP NATS Streamable HTTP & WebSocket -Translates [Agent Client Protocol](https://agentclientprotocol.com) (ACP) messages between WebSocket connections and [NATS](https://nats.io), letting browser-based UIs and remote clients talk to distributed agent backends over a standard WebSocket endpoint. +Translates [Agent Client Protocol](https://agentclientprotocol.com) (ACP) messages between [NATS](https://nats.io) and the draft remote transport served on `/acp`, including both Streamable HTTP (`POST`/`GET`/`DELETE`) and WebSocket upgrade. For managed NATS infrastructure in production, we recommend Synadia Synadia. ```mermaid graph LR - A1[Client1] <-->|ws| B[acp-nats-ws] - A2[Client2] <-->|ws| B + A1[Client1] <-->|http or ws| B[acp-nats-ws] + A2[Client2] <-->|http or ws| B AN[ClientN] <-->|ws| B B <-->|NATS| C[Backend] ``` ## Features -- Multiple concurrent WebSocket connections, each with its own ACP session -- Bidirectional ACP bridge with request forwarding +- Streamable HTTP transport on `/acp` with session-scoped SSE listeners +- WebSocket upgrade on `/acp` +- Multiple concurrent ACP connections sharing the same NATS bridge - OpenTelemetry integration (logs, metrics, traces) - Graceful shutdown (SIGINT/SIGTERM) with per-connection drain - Custom prefix support for multi-tenancy @@ -33,9 +34,25 @@ cargo build --release -p acp-nats-ws Connect with any WebSocket client: ```bash -websocat ws://127.0.0.1:8080/ws +websocat ws://127.0.0.1:8080/acp ``` +Or use Streamable HTTP: + +```bash +curl -i \ + -H 'Content-Type: application/json' \ + -H 'Accept: application/json, text/event-stream' \ + -d '{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":0}}' \ + http://127.0.0.1:8080/acp +``` + +`POST /acp` returns an SSE response for JSON-RPC requests, `GET /acp` opens a session-scoped SSE listener with `Acp-Connection-Id` and `Acp-Session-Id`, and `DELETE /acp` terminates a connection. The WebSocket upgrade response and HTTP initialize response both include `Acp-Connection-Id`. + +After `initialize`, HTTP clients may send `Acp-Protocol-Version` on `POST`/`GET`/`DELETE`. When present, it must match the negotiated ACP protocol version for that connection. + +When clients send an `Origin` header, `/acp` validates it against the bound host and rejects disallowed origins with `403 Forbidden`. Streamable HTTP `POST` SSE responses also emit a priming SSE event ID before JSON-RPC payloads and attach event IDs to streamed JSON events. + ## Configuration ### WebSocket Server diff --git a/rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs b/rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs new file mode 100644 index 000000000..47bd2fc99 --- /dev/null +++ b/rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs @@ -0,0 +1,88 @@ +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct AcpConnectionId(uuid::Uuid); + +impl AcpConnectionId { + pub fn new(uuid: uuid::Uuid) -> Self { + Self(uuid) + } + + pub fn parse(s: &str) -> Result { + uuid::Uuid::parse_str(s) + .map(Self::new) + .map_err(AcpConnectionIdError::InvalidUuid) + } +} + +impl Default for AcpConnectionId { + fn default() -> Self { + Self::new(uuid::Uuid::now_v7()) + } +} + +impl std::fmt::Display for AcpConnectionId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +#[derive(Debug)] +pub enum AcpConnectionIdError { + InvalidUuid(uuid::Error), +} + +impl std::fmt::Display for AcpConnectionIdError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::InvalidUuid(error) => write!(f, "invalid ACP connection id: {error}"), + } + } +} + +impl std::error::Error for AcpConnectionIdError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::InvalidUuid(error) => Some(error), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::error::Error as _; + + #[test] + fn new_wraps_existing_uuid() { + let uuid = uuid::Uuid::nil(); + assert_eq!(AcpConnectionId::new(uuid).to_string(), uuid.to_string()); + } + + #[test] + fn parse_round_trips_uuid() { + let id = AcpConnectionId::default(); + let parsed = AcpConnectionId::parse(&id.to_string()).unwrap(); + assert_eq!(parsed, id); + } + + #[test] + fn parse_rejects_invalid_uuid() { + assert!(AcpConnectionId::parse("not-a-uuid").is_err()); + } + + #[test] + fn default_generates_non_empty_id() { + assert!(!AcpConnectionId::default().to_string().is_empty()); + } + + #[test] + fn parse_error_displays_context() { + let error = AcpConnectionId::parse("not-a-uuid").unwrap_err(); + assert!(error.to_string().contains("invalid ACP connection id")); + } + + #[test] + fn parse_error_exposes_uuid_source() { + let error = AcpConnectionId::parse("not-a-uuid").unwrap_err(); + assert!(error.source().is_some()); + } +} diff --git a/rsworkspace/crates/acp-nats-ws/src/connection.rs b/rsworkspace/crates/acp-nats-ws/src/connection.rs index 7f54f25b9..0cfa840d7 100644 --- a/rsworkspace/crates/acp-nats-ws/src/connection.rs +++ b/rsworkspace/crates/acp-nats-ws/src/connection.rs @@ -9,10 +9,12 @@ use tokio::sync::watch; use tracing::{error, info, warn}; use trogon_std::time::SystemClock; +use crate::acp_connection_id::AcpConnectionId; use crate::constants::DUPLEX_BUFFER_SIZE; /// Handles a single WebSocket connection by bridging it to NATS via ACP. pub async fn handle( + connection_id: AcpConnectionId, socket: WebSocket, nats_client: N, js_client: J, @@ -68,7 +70,7 @@ pub async fn handle( let mut io_task = tokio::task::spawn_local(io_task); - info!("WebSocket connection established, ACP bridge running"); + info!(%connection_id, "WebSocket connection established, ACP bridge running"); let shutdown_result = tokio::select! { result = &mut client_task => { @@ -118,8 +120,8 @@ pub async fn handle( } match shutdown_result { - Ok(()) => info!("WebSocket connection closed cleanly"), - Err(e) => warn!(error = e, "WebSocket connection closed with error"), + Ok(()) => info!(%connection_id, "WebSocket connection closed cleanly"), + Err(e) => warn!(%connection_id, error = e, "WebSocket connection closed with error"), } } @@ -128,34 +130,24 @@ async fn run_recv_pump( mut ws_recv_write: tokio::io::DuplexStream, ) { while let Some(Ok(msg)) = ws_receiver.next().await { - let bytes = match msg { - Message::Text(t) => bytes::Bytes::from(t), - Message::Binary(b) => b, + let text = match msg { + Message::Text(text) => text, + Message::Binary(_) => continue, Message::Close(_) => break, _ => continue, }; - match std::str::from_utf8(&bytes) { - Ok(text) => { - let line = text.trim_end_matches(['\r', '\n']); - if line.is_empty() { - continue; - } + let line = text.trim_end_matches(['\r', '\n']); + if line.is_empty() { + continue; + } - if ws_recv_write.write_all(line.as_bytes()).await.is_err() { - break; - } + if ws_recv_write.write_all(line.as_bytes()).await.is_err() { + break; + } - if ws_recv_write.write_all(b"\n").await.is_err() { - break; - } - } - Err(e) => { - warn!( - error = %e, - "Received non-UTF-8 WebSocket message, dropping frame" - ); - } + if ws_recv_write.write_all(b"\n").await.is_err() { + break; } } } @@ -196,6 +188,8 @@ mod tests { use tokio_tungstenite::connect_async; use tokio_tungstenite::tungstenite::Message as TungsteniteMessage; + use crate::constants::ACP_ENDPOINT; + #[derive(Clone)] struct EchoState; @@ -213,12 +207,12 @@ mod tests { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); let app = axum::Router::new() - .route("/ws", axum::routing::get(echo_handler)) + .route(ACP_ENDPOINT, axum::routing::get(echo_handler)) .with_state(EchoState); tokio::spawn(async move { axum::serve(listener, app).await.unwrap(); }); - format!("ws://{}/ws", addr) + format!("ws://{}{}", addr, ACP_ENDPOINT) } #[tokio::test] @@ -245,4 +239,30 @@ mod tests { } } } + + #[tokio::test] + async fn binary_messages_are_ignored() { + let url = start_echo_server().await; + let (mut ws, _) = connect_async(&url).await.unwrap(); + + ws.send(TungsteniteMessage::Binary(bytes::Bytes::from_static( + b"ignored", + ))) + .await + .unwrap(); + ws.send(TungsteniteMessage::Text("kept".into())) + .await + .unwrap(); + + let msg = tokio::time::timeout(Duration::from_secs(2), ws.next()) + .await + .expect("timeout") + .expect("stream ended") + .unwrap(); + + match msg { + TungsteniteMessage::Text(text) => assert_eq!(text, "kept"), + other => panic!("expected text frame, got {other:?}"), + } + } } diff --git a/rsworkspace/crates/acp-nats-ws/src/constants.rs b/rsworkspace/crates/acp-nats-ws/src/constants.rs index f78fd678c..93c278bad 100644 --- a/rsworkspace/crates/acp-nats-ws/src/constants.rs +++ b/rsworkspace/crates/acp-nats-ws/src/constants.rs @@ -1,6 +1,12 @@ use std::net::{IpAddr, Ipv4Addr}; +pub const ACP_CONNECTION_ID_HEADER: &str = "acp-connection-id"; +pub const ACP_ENDPOINT: &str = "/acp"; +pub const ACP_PROTOCOL_VERSION_HEADER: &str = "acp-protocol-version"; +pub const ACP_SESSION_ID_HEADER: &str = "acp-session-id"; pub const DEFAULT_HOST: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); pub const DEFAULT_PORT: u16 = 8080; pub const DUPLEX_BUFFER_SIZE: usize = 64 * 1024; +pub const HTTP_CHANNEL_CAPACITY: usize = 64; pub const THREAD_NAME: &str = "acp-ws-local"; +pub const X_ACCEL_BUFFERING_HEADER: &str = "x-accel-buffering"; diff --git a/rsworkspace/crates/acp-nats-ws/src/main.rs b/rsworkspace/crates/acp-nats-ws/src/main.rs index 6faee93d7..4be6b0e95 100644 --- a/rsworkspace/crates/acp-nats-ws/src/main.rs +++ b/rsworkspace/crates/acp-nats-ws/src/main.rs @@ -1,11 +1,12 @@ +mod acp_connection_id; mod config; mod connection; mod constants; -mod upgrade; +mod transport; use tokio::sync::{mpsc, watch}; use tracing::info; -use upgrade::{ConnectionRequest, UpgradeState}; +use transport::{AppState, ManagerRequest}; #[cfg(not(coverage))] use { @@ -26,7 +27,7 @@ async fn main() -> Result<(), Box> { ); let ws_config = config::apply_timeout_overrides(ws_config, &SystemEnv); - info!("ACP WebSocket bridge starting"); + info!("ACP remote transport bridge starting"); let nats_connect_timeout = acp_nats::nats_connect_timeout(&SystemEnv); let nats_client = nats::connect(ws_config.acp.nats(), nats_connect_timeout).await?; @@ -35,27 +36,33 @@ async fn main() -> Result<(), Box> { let js_client = trogon_nats::jetstream::NatsJetStreamClient::new(js_context); let (shutdown_tx, _) = watch::channel(false); - let (conn_tx, conn_rx) = mpsc::unbounded_channel::(); + let (manager_tx, manager_rx) = mpsc::unbounded_channel::(); let conn_thread = std::thread::Builder::new() .name(THREAD_NAME.into()) - .spawn(move || run_connection_thread(conn_rx, nats_client, js_client, ws_config.acp))?; + .spawn(move || run_connection_thread(manager_rx, nats_client, js_client, ws_config.acp))?; - let state = UpgradeState { - conn_tx, + let state = AppState { + bind_host: ws_config.host, + manager_tx, shutdown_tx: shutdown_tx.clone(), }; let app = trogon_std::telemetry::http::instrument_router( axum::Router::new() - .route("/ws", axum::routing::get(upgrade::handle)) + .route( + ACP_ENDPOINT, + axum::routing::get(transport::get) + .post(transport::post) + .delete(transport::delete), + ) .with_state(state), ); let addr = SocketAddr::from((ws_config.host, ws_config.port)); let listener = tokio::net::TcpListener::bind(addr).await?; - info!(address = %addr, "Listening for WebSocket connections"); + info!(address = %addr, "Listening for ACP transport connections"); let result = axum::serve(listener, app) .with_graceful_shutdown(async move { @@ -66,8 +73,8 @@ async fn main() -> Result<(), Box> { .await; match &result { - Ok(()) => info!("ACP WebSocket bridge stopped"), - Err(e) => error!(error = %e, "ACP WebSocket bridge stopped with error"), + Ok(()) => info!("ACP remote transport bridge stopped"), + Err(e) => error!(error = %e, "ACP remote transport bridge stopped with error"), } // `serve` returning drops the Router (and its AppState.conn_tx), which @@ -88,13 +95,13 @@ async fn main() -> Result<(), Box> { #[cfg(coverage)] fn main() {} -use constants::THREAD_NAME; +use constants::{ACP_ENDPOINT, THREAD_NAME}; /// Runs a single-threaded tokio runtime with a /// `LocalSet`. All WebSocket connections are processed here because the ACP /// `Agent` trait is `?Send`, requiring `spawn_local` / `Rc`. fn run_connection_thread( - conn_rx: mpsc::UnboundedReceiver, + manager_rx: mpsc::UnboundedReceiver, nats_client: N, js_client: J, config: acp_nats::Config, @@ -115,7 +122,12 @@ fn run_connection_thread( .expect("failed to create per-connection runtime"); let local = tokio::task::LocalSet::new(); - rt.block_on(local.run_until(process_connections(conn_rx, nats_client, js_client, config))); + rt.block_on(local.run_until(process_connections( + manager_rx, + nats_client, + js_client, + config, + ))); // run_until returns once process_connections completes, but // sub-tasks spawned by connection handlers (pumps, @@ -130,7 +142,7 @@ fn run_connection_thread( } async fn process_connections( - mut conn_rx: mpsc::UnboundedReceiver, + mut manager_rx: mpsc::UnboundedReceiver, nats_client: N, js_client: J, config: acp_nats::Config, @@ -145,29 +157,40 @@ async fn process_connections( J: acp_nats::JetStreamPublisher + acp_nats::JetStreamGetStream + 'static, trogon_nats::jetstream::JsMessageOf: trogon_nats::jetstream::JsRequestMessage, { - let mut conn_handles: Vec> = Vec::new(); - - while let Some(req) = conn_rx.recv().await { - conn_handles.retain(|h| !h.is_finished()); - let client = nats_client.clone(); - let js = js_client.clone(); - let cfg = config.clone(); - conn_handles.push(tokio::task::spawn_local(connection::handle( - req.socket, - client, - js, - cfg, - req.shutdown_rx, - ))); + let mut websocket_handles: Vec> = Vec::new(); + let mut http_connection_handles: Vec> = Vec::new(); + let mut http_connections = std::collections::HashMap::new(); + + while let Some(request) = manager_rx.recv().await { + transport::process_manager_request( + request, + &mut http_connections, + &mut websocket_handles, + &mut http_connection_handles, + &nats_client, + &js_client, + &config, + ) + .await; } - let active = conn_handles.iter().filter(|h| !h.is_finished()).count(); + let active = websocket_handles + .iter() + .filter(|h| !h.is_finished()) + .count() + + http_connection_handles + .iter() + .filter(|h| !h.is_finished()) + .count(); info!( active_connections = active, "Connection channel closed, draining active connections" ); - for handle in conn_handles { + for handle in websocket_handles { + let _ = handle.await; + } + for handle in http_connection_handles { let _ = handle.await; } @@ -177,12 +200,21 @@ async fn process_connections( #[cfg(test)] mod tests { use super::*; + use crate::constants::{ + ACP_CONNECTION_ID_HEADER, ACP_PROTOCOL_VERSION_HEADER, ACP_SESSION_ID_HEADER, + }; use acp_nats::Config; + use axum::body::{Body, to_bytes}; + use axum::http::header::{ACCEPT, CONTENT_TYPE}; + use axum::http::{Request, StatusCode}; use futures_util::{SinkExt, StreamExt}; + use serde_json::{Value, json}; + use std::net::{IpAddr, Ipv4Addr}; use std::time::Duration; use tokio::net::TcpListener; use tokio_tungstenite::connect_async; use tokio_tungstenite::tungstenite::Message; + use tower::ServiceExt; use trogon_nats::AdvancedMockNatsClient; #[derive(Clone)] @@ -230,41 +262,75 @@ mod tests { } } - #[tokio::test] - async fn test_websocket_connection_lifecycle() { - let nats_mock = AdvancedMockNatsClient::new(); - let config = Config::new( + fn test_config() -> Config { + Config::new( acp_nats::AcpPrefix::new("acp").unwrap(), acp_nats::NatsConfig { servers: vec!["localhost:4222".to_string()], auth: trogon_nats::NatsAuth::None, }, - ); - - // Required by AdvancedMockNatsClient to not error out on subscribe() - let _injector = nats_mock.inject_messages(); + ) + } - let (shutdown_tx, mut shutdown_rx) = watch::channel(false); - let (conn_tx, conn_rx) = mpsc::unbounded_channel::(); + fn test_app(state: AppState) -> axum::Router { + axum::Router::new() + .route( + ACP_ENDPOINT, + axum::routing::get(transport::get) + .post(transport::post) + .delete(transport::delete), + ) + .with_state(state) + } - let nats_mock_clone = nats_mock.clone(); - let conn_thread = std::thread::Builder::new() + fn spawn_connection_thread( + nats_mock: AdvancedMockNatsClient, + manager_rx: mpsc::UnboundedReceiver, + ) -> std::thread::JoinHandle<()> { + let config = test_config(); + std::thread::Builder::new() .name(THREAD_NAME.into()) - .spawn(move || run_connection_thread(conn_rx, nats_mock_clone, MockJs::new(), config)) - .expect("failed to spawn connection thread"); + .spawn(move || run_connection_thread(manager_rx, nats_mock, MockJs::new(), config)) + .expect("failed to spawn connection thread") + } - let state = UpgradeState { - conn_tx, + fn build_test_app( + nats_mock: AdvancedMockNatsClient, + ) -> ( + axum::Router, + watch::Sender, + std::thread::JoinHandle<()>, + ) { + let (shutdown_tx, _) = watch::channel(false); + let (manager_tx, manager_rx) = mpsc::unbounded_channel::(); + let conn_thread = spawn_connection_thread(nats_mock, manager_rx); + let app = test_app(AppState { + bind_host: IpAddr::V4(Ipv4Addr::LOCALHOST), + manager_tx, shutdown_tx: shutdown_tx.clone(), - }; + }); + (app, shutdown_tx, conn_thread) + } - let app = axum::Router::new() - .route("/ws", axum::routing::get(upgrade::handle)) - .with_state(state); + async fn start_test_server( + nats_mock: AdvancedMockNatsClient, + ) -> ( + std::net::SocketAddr, + watch::Sender, + tokio::task::JoinHandle<()>, + std::thread::JoinHandle<()>, + ) { + let (shutdown_tx, mut shutdown_rx) = watch::channel(false); + let (manager_tx, manager_rx) = mpsc::unbounded_channel::(); + let conn_thread = spawn_connection_thread(nats_mock, manager_rx); + let app = test_app(AppState { + bind_host: IpAddr::V4(Ipv4Addr::LOCALHOST), + manager_tx, + shutdown_tx: shutdown_tx.clone(), + }); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); - let server_task = tokio::spawn(async move { axum::serve(listener, app) .with_graceful_shutdown(async move { @@ -274,13 +340,49 @@ mod tests { .unwrap(); }); + (addr, shutdown_tx, server_task, conn_thread) + } + + async fn body_text(response: axum::response::Response) -> String { + let bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap(); + String::from_utf8(bytes.to_vec()).unwrap() + } + + fn sse_events(body: &str) -> Vec { + body.lines() + .filter_map(|line| line.strip_prefix("data: ")) + .filter(|json| !json.is_empty()) + .map(|json| serde_json::from_str(json).unwrap()) + .collect() + } + + fn http_post_request(body: &str) -> Request { + Request::builder() + .method("POST") + .uri(ACP_ENDPOINT) + .header(CONTENT_TYPE, "application/json") + .header(ACCEPT, "application/json, text/event-stream") + .body(Body::from(body.to_owned())) + .unwrap() + } + + #[tokio::test] + async fn test_websocket_connection_lifecycle() { + let nats_mock = AdvancedMockNatsClient::new(); + + // Required by AdvancedMockNatsClient to not error out on subscribe() + let _injector = nats_mock.inject_messages(); + // Setup mock response for NATS let nats_response = r#"{"agentCapabilities": {"loadSession": false, "mcpCapabilities": {"http": false, "sse": false}, "promptCapabilities": {"audio": false, "embeddedContext": false, "image": false}, "sessionCapabilities": {}}, "authMethods": [], "protocolVersion": 0}"#; nats_mock.set_response("acp.agent.initialize", nats_response.into()); + let (addr, shutdown_tx, server_task, conn_thread) = start_test_server(nats_mock).await; + // Connect client - let ws_url = format!("ws://{}/ws", addr); - let (mut ws_stream, _) = connect_async(ws_url).await.unwrap(); + let ws_url = format!("ws://{}{}", addr, ACP_ENDPOINT); + let (mut ws_stream, response) = connect_async(ws_url).await.unwrap(); + assert!(response.headers().contains_key(ACP_CONNECTION_ID_HEADER)); // Send initialize request let req = @@ -322,49 +424,13 @@ mod tests { #[tokio::test] async fn test_shutdown_while_connection_active() { let nats_mock = AdvancedMockNatsClient::new(); - let config = Config::new( - acp_nats::AcpPrefix::new("acp").unwrap(), - acp_nats::NatsConfig { - servers: vec!["localhost:4222".to_string()], - auth: trogon_nats::NatsAuth::None, - }, - ); let _injector = nats_mock.inject_messages(); nats_mock.hang_next_request(); + let (addr, shutdown_tx, server_task, conn_thread) = start_test_server(nats_mock).await; - let (shutdown_tx, mut shutdown_rx) = watch::channel(false); - let (conn_tx, conn_rx) = mpsc::unbounded_channel::(); - - let nats_mock_clone = nats_mock.clone(); - let conn_thread = std::thread::Builder::new() - .name(THREAD_NAME.into()) - .spawn(move || run_connection_thread(conn_rx, nats_mock_clone, MockJs::new(), config)) - .expect("failed to spawn connection thread"); - - let state = UpgradeState { - conn_tx, - shutdown_tx: shutdown_tx.clone(), - }; - - let app = axum::Router::new() - .route("/ws", axum::routing::get(upgrade::handle)) - .with_state(state); - - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - - let server_task = tokio::spawn(async move { - axum::serve(listener, app) - .with_graceful_shutdown(async move { - let _ = shutdown_rx.changed().await; - }) - .await - .unwrap(); - }); - - let ws_url = format!("ws://{}/ws", addr); + let ws_url = format!("ws://{}{}", addr, ACP_ENDPOINT); let (mut ws_stream, _) = connect_async(&ws_url).await.unwrap(); let req = @@ -383,4 +449,373 @@ mod tests { conn_thread.join().unwrap(); } + + #[tokio::test] + async fn streamable_http_initialize_returns_connection_id_and_sse_response() { + let nats_mock = AdvancedMockNatsClient::new(); + let _injector = nats_mock.inject_messages(); + nats_mock.set_response( + "acp.agent.initialize", + r#"{"agentCapabilities":{"loadSession":false,"mcpCapabilities":{"http":false,"sse":false},"promptCapabilities":{"audio":false,"embeddedContext":false,"image":false},"sessionCapabilities":{}},"authMethods":[],"protocolVersion":0}"# + .into(), + ); + + let (app, shutdown_tx, conn_thread) = build_test_app(nats_mock); + let response = app + .clone() + .oneshot(http_post_request( + r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":0}}"#, + )) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response.headers().get(CONTENT_TYPE).unwrap(), + "text/event-stream" + ); + + let connection_id = response + .headers() + .get(ACP_CONNECTION_ID_HEADER) + .unwrap() + .to_str() + .unwrap() + .to_owned(); + assert!(crate::acp_connection_id::AcpConnectionId::parse(&connection_id).is_ok()); + + let body = body_text(response).await; + let events = sse_events(&body); + assert_eq!(events.len(), 1); + assert_eq!(events[0]["id"], 1); + assert_eq!(events[0]["result"]["protocolVersion"], 0); + + shutdown_tx.send(true).unwrap(); + drop(app); + conn_thread.join().unwrap(); + } + + #[tokio::test] + async fn streamable_http_session_new_returns_session_header_and_body() { + let nats_mock = AdvancedMockNatsClient::new(); + let _injector = nats_mock.inject_messages(); + nats_mock.set_response( + "acp.agent.initialize", + r#"{"agentCapabilities":{"loadSession":false,"mcpCapabilities":{"http":false,"sse":false},"promptCapabilities":{"audio":false,"embeddedContext":false,"image":false},"sessionCapabilities":{}},"authMethods":[],"protocolVersion":0}"# + .into(), + ); + nats_mock.set_response( + "acp.agent.session.new", + r#"{"sessionId":"test-session-1"}"#.into(), + ); + + let (app, shutdown_tx, conn_thread) = build_test_app(nats_mock); + let initialize = app + .clone() + .oneshot(http_post_request( + r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":0}}"#, + )) + .await + .unwrap(); + let connection_id = initialize + .headers() + .get(ACP_CONNECTION_ID_HEADER) + .unwrap() + .to_str() + .unwrap() + .to_owned(); + let _ = body_text(initialize).await; + + let session_new = app + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri(ACP_ENDPOINT) + .header(CONTENT_TYPE, "application/json") + .header(ACCEPT, "application/json, text/event-stream") + .header(ACP_CONNECTION_ID_HEADER, &connection_id) + .header(ACP_PROTOCOL_VERSION_HEADER, "0") + .body(Body::from( + r#"{"jsonrpc":"2.0","id":2,"method":"session/new","params":{"cwd":".","mcpServers":[]}}"#, + )) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(session_new.status(), StatusCode::OK); + assert_eq!( + session_new + .headers() + .get(ACP_CONNECTION_ID_HEADER) + .unwrap() + .to_str() + .unwrap(), + connection_id + ); + + let session_id = session_new + .headers() + .get(ACP_SESSION_ID_HEADER) + .and_then(|value| value.to_str().ok()) + .map(str::to_owned); + assert_eq!( + session_new + .headers() + .get(ACP_PROTOCOL_VERSION_HEADER) + .and_then(|value| value.to_str().ok()), + Some("0") + ); + let body = body_text(session_new).await; + let events = sse_events(&body); + assert_eq!(events.len(), 1); + assert_eq!(events[0]["id"], 2); + assert_eq!(events[0]["result"]["sessionId"], "test-session-1"); + assert_eq!(session_id.as_deref(), Some("test-session-1")); + + let _ = shutdown_tx.send(true); + drop(app); + conn_thread.join().unwrap(); + } + + #[tokio::test] + async fn streamable_http_session_load_uses_request_session_id_header() { + let nats_mock = AdvancedMockNatsClient::new(); + let _injector = nats_mock.inject_messages(); + nats_mock.set_response( + "acp.agent.initialize", + r#"{"agentCapabilities":{"loadSession":false,"mcpCapabilities":{"http":false,"sse":false},"promptCapabilities":{"audio":false,"embeddedContext":false,"image":false},"sessionCapabilities":{}},"authMethods":[],"protocolVersion":0}"# + .into(), + ); + nats_mock.set_response("acp.session.test-session-1.agent.load", "{}".into()); + + let (app, shutdown_tx, conn_thread) = build_test_app(nats_mock); + let initialize = app + .clone() + .oneshot(http_post_request( + r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":0}}"#, + )) + .await + .unwrap(); + let connection_id = initialize + .headers() + .get(ACP_CONNECTION_ID_HEADER) + .unwrap() + .to_str() + .unwrap() + .to_owned(); + let _ = body_text(initialize).await; + + let session_load = app + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri(ACP_ENDPOINT) + .header(CONTENT_TYPE, "application/json") + .header(ACCEPT, "application/json, text/event-stream") + .header(ACP_CONNECTION_ID_HEADER, &connection_id) + .header(ACP_PROTOCOL_VERSION_HEADER, "0") + .body(Body::from( + r#"{"jsonrpc":"2.0","id":2,"method":"session/load","params":{"sessionId":"test-session-1","cwd":"."}}"#, + )) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(session_load.status(), StatusCode::OK); + assert_eq!( + session_load + .headers() + .get(ACP_CONNECTION_ID_HEADER) + .unwrap() + .to_str() + .unwrap(), + connection_id + ); + assert_eq!( + session_load + .headers() + .get(ACP_SESSION_ID_HEADER) + .and_then(|value| value.to_str().ok()), + Some("test-session-1") + ); + assert_eq!( + session_load + .headers() + .get(ACP_PROTOCOL_VERSION_HEADER) + .and_then(|value| value.to_str().ok()), + Some("0") + ); + + let body = body_text(session_load).await; + let events = sse_events(&body); + assert_eq!(events.len(), 1); + assert_eq!(events[0]["id"], 2); + assert_eq!(events[0]["result"], json!(null)); + + let _ = shutdown_tx.send(true); + drop(app); + conn_thread.join().unwrap(); + } + + #[tokio::test] + async fn streamable_http_get_requires_connection_and_session_headers() { + let nats_mock = AdvancedMockNatsClient::new(); + let (app, shutdown_tx, conn_thread) = build_test_app(nats_mock); + + let response = app + .clone() + .oneshot( + Request::builder() + .method("GET") + .uri(ACP_ENDPOINT) + .header(ACCEPT, "text/event-stream") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + let _ = shutdown_tx.send(true); + drop(app); + conn_thread.join().unwrap(); + } + + #[tokio::test] + async fn legacy_websocket_alias_is_not_routed() { + let nats_mock = AdvancedMockNatsClient::new(); + let (app, shutdown_tx, conn_thread) = build_test_app(nats_mock); + + let response = app + .clone() + .oneshot( + Request::builder() + .method("GET") + .uri("/ws") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::NOT_FOUND); + + let _ = shutdown_tx.send(true); + drop(app); + conn_thread.join().unwrap(); + } + + #[tokio::test] + async fn streamable_http_delete_terminates_initialized_connection() { + let nats_mock = AdvancedMockNatsClient::new(); + let _injector = nats_mock.inject_messages(); + nats_mock.set_response( + "acp.agent.initialize", + r#"{"agentCapabilities":{"loadSession":false,"mcpCapabilities":{"http":false,"sse":false},"promptCapabilities":{"audio":false,"embeddedContext":false,"image":false},"sessionCapabilities":{}},"authMethods":[],"protocolVersion":0}"# + .into(), + ); + + let (app, shutdown_tx, conn_thread) = build_test_app(nats_mock); + let initialize = app + .clone() + .oneshot(http_post_request( + r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":0}}"#, + )) + .await + .unwrap(); + let connection_id = initialize + .headers() + .get(ACP_CONNECTION_ID_HEADER) + .unwrap() + .to_str() + .unwrap() + .to_owned(); + let _ = body_text(initialize).await; + + let response = app + .clone() + .oneshot( + Request::builder() + .method("DELETE") + .uri(ACP_ENDPOINT) + .header(ACP_CONNECTION_ID_HEADER, &connection_id) + .header(ACP_PROTOCOL_VERSION_HEADER, "0") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::ACCEPTED); + assert_eq!( + response + .headers() + .get(ACP_PROTOCOL_VERSION_HEADER) + .and_then(|value| value.to_str().ok()), + Some("0") + ); + + let _ = shutdown_tx.send(true); + drop(app); + conn_thread.join().unwrap(); + } + + #[tokio::test] + async fn streamable_http_rejects_mismatched_protocol_version_after_initialize() { + let nats_mock = AdvancedMockNatsClient::new(); + let _injector = nats_mock.inject_messages(); + nats_mock.set_response( + "acp.agent.initialize", + r#"{"agentCapabilities":{"loadSession":false,"mcpCapabilities":{"http":false,"sse":false},"promptCapabilities":{"audio":false,"embeddedContext":false,"image":false},"sessionCapabilities":{}},"authMethods":[],"protocolVersion":0}"# + .into(), + ); + + let (app, shutdown_tx, conn_thread) = build_test_app(nats_mock); + let initialize = app + .clone() + .oneshot(http_post_request( + r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":0}}"#, + )) + .await + .unwrap(); + let connection_id = initialize + .headers() + .get(ACP_CONNECTION_ID_HEADER) + .unwrap() + .to_str() + .unwrap() + .to_owned(); + let _ = body_text(initialize).await; + + let response = app + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri(ACP_ENDPOINT) + .header(CONTENT_TYPE, "application/json") + .header(ACCEPT, "application/json, text/event-stream") + .header(ACP_CONNECTION_ID_HEADER, &connection_id) + .header(ACP_PROTOCOL_VERSION_HEADER, "1") + .body(Body::from(r#"{"jsonrpc":"2.0","method":"initialized"}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + assert_eq!( + body_text(response).await, + "Acp-Protocol-Version header does not match initialized protocol version" + ); + + let _ = shutdown_tx.send(true); + drop(app); + conn_thread.join().unwrap(); + } } diff --git a/rsworkspace/crates/acp-nats-ws/src/transport.rs b/rsworkspace/crates/acp-nats-ws/src/transport.rs new file mode 100644 index 000000000..9c710b49a --- /dev/null +++ b/rsworkspace/crates/acp-nats-ws/src/transport.rs @@ -0,0 +1,2853 @@ +use crate::acp_connection_id::AcpConnectionId; +use crate::connection; +use crate::constants::{ + ACP_CONNECTION_ID_HEADER, ACP_PROTOCOL_VERSION_HEADER, ACP_SESSION_ID_HEADER, + HTTP_CHANNEL_CAPACITY, X_ACCEL_BUFFERING_HEADER, +}; +use acp_nats::{StdJsonSerialize, agent::Bridge, client, spawn_notification_forwarder}; +use agent_client_protocol::{AgentSideConnection, ProtocolVersion, RequestId, SessionNotification}; +use axum::extract::FromRequestParts; +use axum::extract::Request; +use axum::extract::State; +use axum::extract::ws::{WebSocket, WebSocketUpgrade}; +use axum::http::header::{ACCEPT, CONTENT_TYPE, HOST, ORIGIN}; +use axum::http::{HeaderMap, HeaderValue, StatusCode, Uri, uri::Authority}; +use axum::response::sse::{Event, Sse}; +use axum::response::{IntoResponse, Response}; +use futures_util::{StreamExt, stream}; +use serde::Deserialize; +use serde_json::Value; +use std::collections::{HashMap, HashSet}; +use std::convert::Infallible; +use std::net::IpAddr; +use std::rc::Rc; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt}; +use tokio::sync::{mpsc, oneshot, watch}; +use tracing::{error, info, warn}; +use trogon_std::time::SystemClock; +use uuid::Uuid; + +type BoxError = Box; +type SseSender = mpsc::Sender; +type SseReceiver = mpsc::Receiver; + +#[derive(Clone)] +pub struct AppState { + pub bind_host: IpAddr, + pub manager_tx: mpsc::UnboundedSender, + pub shutdown_tx: watch::Sender, +} + +pub struct ConnectionRequest { + pub connection_id: AcpConnectionId, + pub socket: WebSocket, + pub shutdown_rx: watch::Receiver, +} + +pub enum ManagerRequest { + WebSocket(Box), + HttpPost { + connection_id: Option, + protocol_version: Option, + session_id: Option, + message: IncomingHttpMessage, + response: oneshot::Sender>, + shutdown_rx: watch::Receiver, + }, + HttpGet { + connection_id: AcpConnectionId, + session_id: acp_nats::AcpSessionId, + protocol_version: Option, + response: oneshot::Sender>, + }, + HttpDelete { + connection_id: AcpConnectionId, + protocol_version: Option, + response: oneshot::Sender, HttpTransportError>>, + }, +} + +#[derive(Debug)] +pub enum HttpPostOutcome { + Accepted, + Live { + connection_id: AcpConnectionId, + protocol_version: Option, + session_id: Option, + stream: SseReceiver, + }, + Buffered { + connection_id: AcpConnectionId, + protocol_version: Option, + session_id: Option, + events: Vec, + }, +} + +#[derive(Debug)] +pub struct HttpGetOutcome { + pub protocol_version: Option, + pub stream: SseReceiver, +} + +#[derive(Debug)] +pub enum HttpTransportError { + BadRequest { + message: &'static str, + source: Option, + }, + NotFound { + message: &'static str, + source: Option, + }, + Conflict { + message: &'static str, + source: Option, + }, + Forbidden { + message: &'static str, + source: Option, + }, + UnsupportedMediaType { + message: &'static str, + source: Option, + }, + NotAcceptable { + message: &'static str, + source: Option, + }, + NotImplemented { + message: &'static str, + source: Option, + }, + Internal { + message: &'static str, + source: Option, + }, +} + +impl std::fmt::Display for HttpTransportError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.message()) + } +} + +impl std::error::Error for HttpTransportError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.source_ref() + } +} + +impl HttpTransportError { + fn message(&self) -> &'static str { + match self { + Self::BadRequest { message, .. } + | Self::NotFound { message, .. } + | Self::Conflict { message, .. } + | Self::Forbidden { message, .. } + | Self::UnsupportedMediaType { message, .. } + | Self::NotAcceptable { message, .. } + | Self::NotImplemented { message, .. } + | Self::Internal { message, .. } => message, + } + } + + fn source_ref(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::BadRequest { source, .. } + | Self::NotFound { source, .. } + | Self::Conflict { source, .. } + | Self::Forbidden { source, .. } + | Self::UnsupportedMediaType { source, .. } + | Self::NotAcceptable { source, .. } + | Self::NotImplemented { source, .. } + | Self::Internal { source, .. } => source + .as_deref() + .map(|source| source as &(dyn std::error::Error + 'static)), + } + } + + fn bad_request(message: &'static str) -> Self { + Self::BadRequest { + message, + source: None, + } + } + + fn bad_request_with( + message: &'static str, + source: impl std::error::Error + Send + 'static, + ) -> Self { + Self::BadRequest { + message, + source: Some(Box::new(source)), + } + } + + fn not_found(message: &'static str) -> Self { + Self::NotFound { + message, + source: None, + } + } + + fn conflict(message: &'static str) -> Self { + Self::Conflict { + message, + source: None, + } + } + + fn forbidden(message: &'static str) -> Self { + Self::Forbidden { + message, + source: None, + } + } + + fn forbidden_with( + message: &'static str, + source: impl std::error::Error + Send + 'static, + ) -> Self { + Self::Forbidden { + message, + source: Some(Box::new(source)), + } + } + + fn unsupported_media_type(message: &'static str) -> Self { + Self::UnsupportedMediaType { + message, + source: None, + } + } + + fn not_acceptable(message: &'static str) -> Self { + Self::NotAcceptable { + message, + source: None, + } + } + + fn not_implemented(message: &'static str) -> Self { + Self::NotImplemented { + message, + source: None, + } + } + + fn internal(message: &'static str) -> Self { + Self::Internal { + message, + source: None, + } + } + + fn internal_with( + message: &'static str, + source: impl std::error::Error + Send + 'static, + ) -> Self { + Self::Internal { + message, + source: Some(Box::new(source)), + } + } + + fn status_code(&self) -> StatusCode { + match self { + Self::BadRequest { .. } => StatusCode::BAD_REQUEST, + Self::NotFound { .. } => StatusCode::NOT_FOUND, + Self::Conflict { .. } => StatusCode::CONFLICT, + Self::Forbidden { .. } => StatusCode::FORBIDDEN, + Self::UnsupportedMediaType { .. } => StatusCode::UNSUPPORTED_MEDIA_TYPE, + Self::NotAcceptable { .. } => StatusCode::NOT_ACCEPTABLE, + Self::NotImplemented { .. } => StatusCode::NOT_IMPLEMENTED, + Self::Internal { .. } => StatusCode::INTERNAL_SERVER_ERROR, + } + } + + fn into_response(self) -> Response { + let status = self.status_code(); + (status, self.to_string()).into_response() + } +} + +#[derive(Debug)] +pub struct HttpConnectionHandle { + pub command_tx: mpsc::UnboundedSender, +} + +#[derive(Debug)] +pub enum HttpConnectionCommand { + Post { + protocol_version: Option, + session_id: Option, + message: IncomingHttpMessage, + response: oneshot::Sender>, + }, + AttachListener { + session_id: acp_nats::AcpSessionId, + protocol_version: Option, + response: oneshot::Sender>, + }, + Close { + protocol_version: Option, + response: oneshot::Sender, HttpTransportError>>, + }, +} + +#[derive(Clone, Debug)] +pub(crate) enum SseFrame { + Empty { + event_id: String, + }, + Json { + event_id: Option, + json: String, + }, +} + +impl SseFrame { + fn prime() -> Self { + Self::Empty { + event_id: Uuid::new_v4().to_string(), + } + } + + fn json(json: String) -> Self { + Self::Json { + event_id: Some(Uuid::new_v4().to_string()), + json, + } + } + + fn into_event(self) -> Event { + match self { + Self::Empty { event_id } => Event::default().id(event_id).data(""), + Self::Json { event_id, json } => { + let event = Event::default().data(json); + if let Some(event_id) = event_id { + event.id(event_id) + } else { + event + } + } + } + } +} + +#[derive(Debug)] +enum PendingRequest { + Live { + request_id: RequestId, + capture_protocol_version: bool, + session_id: Option, + sender: SseSender, + }, + Buffered { + request_id: RequestId, + fallback_session_id: Option, + events: Vec, + response: oneshot::Sender>, + }, +} + +#[derive(Debug, Deserialize)] +pub struct IncomingHttpMessage { + pub id: Option, + pub method: Option, + pub params: Option, + #[serde(skip)] + pub raw: String, + #[serde(skip)] + has_result: bool, + #[serde(skip)] + has_error: bool, +} + +impl IncomingHttpMessage { + pub fn parse(raw: String) -> Result { + let trimmed = raw.trim_start(); + if trimmed.starts_with('[') { + return Err(HttpTransportError::not_implemented( + "batch JSON-RPC requests are not supported", + )); + } + + let value = serde_json::from_str::(&raw).map_err(|error| { + HttpTransportError::bad_request_with("invalid JSON-RPC payload", error) + })?; + let (has_result, has_error) = value + .as_object() + .map(|object| (object.contains_key("result"), object.contains_key("error"))) + .unwrap_or((false, false)); + let mut parsed = serde_json::from_value::(value).map_err(|error| { + HttpTransportError::bad_request_with("invalid JSON-RPC payload", error) + })?; + parsed.raw = raw; + parsed.has_result = has_result; + parsed.has_error = has_error; + Ok(parsed) + } + + fn is_request(&self) -> bool { + self.id.is_some() && self.method.is_some() + } + + fn is_notification(&self) -> bool { + self.id.is_none() && self.method.is_some() + } + + fn is_response(&self) -> bool { + self.id.is_some() && self.method.is_none() && (self.has_result || self.has_error) + } + + fn method_name(&self) -> Option<&str> { + self.method.as_deref() + } + + fn is_initialize(&self) -> bool { + self.method_name() == Some("initialize") + } + + fn activates_session(&self) -> bool { + matches!( + self.method_name(), + Some("session/new") + | Some("session/load") + | Some("session/resume") + | Some("session/fork") + ) + } + + fn requires_session_id(&self) -> bool { + if self.is_response() { + return true; + } + + match self.method_name() { + Some("initialize") | Some("authenticate") | Some("session/list") => false, + _ if self.activates_session() => false, + Some(method) if method.starts_with("session/") => true, + _ => false, + } + } + + fn params_session_id(&self) -> Result, HttpTransportError> { + let Some(params) = &self.params else { + return Ok(None); + }; + + let Some(session_id) = params.get("sessionId").and_then(Value::as_str) else { + return Ok(None); + }; + + acp_nats::AcpSessionId::new(session_id) + .map(Some) + .map_err(|error| { + HttpTransportError::bad_request_with("invalid sessionId in request body", error) + }) + } +} + +#[derive(Debug, Deserialize)] +struct OutgoingHttpMessage { + id: Option, + params: Option, + result: Option, +} + +impl OutgoingHttpMessage { + fn parse(raw: &str) -> Option { + serde_json::from_str(raw).ok() + } + + fn params_session_id(&self) -> Option { + let params = self.params.as_ref()?; + let session_id = params.get("sessionId")?.as_str()?; + acp_nats::AcpSessionId::new(session_id).ok() + } + + fn result_session_id(&self) -> Option { + let result = self.result.as_ref()?; + let session_id = result.get("sessionId")?.as_str()?; + acp_nats::AcpSessionId::new(session_id).ok() + } + + fn result_protocol_version(&self) -> Option { + let result = self.result.as_ref()?; + let protocol_version = result.get("protocolVersion")?; + serde_json::from_value(protocol_version.clone()).ok() + } +} + +enum LiveFrameOutcome { + Keep, + Clear, + Drop, +} + +enum BufferedFrameOutcome { + Buffered, + Routed, + Finalize { + session_id: Option, + }, +} + +enum ListenerDispatch { + Missing, + Delivered, + Dropped, +} + +fn try_send_sse_frame(sender: &SseSender, frame: &SseFrame) -> bool { + match sender.try_send(frame.clone()) { + Ok(()) => true, + Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => false, + Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => false, + } +} + +fn dispatch_to_get_listeners( + frame: &SseFrame, + session_id: &acp_nats::AcpSessionId, + get_listeners: &mut HashMap>, +) -> ListenerDispatch { + let Some(_) = get_listeners.get(session_id) else { + return ListenerDispatch::Missing; + }; + + let (outcome, remove_session) = { + let listeners = get_listeners + .get_mut(session_id) + .expect("session listeners must exist after presence check"); + let mut delivered = false; + let mut retained = Vec::with_capacity(listeners.len()); + + for listener in listeners.drain(..) { + if delivered { + retained.push(listener); + continue; + } + + match listener.try_send(frame.clone()) { + Ok(()) => { + delivered = true; + retained.push(listener); + } + Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => { + warn!(session_id = %session_id, "Dropping stalled HTTP SSE listener"); + } + Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {} + } + } + *listeners = retained; + + ( + if delivered { + ListenerDispatch::Delivered + } else { + ListenerDispatch::Dropped + }, + listeners.is_empty(), + ) + }; + + if remove_session { + get_listeners.remove(session_id); + } + + outcome +} + +fn route_live_frame( + frame: &SseFrame, + parsed: Option<&OutgoingHttpMessage>, + request_id: &RequestId, + request_session_id: Option<&acp_nats::AcpSessionId>, + sender: &SseSender, + get_listeners: &mut HashMap>, +) -> LiveFrameOutcome { + if parsed.and_then(|message| message.id.as_ref()) == Some(request_id) { + return if try_send_sse_frame(sender, frame) { + LiveFrameOutcome::Clear + } else { + warn!(request_id = ?request_id, "Dropping stalled HTTP POST response stream"); + LiveFrameOutcome::Drop + }; + } + + if let Some(frame_session_id) = parsed.and_then(OutgoingHttpMessage::params_session_id) { + if request_session_id == Some(&frame_session_id) { + return if try_send_sse_frame(sender, frame) { + LiveFrameOutcome::Keep + } else { + warn!(request_id = ?request_id, session_id = %frame_session_id, "Dropping stalled HTTP POST response stream"); + LiveFrameOutcome::Drop + }; + } + + match dispatch_to_get_listeners(frame, &frame_session_id, get_listeners) { + ListenerDispatch::Delivered | ListenerDispatch::Dropped => { + return LiveFrameOutcome::Keep; + } + ListenerDispatch::Missing => {} + } + } + + if try_send_sse_frame(sender, frame) { + LiveFrameOutcome::Keep + } else { + warn!(request_id = ?request_id, "Dropping stalled HTTP POST response stream"); + LiveFrameOutcome::Drop + } +} + +fn route_buffered_frame( + frame: &SseFrame, + parsed: Option<&OutgoingHttpMessage>, + request_id: &RequestId, + events: &mut Vec, + get_listeners: &mut HashMap>, +) -> BufferedFrameOutcome { + if parsed.and_then(|message| message.id.as_ref()) == Some(request_id) { + events.push(frame.clone()); + return BufferedFrameOutcome::Finalize { + session_id: parsed.and_then(OutgoingHttpMessage::result_session_id), + }; + } + + if let Some(frame_session_id) = parsed.and_then(OutgoingHttpMessage::params_session_id) { + match dispatch_to_get_listeners(frame, &frame_session_id, get_listeners) { + ListenerDispatch::Delivered | ListenerDispatch::Dropped => { + return BufferedFrameOutcome::Routed; + } + ListenerDispatch::Missing => {} + } + } + + events.push(frame.clone()); + BufferedFrameOutcome::Buffered +} + +pub async fn get(State(state): State, request: Request) -> Response { + if let Err(error) = validate_origin(request.headers(), state.bind_host) { + return error.into_response(); + } + + if is_websocket_request(request.headers()) { + let (mut parts, _body) = request.into_parts(); + match WebSocketUpgrade::from_request_parts(&mut parts, &state).await { + Ok(ws) => websocket_response(ws, state), + Err(_) => { + HttpTransportError::bad_request("invalid WebSocket upgrade request").into_response() + } + } + } else { + match http_get(request.headers().clone(), state).await { + Ok(response) => response, + Err(error) => error.into_response(), + } + } +} + +pub async fn post(headers: HeaderMap, State(state): State, body: String) -> Response { + if let Err(error) = validate_origin(&headers, state.bind_host) { + return error.into_response(); + } + + match http_post(headers, state, body).await { + Ok(response) => response, + Err(error) => error.into_response(), + } +} + +pub async fn delete(headers: HeaderMap, State(state): State) -> Response { + if let Err(error) = validate_origin(&headers, state.bind_host) { + return error.into_response(); + } + + match http_delete(headers, state).await { + Ok(response) => response, + Err(error) => error.into_response(), + } +} + +fn websocket_response(ws: WebSocketUpgrade, state: AppState) -> Response { + let connection_id = AcpConnectionId::default(); + let response_header = HeaderValue::from_str(&connection_id.to_string()) + .expect("generated ACP connection id must be a valid header value"); + let shutdown_rx = state.shutdown_tx.subscribe(); + let mut response = ws.on_upgrade(move |socket| async move { + if state + .manager_tx + .send(ManagerRequest::WebSocket(Box::new(ConnectionRequest { + connection_id, + socket, + shutdown_rx, + }))) + .is_err() + { + error!("Connection thread is gone; dropping WebSocket"); + } + }); + response + .headers_mut() + .insert(ACP_CONNECTION_ID_HEADER, response_header); + response +} + +async fn http_post( + headers: HeaderMap, + state: AppState, + body: String, +) -> Result { + validate_post_headers(&headers)?; + + let message = IncomingHttpMessage::parse(body)?; + if !(message.is_request() || message.is_notification() || message.is_response()) { + return Err(HttpTransportError::bad_request( + "invalid JSON-RPC message shape", + )); + } + + let connection_id = parse_connection_id_header(&headers)?; + let protocol_version = parse_protocol_version_header(&headers)?; + let session_id = parse_session_id_header(&headers)?; + + validate_http_context(&message, connection_id.as_ref(), session_id.as_ref())?; + + let (response_tx, response_rx) = oneshot::channel(); + state + .manager_tx + .send(ManagerRequest::HttpPost { + connection_id, + protocol_version, + session_id, + message, + response: response_tx, + shutdown_rx: state.shutdown_tx.subscribe(), + }) + .map_err(|error| { + HttpTransportError::internal_with("connection manager is unavailable", error) + })?; + + match response_rx.await.map_err(|error| { + HttpTransportError::internal_with("connection manager dropped the request", error) + })?? { + HttpPostOutcome::Accepted => Ok(StatusCode::ACCEPTED.into_response()), + HttpPostOutcome::Live { + connection_id, + protocol_version, + session_id, + stream, + } => Ok(build_sse_response( + connection_id, + protocol_version, + session_id, + stream, + )), + HttpPostOutcome::Buffered { + connection_id, + protocol_version, + session_id, + events, + } => Ok(build_buffered_sse_response( + connection_id, + protocol_version, + session_id, + events, + )), + } +} + +async fn http_get(headers: HeaderMap, state: AppState) -> Result { + validate_get_headers(&headers)?; + + let connection_id = parse_connection_id_header(&headers)?.ok_or( + HttpTransportError::bad_request("missing Acp-Connection-Id header"), + )?; + let protocol_version = parse_protocol_version_header(&headers)?; + let session_id = parse_session_id_header(&headers)?.ok_or(HttpTransportError::bad_request( + "missing Acp-Session-Id header", + ))?; + + let (response_tx, response_rx) = oneshot::channel(); + state + .manager_tx + .send(ManagerRequest::HttpGet { + connection_id: connection_id.clone(), + session_id: session_id.clone(), + protocol_version, + response: response_tx, + }) + .map_err(|error| { + HttpTransportError::internal_with("connection manager is unavailable", error) + })?; + + let outcome = response_rx.await.map_err(|error| { + HttpTransportError::internal_with("connection manager dropped the request", error) + })??; + + let mut response = Sse::new(stream::unfold(outcome.stream, |mut stream| async move { + stream + .recv() + .await + .map(|item| (Ok::(item.into_event()), stream)) + })) + .into_response(); + set_transport_headers( + response.headers_mut(), + &connection_id, + outcome.protocol_version.as_ref(), + Some(&session_id), + ); + response + .headers_mut() + .insert(X_ACCEL_BUFFERING_HEADER, HeaderValue::from_static("no")); + Ok(response) +} + +async fn http_delete(headers: HeaderMap, state: AppState) -> Result { + let connection_id = parse_connection_id_header(&headers)?.ok_or( + HttpTransportError::bad_request("missing Acp-Connection-Id header"), + )?; + let protocol_version = parse_protocol_version_header(&headers)?; + + let (response_tx, response_rx) = oneshot::channel(); + state + .manager_tx + .send(ManagerRequest::HttpDelete { + connection_id, + protocol_version, + response: response_tx, + }) + .map_err(|error| { + HttpTransportError::internal_with("connection manager is unavailable", error) + })?; + + let protocol_version = response_rx.await.map_err(|error| { + HttpTransportError::internal_with("connection manager dropped the request", error) + })??; + + let mut response = StatusCode::ACCEPTED.into_response(); + set_protocol_version_header(response.headers_mut(), protocol_version.as_ref()); + Ok(response) +} + +fn validate_post_headers(headers: &HeaderMap) -> Result<(), HttpTransportError> { + match headers + .get(CONTENT_TYPE) + .and_then(|value| value.to_str().ok()) + { + Some(value) if media_type_matches(value, "application/json") => {} + _ => { + return Err(HttpTransportError::unsupported_media_type( + "Content-Type must be application/json", + )); + } + } + + let accept = headers + .get(ACCEPT) + .and_then(|value| value.to_str().ok()) + .ok_or(HttpTransportError::not_acceptable( + "Accept must include application/json and text/event-stream", + ))?; + + if !accept_contains(accept, "application/json") || !accept_contains(accept, "text/event-stream") + { + return Err(HttpTransportError::not_acceptable( + "Accept must include application/json and text/event-stream", + )); + } + + Ok(()) +} + +fn validate_get_headers(headers: &HeaderMap) -> Result<(), HttpTransportError> { + let accept = headers + .get(ACCEPT) + .and_then(|value| value.to_str().ok()) + .ok_or(HttpTransportError::not_acceptable( + "Accept must include text/event-stream", + ))?; + + if !accept_contains(accept, "text/event-stream") { + return Err(HttpTransportError::not_acceptable( + "Accept must include text/event-stream", + )); + } + + Ok(()) +} + +fn validate_http_context( + message: &IncomingHttpMessage, + connection_id: Option<&AcpConnectionId>, + session_id: Option<&acp_nats::AcpSessionId>, +) -> Result<(), HttpTransportError> { + if message.is_initialize() { + if connection_id.is_some() { + return Err(HttpTransportError::bad_request( + "initialize must not include Acp-Connection-Id", + )); + } + if session_id.is_some() { + return Err(HttpTransportError::bad_request( + "initialize must not include Acp-Session-Id", + )); + } + return Ok(()); + } + + if connection_id.is_none() { + return Err(HttpTransportError::bad_request( + "missing Acp-Connection-Id header", + )); + } + + let body_session_id = message.params_session_id()?; + if message.requires_session_id() && session_id.is_none() { + return Err(HttpTransportError::bad_request( + "missing Acp-Session-Id header", + )); + } + + if let (Some(header_session_id), Some(body_session_id)) = (session_id, body_session_id.as_ref()) + && header_session_id != body_session_id + { + return Err(HttpTransportError::bad_request( + "Acp-Session-Id header does not match request body sessionId", + )); + } + + Ok(()) +} + +fn accept_contains(header: &str, expected: &str) -> bool { + header + .split(',') + .map(str::trim) + .any(|media_type| media_type_matches(media_type, expected)) +} + +fn media_type_matches(header_value: &str, expected: &str) -> bool { + header_value + .split(';') + .next() + .map(str::trim) + .is_some_and(|media_type| media_type.eq_ignore_ascii_case(expected)) +} + +fn validate_origin(headers: &HeaderMap, bind_host: IpAddr) -> Result<(), HttpTransportError> { + let Some(origin) = headers.get(ORIGIN) else { + return Ok(()); + }; + + let origin = origin + .to_str() + .map_err(|error| HttpTransportError::forbidden_with("invalid Origin header", error))?; + let origin = origin + .parse::() + .map_err(|error| HttpTransportError::forbidden_with("invalid Origin header", error))?; + let Some(origin_host) = origin.host() else { + return Err(HttpTransportError::forbidden("invalid Origin header")); + }; + + let request_host = headers + .get(HOST) + .and_then(|value| value.to_str().ok()) + .and_then(|value| value.parse::().ok()) + .map(|authority| authority.host().to_owned()); + + let allowed = if bind_host.is_loopback() { + is_loopback_host(origin_host) + } else if bind_host.is_unspecified() { + is_loopback_host(origin_host) + || request_host + .as_deref() + .is_some_and(|host| host.eq_ignore_ascii_case(origin_host)) + } else { + origin_host.eq_ignore_ascii_case(&bind_host.to_string()) + || request_host + .as_deref() + .is_some_and(|host| host.eq_ignore_ascii_case(origin_host)) + }; + + if allowed { + Ok(()) + } else { + Err(HttpTransportError::forbidden("Origin is not allowed")) + } +} + +fn is_loopback_host(host: &str) -> bool { + matches!(host, "localhost" | "127.0.0.1" | "::1" | "[::1]") +} + +fn parse_connection_id_header( + headers: &HeaderMap, +) -> Result, HttpTransportError> { + headers + .get(ACP_CONNECTION_ID_HEADER) + .map(|value| { + value + .to_str() + .map_err(|error| { + HttpTransportError::bad_request_with("invalid Acp-Connection-Id header", error) + }) + .and_then(|value| { + AcpConnectionId::parse(value).map_err(|error| { + HttpTransportError::bad_request_with( + "invalid Acp-Connection-Id header", + error, + ) + }) + }) + }) + .transpose() +} + +fn parse_protocol_version_header( + headers: &HeaderMap, +) -> Result, HttpTransportError> { + headers + .get(ACP_PROTOCOL_VERSION_HEADER) + .map(|value| { + value + .to_str() + .map_err(|error| { + HttpTransportError::bad_request_with( + "invalid Acp-Protocol-Version header", + error, + ) + }) + .and_then(|value| { + value + .parse::() + .map(ProtocolVersion::from) + .map_err(|error| { + HttpTransportError::bad_request_with( + "invalid Acp-Protocol-Version header", + error, + ) + }) + }) + }) + .transpose() +} + +fn parse_session_id_header( + headers: &HeaderMap, +) -> Result, HttpTransportError> { + headers + .get(ACP_SESSION_ID_HEADER) + .map(|value| { + value + .to_str() + .map_err(|error| { + HttpTransportError::bad_request_with("invalid Acp-Session-Id header", error) + }) + .and_then(|value| { + acp_nats::AcpSessionId::new(value).map_err(|error| { + HttpTransportError::bad_request_with("invalid Acp-Session-Id header", error) + }) + }) + }) + .transpose() +} + +fn is_websocket_request(headers: &HeaderMap) -> bool { + headers + .get("upgrade") + .and_then(|value| value.to_str().ok()) + .map(|value| value.eq_ignore_ascii_case("websocket")) + .unwrap_or(false) +} + +fn build_sse_response( + connection_id: AcpConnectionId, + protocol_version: Option, + session_id: Option, + stream: SseReceiver, +) -> Response { + let primed_stream = + stream::once(async { Ok::(SseFrame::prime().into_event()) }).chain( + stream::unfold(stream, |mut stream| async move { + stream + .recv() + .await + .map(|item| (Ok::(item.into_event()), stream)) + }), + ); + let mut response = Sse::new(primed_stream).into_response(); + set_transport_headers( + response.headers_mut(), + &connection_id, + protocol_version.as_ref(), + session_id.as_ref(), + ); + response + .headers_mut() + .insert(X_ACCEL_BUFFERING_HEADER, HeaderValue::from_static("no")); + response +} + +fn build_buffered_sse_response( + connection_id: AcpConnectionId, + protocol_version: Option, + session_id: Option, + events: Vec, +) -> Response { + let stream = stream::iter( + std::iter::once(SseFrame::prime()) + .chain(events) + .map(|item| Ok::(item.into_event())), + ); + let mut response = Sse::new(stream).into_response(); + set_transport_headers( + response.headers_mut(), + &connection_id, + protocol_version.as_ref(), + session_id.as_ref(), + ); + response + .headers_mut() + .insert(X_ACCEL_BUFFERING_HEADER, HeaderValue::from_static("no")); + response +} + +fn set_transport_headers( + headers: &mut HeaderMap, + connection_id: &AcpConnectionId, + protocol_version: Option<&ProtocolVersion>, + session_id: Option<&acp_nats::AcpSessionId>, +) { + headers.insert( + ACP_CONNECTION_ID_HEADER, + HeaderValue::from_str(&connection_id.to_string()) + .expect("generated ACP connection id must be a valid header value"), + ); + set_protocol_version_header(headers, protocol_version); + if let Some(session_id) = session_id { + headers.insert( + ACP_SESSION_ID_HEADER, + HeaderValue::from_str(session_id.as_str()) + .expect("generated ACP session id must be a valid header value"), + ); + } +} + +fn set_protocol_version_header( + headers: &mut HeaderMap, + protocol_version: Option<&ProtocolVersion>, +) { + if let Some(protocol_version) = protocol_version { + headers.insert( + ACP_PROTOCOL_VERSION_HEADER, + HeaderValue::from_str(&protocol_version.to_string()) + .expect("ACP protocol version must be a valid header value"), + ); + } +} + +fn validate_protocol_version_header( + provided: Option<&ProtocolVersion>, + negotiated: Option<&ProtocolVersion>, +) -> Result<(), HttpTransportError> { + if let (Some(provided), Some(negotiated)) = (provided, negotiated) + && provided != negotiated + { + return Err(HttpTransportError::bad_request( + "Acp-Protocol-Version header does not match initialized protocol version", + )); + } + + Ok(()) +} + +pub async fn run_http_connection( + connection_id: AcpConnectionId, + nats_client: N, + js_client: J, + config: acp_nats::Config, + mut command_rx: mpsc::UnboundedReceiver, + mut shutdown_rx: watch::Receiver, +) where + N: acp_nats::RequestClient + + acp_nats::PublishClient + + acp_nats::FlushClient + + acp_nats::SubscribeClient + + Clone + + 'static, + J: acp_nats::JetStreamPublisher + acp_nats::JetStreamGetStream + 'static, + trogon_nats::jetstream::JsMessageOf: trogon_nats::jetstream::JsRequestMessage, +{ + let (agent_write, mut output_read) = tokio::io::duplex(crate::constants::DUPLEX_BUFFER_SIZE); + let (mut input_write, agent_read) = tokio::io::duplex(crate::constants::DUPLEX_BUFFER_SIZE); + + let incoming = async_compat::Compat::new(agent_read); + let outgoing = async_compat::Compat::new(agent_write); + + let meter = acp_telemetry::meter("acp-nats-ws"); + let (notification_tx, notification_rx) = tokio::sync::mpsc::channel::(64); + let bridge = Rc::new(Bridge::new( + nats_client.clone(), + js_client, + SystemClock, + &meter, + config, + notification_tx, + )); + + let (connection, io_task) = + AgentSideConnection::new(bridge.clone(), outgoing, incoming, |fut| { + tokio::task::spawn_local(fut); + }); + + let connection = Rc::new(connection); + spawn_notification_forwarder(connection.clone(), notification_rx); + + let (input_tx, mut input_rx) = mpsc::channel::(HTTP_CHANNEL_CAPACITY); + let input_task = tokio::task::spawn_local(async move { + while let Some(message) = input_rx.recv().await { + if input_write.write_all(message.as_bytes()).await.is_err() { + break; + } + if input_write.write_all(b"\n").await.is_err() { + break; + } + } + }); + + let (output_tx, mut output_rx) = mpsc::channel::(HTTP_CHANNEL_CAPACITY); + let output_task = tokio::task::spawn_local(async move { + let mut reader = tokio::io::BufReader::new(&mut output_read); + let mut line = String::new(); + loop { + line.clear(); + match reader.read_line(&mut line).await { + Ok(0) => break, + Ok(_) => { + let trimmed = line.trim_end_matches(['\r', '\n']); + if trimmed.is_empty() { + continue; + } + if output_tx.send(trimmed.to_string()).await.is_err() { + break; + } + } + Err(_) => break, + } + } + }); + + let mut client_task = tokio::task::spawn_local(client::run( + nats_client, + connection.clone(), + bridge, + StdJsonSerialize, + )); + let mut io_task = tokio::task::spawn_local(io_task); + + let mut pending_request: Option = None; + let mut protocol_version: Option = None; + let mut sessions = HashSet::::new(); + let mut get_listeners = HashMap::>::new(); + + info!(%connection_id, "HTTP connection established"); + + loop { + tokio::select! { + command = command_rx.recv() => { + let Some(command) = command else { + break; + }; + + match command { + HttpConnectionCommand::Post { protocol_version: header_protocol_version, session_id, message, response } => { + if let Err(error) = validate_protocol_version_header( + header_protocol_version.as_ref(), + protocol_version.as_ref(), + ) { + let _ = response.send(Err(error)); + continue; + } + + if message.is_request() { + if pending_request.is_some() { + let _ = response.send(Err(HttpTransportError::conflict( + "only one in-flight HTTP request is supported per ACP connection", + ))); + continue; + } + + if message.activates_session() { + pending_request = Some(PendingRequest::Buffered { + request_id: message.id.clone().expect("request must have id"), + fallback_session_id: message.params_session_id().ok().flatten(), + events: Vec::new(), + response, + }); + if input_tx.try_send(message.raw).is_err() + && let Some(PendingRequest::Buffered { response, .. }) = + pending_request.take() + { + let _ = response.send(Err(HttpTransportError::internal( + "ACP runtime input queue is full", + ))); + } + continue; + } + + if let Some(session_id) = session_id.clone() { + sessions.insert(session_id); + } + + let (stream_tx, stream_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); + pending_request = Some(PendingRequest::Live { + request_id: message.id.clone().expect("request must have id"), + capture_protocol_version: message.is_initialize(), + session_id: session_id.clone(), + sender: stream_tx, + }); + if input_tx.try_send(message.raw).is_err() { + pending_request = None; + let _ = response.send(Err(HttpTransportError::internal( + "ACP runtime input queue is full", + ))); + continue; + } + let _ = response.send(Ok(HttpPostOutcome::Live { + connection_id: connection_id.clone(), + protocol_version: protocol_version.clone(), + session_id, + stream: stream_rx, + })); + continue; + } + + if let Some(session_id) = session_id.clone() { + sessions.insert(session_id); + } + + if input_tx.try_send(message.raw).is_err() { + let _ = response.send(Err(HttpTransportError::internal( + "ACP runtime input queue is full", + ))); + continue; + } + + let _ = response.send(Ok(HttpPostOutcome::Accepted)); + } + HttpConnectionCommand::AttachListener { + session_id, + protocol_version: header_protocol_version, + response, + } => { + if let Err(error) = validate_protocol_version_header( + header_protocol_version.as_ref(), + protocol_version.as_ref(), + ) { + let _ = response.send(Err(error)); + continue; + } + + if !sessions.contains(&session_id) { + let _ = response + .send(Err(HttpTransportError::not_found("unknown ACP session"))); + continue; + } + + let (stream_tx, stream_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); + get_listeners.entry(session_id).or_default().push(stream_tx); + let _ = response.send(Ok(HttpGetOutcome { + protocol_version: protocol_version.clone(), + stream: stream_rx, + })); + } + HttpConnectionCommand::Close { + protocol_version: header_protocol_version, + response, + } => { + if let Err(error) = validate_protocol_version_header( + header_protocol_version.as_ref(), + protocol_version.as_ref(), + ) { + let _ = response.send(Err(error)); + continue; + } + + let _ = response.send(Ok(protocol_version.clone())); + break; + } + } + } + outbound = output_rx.recv() => { + let Some(outbound) = outbound else { + break; + }; + + let frame = SseFrame::json(outbound.clone()); + let parsed = OutgoingHttpMessage::parse(&outbound); + + if let Some(pending) = pending_request.as_mut() { + match pending { + PendingRequest::Live { + request_id, + capture_protocol_version, + session_id, + sender, + } => { + if *capture_protocol_version + && parsed.as_ref().and_then(|message| message.id.as_ref()) + == Some(request_id) + { + protocol_version = parsed + .as_ref() + .and_then(OutgoingHttpMessage::result_protocol_version); + } + + match route_live_frame( + &frame, + parsed.as_ref(), + request_id, + session_id.as_ref(), + sender, + &mut get_listeners, + ) { + LiveFrameOutcome::Keep => {} + LiveFrameOutcome::Clear | LiveFrameOutcome::Drop => { + pending_request = None; + } + } + continue; + } + PendingRequest::Buffered { + request_id, + fallback_session_id, + events, + .. + } => { + match route_buffered_frame( + &frame, + parsed.as_ref(), + request_id, + events, + &mut get_listeners, + ) { + BufferedFrameOutcome::Buffered | BufferedFrameOutcome::Routed => {} + BufferedFrameOutcome::Finalize { session_id } => { + let session_id = + session_id.or_else(|| fallback_session_id.clone()); + if let Some(session_id) = session_id.clone() { + sessions.insert(session_id); + } + let events = std::mem::take(events); + if let Some(PendingRequest::Buffered { response, .. }) = + pending_request.take() + { + let _ = response.send(Ok(HttpPostOutcome::Buffered { + connection_id: connection_id.clone(), + protocol_version: protocol_version.clone(), + session_id, + events, + })); + } + } + } + continue; + } + } + } + + let Some(session_id) = parsed.and_then(|message| message.params_session_id()) else { + continue; + }; + + let _ = dispatch_to_get_listeners(&frame, &session_id, &mut get_listeners); + } + result = &mut client_task => { + match result { + Ok(()) => info!(%connection_id, "HTTP client task completed"), + Err(error) => warn!(%connection_id, error = %error, "HTTP client task failed"), + } + break; + } + result = &mut io_task => { + match result { + Ok(Ok(())) => info!(%connection_id, "HTTP IO task completed"), + Ok(Err(error)) => warn!(%connection_id, error = %error, "HTTP IO task failed"), + Err(error) => warn!(%connection_id, error = %error, "HTTP IO task join failed"), + } + break; + } + _ = shutdown_rx.wait_for(|&shutdown| shutdown) => { + info!(%connection_id, "HTTP connection shutting down"); + break; + } + } + } + + if let Some(PendingRequest::Buffered { response, .. }) = pending_request.take() { + let _ = response.send(Err(HttpTransportError::internal( + "HTTP connection closed before the request completed", + ))); + } + + input_task.abort(); + output_task.abort(); + + if !client_task.is_finished() { + client_task.abort(); + let _ = client_task.await; + } + if !io_task.is_finished() { + io_task.abort(); + let _ = io_task.await; + } + + info!(%connection_id, "HTTP connection closed"); +} + +pub async fn process_manager_request( + request: ManagerRequest, + http_connections: &mut HashMap, + websocket_handles: &mut Vec>, + http_connection_handles: &mut Vec>, + nats_client: &N, + js_client: &J, + config: &acp_nats::Config, +) where + N: acp_nats::RequestClient + + acp_nats::PublishClient + + acp_nats::FlushClient + + acp_nats::SubscribeClient + + Clone + + Send + + 'static, + J: acp_nats::JetStreamPublisher + acp_nats::JetStreamGetStream + Clone + 'static, + trogon_nats::jetstream::JsMessageOf: trogon_nats::jetstream::JsRequestMessage, +{ + websocket_handles.retain(|handle| !handle.is_finished()); + http_connection_handles.retain(|handle| !handle.is_finished()); + http_connections.retain(|_, handle| !handle.command_tx.is_closed()); + + match request { + ManagerRequest::WebSocket(request) => { + let ConnectionRequest { + connection_id, + socket, + shutdown_rx, + } = *request; + websocket_handles.push(tokio::task::spawn_local(connection::handle( + connection_id, + socket, + nats_client.clone(), + js_client.clone(), + config.clone(), + shutdown_rx, + ))); + } + ManagerRequest::HttpPost { + connection_id, + protocol_version, + session_id, + message, + response, + shutdown_rx, + } => { + let connection_id = match connection_id { + Some(connection_id) => connection_id, + None => { + if !message.is_initialize() { + let _ = response.send(Err(HttpTransportError::bad_request( + "missing Acp-Connection-Id header", + ))); + return; + } + + let connection_id = AcpConnectionId::default(); + let (command_tx, command_rx) = mpsc::unbounded_channel(); + http_connections.insert( + connection_id.clone(), + HttpConnectionHandle { + command_tx: command_tx.clone(), + }, + ); + http_connection_handles.push(tokio::task::spawn_local(run_http_connection( + connection_id.clone(), + nats_client.clone(), + js_client.clone(), + config.clone(), + command_rx, + shutdown_rx, + ))); + connection_id + } + }; + + let Some(handle) = http_connections.get(&connection_id) else { + let _ = response.send(Err(HttpTransportError::not_found("unknown ACP connection"))); + return; + }; + + if handle + .command_tx + .send(HttpConnectionCommand::Post { + protocol_version, + session_id, + message, + response, + }) + .is_err() + { + http_connections.remove(&connection_id); + } + } + ManagerRequest::HttpGet { + connection_id, + session_id, + protocol_version, + response, + } => { + let Some(handle) = http_connections.get(&connection_id) else { + let _ = response.send(Err(HttpTransportError::not_found("unknown ACP connection"))); + return; + }; + + if handle + .command_tx + .send(HttpConnectionCommand::AttachListener { + session_id, + protocol_version, + response, + }) + .is_err() + { + http_connections.remove(&connection_id); + } + } + ManagerRequest::HttpDelete { + connection_id, + protocol_version, + response, + } => { + let Some(handle) = http_connections.remove(&connection_id) else { + let _ = response.send(Err(HttpTransportError::not_found("unknown ACP connection"))); + return; + }; + + let _ = handle.command_tx.send(HttpConnectionCommand::Close { + protocol_version, + response, + }); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use acp_nats::Config; + use axum::body::{Body, to_bytes}; + use axum::http::Request as HttpRequest; + use axum::http::header::{HOST, ORIGIN}; + use serde_json::{Value, json}; + use std::net::{IpAddr, Ipv4Addr}; + use tokio::sync::{mpsc, oneshot, watch}; + use trogon_nats::AdvancedMockNatsClient; + + #[derive(Clone)] + struct MockJs { + publisher: trogon_nats::jetstream::MockJetStreamPublisher, + consumer_factory: trogon_nats::jetstream::MockJetStreamConsumerFactory, + } + + impl MockJs { + fn new() -> Self { + Self { + publisher: trogon_nats::jetstream::MockJetStreamPublisher::new(), + consumer_factory: trogon_nats::jetstream::MockJetStreamConsumerFactory::new(), + } + } + } + + impl trogon_nats::jetstream::JetStreamPublisher for MockJs { + type PublishError = trogon_nats::mocks::MockError; + type AckFuture = std::future::Ready< + Result, + >; + + async fn publish_with_headers( + &self, + subject: S, + headers: async_nats::HeaderMap, + payload: bytes::Bytes, + ) -> Result { + self.publisher + .publish_with_headers(subject, headers, payload) + .await + } + } + + impl trogon_nats::jetstream::JetStreamGetStream for MockJs { + type Error = trogon_nats::mocks::MockError; + type Stream = trogon_nats::jetstream::MockJetStreamStream; + + async fn get_stream + Send>( + &self, + stream_name: T, + ) -> Result { + self.consumer_factory.get_stream(stream_name).await + } + } + + fn test_config() -> Config { + Config::new( + acp_nats::AcpPrefix::new("acp").unwrap(), + acp_nats::NatsConfig { + servers: vec!["localhost:4222".to_string()], + auth: trogon_nats::NatsAuth::None, + }, + ) + } + + fn test_state() -> (AppState, mpsc::UnboundedReceiver) { + let (manager_tx, manager_rx) = mpsc::unbounded_channel(); + let (shutdown_tx, _) = watch::channel(false); + ( + AppState { + bind_host: IpAddr::V4(Ipv4Addr::LOCALHOST), + manager_tx, + shutdown_tx, + }, + manager_rx, + ) + } + + async fn json_event_body(response: Response) -> Vec { + let bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let body = String::from_utf8(bytes.to_vec()).unwrap(); + body.lines() + .filter_map(|line| line.strip_prefix("data: ")) + .filter(|json| !json.is_empty()) + .map(|json| serde_json::from_str(json).unwrap()) + .collect() + } + + fn frame_json(frame: &SseFrame) -> &str { + match frame { + SseFrame::Json { json, .. } => json, + SseFrame::Empty { .. } => panic!("expected JSON SSE frame"), + } + } + + fn post_headers() -> HeaderMap { + let mut headers = HeaderMap::new(); + headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + headers.insert( + ACCEPT, + HeaderValue::from_static("application/json, text/event-stream"), + ); + headers + } + + fn get_headers() -> HeaderMap { + let mut headers = HeaderMap::new(); + headers.insert(ACCEPT, HeaderValue::from_static("text/event-stream")); + headers + } + + fn session_id() -> acp_nats::AcpSessionId { + acp_nats::AcpSessionId::new("session-1").unwrap() + } + + #[test] + fn http_transport_error_into_response_maps_status_codes() { + let cases = [ + ( + HttpTransportError::bad_request("bad"), + StatusCode::BAD_REQUEST, + ), + ( + HttpTransportError::not_found("missing"), + StatusCode::NOT_FOUND, + ), + ( + HttpTransportError::conflict("conflict"), + StatusCode::CONFLICT, + ), + ( + HttpTransportError::forbidden("forbidden"), + StatusCode::FORBIDDEN, + ), + ( + HttpTransportError::unsupported_media_type("unsupported"), + StatusCode::UNSUPPORTED_MEDIA_TYPE, + ), + ( + HttpTransportError::not_acceptable("not-acceptable"), + StatusCode::NOT_ACCEPTABLE, + ), + ( + HttpTransportError::not_implemented("not-implemented"), + StatusCode::NOT_IMPLEMENTED, + ), + ( + HttpTransportError::internal("internal"), + StatusCode::INTERNAL_SERVER_ERROR, + ), + ]; + + for (error, status) in cases { + assert_eq!(error.into_response().status(), status); + } + + let sourced = IncomingHttpMessage::parse("{".to_string()).unwrap_err(); + assert_eq!(sourced.into_response().status(), StatusCode::BAD_REQUEST); + } + + #[test] + fn http_transport_error_display_and_source_chain_return_expected_values() { + use std::error::Error as _; + + assert_eq!(HttpTransportError::bad_request("bad").to_string(), "bad"); + assert!(HttpTransportError::bad_request("bad").source().is_none()); + assert!(HttpTransportError::not_found("missing").source().is_none()); + assert_eq!( + HttpTransportError::conflict("conflict").to_string(), + "conflict" + ); + assert!(HttpTransportError::conflict("conflict").source().is_none()); + assert_eq!( + HttpTransportError::forbidden("forbidden").to_string(), + "forbidden" + ); + assert!( + HttpTransportError::forbidden("forbidden") + .source() + .is_none() + ); + assert!( + HttpTransportError::unsupported_media_type("unsupported") + .source() + .is_none() + ); + assert!( + HttpTransportError::not_acceptable("not-acceptable") + .source() + .is_none() + ); + assert!( + HttpTransportError::not_implemented("not-implemented") + .source() + .is_none() + ); + assert_eq!( + HttpTransportError::internal("internal").to_string(), + "internal" + ); + assert!(HttpTransportError::internal("internal").source().is_none()); + + let invalid_json = IncomingHttpMessage::parse("{".to_string()).unwrap_err(); + assert_eq!(invalid_json.to_string(), "invalid JSON-RPC payload"); + assert!(invalid_json.source().is_some()); + } + + #[test] + fn incoming_http_message_parses_and_classifies_shapes() { + let request = IncomingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":1,"method":"session/new","params":{"cwd":".","mcpServers":[]}}"# + .to_string(), + ) + .unwrap(); + assert!(request.is_request()); + assert!(request.activates_session()); + + for raw in [ + r#"{"jsonrpc":"2.0","id":2,"method":"session/load","params":{"sessionId":"session-1","cwd":"."}}"#, + r#"{"jsonrpc":"2.0","id":3,"method":"session/resume","params":{"sessionId":"session-1","cwd":"."}}"#, + r#"{"jsonrpc":"2.0","id":4,"method":"session/fork","params":{"sessionId":"session-1","cwd":"."}}"#, + ] { + let activating = IncomingHttpMessage::parse(raw.to_string()).unwrap(); + assert!(activating.activates_session()); + assert!(!activating.requires_session_id()); + } + + let notification = + IncomingHttpMessage::parse(r#"{"jsonrpc":"2.0","method":"initialized"}"#.to_string()) + .unwrap(); + assert!(notification.is_notification()); + assert!(!notification.requires_session_id()); + + let response = IncomingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":99,"result":{"ok":true}}"#.to_string(), + ) + .unwrap(); + assert!(response.is_response()); + assert!(response.requires_session_id()); + + let null_response = + IncomingHttpMessage::parse(r#"{"jsonrpc":"2.0","id":100,"result":null}"#.to_string()) + .unwrap(); + assert!(null_response.is_response()); + assert!(null_response.requires_session_id()); + } + + #[test] + fn incoming_http_message_parse_rejects_batch_and_invalid_json() { + let batch = IncomingHttpMessage::parse(r#"[{"jsonrpc":"2.0"}]"#.to_string()).unwrap_err(); + assert!(matches!( + batch, + HttpTransportError::NotImplemented { + message: "batch JSON-RPC requests are not supported", + source: None, + } + )); + + let invalid = IncomingHttpMessage::parse("{".to_string()).unwrap_err(); + assert!(matches!( + invalid, + HttpTransportError::BadRequest { + message: "invalid JSON-RPC payload", + source: Some(_), + } + )); + } + + #[test] + fn incoming_http_message_session_id_helpers_handle_valid_and_invalid_values() { + let message = IncomingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":1,"method":"session/prompt","params":{"sessionId":"session-1"}}"# + .to_string(), + ) + .unwrap(); + assert_eq!(message.params_session_id().unwrap(), Some(session_id())); + + let invalid = IncomingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":1,"method":"session/prompt","params":{"sessionId":"bad.session"}}"# + .to_string(), + ) + .unwrap(); + assert!(matches!( + invalid.params_session_id(), + Err(HttpTransportError::BadRequest { + message: "invalid sessionId in request body", + source: Some(_), + }) + )); + } + + #[test] + fn outgoing_http_message_extracts_session_ids() { + let outbound = OutgoingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":1,"params":{"sessionId":"session-1"}}"#, + ) + .unwrap(); + assert_eq!(outbound.params_session_id(), Some(session_id())); + + let response = OutgoingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":2,"result":{"sessionId":"session-1"}}"#, + ) + .unwrap(); + assert_eq!(response.result_session_id(), Some(session_id())); + } + + #[test] + fn route_live_frame_keeps_same_session_notifications_on_post_stream() { + let frame = SseFrame::json( + r#"{"jsonrpc":"2.0","method":"session/update","params":{"sessionId":"session-1"}}"# + .to_string(), + ); + let parsed = OutgoingHttpMessage::parse(frame_json(&frame)).unwrap(); + let request_id = RequestId::Number(1); + let request_session_id = session_id(); + let (live_tx, mut live_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); + let (get_tx, mut get_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); + let mut get_listeners = HashMap::new(); + get_listeners.insert(request_session_id.clone(), vec![get_tx]); + + let outcome = route_live_frame( + &frame, + Some(&parsed), + &request_id, + Some(&request_session_id), + &live_tx, + &mut get_listeners, + ); + + assert!(matches!(outcome, LiveFrameOutcome::Keep)); + match live_rx.try_recv().unwrap() { + SseFrame::Json { json, .. } => assert!(json.contains(r#""sessionId":"session-1""#)), + SseFrame::Empty { .. } => panic!("expected JSON SSE frame"), + } + assert!(matches!( + get_rx.try_recv(), + Err(tokio::sync::mpsc::error::TryRecvError::Empty) + )); + } + + #[test] + fn route_live_frame_sends_other_session_notifications_to_get_listeners() { + let frame = SseFrame::json( + r#"{"jsonrpc":"2.0","method":"session/update","params":{"sessionId":"session-2"}}"# + .to_string(), + ); + let parsed = OutgoingHttpMessage::parse(frame_json(&frame)).unwrap(); + let request_id = RequestId::Number(1); + let request_session_id = session_id(); + let other_session_id = acp_nats::AcpSessionId::new("session-2").unwrap(); + let (live_tx, mut live_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); + let (get_tx, mut get_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); + let mut get_listeners = HashMap::new(); + get_listeners.insert(other_session_id, vec![get_tx]); + + let outcome = route_live_frame( + &frame, + Some(&parsed), + &request_id, + Some(&request_session_id), + &live_tx, + &mut get_listeners, + ); + + assert!(matches!(outcome, LiveFrameOutcome::Keep)); + assert!(matches!( + live_rx.try_recv(), + Err(tokio::sync::mpsc::error::TryRecvError::Empty) + )); + match get_rx.try_recv().unwrap() { + SseFrame::Json { json, .. } => assert!(json.contains(r#""sessionId":"session-2""#)), + SseFrame::Empty { .. } => panic!("expected JSON SSE frame"), + } + } + + #[test] + fn route_buffered_frame_sends_other_session_notifications_to_get_listeners() { + let frame = SseFrame::json( + r#"{"jsonrpc":"2.0","method":"session/update","params":{"sessionId":"session-2"}}"# + .to_string(), + ); + let parsed = OutgoingHttpMessage::parse(frame_json(&frame)).unwrap(); + let request_id = RequestId::Number(1); + let other_session_id = acp_nats::AcpSessionId::new("session-2").unwrap(); + let (get_tx, mut get_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); + let mut events = Vec::new(); + let mut get_listeners = HashMap::new(); + get_listeners.insert(other_session_id, vec![get_tx]); + + let outcome = route_buffered_frame( + &frame, + Some(&parsed), + &request_id, + &mut events, + &mut get_listeners, + ); + + assert!(matches!(outcome, BufferedFrameOutcome::Routed)); + assert!(events.is_empty()); + match get_rx.try_recv().unwrap() { + SseFrame::Json { json, .. } => assert!(json.contains(r#""sessionId":"session-2""#)), + SseFrame::Empty { .. } => panic!("expected JSON SSE frame"), + } + } + + #[test] + fn dispatch_to_get_listeners_drops_full_listener() { + let session_id = session_id(); + let mut get_listeners = HashMap::new(); + let (listener_tx, mut listener_rx) = mpsc::channel(1); + listener_tx + .try_send(SseFrame::json( + r#"{"jsonrpc":"2.0","method":"session/update"}"#.to_string(), + )) + .unwrap(); + get_listeners.insert(session_id.clone(), vec![listener_tx]); + + let frame = SseFrame::json(r#"{"jsonrpc":"2.0","method":"session/update"}"#.to_string()); + let outcome = dispatch_to_get_listeners(&frame, &session_id, &mut get_listeners); + + assert!(matches!(outcome, ListenerDispatch::Dropped)); + assert!(!get_listeners.contains_key(&session_id)); + assert!(matches!(listener_rx.try_recv(), Ok(SseFrame::Json { .. }))); + assert!(matches!( + listener_rx.try_recv(), + Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) + )); + } + + #[test] + fn dispatch_to_get_listeners_delivers_each_message_on_only_one_stream() { + let session_id = session_id(); + let mut get_listeners = HashMap::new(); + let (first_tx, mut first_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); + let (second_tx, mut second_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); + get_listeners.insert(session_id.clone(), vec![first_tx, second_tx]); + + let frame = SseFrame::json( + r#"{"jsonrpc":"2.0","method":"session/update","params":{"sessionId":"session-1"}}"# + .to_string(), + ); + let outcome = dispatch_to_get_listeners(&frame, &session_id, &mut get_listeners); + + assert!(matches!(outcome, ListenerDispatch::Delivered)); + assert!(matches!(first_rx.try_recv(), Ok(SseFrame::Json { .. }))); + assert!(matches!( + second_rx.try_recv(), + Err(tokio::sync::mpsc::error::TryRecvError::Empty) + )); + } + + #[test] + fn header_validators_enforce_content_negotiation() { + let valid_post = post_headers(); + assert!(validate_post_headers(&valid_post).is_ok()); + + let mut valid_post_with_charset = post_headers(); + valid_post_with_charset.insert( + CONTENT_TYPE, + HeaderValue::from_static("application/json; charset=utf-8"), + ); + assert!(validate_post_headers(&valid_post_with_charset).is_ok()); + + let mut bad_content_type = valid_post.clone(); + bad_content_type.insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); + assert!(matches!( + validate_post_headers(&bad_content_type), + Err(HttpTransportError::UnsupportedMediaType { + message: "Content-Type must be application/json", + source: None, + }) + )); + + let mut bad_accept = HeaderMap::new(); + bad_accept.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + bad_accept.insert(ACCEPT, HeaderValue::from_static("application/json")); + assert!(matches!( + validate_post_headers(&bad_accept), + Err(HttpTransportError::NotAcceptable { + message: "Accept must include application/json and text/event-stream", + source: None, + }) + )); + + let valid_get = get_headers(); + assert!(validate_get_headers(&valid_get).is_ok()); + + let mut valid_post_with_q = HeaderMap::new(); + valid_post_with_q.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + valid_post_with_q.insert( + ACCEPT, + HeaderValue::from_static("application/json;q=0.9, text/event-stream"), + ); + assert!(validate_post_headers(&valid_post_with_q).is_ok()); + + let mut valid_get_with_q = HeaderMap::new(); + valid_get_with_q.insert(ACCEPT, HeaderValue::from_static("text/event-stream; q=0.5")); + assert!(validate_get_headers(&valid_get_with_q).is_ok()); + + let mut invalid_get = HeaderMap::new(); + invalid_get.insert(ACCEPT, HeaderValue::from_static("application/json")); + assert!(matches!( + validate_get_headers(&invalid_get), + Err(HttpTransportError::NotAcceptable { + message: "Accept must include text/event-stream", + source: None, + }) + )); + } + + #[test] + fn validate_origin_allows_loopback_and_matching_hosts() { + let mut loopback = HeaderMap::new(); + loopback.insert(ORIGIN, HeaderValue::from_static("http://localhost:3000")); + assert!(validate_origin(&loopback, IpAddr::V4(Ipv4Addr::LOCALHOST)).is_ok()); + + let mut proxied = HeaderMap::new(); + proxied.insert(ORIGIN, HeaderValue::from_static("https://example.com")); + proxied.insert(HOST, HeaderValue::from_static("example.com")); + assert!(validate_origin(&proxied, IpAddr::V4(Ipv4Addr::UNSPECIFIED)).is_ok()); + } + + #[test] + fn validate_origin_rejects_invalid_and_remote_origins() { + let mut invalid = HeaderMap::new(); + invalid.insert(ORIGIN, HeaderValue::from_static("not a uri")); + assert!(matches!( + validate_origin(&invalid, IpAddr::V4(Ipv4Addr::LOCALHOST)), + Err(HttpTransportError::Forbidden { + message: "invalid Origin header", + source: Some(_), + }) + )); + + let mut remote = HeaderMap::new(); + remote.insert(ORIGIN, HeaderValue::from_static("https://evil.example")); + assert!(matches!( + validate_origin(&remote, IpAddr::V4(Ipv4Addr::LOCALHOST)), + Err(HttpTransportError::Forbidden { + message: "Origin is not allowed", + source: None, + }) + )); + } + + #[test] + fn validate_http_context_enforces_connection_and_session_rules() { + let initialize = IncomingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":0}}"# + .to_string(), + ) + .unwrap(); + let connection_id = AcpConnectionId::default(); + let session_id = session_id(); + + assert!(validate_http_context(&initialize, None, None).is_ok()); + assert!(matches!( + validate_http_context(&initialize, Some(&connection_id), None), + Err(HttpTransportError::BadRequest { + message: "initialize must not include Acp-Connection-Id", + source: None, + }) + )); + + let initialized = + IncomingHttpMessage::parse(r#"{"jsonrpc":"2.0","method":"initialized"}"#.to_string()) + .unwrap(); + assert!(matches!( + validate_http_context(&initialized, None, None), + Err(HttpTransportError::BadRequest { + message: "missing Acp-Connection-Id header", + source: None, + }) + )); + + let prompt = IncomingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":2,"method":"session/prompt","params":{"sessionId":"session-1"}}"# + .to_string(), + ) + .unwrap(); + assert!(matches!( + validate_http_context(&prompt, Some(&connection_id), None), + Err(HttpTransportError::BadRequest { + message: "missing Acp-Session-Id header", + source: None, + }) + )); + + let mismatched = IncomingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":2,"method":"session/prompt","params":{"sessionId":"session-2"}}"# + .to_string(), + ) + .unwrap(); + assert!(matches!( + validate_http_context(&mismatched, Some(&connection_id), Some(&session_id)), + Err(HttpTransportError::BadRequest { + message: "Acp-Session-Id header does not match request body sessionId", + source: None, + }) + )); + + let load = IncomingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":2,"method":"session/load","params":{"sessionId":"session-1","cwd":"."}}"# + .to_string(), + ) + .unwrap(); + assert!(validate_http_context(&load, Some(&connection_id), None).is_ok()); + + let resume = IncomingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":2,"method":"session/resume","params":{"sessionId":"session-1","cwd":"."}}"# + .to_string(), + ) + .unwrap(); + assert!(validate_http_context(&resume, Some(&connection_id), None).is_ok()); + + let fork = IncomingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":2,"method":"session/fork","params":{"sessionId":"session-1","cwd":"."}}"# + .to_string(), + ) + .unwrap(); + assert!(validate_http_context(&fork, Some(&connection_id), None).is_ok()); + } + + #[test] + fn validate_protocol_version_header_allows_missing_and_rejects_mismatches() { + assert!(validate_protocol_version_header(None, None).is_ok()); + assert!(validate_protocol_version_header(Some(&ProtocolVersion::V0), None).is_ok()); + assert!(validate_protocol_version_header(None, Some(&ProtocolVersion::V0)).is_ok()); + assert!( + validate_protocol_version_header( + Some(&ProtocolVersion::V0), + Some(&ProtocolVersion::V0) + ) + .is_ok() + ); + + assert!(matches!( + validate_protocol_version_header( + Some(&ProtocolVersion::V1), + Some(&ProtocolVersion::V0) + ), + Err(HttpTransportError::BadRequest { + message: "Acp-Protocol-Version header does not match initialized protocol version", + source: None, + }) + )); + } + + #[test] + fn header_parsers_and_websocket_detection_handle_valid_and_invalid_values() { + let connection_id = AcpConnectionId::default(); + let session_id = session_id(); + let mut headers = HeaderMap::new(); + headers.insert( + ACP_CONNECTION_ID_HEADER, + HeaderValue::from_str(&connection_id.to_string()).unwrap(), + ); + headers.insert( + ACP_SESSION_ID_HEADER, + HeaderValue::from_str(session_id.as_str()).unwrap(), + ); + headers.insert(ACP_PROTOCOL_VERSION_HEADER, HeaderValue::from_static("0")); + headers.insert("upgrade", HeaderValue::from_static("websocket")); + + assert_eq!( + parse_connection_id_header(&headers).unwrap(), + Some(connection_id.clone()) + ); + assert_eq!(parse_session_id_header(&headers).unwrap(), Some(session_id)); + assert_eq!( + parse_protocol_version_header(&headers).unwrap(), + Some(ProtocolVersion::V0) + ); + assert!(is_websocket_request(&headers)); + + headers.insert( + ACP_CONNECTION_ID_HEADER, + HeaderValue::from_static("not-a-uuid"), + ); + assert!(matches!( + parse_connection_id_header(&headers), + Err(HttpTransportError::BadRequest { + message: "invalid Acp-Connection-Id header", + source: Some(_), + }) + )); + + headers.insert( + ACP_CONNECTION_ID_HEADER, + HeaderValue::from_str(&connection_id.to_string()).unwrap(), + ); + headers.insert( + ACP_PROTOCOL_VERSION_HEADER, + HeaderValue::from_static("not-a-version"), + ); + assert!(matches!( + parse_protocol_version_header(&headers), + Err(HttpTransportError::BadRequest { + message: "invalid Acp-Protocol-Version header", + source: Some(_), + }) + )); + } + + #[tokio::test] + async fn http_post_returns_accepted_for_notifications() { + let (state, mut manager_rx) = test_state(); + let connection_id = AcpConnectionId::default(); + let expected_connection_id = connection_id.clone(); + + tokio::spawn(async move { + match manager_rx.recv().await.unwrap() { + ManagerRequest::HttpPost { + connection_id: Some(actual_connection_id), + session_id: None, + message, + response, + .. + } => { + assert_eq!(actual_connection_id, expected_connection_id); + assert!(message.is_notification()); + let _ = response.send(Ok(HttpPostOutcome::Accepted)); + } + _ => panic!("unexpected manager request"), + } + }); + + let mut headers = post_headers(); + headers.insert( + ACP_CONNECTION_ID_HEADER, + HeaderValue::from_str(&connection_id.to_string()).unwrap(), + ); + + let response = http_post( + headers, + state, + r#"{"jsonrpc":"2.0","method":"initialized"}"#.to_string(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::ACCEPTED); + } + + #[tokio::test] + async fn http_post_accepts_null_result_responses() { + let (state, mut manager_rx) = test_state(); + let connection_id = AcpConnectionId::default(); + let session_id = session_id(); + let expected_connection_id = connection_id.clone(); + let expected_session_id = session_id.clone(); + + tokio::spawn(async move { + match manager_rx.recv().await.unwrap() { + ManagerRequest::HttpPost { + connection_id: Some(actual_connection_id), + session_id: Some(actual_session_id), + message, + response, + .. + } => { + assert_eq!(actual_connection_id, expected_connection_id); + assert_eq!(actual_session_id, expected_session_id); + assert!(message.is_response()); + let _ = response.send(Ok(HttpPostOutcome::Accepted)); + } + _ => panic!("unexpected manager request"), + } + }); + + let mut headers = post_headers(); + headers.insert( + ACP_CONNECTION_ID_HEADER, + HeaderValue::from_str(&connection_id.to_string()).unwrap(), + ); + headers.insert( + ACP_SESSION_ID_HEADER, + HeaderValue::from_str(session_id.as_str()).unwrap(), + ); + + let response = http_post( + headers, + state, + r#"{"jsonrpc":"2.0","id":1,"result":null}"#.to_string(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::ACCEPTED); + } + + #[tokio::test] + async fn http_post_returns_buffered_sse_with_session_headers() { + let (state, mut manager_rx) = test_state(); + let connection_id = AcpConnectionId::default(); + let session_id = session_id(); + let event = json!({ + "jsonrpc": "2.0", + "id": 2, + "result": { "sessionId": session_id.as_str() } + }); + let expected_event = event.clone(); + + let expected_connection_id = connection_id.clone(); + let expected_session_id = session_id.clone(); + tokio::spawn(async move { + match manager_rx.recv().await.unwrap() { + ManagerRequest::HttpPost { + connection_id: Some(actual_connection_id), + session_id: None, + message, + response, + .. + } => { + assert_eq!(actual_connection_id, expected_connection_id.clone()); + assert!(message.activates_session()); + let _ = response.send(Ok(HttpPostOutcome::Buffered { + connection_id: expected_connection_id, + protocol_version: Some(ProtocolVersion::V0), + session_id: Some(expected_session_id), + events: vec![SseFrame::json(event.to_string())], + })); + } + _ => panic!("unexpected manager request"), + } + }); + + let mut headers = post_headers(); + headers.insert( + ACP_CONNECTION_ID_HEADER, + HeaderValue::from_str(&connection_id.to_string()).unwrap(), + ); + + let response = http_post( + headers, + state, + r#"{"jsonrpc":"2.0","id":2,"method":"session/new","params":{"cwd":".","mcpServers":[]}}"# + .to_string(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response.headers().get(ACP_CONNECTION_ID_HEADER).unwrap(), + HeaderValue::from_str(&connection_id.to_string()).unwrap() + ); + assert_eq!( + response.headers().get(ACP_SESSION_ID_HEADER).unwrap(), + HeaderValue::from_str(session_id.as_str()).unwrap() + ); + assert_eq!( + response.headers().get(ACP_PROTOCOL_VERSION_HEADER).unwrap(), + HeaderValue::from_static("0") + ); + assert_eq!(json_event_body(response).await, vec![expected_event]); + } + + #[tokio::test] + async fn http_get_and_delete_round_trip_through_manager() { + let (state, mut manager_rx) = test_state(); + let connection_id = AcpConnectionId::default(); + let session_id = session_id(); + let expected_connection_id = connection_id.clone(); + let expected_session_id = session_id.clone(); + + tokio::spawn(async move { + match manager_rx.recv().await.unwrap() { + ManagerRequest::HttpGet { + connection_id: actual_connection_id, + session_id: actual_session_id, + protocol_version, + response, + } => { + assert_eq!(actual_connection_id, expected_connection_id.clone()); + assert_eq!(actual_session_id, expected_session_id); + assert_eq!(protocol_version, None); + let (stream_tx, stream_rx) = mpsc::channel(HTTP_CHANNEL_CAPACITY); + let _ = stream_tx.try_send(SseFrame::json( + json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { "sessionId": "session-1" } + }) + .to_string(), + )); + drop(stream_tx); + let _ = response.send(Ok(HttpGetOutcome { + protocol_version: Some(ProtocolVersion::V0), + stream: stream_rx, + })); + } + _ => panic!("unexpected manager request"), + } + match manager_rx.recv().await.unwrap() { + ManagerRequest::HttpDelete { + connection_id: actual_connection_id, + protocol_version, + response, + } => { + assert_eq!(actual_connection_id, expected_connection_id); + assert_eq!(protocol_version, None); + let _ = response.send(Ok(Some(ProtocolVersion::V0))); + } + _ => panic!("unexpected manager request"), + } + }); + + let mut headers = get_headers(); + headers.insert( + ACP_CONNECTION_ID_HEADER, + HeaderValue::from_str(&connection_id.to_string()).unwrap(), + ); + headers.insert( + ACP_SESSION_ID_HEADER, + HeaderValue::from_str(session_id.as_str()).unwrap(), + ); + + let response = http_get(headers, state.clone()).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response.headers().get(ACP_CONNECTION_ID_HEADER).unwrap(), + HeaderValue::from_str(&connection_id.to_string()).unwrap() + ); + assert_eq!( + response.headers().get(ACP_SESSION_ID_HEADER).unwrap(), + HeaderValue::from_str(session_id.as_str()).unwrap() + ); + assert_eq!( + response.headers().get(ACP_PROTOCOL_VERSION_HEADER).unwrap(), + HeaderValue::from_static("0") + ); + assert_eq!( + response.headers().get(X_ACCEL_BUFFERING_HEADER).unwrap(), + HeaderValue::from_static("no") + ); + assert_eq!( + json_event_body(response).await, + vec![json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { "sessionId": "session-1" } + })] + ); + + let mut delete_headers = HeaderMap::new(); + delete_headers.insert( + ACP_CONNECTION_ID_HEADER, + HeaderValue::from_str(&connection_id.to_string()).unwrap(), + ); + + let response = http_delete(delete_headers, state).await.unwrap(); + assert_eq!(response.status(), StatusCode::ACCEPTED); + assert_eq!( + response.headers().get(ACP_PROTOCOL_VERSION_HEADER).unwrap(), + HeaderValue::from_static("0") + ); + } + + #[tokio::test] + async fn get_rejects_incomplete_websocket_upgrade() { + let (state, _manager_rx) = test_state(); + let request = HttpRequest::builder() + .method("GET") + .uri("/acp") + .header("upgrade", "websocket") + .body(Body::empty()) + .unwrap(); + + let response = get(State(state), request).await; + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + } + + #[tokio::test] + async fn post_rejects_disallowed_origin_before_processing() { + let (state, _manager_rx) = test_state(); + let mut headers = post_headers(); + headers.insert(ORIGIN, HeaderValue::from_static("https://evil.example")); + + let response = post( + headers, + State(state), + r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":0}}"# + .to_string(), + ) + .await; + + assert_eq!(response.status(), StatusCode::FORBIDDEN); + let bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap(); + assert_eq!( + String::from_utf8(bytes.to_vec()).unwrap(), + "Origin is not allowed" + ); + } + + #[tokio::test] + async fn get_rejects_websocket_upgrade_with_disallowed_origin() { + let (state, _manager_rx) = test_state(); + let request = HttpRequest::builder() + .method("GET") + .uri("/acp") + .header("upgrade", "websocket") + .header(ORIGIN, "https://evil.example") + .body(Body::empty()) + .unwrap(); + + let response = get(State(state), request).await; + assert_eq!(response.status(), StatusCode::FORBIDDEN); + } + + #[tokio::test] + async fn process_manager_request_rejects_invalid_or_unknown_http_targets() { + let nats_client = AdvancedMockNatsClient::new(); + let js_client = MockJs::new(); + let config = test_config(); + let mut http_connections = HashMap::new(); + let mut websocket_handles = Vec::new(); + let mut http_connection_handles = Vec::new(); + let (_shutdown_tx, shutdown_rx) = watch::channel(false); + + let (post_response_tx, post_response_rx) = oneshot::channel(); + let post_message = + IncomingHttpMessage::parse(r#"{"jsonrpc":"2.0","method":"initialized"}"#.to_string()) + .unwrap(); + process_manager_request( + ManagerRequest::HttpPost { + connection_id: None, + protocol_version: None, + session_id: None, + message: post_message, + response: post_response_tx, + shutdown_rx: shutdown_rx.clone(), + }, + &mut http_connections, + &mut websocket_handles, + &mut http_connection_handles, + &nats_client, + &js_client, + &config, + ) + .await; + assert!(matches!( + post_response_rx.await.unwrap(), + Err(HttpTransportError::BadRequest { + message: "missing Acp-Connection-Id header", + source: None, + }) + )); + + let unknown_connection_id = AcpConnectionId::default(); + let (get_response_tx, get_response_rx) = oneshot::channel(); + process_manager_request( + ManagerRequest::HttpGet { + connection_id: unknown_connection_id.clone(), + session_id: session_id(), + protocol_version: None, + response: get_response_tx, + }, + &mut http_connections, + &mut websocket_handles, + &mut http_connection_handles, + &nats_client, + &js_client, + &config, + ) + .await; + assert!(matches!( + get_response_rx.await.unwrap(), + Err(HttpTransportError::NotFound { + message: "unknown ACP connection", + source: None, + }) + )); + + let (delete_response_tx, delete_response_rx) = oneshot::channel(); + process_manager_request( + ManagerRequest::HttpDelete { + connection_id: unknown_connection_id, + protocol_version: None, + response: delete_response_tx, + }, + &mut http_connections, + &mut websocket_handles, + &mut http_connection_handles, + &nats_client, + &js_client, + &config, + ) + .await; + assert!(matches!( + delete_response_rx.await.unwrap(), + Err(HttpTransportError::NotFound { + message: "unknown ACP connection", + source: None, + }) + )); + } + + #[tokio::test] + async fn process_manager_request_prunes_closed_http_connections() { + let nats_client = AdvancedMockNatsClient::new(); + let js_client = MockJs::new(); + let config = test_config(); + let mut http_connections = HashMap::new(); + let mut websocket_handles = Vec::new(); + let mut http_connection_handles = Vec::new(); + + let stale_connection_id = AcpConnectionId::default(); + let (command_tx, command_rx) = mpsc::unbounded_channel(); + drop(command_rx); + http_connections.insert(stale_connection_id, HttpConnectionHandle { command_tx }); + + let unknown_connection_id = AcpConnectionId::default(); + let (response_tx, response_rx) = oneshot::channel(); + process_manager_request( + ManagerRequest::HttpGet { + connection_id: unknown_connection_id, + session_id: session_id(), + protocol_version: None, + response: response_tx, + }, + &mut http_connections, + &mut websocket_handles, + &mut http_connection_handles, + &nats_client, + &js_client, + &config, + ) + .await; + + assert!(http_connections.is_empty()); + assert!(matches!( + response_rx.await.unwrap(), + Err(HttpTransportError::NotFound { + message: "unknown ACP connection", + source: None, + }) + )); + } + + #[tokio::test(flavor = "current_thread")] + async fn process_manager_request_tracks_http_connection_tasks() { + let local = tokio::task::LocalSet::new(); + local + .run_until(async { + let nats_client = AdvancedMockNatsClient::new(); + let _injector = nats_client.inject_messages(); + nats_client.hang_next_request(); + + let js_client = MockJs::new(); + let config = test_config(); + let mut http_connections = HashMap::new(); + let mut websocket_handles = Vec::new(); + let mut http_connection_handles = Vec::new(); + let (shutdown_tx, shutdown_rx) = watch::channel(false); + + let (response_tx, response_rx) = oneshot::channel(); + let initialize = IncomingHttpMessage::parse( + r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":0}}"# + .to_string(), + ) + .unwrap(); + + process_manager_request( + ManagerRequest::HttpPost { + connection_id: None, + protocol_version: None, + session_id: None, + message: initialize, + response: response_tx, + shutdown_rx, + }, + &mut http_connections, + &mut websocket_handles, + &mut http_connection_handles, + &nats_client, + &js_client, + &config, + ) + .await; + + assert_eq!(http_connections.len(), 1); + assert_eq!(http_connection_handles.len(), 1); + assert!(!http_connection_handles[0].is_finished()); + assert!(matches!( + response_rx.await.unwrap(), + Ok(HttpPostOutcome::Live { .. }) + )); + + let _ = shutdown_tx.send(true); + tokio::time::timeout(std::time::Duration::from_secs(2), async { + while !http_connection_handles[0].is_finished() { + tokio::task::yield_now().await; + } + }) + .await + .expect("HTTP connection task did not finish after shutdown"); + }) + .await; + } +} diff --git a/rsworkspace/crates/acp-nats-ws/src/upgrade.rs b/rsworkspace/crates/acp-nats-ws/src/upgrade.rs deleted file mode 100644 index a63ad1343..000000000 --- a/rsworkspace/crates/acp-nats-ws/src/upgrade.rs +++ /dev/null @@ -1,99 +0,0 @@ -use axum::extract::State; -use axum::extract::ws::{WebSocket, WebSocketUpgrade}; -use axum::response::Response; -use tokio::sync::{mpsc, watch}; -use tracing::error; - -pub struct ConnectionRequest { - pub socket: WebSocket, - pub shutdown_rx: watch::Receiver, -} - -#[derive(Clone)] -pub struct UpgradeState { - pub conn_tx: mpsc::UnboundedSender, - pub shutdown_tx: watch::Sender, -} - -pub async fn handle(ws: WebSocketUpgrade, State(state): State) -> Response { - let shutdown_rx = state.shutdown_tx.subscribe(); - ws.on_upgrade(move |socket| async move { - if state - .conn_tx - .send(ConnectionRequest { - socket, - shutdown_rx, - }) - .is_err() - { - error!("Connection thread is gone; dropping WebSocket"); - } - }) -} - -#[cfg(test)] -mod tests { - use super::*; - use std::time::Duration; - use tokio::net::TcpListener; - use tokio_tungstenite::connect_async; - - #[tokio::test] - async fn handle_sends_connection_request_through_channel() { - let (shutdown_tx, _shutdown_rx) = watch::channel(false); - let (conn_tx, mut conn_rx) = mpsc::unbounded_channel::(); - - let state = UpgradeState { - conn_tx, - shutdown_tx: shutdown_tx.clone(), - }; - - let app = axum::Router::new() - .route("/ws", axum::routing::get(handle)) - .with_state(state); - - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - tokio::spawn(async move { - axum::serve(listener, app).await.unwrap(); - }); - - let url = format!("ws://{}/ws", addr); - let (_ws, _) = connect_async(&url).await.unwrap(); - - let req = tokio::time::timeout(Duration::from_secs(2), conn_rx.recv()) - .await - .expect("timeout waiting for ConnectionRequest") - .expect("channel closed"); - - assert!(!*req.shutdown_rx.borrow()); - } - - #[tokio::test] - async fn handle_logs_error_when_conn_rx_dropped() { - let (shutdown_tx, _shutdown_rx) = watch::channel(false); - let (conn_tx, conn_rx) = mpsc::unbounded_channel::(); - - let state = UpgradeState { - conn_tx, - shutdown_tx: shutdown_tx.clone(), - }; - - let app = axum::Router::new() - .route("/ws", axum::routing::get(handle)) - .with_state(state); - - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - tokio::spawn(async move { - axum::serve(listener, app).await.unwrap(); - }); - - drop(conn_rx); - - let url = format!("ws://{}/ws", addr); - let (_ws, _) = connect_async(&url).await.unwrap(); - - tokio::time::sleep(Duration::from_millis(100)).await; - } -}