diff --git a/rsworkspace/Cargo.lock b/rsworkspace/Cargo.lock
index 9ed295afb..2f720feae 100644
--- a/rsworkspace/Cargo.lock
+++ b/rsworkspace/Cargo.lock
@@ -76,13 +76,16 @@ dependencies = [
"clap",
"futures-util",
"opentelemetry",
+ "serde",
"serde_json",
"tokio",
"tokio-tungstenite 0.29.0",
+ "tower",
"tracing",
"tracing-subscriber",
"trogon-nats",
"trogon-std",
+ "uuid",
]
[[package]]
diff --git a/rsworkspace/crates/acp-nats-ws/Cargo.toml b/rsworkspace/crates/acp-nats-ws/Cargo.toml
index c5451d991..9850b06d3 100644
--- a/rsworkspace/crates/acp-nats-ws/Cargo.toml
+++ b/rsworkspace/crates/acp-nats-ws/Cargo.toml
@@ -17,13 +17,16 @@ bytes = { workspace = true }
clap = { workspace = true, features = ["env"] }
futures-util = { workspace = true, features = ["sink"] }
opentelemetry = { workspace = true }
+serde = { workspace = true, features = ["derive"] }
+serde_json = { workspace = true }
tokio = { workspace = true, features = ["rt-multi-thread", "macros", "signal", "net", "sync", "io-util"] }
tracing = { workspace = true }
trogon-nats = { workspace = true }
trogon-std = { workspace = true, features = ["telemetry-http"] }
+uuid = { workspace = true }
[dev-dependencies]
-serde_json = { workspace = true }
+tower = { version = "0.5", features = ["util"] }
tokio-tungstenite = { workspace = true }
tracing-subscriber = { workspace = true, features = ["fmt"] }
trogon-nats = { workspace = true, features = ["test-support"] }
diff --git a/rsworkspace/crates/acp-nats-ws/README.md b/rsworkspace/crates/acp-nats-ws/README.md
index 846fc17b6..c30f99a05 100644
--- a/rsworkspace/crates/acp-nats-ws/README.md
+++ b/rsworkspace/crates/acp-nats-ws/README.md
@@ -1,21 +1,22 @@
-# ACP NATS WebSocket
+# ACP NATS Streamable HTTP & WebSocket
-Translates [Agent Client Protocol](https://agentclientprotocol.com) (ACP) messages between WebSocket connections and [NATS](https://nats.io), letting browser-based UIs and remote clients talk to distributed agent backends over a standard WebSocket endpoint.
+Translates [Agent Client Protocol](https://agentclientprotocol.com) (ACP) messages between [NATS](https://nats.io) and the draft remote transport served on `/acp`, including both Streamable HTTP (`POST`/`GET`/`DELETE`) and WebSocket upgrade.
For managed NATS infrastructure in production, we recommend
Synadia.
```mermaid
graph LR
- A1[Client1] <-->|ws| B[acp-nats-ws]
- A2[Client2] <-->|ws| B
+ A1[Client1] <-->|http or ws| B[acp-nats-ws]
+ A2[Client2] <-->|http or ws| B
AN[ClientN] <-->|ws| B
B <-->|NATS| C[Backend]
```
## Features
-- Multiple concurrent WebSocket connections, each with its own ACP session
-- Bidirectional ACP bridge with request forwarding
+- Streamable HTTP transport on `/acp` with session-scoped SSE listeners
+- WebSocket upgrade on `/acp`
+- Multiple concurrent ACP connections sharing the same NATS bridge
- OpenTelemetry integration (logs, metrics, traces)
- Graceful shutdown (SIGINT/SIGTERM) with per-connection drain
- Custom prefix support for multi-tenancy
@@ -33,9 +34,25 @@ cargo build --release -p acp-nats-ws
Connect with any WebSocket client:
```bash
-websocat ws://127.0.0.1:8080/ws
+websocat ws://127.0.0.1:8080/acp
```
+Or use Streamable HTTP:
+
+```bash
+curl -i \
+ -H 'Content-Type: application/json' \
+ -H 'Accept: application/json, text/event-stream' \
+ -d '{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":0}}' \
+ http://127.0.0.1:8080/acp
+```
+
+`POST /acp` returns an SSE response for JSON-RPC requests, `GET /acp` opens a session-scoped SSE listener with `Acp-Connection-Id` and `Acp-Session-Id`, and `DELETE /acp` terminates a connection. The WebSocket upgrade response and HTTP initialize response both include `Acp-Connection-Id`.
+
+After `initialize`, HTTP clients may send `Acp-Protocol-Version` on `POST`/`GET`/`DELETE`. When present, it must match the negotiated ACP protocol version for that connection.
+
+When clients send an `Origin` header, `/acp` validates it against the bound host and rejects disallowed origins with `403 Forbidden`. Streamable HTTP `POST` SSE responses also emit a priming SSE event ID before JSON-RPC payloads and attach event IDs to streamed JSON events.
+
## Configuration
### WebSocket Server
diff --git a/rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs b/rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs
new file mode 100644
index 000000000..47bd2fc99
--- /dev/null
+++ b/rsworkspace/crates/acp-nats-ws/src/acp_connection_id.rs
@@ -0,0 +1,88 @@
+#[derive(Clone, Debug, PartialEq, Eq, Hash)]
+pub struct AcpConnectionId(uuid::Uuid);
+
+impl AcpConnectionId {
+ pub fn new(uuid: uuid::Uuid) -> Self {
+ Self(uuid)
+ }
+
+ pub fn parse(s: &str) -> Result {
+ uuid::Uuid::parse_str(s)
+ .map(Self::new)
+ .map_err(AcpConnectionIdError::InvalidUuid)
+ }
+}
+
+impl Default for AcpConnectionId {
+ fn default() -> Self {
+ Self::new(uuid::Uuid::now_v7())
+ }
+}
+
+impl std::fmt::Display for AcpConnectionId {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ self.0.fmt(f)
+ }
+}
+
+#[derive(Debug)]
+pub enum AcpConnectionIdError {
+ InvalidUuid(uuid::Error),
+}
+
+impl std::fmt::Display for AcpConnectionIdError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Self::InvalidUuid(error) => write!(f, "invalid ACP connection id: {error}"),
+ }
+ }
+}
+
+impl std::error::Error for AcpConnectionIdError {
+ fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+ match self {
+ Self::InvalidUuid(error) => Some(error),
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::error::Error as _;
+
+ #[test]
+ fn new_wraps_existing_uuid() {
+ let uuid = uuid::Uuid::nil();
+ assert_eq!(AcpConnectionId::new(uuid).to_string(), uuid.to_string());
+ }
+
+ #[test]
+ fn parse_round_trips_uuid() {
+ let id = AcpConnectionId::default();
+ let parsed = AcpConnectionId::parse(&id.to_string()).unwrap();
+ assert_eq!(parsed, id);
+ }
+
+ #[test]
+ fn parse_rejects_invalid_uuid() {
+ assert!(AcpConnectionId::parse("not-a-uuid").is_err());
+ }
+
+ #[test]
+ fn default_generates_non_empty_id() {
+ assert!(!AcpConnectionId::default().to_string().is_empty());
+ }
+
+ #[test]
+ fn parse_error_displays_context() {
+ let error = AcpConnectionId::parse("not-a-uuid").unwrap_err();
+ assert!(error.to_string().contains("invalid ACP connection id"));
+ }
+
+ #[test]
+ fn parse_error_exposes_uuid_source() {
+ let error = AcpConnectionId::parse("not-a-uuid").unwrap_err();
+ assert!(error.source().is_some());
+ }
+}
diff --git a/rsworkspace/crates/acp-nats-ws/src/connection.rs b/rsworkspace/crates/acp-nats-ws/src/connection.rs
index 7f54f25b9..0cfa840d7 100644
--- a/rsworkspace/crates/acp-nats-ws/src/connection.rs
+++ b/rsworkspace/crates/acp-nats-ws/src/connection.rs
@@ -9,10 +9,12 @@ use tokio::sync::watch;
use tracing::{error, info, warn};
use trogon_std::time::SystemClock;
+use crate::acp_connection_id::AcpConnectionId;
use crate::constants::DUPLEX_BUFFER_SIZE;
/// Handles a single WebSocket connection by bridging it to NATS via ACP.
pub async fn handle(
+ connection_id: AcpConnectionId,
socket: WebSocket,
nats_client: N,
js_client: J,
@@ -68,7 +70,7 @@ pub async fn handle(
let mut io_task = tokio::task::spawn_local(io_task);
- info!("WebSocket connection established, ACP bridge running");
+ info!(%connection_id, "WebSocket connection established, ACP bridge running");
let shutdown_result = tokio::select! {
result = &mut client_task => {
@@ -118,8 +120,8 @@ pub async fn handle(
}
match shutdown_result {
- Ok(()) => info!("WebSocket connection closed cleanly"),
- Err(e) => warn!(error = e, "WebSocket connection closed with error"),
+ Ok(()) => info!(%connection_id, "WebSocket connection closed cleanly"),
+ Err(e) => warn!(%connection_id, error = e, "WebSocket connection closed with error"),
}
}
@@ -128,34 +130,24 @@ async fn run_recv_pump(
mut ws_recv_write: tokio::io::DuplexStream,
) {
while let Some(Ok(msg)) = ws_receiver.next().await {
- let bytes = match msg {
- Message::Text(t) => bytes::Bytes::from(t),
- Message::Binary(b) => b,
+ let text = match msg {
+ Message::Text(text) => text,
+ Message::Binary(_) => continue,
Message::Close(_) => break,
_ => continue,
};
- match std::str::from_utf8(&bytes) {
- Ok(text) => {
- let line = text.trim_end_matches(['\r', '\n']);
- if line.is_empty() {
- continue;
- }
+ let line = text.trim_end_matches(['\r', '\n']);
+ if line.is_empty() {
+ continue;
+ }
- if ws_recv_write.write_all(line.as_bytes()).await.is_err() {
- break;
- }
+ if ws_recv_write.write_all(line.as_bytes()).await.is_err() {
+ break;
+ }
- if ws_recv_write.write_all(b"\n").await.is_err() {
- break;
- }
- }
- Err(e) => {
- warn!(
- error = %e,
- "Received non-UTF-8 WebSocket message, dropping frame"
- );
- }
+ if ws_recv_write.write_all(b"\n").await.is_err() {
+ break;
}
}
}
@@ -196,6 +188,8 @@ mod tests {
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite::Message as TungsteniteMessage;
+ use crate::constants::ACP_ENDPOINT;
+
#[derive(Clone)]
struct EchoState;
@@ -213,12 +207,12 @@ mod tests {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let app = axum::Router::new()
- .route("/ws", axum::routing::get(echo_handler))
+ .route(ACP_ENDPOINT, axum::routing::get(echo_handler))
.with_state(EchoState);
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
- format!("ws://{}/ws", addr)
+ format!("ws://{}{}", addr, ACP_ENDPOINT)
}
#[tokio::test]
@@ -245,4 +239,30 @@ mod tests {
}
}
}
+
+ #[tokio::test]
+ async fn binary_messages_are_ignored() {
+ let url = start_echo_server().await;
+ let (mut ws, _) = connect_async(&url).await.unwrap();
+
+ ws.send(TungsteniteMessage::Binary(bytes::Bytes::from_static(
+ b"ignored",
+ )))
+ .await
+ .unwrap();
+ ws.send(TungsteniteMessage::Text("kept".into()))
+ .await
+ .unwrap();
+
+ let msg = tokio::time::timeout(Duration::from_secs(2), ws.next())
+ .await
+ .expect("timeout")
+ .expect("stream ended")
+ .unwrap();
+
+ match msg {
+ TungsteniteMessage::Text(text) => assert_eq!(text, "kept"),
+ other => panic!("expected text frame, got {other:?}"),
+ }
+ }
}
diff --git a/rsworkspace/crates/acp-nats-ws/src/constants.rs b/rsworkspace/crates/acp-nats-ws/src/constants.rs
index f78fd678c..93c278bad 100644
--- a/rsworkspace/crates/acp-nats-ws/src/constants.rs
+++ b/rsworkspace/crates/acp-nats-ws/src/constants.rs
@@ -1,6 +1,12 @@
use std::net::{IpAddr, Ipv4Addr};
+pub const ACP_CONNECTION_ID_HEADER: &str = "acp-connection-id";
+pub const ACP_ENDPOINT: &str = "/acp";
+pub const ACP_PROTOCOL_VERSION_HEADER: &str = "acp-protocol-version";
+pub const ACP_SESSION_ID_HEADER: &str = "acp-session-id";
pub const DEFAULT_HOST: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
pub const DEFAULT_PORT: u16 = 8080;
pub const DUPLEX_BUFFER_SIZE: usize = 64 * 1024;
+pub const HTTP_CHANNEL_CAPACITY: usize = 64;
pub const THREAD_NAME: &str = "acp-ws-local";
+pub const X_ACCEL_BUFFERING_HEADER: &str = "x-accel-buffering";
diff --git a/rsworkspace/crates/acp-nats-ws/src/main.rs b/rsworkspace/crates/acp-nats-ws/src/main.rs
index 6faee93d7..4be6b0e95 100644
--- a/rsworkspace/crates/acp-nats-ws/src/main.rs
+++ b/rsworkspace/crates/acp-nats-ws/src/main.rs
@@ -1,11 +1,12 @@
+mod acp_connection_id;
mod config;
mod connection;
mod constants;
-mod upgrade;
+mod transport;
use tokio::sync::{mpsc, watch};
use tracing::info;
-use upgrade::{ConnectionRequest, UpgradeState};
+use transport::{AppState, ManagerRequest};
#[cfg(not(coverage))]
use {
@@ -26,7 +27,7 @@ async fn main() -> Result<(), Box> {
);
let ws_config = config::apply_timeout_overrides(ws_config, &SystemEnv);
- info!("ACP WebSocket bridge starting");
+ info!("ACP remote transport bridge starting");
let nats_connect_timeout = acp_nats::nats_connect_timeout(&SystemEnv);
let nats_client = nats::connect(ws_config.acp.nats(), nats_connect_timeout).await?;
@@ -35,27 +36,33 @@ async fn main() -> Result<(), Box> {
let js_client = trogon_nats::jetstream::NatsJetStreamClient::new(js_context);
let (shutdown_tx, _) = watch::channel(false);
- let (conn_tx, conn_rx) = mpsc::unbounded_channel::();
+ let (manager_tx, manager_rx) = mpsc::unbounded_channel::();
let conn_thread = std::thread::Builder::new()
.name(THREAD_NAME.into())
- .spawn(move || run_connection_thread(conn_rx, nats_client, js_client, ws_config.acp))?;
+ .spawn(move || run_connection_thread(manager_rx, nats_client, js_client, ws_config.acp))?;
- let state = UpgradeState {
- conn_tx,
+ let state = AppState {
+ bind_host: ws_config.host,
+ manager_tx,
shutdown_tx: shutdown_tx.clone(),
};
let app = trogon_std::telemetry::http::instrument_router(
axum::Router::new()
- .route("/ws", axum::routing::get(upgrade::handle))
+ .route(
+ ACP_ENDPOINT,
+ axum::routing::get(transport::get)
+ .post(transport::post)
+ .delete(transport::delete),
+ )
.with_state(state),
);
let addr = SocketAddr::from((ws_config.host, ws_config.port));
let listener = tokio::net::TcpListener::bind(addr).await?;
- info!(address = %addr, "Listening for WebSocket connections");
+ info!(address = %addr, "Listening for ACP transport connections");
let result = axum::serve(listener, app)
.with_graceful_shutdown(async move {
@@ -66,8 +73,8 @@ async fn main() -> Result<(), Box> {
.await;
match &result {
- Ok(()) => info!("ACP WebSocket bridge stopped"),
- Err(e) => error!(error = %e, "ACP WebSocket bridge stopped with error"),
+ Ok(()) => info!("ACP remote transport bridge stopped"),
+ Err(e) => error!(error = %e, "ACP remote transport bridge stopped with error"),
}
// `serve` returning drops the Router (and its AppState.conn_tx), which
@@ -88,13 +95,13 @@ async fn main() -> Result<(), Box> {
#[cfg(coverage)]
fn main() {}
-use constants::THREAD_NAME;
+use constants::{ACP_ENDPOINT, THREAD_NAME};
/// Runs a single-threaded tokio runtime with a
/// `LocalSet`. All WebSocket connections are processed here because the ACP
/// `Agent` trait is `?Send`, requiring `spawn_local` / `Rc`.
fn run_connection_thread(
- conn_rx: mpsc::UnboundedReceiver,
+ manager_rx: mpsc::UnboundedReceiver,
nats_client: N,
js_client: J,
config: acp_nats::Config,
@@ -115,7 +122,12 @@ fn run_connection_thread(
.expect("failed to create per-connection runtime");
let local = tokio::task::LocalSet::new();
- rt.block_on(local.run_until(process_connections(conn_rx, nats_client, js_client, config)));
+ rt.block_on(local.run_until(process_connections(
+ manager_rx,
+ nats_client,
+ js_client,
+ config,
+ )));
// run_until returns once process_connections completes, but
// sub-tasks spawned by connection handlers (pumps,
@@ -130,7 +142,7 @@ fn run_connection_thread(
}
async fn process_connections(
- mut conn_rx: mpsc::UnboundedReceiver,
+ mut manager_rx: mpsc::UnboundedReceiver,
nats_client: N,
js_client: J,
config: acp_nats::Config,
@@ -145,29 +157,40 @@ async fn process_connections(
J: acp_nats::JetStreamPublisher + acp_nats::JetStreamGetStream + 'static,
trogon_nats::jetstream::JsMessageOf: trogon_nats::jetstream::JsRequestMessage,
{
- let mut conn_handles: Vec> = Vec::new();
-
- while let Some(req) = conn_rx.recv().await {
- conn_handles.retain(|h| !h.is_finished());
- let client = nats_client.clone();
- let js = js_client.clone();
- let cfg = config.clone();
- conn_handles.push(tokio::task::spawn_local(connection::handle(
- req.socket,
- client,
- js,
- cfg,
- req.shutdown_rx,
- )));
+ let mut websocket_handles: Vec> = Vec::new();
+ let mut http_connection_handles: Vec> = Vec::new();
+ let mut http_connections = std::collections::HashMap::new();
+
+ while let Some(request) = manager_rx.recv().await {
+ transport::process_manager_request(
+ request,
+ &mut http_connections,
+ &mut websocket_handles,
+ &mut http_connection_handles,
+ &nats_client,
+ &js_client,
+ &config,
+ )
+ .await;
}
- let active = conn_handles.iter().filter(|h| !h.is_finished()).count();
+ let active = websocket_handles
+ .iter()
+ .filter(|h| !h.is_finished())
+ .count()
+ + http_connection_handles
+ .iter()
+ .filter(|h| !h.is_finished())
+ .count();
info!(
active_connections = active,
"Connection channel closed, draining active connections"
);
- for handle in conn_handles {
+ for handle in websocket_handles {
+ let _ = handle.await;
+ }
+ for handle in http_connection_handles {
let _ = handle.await;
}
@@ -177,12 +200,21 @@ async fn process_connections(
#[cfg(test)]
mod tests {
use super::*;
+ use crate::constants::{
+ ACP_CONNECTION_ID_HEADER, ACP_PROTOCOL_VERSION_HEADER, ACP_SESSION_ID_HEADER,
+ };
use acp_nats::Config;
+ use axum::body::{Body, to_bytes};
+ use axum::http::header::{ACCEPT, CONTENT_TYPE};
+ use axum::http::{Request, StatusCode};
use futures_util::{SinkExt, StreamExt};
+ use serde_json::{Value, json};
+ use std::net::{IpAddr, Ipv4Addr};
use std::time::Duration;
use tokio::net::TcpListener;
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite::Message;
+ use tower::ServiceExt;
use trogon_nats::AdvancedMockNatsClient;
#[derive(Clone)]
@@ -230,41 +262,75 @@ mod tests {
}
}
- #[tokio::test]
- async fn test_websocket_connection_lifecycle() {
- let nats_mock = AdvancedMockNatsClient::new();
- let config = Config::new(
+ fn test_config() -> Config {
+ Config::new(
acp_nats::AcpPrefix::new("acp").unwrap(),
acp_nats::NatsConfig {
servers: vec!["localhost:4222".to_string()],
auth: trogon_nats::NatsAuth::None,
},
- );
-
- // Required by AdvancedMockNatsClient to not error out on subscribe()
- let _injector = nats_mock.inject_messages();
+ )
+ }
- let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
- let (conn_tx, conn_rx) = mpsc::unbounded_channel::();
+ fn test_app(state: AppState) -> axum::Router {
+ axum::Router::new()
+ .route(
+ ACP_ENDPOINT,
+ axum::routing::get(transport::get)
+ .post(transport::post)
+ .delete(transport::delete),
+ )
+ .with_state(state)
+ }
- let nats_mock_clone = nats_mock.clone();
- let conn_thread = std::thread::Builder::new()
+ fn spawn_connection_thread(
+ nats_mock: AdvancedMockNatsClient,
+ manager_rx: mpsc::UnboundedReceiver,
+ ) -> std::thread::JoinHandle<()> {
+ let config = test_config();
+ std::thread::Builder::new()
.name(THREAD_NAME.into())
- .spawn(move || run_connection_thread(conn_rx, nats_mock_clone, MockJs::new(), config))
- .expect("failed to spawn connection thread");
+ .spawn(move || run_connection_thread(manager_rx, nats_mock, MockJs::new(), config))
+ .expect("failed to spawn connection thread")
+ }
- let state = UpgradeState {
- conn_tx,
+ fn build_test_app(
+ nats_mock: AdvancedMockNatsClient,
+ ) -> (
+ axum::Router,
+ watch::Sender,
+ std::thread::JoinHandle<()>,
+ ) {
+ let (shutdown_tx, _) = watch::channel(false);
+ let (manager_tx, manager_rx) = mpsc::unbounded_channel::();
+ let conn_thread = spawn_connection_thread(nats_mock, manager_rx);
+ let app = test_app(AppState {
+ bind_host: IpAddr::V4(Ipv4Addr::LOCALHOST),
+ manager_tx,
shutdown_tx: shutdown_tx.clone(),
- };
+ });
+ (app, shutdown_tx, conn_thread)
+ }
- let app = axum::Router::new()
- .route("/ws", axum::routing::get(upgrade::handle))
- .with_state(state);
+ async fn start_test_server(
+ nats_mock: AdvancedMockNatsClient,
+ ) -> (
+ std::net::SocketAddr,
+ watch::Sender,
+ tokio::task::JoinHandle<()>,
+ std::thread::JoinHandle<()>,
+ ) {
+ let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
+ let (manager_tx, manager_rx) = mpsc::unbounded_channel::();
+ let conn_thread = spawn_connection_thread(nats_mock, manager_rx);
+ let app = test_app(AppState {
+ bind_host: IpAddr::V4(Ipv4Addr::LOCALHOST),
+ manager_tx,
+ shutdown_tx: shutdown_tx.clone(),
+ });
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
-
let server_task = tokio::spawn(async move {
axum::serve(listener, app)
.with_graceful_shutdown(async move {
@@ -274,13 +340,49 @@ mod tests {
.unwrap();
});
+ (addr, shutdown_tx, server_task, conn_thread)
+ }
+
+ async fn body_text(response: axum::response::Response) -> String {
+ let bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap();
+ String::from_utf8(bytes.to_vec()).unwrap()
+ }
+
+ fn sse_events(body: &str) -> Vec {
+ body.lines()
+ .filter_map(|line| line.strip_prefix("data: "))
+ .filter(|json| !json.is_empty())
+ .map(|json| serde_json::from_str(json).unwrap())
+ .collect()
+ }
+
+ fn http_post_request(body: &str) -> Request {
+ Request::builder()
+ .method("POST")
+ .uri(ACP_ENDPOINT)
+ .header(CONTENT_TYPE, "application/json")
+ .header(ACCEPT, "application/json, text/event-stream")
+ .body(Body::from(body.to_owned()))
+ .unwrap()
+ }
+
+ #[tokio::test]
+ async fn test_websocket_connection_lifecycle() {
+ let nats_mock = AdvancedMockNatsClient::new();
+
+ // Required by AdvancedMockNatsClient to not error out on subscribe()
+ let _injector = nats_mock.inject_messages();
+
// Setup mock response for NATS
let nats_response = r#"{"agentCapabilities": {"loadSession": false, "mcpCapabilities": {"http": false, "sse": false}, "promptCapabilities": {"audio": false, "embeddedContext": false, "image": false}, "sessionCapabilities": {}}, "authMethods": [], "protocolVersion": 0}"#;
nats_mock.set_response("acp.agent.initialize", nats_response.into());
+ let (addr, shutdown_tx, server_task, conn_thread) = start_test_server(nats_mock).await;
+
// Connect client
- let ws_url = format!("ws://{}/ws", addr);
- let (mut ws_stream, _) = connect_async(ws_url).await.unwrap();
+ let ws_url = format!("ws://{}{}", addr, ACP_ENDPOINT);
+ let (mut ws_stream, response) = connect_async(ws_url).await.unwrap();
+ assert!(response.headers().contains_key(ACP_CONNECTION_ID_HEADER));
// Send initialize request
let req =
@@ -322,49 +424,13 @@ mod tests {
#[tokio::test]
async fn test_shutdown_while_connection_active() {
let nats_mock = AdvancedMockNatsClient::new();
- let config = Config::new(
- acp_nats::AcpPrefix::new("acp").unwrap(),
- acp_nats::NatsConfig {
- servers: vec!["localhost:4222".to_string()],
- auth: trogon_nats::NatsAuth::None,
- },
- );
let _injector = nats_mock.inject_messages();
nats_mock.hang_next_request();
+ let (addr, shutdown_tx, server_task, conn_thread) = start_test_server(nats_mock).await;
- let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
- let (conn_tx, conn_rx) = mpsc::unbounded_channel::();
-
- let nats_mock_clone = nats_mock.clone();
- let conn_thread = std::thread::Builder::new()
- .name(THREAD_NAME.into())
- .spawn(move || run_connection_thread(conn_rx, nats_mock_clone, MockJs::new(), config))
- .expect("failed to spawn connection thread");
-
- let state = UpgradeState {
- conn_tx,
- shutdown_tx: shutdown_tx.clone(),
- };
-
- let app = axum::Router::new()
- .route("/ws", axum::routing::get(upgrade::handle))
- .with_state(state);
-
- let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
- let addr = listener.local_addr().unwrap();
-
- let server_task = tokio::spawn(async move {
- axum::serve(listener, app)
- .with_graceful_shutdown(async move {
- let _ = shutdown_rx.changed().await;
- })
- .await
- .unwrap();
- });
-
- let ws_url = format!("ws://{}/ws", addr);
+ let ws_url = format!("ws://{}{}", addr, ACP_ENDPOINT);
let (mut ws_stream, _) = connect_async(&ws_url).await.unwrap();
let req =
@@ -383,4 +449,373 @@ mod tests {
conn_thread.join().unwrap();
}
+
+ #[tokio::test]
+ async fn streamable_http_initialize_returns_connection_id_and_sse_response() {
+ let nats_mock = AdvancedMockNatsClient::new();
+ let _injector = nats_mock.inject_messages();
+ nats_mock.set_response(
+ "acp.agent.initialize",
+ r#"{"agentCapabilities":{"loadSession":false,"mcpCapabilities":{"http":false,"sse":false},"promptCapabilities":{"audio":false,"embeddedContext":false,"image":false},"sessionCapabilities":{}},"authMethods":[],"protocolVersion":0}"#
+ .into(),
+ );
+
+ let (app, shutdown_tx, conn_thread) = build_test_app(nats_mock);
+ let response = app
+ .clone()
+ .oneshot(http_post_request(
+ r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":0}}"#,
+ ))
+ .await
+ .unwrap();
+
+ assert_eq!(response.status(), StatusCode::OK);
+ assert_eq!(
+ response.headers().get(CONTENT_TYPE).unwrap(),
+ "text/event-stream"
+ );
+
+ let connection_id = response
+ .headers()
+ .get(ACP_CONNECTION_ID_HEADER)
+ .unwrap()
+ .to_str()
+ .unwrap()
+ .to_owned();
+ assert!(crate::acp_connection_id::AcpConnectionId::parse(&connection_id).is_ok());
+
+ let body = body_text(response).await;
+ let events = sse_events(&body);
+ assert_eq!(events.len(), 1);
+ assert_eq!(events[0]["id"], 1);
+ assert_eq!(events[0]["result"]["protocolVersion"], 0);
+
+ shutdown_tx.send(true).unwrap();
+ drop(app);
+ conn_thread.join().unwrap();
+ }
+
+ #[tokio::test]
+ async fn streamable_http_session_new_returns_session_header_and_body() {
+ let nats_mock = AdvancedMockNatsClient::new();
+ let _injector = nats_mock.inject_messages();
+ nats_mock.set_response(
+ "acp.agent.initialize",
+ r#"{"agentCapabilities":{"loadSession":false,"mcpCapabilities":{"http":false,"sse":false},"promptCapabilities":{"audio":false,"embeddedContext":false,"image":false},"sessionCapabilities":{}},"authMethods":[],"protocolVersion":0}"#
+ .into(),
+ );
+ nats_mock.set_response(
+ "acp.agent.session.new",
+ r#"{"sessionId":"test-session-1"}"#.into(),
+ );
+
+ let (app, shutdown_tx, conn_thread) = build_test_app(nats_mock);
+ let initialize = app
+ .clone()
+ .oneshot(http_post_request(
+ r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":0}}"#,
+ ))
+ .await
+ .unwrap();
+ let connection_id = initialize
+ .headers()
+ .get(ACP_CONNECTION_ID_HEADER)
+ .unwrap()
+ .to_str()
+ .unwrap()
+ .to_owned();
+ let _ = body_text(initialize).await;
+
+ let session_new = app
+ .clone()
+ .oneshot(
+ Request::builder()
+ .method("POST")
+ .uri(ACP_ENDPOINT)
+ .header(CONTENT_TYPE, "application/json")
+ .header(ACCEPT, "application/json, text/event-stream")
+ .header(ACP_CONNECTION_ID_HEADER, &connection_id)
+ .header(ACP_PROTOCOL_VERSION_HEADER, "0")
+ .body(Body::from(
+ r#"{"jsonrpc":"2.0","id":2,"method":"session/new","params":{"cwd":".","mcpServers":[]}}"#,
+ ))
+ .unwrap(),
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(session_new.status(), StatusCode::OK);
+ assert_eq!(
+ session_new
+ .headers()
+ .get(ACP_CONNECTION_ID_HEADER)
+ .unwrap()
+ .to_str()
+ .unwrap(),
+ connection_id
+ );
+
+ let session_id = session_new
+ .headers()
+ .get(ACP_SESSION_ID_HEADER)
+ .and_then(|value| value.to_str().ok())
+ .map(str::to_owned);
+ assert_eq!(
+ session_new
+ .headers()
+ .get(ACP_PROTOCOL_VERSION_HEADER)
+ .and_then(|value| value.to_str().ok()),
+ Some("0")
+ );
+ let body = body_text(session_new).await;
+ let events = sse_events(&body);
+ assert_eq!(events.len(), 1);
+ assert_eq!(events[0]["id"], 2);
+ assert_eq!(events[0]["result"]["sessionId"], "test-session-1");
+ assert_eq!(session_id.as_deref(), Some("test-session-1"));
+
+ let _ = shutdown_tx.send(true);
+ drop(app);
+ conn_thread.join().unwrap();
+ }
+
+ #[tokio::test]
+ async fn streamable_http_session_load_uses_request_session_id_header() {
+ let nats_mock = AdvancedMockNatsClient::new();
+ let _injector = nats_mock.inject_messages();
+ nats_mock.set_response(
+ "acp.agent.initialize",
+ r#"{"agentCapabilities":{"loadSession":false,"mcpCapabilities":{"http":false,"sse":false},"promptCapabilities":{"audio":false,"embeddedContext":false,"image":false},"sessionCapabilities":{}},"authMethods":[],"protocolVersion":0}"#
+ .into(),
+ );
+ nats_mock.set_response("acp.session.test-session-1.agent.load", "{}".into());
+
+ let (app, shutdown_tx, conn_thread) = build_test_app(nats_mock);
+ let initialize = app
+ .clone()
+ .oneshot(http_post_request(
+ r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":0}}"#,
+ ))
+ .await
+ .unwrap();
+ let connection_id = initialize
+ .headers()
+ .get(ACP_CONNECTION_ID_HEADER)
+ .unwrap()
+ .to_str()
+ .unwrap()
+ .to_owned();
+ let _ = body_text(initialize).await;
+
+ let session_load = app
+ .clone()
+ .oneshot(
+ Request::builder()
+ .method("POST")
+ .uri(ACP_ENDPOINT)
+ .header(CONTENT_TYPE, "application/json")
+ .header(ACCEPT, "application/json, text/event-stream")
+ .header(ACP_CONNECTION_ID_HEADER, &connection_id)
+ .header(ACP_PROTOCOL_VERSION_HEADER, "0")
+ .body(Body::from(
+ r#"{"jsonrpc":"2.0","id":2,"method":"session/load","params":{"sessionId":"test-session-1","cwd":"."}}"#,
+ ))
+ .unwrap(),
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(session_load.status(), StatusCode::OK);
+ assert_eq!(
+ session_load
+ .headers()
+ .get(ACP_CONNECTION_ID_HEADER)
+ .unwrap()
+ .to_str()
+ .unwrap(),
+ connection_id
+ );
+ assert_eq!(
+ session_load
+ .headers()
+ .get(ACP_SESSION_ID_HEADER)
+ .and_then(|value| value.to_str().ok()),
+ Some("test-session-1")
+ );
+ assert_eq!(
+ session_load
+ .headers()
+ .get(ACP_PROTOCOL_VERSION_HEADER)
+ .and_then(|value| value.to_str().ok()),
+ Some("0")
+ );
+
+ let body = body_text(session_load).await;
+ let events = sse_events(&body);
+ assert_eq!(events.len(), 1);
+ assert_eq!(events[0]["id"], 2);
+ assert_eq!(events[0]["result"], json!(null));
+
+ let _ = shutdown_tx.send(true);
+ drop(app);
+ conn_thread.join().unwrap();
+ }
+
+ #[tokio::test]
+ async fn streamable_http_get_requires_connection_and_session_headers() {
+ let nats_mock = AdvancedMockNatsClient::new();
+ let (app, shutdown_tx, conn_thread) = build_test_app(nats_mock);
+
+ let response = app
+ .clone()
+ .oneshot(
+ Request::builder()
+ .method("GET")
+ .uri(ACP_ENDPOINT)
+ .header(ACCEPT, "text/event-stream")
+ .body(Body::empty())
+ .unwrap(),
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(response.status(), StatusCode::BAD_REQUEST);
+
+ let _ = shutdown_tx.send(true);
+ drop(app);
+ conn_thread.join().unwrap();
+ }
+
+ #[tokio::test]
+ async fn legacy_websocket_alias_is_not_routed() {
+ let nats_mock = AdvancedMockNatsClient::new();
+ let (app, shutdown_tx, conn_thread) = build_test_app(nats_mock);
+
+ let response = app
+ .clone()
+ .oneshot(
+ Request::builder()
+ .method("GET")
+ .uri("/ws")
+ .body(Body::empty())
+ .unwrap(),
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(response.status(), StatusCode::NOT_FOUND);
+
+ let _ = shutdown_tx.send(true);
+ drop(app);
+ conn_thread.join().unwrap();
+ }
+
+ #[tokio::test]
+ async fn streamable_http_delete_terminates_initialized_connection() {
+ let nats_mock = AdvancedMockNatsClient::new();
+ let _injector = nats_mock.inject_messages();
+ nats_mock.set_response(
+ "acp.agent.initialize",
+ r#"{"agentCapabilities":{"loadSession":false,"mcpCapabilities":{"http":false,"sse":false},"promptCapabilities":{"audio":false,"embeddedContext":false,"image":false},"sessionCapabilities":{}},"authMethods":[],"protocolVersion":0}"#
+ .into(),
+ );
+
+ let (app, shutdown_tx, conn_thread) = build_test_app(nats_mock);
+ let initialize = app
+ .clone()
+ .oneshot(http_post_request(
+ r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":0}}"#,
+ ))
+ .await
+ .unwrap();
+ let connection_id = initialize
+ .headers()
+ .get(ACP_CONNECTION_ID_HEADER)
+ .unwrap()
+ .to_str()
+ .unwrap()
+ .to_owned();
+ let _ = body_text(initialize).await;
+
+ let response = app
+ .clone()
+ .oneshot(
+ Request::builder()
+ .method("DELETE")
+ .uri(ACP_ENDPOINT)
+ .header(ACP_CONNECTION_ID_HEADER, &connection_id)
+ .header(ACP_PROTOCOL_VERSION_HEADER, "0")
+ .body(Body::empty())
+ .unwrap(),
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(response.status(), StatusCode::ACCEPTED);
+ assert_eq!(
+ response
+ .headers()
+ .get(ACP_PROTOCOL_VERSION_HEADER)
+ .and_then(|value| value.to_str().ok()),
+ Some("0")
+ );
+
+ let _ = shutdown_tx.send(true);
+ drop(app);
+ conn_thread.join().unwrap();
+ }
+
+ #[tokio::test]
+ async fn streamable_http_rejects_mismatched_protocol_version_after_initialize() {
+ let nats_mock = AdvancedMockNatsClient::new();
+ let _injector = nats_mock.inject_messages();
+ nats_mock.set_response(
+ "acp.agent.initialize",
+ r#"{"agentCapabilities":{"loadSession":false,"mcpCapabilities":{"http":false,"sse":false},"promptCapabilities":{"audio":false,"embeddedContext":false,"image":false},"sessionCapabilities":{}},"authMethods":[],"protocolVersion":0}"#
+ .into(),
+ );
+
+ let (app, shutdown_tx, conn_thread) = build_test_app(nats_mock);
+ let initialize = app
+ .clone()
+ .oneshot(http_post_request(
+ r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":0}}"#,
+ ))
+ .await
+ .unwrap();
+ let connection_id = initialize
+ .headers()
+ .get(ACP_CONNECTION_ID_HEADER)
+ .unwrap()
+ .to_str()
+ .unwrap()
+ .to_owned();
+ let _ = body_text(initialize).await;
+
+ let response = app
+ .clone()
+ .oneshot(
+ Request::builder()
+ .method("POST")
+ .uri(ACP_ENDPOINT)
+ .header(CONTENT_TYPE, "application/json")
+ .header(ACCEPT, "application/json, text/event-stream")
+ .header(ACP_CONNECTION_ID_HEADER, &connection_id)
+ .header(ACP_PROTOCOL_VERSION_HEADER, "1")
+ .body(Body::from(r#"{"jsonrpc":"2.0","method":"initialized"}"#))
+ .unwrap(),
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(response.status(), StatusCode::BAD_REQUEST);
+ assert_eq!(
+ body_text(response).await,
+ "Acp-Protocol-Version header does not match initialized protocol version"
+ );
+
+ let _ = shutdown_tx.send(true);
+ drop(app);
+ conn_thread.join().unwrap();
+ }
}
diff --git a/rsworkspace/crates/acp-nats-ws/src/transport.rs b/rsworkspace/crates/acp-nats-ws/src/transport.rs
new file mode 100644
index 000000000..9c710b49a
--- /dev/null
+++ b/rsworkspace/crates/acp-nats-ws/src/transport.rs
@@ -0,0 +1,2853 @@
+use crate::acp_connection_id::AcpConnectionId;
+use crate::connection;
+use crate::constants::{
+ ACP_CONNECTION_ID_HEADER, ACP_PROTOCOL_VERSION_HEADER, ACP_SESSION_ID_HEADER,
+ HTTP_CHANNEL_CAPACITY, X_ACCEL_BUFFERING_HEADER,
+};
+use acp_nats::{StdJsonSerialize, agent::Bridge, client, spawn_notification_forwarder};
+use agent_client_protocol::{AgentSideConnection, ProtocolVersion, RequestId, SessionNotification};
+use axum::extract::FromRequestParts;
+use axum::extract::Request;
+use axum::extract::State;
+use axum::extract::ws::{WebSocket, WebSocketUpgrade};
+use axum::http::header::{ACCEPT, CONTENT_TYPE, HOST, ORIGIN};
+use axum::http::{HeaderMap, HeaderValue, StatusCode, Uri, uri::Authority};
+use axum::response::sse::{Event, Sse};
+use axum::response::{IntoResponse, Response};
+use futures_util::{StreamExt, stream};
+use serde::Deserialize;
+use serde_json::Value;
+use std::collections::{HashMap, HashSet};
+use std::convert::Infallible;
+use std::net::IpAddr;
+use std::rc::Rc;
+use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
+use tokio::sync::{mpsc, oneshot, watch};
+use tracing::{error, info, warn};
+use trogon_std::time::SystemClock;
+use uuid::Uuid;
+
+type BoxError = Box;
+type SseSender = mpsc::Sender;
+type SseReceiver = mpsc::Receiver;
+
+#[derive(Clone)]
+pub struct AppState {
+ pub bind_host: IpAddr,
+ pub manager_tx: mpsc::UnboundedSender,
+ pub shutdown_tx: watch::Sender,
+}
+
+pub struct ConnectionRequest {
+ pub connection_id: AcpConnectionId,
+ pub socket: WebSocket,
+ pub shutdown_rx: watch::Receiver,
+}
+
+pub enum ManagerRequest {
+ WebSocket(Box),
+ HttpPost {
+ connection_id: Option,
+ protocol_version: Option,
+ session_id: Option,
+ message: IncomingHttpMessage,
+ response: oneshot::Sender>,
+ shutdown_rx: watch::Receiver,
+ },
+ HttpGet {
+ connection_id: AcpConnectionId,
+ session_id: acp_nats::AcpSessionId,
+ protocol_version: Option,
+ response: oneshot::Sender>,
+ },
+ HttpDelete {
+ connection_id: AcpConnectionId,
+ protocol_version: Option,
+ response: oneshot::Sender, HttpTransportError>>,
+ },
+}
+
+#[derive(Debug)]
+pub enum HttpPostOutcome {
+ Accepted,
+ Live {
+ connection_id: AcpConnectionId,
+ protocol_version: Option,
+ session_id: Option,
+ stream: SseReceiver,
+ },
+ Buffered {
+ connection_id: AcpConnectionId,
+ protocol_version: Option,
+ session_id: Option,
+ events: Vec,
+ },
+}
+
+#[derive(Debug)]
+pub struct HttpGetOutcome {
+ pub protocol_version: Option,
+ pub stream: SseReceiver,
+}
+
+#[derive(Debug)]
+pub enum HttpTransportError {
+ BadRequest {
+ message: &'static str,
+ source: Option,
+ },
+ NotFound {
+ message: &'static str,
+ source: Option,
+ },
+ Conflict {
+ message: &'static str,
+ source: Option,
+ },
+ Forbidden {
+ message: &'static str,
+ source: Option,
+ },
+ UnsupportedMediaType {
+ message: &'static str,
+ source: Option,
+ },
+ NotAcceptable {
+ message: &'static str,
+ source: Option,
+ },
+ NotImplemented {
+ message: &'static str,
+ source: Option,
+ },
+ Internal {
+ message: &'static str,
+ source: Option,
+ },
+}
+
+impl std::fmt::Display for HttpTransportError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.write_str(self.message())
+ }
+}
+
+impl std::error::Error for HttpTransportError {
+ fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+ self.source_ref()
+ }
+}
+
+impl HttpTransportError {
+ fn message(&self) -> &'static str {
+ match self {
+ Self::BadRequest { message, .. }
+ | Self::NotFound { message, .. }
+ | Self::Conflict { message, .. }
+ | Self::Forbidden { message, .. }
+ | Self::UnsupportedMediaType { message, .. }
+ | Self::NotAcceptable { message, .. }
+ | Self::NotImplemented { message, .. }
+ | Self::Internal { message, .. } => message,
+ }
+ }
+
+ fn source_ref(&self) -> Option<&(dyn std::error::Error + 'static)> {
+ match self {
+ Self::BadRequest { source, .. }
+ | Self::NotFound { source, .. }
+ | Self::Conflict { source, .. }
+ | Self::Forbidden { source, .. }
+ | Self::UnsupportedMediaType { source, .. }
+ | Self::NotAcceptable { source, .. }
+ | Self::NotImplemented { source, .. }
+ | Self::Internal { source, .. } => source
+ .as_deref()
+ .map(|source| source as &(dyn std::error::Error + 'static)),
+ }
+ }
+
+ fn bad_request(message: &'static str) -> Self {
+ Self::BadRequest {
+ message,
+ source: None,
+ }
+ }
+
+ fn bad_request_with(
+ message: &'static str,
+ source: impl std::error::Error + Send + 'static,
+ ) -> Self {
+ Self::BadRequest {
+ message,
+ source: Some(Box::new(source)),
+ }
+ }
+
+ fn not_found(message: &'static str) -> Self {
+ Self::NotFound {
+ message,
+ source: None,
+ }
+ }
+
+ fn conflict(message: &'static str) -> Self {
+ Self::Conflict {
+ message,
+ source: None,
+ }
+ }
+
+ fn forbidden(message: &'static str) -> Self {
+ Self::Forbidden {
+ message,
+ source: None,
+ }
+ }
+
+ fn forbidden_with(
+ message: &'static str,
+ source: impl std::error::Error + Send + 'static,
+ ) -> Self {
+ Self::Forbidden {
+ message,
+ source: Some(Box::new(source)),
+ }
+ }
+
+ fn unsupported_media_type(message: &'static str) -> Self {
+ Self::UnsupportedMediaType {
+ message,
+ source: None,
+ }
+ }
+
+ fn not_acceptable(message: &'static str) -> Self {
+ Self::NotAcceptable {
+ message,
+ source: None,
+ }
+ }
+
+ fn not_implemented(message: &'static str) -> Self {
+ Self::NotImplemented {
+ message,
+ source: None,
+ }
+ }
+
+ fn internal(message: &'static str) -> Self {
+ Self::Internal {
+ message,
+ source: None,
+ }
+ }
+
+ fn internal_with(
+ message: &'static str,
+ source: impl std::error::Error + Send + 'static,
+ ) -> Self {
+ Self::Internal {
+ message,
+ source: Some(Box::new(source)),
+ }
+ }
+
+ fn status_code(&self) -> StatusCode {
+ match self {
+ Self::BadRequest { .. } => StatusCode::BAD_REQUEST,
+ Self::NotFound { .. } => StatusCode::NOT_FOUND,
+ Self::Conflict { .. } => StatusCode::CONFLICT,
+ Self::Forbidden { .. } => StatusCode::FORBIDDEN,
+ Self::UnsupportedMediaType { .. } => StatusCode::UNSUPPORTED_MEDIA_TYPE,
+ Self::NotAcceptable { .. } => StatusCode::NOT_ACCEPTABLE,
+ Self::NotImplemented { .. } => StatusCode::NOT_IMPLEMENTED,
+ Self::Internal { .. } => StatusCode::INTERNAL_SERVER_ERROR,
+ }
+ }
+
+ fn into_response(self) -> Response {
+ let status = self.status_code();
+ (status, self.to_string()).into_response()
+ }
+}
+
+#[derive(Debug)]
+pub struct HttpConnectionHandle {
+ pub command_tx: mpsc::UnboundedSender,
+}
+
+#[derive(Debug)]
+pub enum HttpConnectionCommand {
+ Post {
+ protocol_version: Option,
+ session_id: Option,
+ message: IncomingHttpMessage,
+ response: oneshot::Sender>,
+ },
+ AttachListener {
+ session_id: acp_nats::AcpSessionId,
+ protocol_version: Option,
+ response: oneshot::Sender>,
+ },
+ Close {
+ protocol_version: Option,
+ response: oneshot::Sender, HttpTransportError>>,
+ },
+}
+
+#[derive(Clone, Debug)]
+pub(crate) enum SseFrame {
+ Empty {
+ event_id: String,
+ },
+ Json {
+ event_id: Option,
+ json: String,
+ },
+}
+
+impl SseFrame {
+ fn prime() -> Self {
+ Self::Empty {
+ event_id: Uuid::new_v4().to_string(),
+ }
+ }
+
+ fn json(json: String) -> Self {
+ Self::Json {
+ event_id: Some(Uuid::new_v4().to_string()),
+ json,
+ }
+ }
+
+ fn into_event(self) -> Event {
+ match self {
+ Self::Empty { event_id } => Event::default().id(event_id).data(""),
+ Self::Json { event_id, json } => {
+ let event = Event::default().data(json);
+ if let Some(event_id) = event_id {
+ event.id(event_id)
+ } else {
+ event
+ }
+ }
+ }
+ }
+}
+
+#[derive(Debug)]
+enum PendingRequest {
+ Live {
+ request_id: RequestId,
+ capture_protocol_version: bool,
+ session_id: Option,
+ sender: SseSender,
+ },
+ Buffered {
+ request_id: RequestId,
+ fallback_session_id: Option,
+ events: Vec,
+ response: oneshot::Sender>,
+ },
+}
+
+#[derive(Debug, Deserialize)]
+pub struct IncomingHttpMessage {
+ pub id: Option,
+ pub method: Option,
+ pub params: Option,
+ #[serde(skip)]
+ pub raw: String,
+ #[serde(skip)]
+ has_result: bool,
+ #[serde(skip)]
+ has_error: bool,
+}
+
+impl IncomingHttpMessage {
+ pub fn parse(raw: String) -> Result {
+ let trimmed = raw.trim_start();
+ if trimmed.starts_with('[') {
+ return Err(HttpTransportError::not_implemented(
+ "batch JSON-RPC requests are not supported",
+ ));
+ }
+
+ let value = serde_json::from_str::(&raw).map_err(|error| {
+ HttpTransportError::bad_request_with("invalid JSON-RPC payload", error)
+ })?;
+ let (has_result, has_error) = value
+ .as_object()
+ .map(|object| (object.contains_key("result"), object.contains_key("error")))
+ .unwrap_or((false, false));
+ let mut parsed = serde_json::from_value::(value).map_err(|error| {
+ HttpTransportError::bad_request_with("invalid JSON-RPC payload", error)
+ })?;
+ parsed.raw = raw;
+ parsed.has_result = has_result;
+ parsed.has_error = has_error;
+ Ok(parsed)
+ }
+
+ fn is_request(&self) -> bool {
+ self.id.is_some() && self.method.is_some()
+ }
+
+ fn is_notification(&self) -> bool {
+ self.id.is_none() && self.method.is_some()
+ }
+
+ fn is_response(&self) -> bool {
+ self.id.is_some() && self.method.is_none() && (self.has_result || self.has_error)
+ }
+
+ fn method_name(&self) -> Option<&str> {
+ self.method.as_deref()
+ }
+
+ fn is_initialize(&self) -> bool {
+ self.method_name() == Some("initialize")
+ }
+
+ fn activates_session(&self) -> bool {
+ matches!(
+ self.method_name(),
+ Some("session/new")
+ | Some("session/load")
+ | Some("session/resume")
+ | Some("session/fork")
+ )
+ }
+
+ fn requires_session_id(&self) -> bool {
+ if self.is_response() {
+ return true;
+ }
+
+ match self.method_name() {
+ Some("initialize") | Some("authenticate") | Some("session/list") => false,
+ _ if self.activates_session() => false,
+ Some(method) if method.starts_with("session/") => true,
+ _ => false,
+ }
+ }
+
+ fn params_session_id(&self) -> Result