diff --git a/.github/workflows/ci-rust.yml b/.github/workflows/ci-rust.yml index c1396702e..3828df0df 100644 --- a/.github/workflows/ci-rust.yml +++ b/.github/workflows/ci-rust.yml @@ -56,7 +56,7 @@ jobs: diff: true diff-branch: main diff-storage: _xml_coverage_reports - uncovered-statements-increase-failure: true - new-uncovered-statements-failure: true + uncovered-statements-increase-failure: true # DO NOT CHANGE THIS, ADD TESTS + new-uncovered-statements-failure: true # DO NOT CHANGE THIS, ADD TESTS coverage-rate-reduction-failure: true togglable-report: true diff --git a/rsworkspace/crates/acp-nats/src/client/fs_read_text_file.rs b/rsworkspace/crates/acp-nats/src/client/fs_read_text_file.rs index 162e8674f..97ba2e749 100644 --- a/rsworkspace/crates/acp-nats/src/client/fs_read_text_file.rs +++ b/rsworkspace/crates/acp-nats/src/client/fs_read_text_file.rs @@ -1,72 +1,14 @@ +use crate::client::rpc_reply; use crate::jsonrpc::extract_request_id; -use crate::nats::{FlushClient, PublishClient, headers_with_trace_context}; +use crate::nats::{FlushClient, PublishClient}; use agent_client_protocol::{ - Client, Error, ErrorCode, ReadTextFileRequest, ReadTextFileResponse, Request, RequestId, - Response, + Client, ErrorCode, ReadTextFileRequest, ReadTextFileResponse, Request, Response, }; use bytes::Bytes; use serde::de::Error as SerdeDeError; use tracing::{instrument, warn}; use trogon_std::JsonSerialize; -const CONTENT_TYPE_JSON: &str = "application/json"; -const CONTENT_TYPE_PLAIN: &str = "text/plain"; - -fn error_response_fallback_bytes(serializer: &S) -> (Bytes, &'static str) { - match serializer.to_vec(&Response::<()>::Error { - id: RequestId::Null, - error: Error::new(-32603, "Internal error"), - }) { - Ok(v) => (Bytes::from(v), CONTENT_TYPE_JSON), - Err(e) => { - warn!( - error = %e, - "Fallback JSON serialization failed, response may not be valid JSON-RPC" - ); - (Bytes::from("Internal error"), CONTENT_TYPE_PLAIN) - } - } -} - -async fn publish_reply( - nats: &N, - reply_to: &str, - bytes: Bytes, - content_type: &str, - context: &str, -) { - let mut headers = headers_with_trace_context(); - headers.insert("Content-Type", content_type); - if let Err(e) = nats - .publish_with_headers(reply_to.to_string(), headers, bytes) - .await - { - warn!(error = %e, "Failed to publish {}", context); - } - if let Err(e) = nats.flush().await { - warn!(error = %e, "Failed to flush {}", context); - } -} - -fn error_response_bytes( - serializer: &S, - request_id: RequestId, - code: ErrorCode, - message: &str, -) -> (Bytes, &'static str) { - let response = Response::<()>::Error { - id: request_id, - error: Error::new(i32::from(code), message), - }; - match serializer.to_vec(&response) { - Ok(v) => (Bytes::from(v), CONTENT_TYPE_JSON), - Err(e) => { - warn!(error = %e, "JSON serialization failed, using fallback error"); - error_response_fallback_bytes(serializer) - } - } -} - #[derive(Debug)] pub enum FsReadTextFileError { InvalidRequest(serde_json::Error), @@ -134,17 +76,17 @@ pub async fn handle id: request_id.clone(), result: response, }) - .map(|v| (Bytes::from(v), CONTENT_TYPE_JSON)) + .map(|v| (Bytes::from(v), rpc_reply::CONTENT_TYPE_JSON)) .unwrap_or_else(|e| { warn!(error = %e, "JSON serialization of response failed, sending error reply"); - error_response_bytes( + rpc_reply::error_response_bytes( serializer, request_id, ErrorCode::InternalError, &format!("Failed to serialize response: {}", e), ) }); - publish_reply( + rpc_reply::publish_reply( nats, reply_to, response_bytes, @@ -161,8 +103,8 @@ pub async fn handle "Failed to handle fs_read_text_file" ); let (bytes, content_type) = - error_response_bytes(serializer, request_id, code, &message); - publish_reply( + rpc_reply::error_response_bytes(serializer, request_id, code, &message); + rpc_reply::publish_reply( nats, reply_to, bytes, @@ -634,21 +576,6 @@ mod tests { assert!(fs_err.source().is_some()); } - #[test] - fn error_response_bytes_first_fallback_uses_null_id() { - let mock = FailNextSerialize::new(1); - let (bytes, content_type) = error_response_bytes( - &mock, - RequestId::Number(42), - ErrorCode::InvalidParams, - "test message", - ); - assert_eq!(content_type, "application/json"); - let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap(); - assert_eq!(parsed["id"], serde_json::Value::Null); - assert_eq!(parsed["error"]["code"], -32603); - } - #[tokio::test] async fn mock_client_request_permission_returns_err() { let client = MockClient::new("x"); @@ -674,22 +601,4 @@ mod tests { let result = client.request_permission(req).await; assert!(result.is_err()); } - - #[test] - fn error_response_bytes_last_resort_returns_plain_text() { - let mock = FailNextSerialize::new(2); - let (bytes, content_type) = - error_response_bytes(&mock, RequestId::Number(1), ErrorCode::InternalError, "msg"); - assert_eq!(content_type, "text/plain"); - assert_eq!(bytes.as_ref(), b"Internal error"); - } - - #[test] - fn error_response_fallback_bytes_std_serializer_returns_json() { - let (bytes, content_type) = error_response_fallback_bytes(&StdJsonSerialize); - assert_eq!(content_type, "application/json"); - let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap(); - assert_eq!(parsed["id"], serde_json::Value::Null); - assert_eq!(parsed["error"]["code"], -32603); - } } diff --git a/rsworkspace/crates/acp-nats/src/client/mod.rs b/rsworkspace/crates/acp-nats/src/client/mod.rs index 3d6a9eaf8..3ad17eb08 100644 --- a/rsworkspace/crates/acp-nats/src/client/mod.rs +++ b/rsworkspace/crates/acp-nats/src/client/mod.rs @@ -1,4 +1,6 @@ pub(crate) mod fs_read_text_file; +pub(crate) mod request_permission; +pub(crate) mod rpc_reply; pub(crate) mod session_update; use crate::agent::Bridge; @@ -7,9 +9,9 @@ use crate::in_flight_slot_guard::InFlightSlotGuard; use crate::jsonrpc::extract_request_id; use crate::nats::{ ClientMethod, FlushClient, PublishClient, RequestClient, SubscribeClient, client, - headers_with_trace_context, parse_client_subject, + parse_client_subject, }; -use agent_client_protocol::{Client, Error, ErrorCode, Response}; +use agent_client_protocol::{Client, ErrorCode}; use async_nats::Message; use bytes::Bytes; use futures::StreamExt; @@ -19,9 +21,6 @@ use tracing::{Span, error, info, instrument, warn}; use trogon_std::JsonSerialize; use trogon_std::time::GetElapsed; -const CONTENT_TYPE_JSON: &str = "application/json"; -const CONTENT_TYPE_PLAIN: &str = "text/plain"; - async fn publish_backpressure_error_reply( nats: &N, payload: &[u8], @@ -29,38 +28,20 @@ async fn publish_backpressure_error_reply::Error { - id: request_id, - error: Error::new( - i32::from(ErrorCode::Other(AGENT_UNAVAILABLE)), - "Client proxy overloaded; retry with backoff", - ), - }; - let (bytes, content_type) = serializer - .to_vec(&response) - .or_else(|e| { - warn!(error = %e, "JSON serialization of backpressure error failed, using fallback"); - serializer.to_vec(&Response::<()>::Error { - id: agent_client_protocol::RequestId::Null, - error: Error::new(-32603, "Internal error"), - }) - }) - .map(|v| (Bytes::from(v), CONTENT_TYPE_JSON)) - .unwrap_or_else(|e| { - warn!(error = %e, "Fallback JSON serialization failed, response may not be valid JSON-RPC"); - (Bytes::from("Internal error"), CONTENT_TYPE_PLAIN) - }); - let mut headers = headers_with_trace_context(); - headers.insert("Content-Type", content_type); - if let Err(e) = nats - .publish_with_headers(reply_to.to_string(), headers, bytes) - .await - { - warn!(error = %e, "Failed to publish backpressure error reply"); - } - if let Err(e) = nats.flush().await { - warn!(error = %e, "Failed to flush backpressure error reply"); - } + let (bytes, content_type) = rpc_reply::error_response_bytes( + serializer, + request_id, + ErrorCode::Other(AGENT_UNAVAILABLE), + "Client proxy overloaded; retry with backoff", + ); + rpc_reply::publish_reply( + nats, + reply_to, + bytes, + content_type, + "backpressure error reply", + ) + .await; } /// Runs the client proxy, subscribing to client subjects and dispatching to handlers. @@ -206,6 +187,17 @@ async fn dispatch_client_method< ) .await; } + ClientMethod::SessionRequestPermission => { + request_permission::handle( + &payload, + ctx.client, + reply.as_deref(), + ctx.nats, + parsed.session_id.as_str(), + ctx.serializer, + ) + .await; + } ClientMethod::SessionUpdate => { session_update::handle(&payload, ctx.client, &parsed.session_id).await; } @@ -218,7 +210,8 @@ mod tests { use crate::session_id::AcpSessionId; use agent_client_protocol::{ ContentBlock, ContentChunk, ReadTextFileRequest, ReadTextFileResponse, Request, RequestId, - RequestPermissionRequest, RequestPermissionResponse, SessionNotification, SessionUpdate, + RequestPermissionOutcome, RequestPermissionRequest, RequestPermissionResponse, + SessionNotification, SessionUpdate, }; use async_trait::async_trait; use std::cell::RefCell; @@ -499,6 +492,280 @@ mod tests { assert_eq!(nats.published_messages(), vec!["_INBOX.reply"]); } + #[derive(Debug)] + struct RpcMockClient; + + #[async_trait(?Send)] + impl Client for RpcMockClient { + async fn session_notification( + &self, + _: SessionNotification, + ) -> agent_client_protocol::Result<()> { + Ok(()) + } + + async fn request_permission( + &self, + _: RequestPermissionRequest, + ) -> agent_client_protocol::Result { + Ok(RequestPermissionResponse::new( + RequestPermissionOutcome::Cancelled, + )) + } + + async fn read_text_file( + &self, + _: ReadTextFileRequest, + ) -> agent_client_protocol::Result { + Ok(ReadTextFileResponse::new("file contents".to_string())) + } + } + + #[tokio::test] + async fn dispatch_client_method_dispatches_session_update_with_rpc_mock_client() { + let nats = MockNatsClient::new(); + let client = RpcMockClient; + let session_id = AcpSessionId::new("sess-1").unwrap(); + + let notification = SessionNotification::new( + "sess-1", + SessionUpdate::AgentMessageChunk(ContentChunk::new(ContentBlock::from("hi"))), + ); + let payload = bytes::Bytes::from(serde_json::to_vec(¬ification).unwrap()); + + let parsed = crate::nats::ParsedClientSubject { + session_id, + method: ClientMethod::SessionUpdate, + }; + + let ctx = DispatchContext { + nats: &nats, + client: &client, + serializer: &StdJsonSerialize, + }; + dispatch_client_method( + "acp.sess-1.client.session.update", + parsed, + payload, + None, + &ctx, + ) + .await; + } + + #[tokio::test] + async fn dispatch_client_method_dispatches_fs_read_text_file_with_rpc_mock_client() { + let nats = MockNatsClient::new(); + let client = RpcMockClient; + let session_id = AcpSessionId::new("sess-1").unwrap(); + + let envelope = Request { + id: RequestId::Number(1), + method: std::sync::Arc::from("fs/read_text_file"), + params: Some(ReadTextFileRequest::new( + agent_client_protocol::SessionId::from("sess-1"), + "/tmp/foo.txt".to_string(), + )), + }; + let payload = bytes::Bytes::from(serde_json::to_vec(&envelope).unwrap()); + + let parsed = crate::nats::ParsedClientSubject { + session_id, + method: ClientMethod::FsReadTextFile, + }; + + let ctx = DispatchContext { + nats: &nats, + client: &client, + serializer: &StdJsonSerialize, + }; + dispatch_client_method( + "acp.sess-1.client.fs.read_text_file", + parsed, + payload, + Some("_INBOX.reply".to_string()), + &ctx, + ) + .await; + + assert_eq!(nats.published_messages(), vec!["_INBOX.reply"]); + } + + #[tokio::test] + async fn dispatch_client_method_dispatches_request_permission() { + let nats = MockNatsClient::new(); + let client = RpcMockClient; + let session_id = AcpSessionId::new("sess-1").unwrap(); + + let request = RequestPermissionRequest::new( + "sess-1", + agent_client_protocol::ToolCallUpdate::new( + "call-1", + agent_client_protocol::ToolCallUpdateFields::new(), + ), + vec![], + ); + let envelope = Request { + id: RequestId::Number(1), + method: std::sync::Arc::from("session/request_permission"), + params: Some(request), + }; + let payload = bytes::Bytes::from(serde_json::to_vec(&envelope).unwrap()); + + let parsed = crate::nats::ParsedClientSubject { + session_id, + method: ClientMethod::SessionRequestPermission, + }; + + let ctx = DispatchContext { + nats: &nats, + client: &client, + serializer: &StdJsonSerialize, + }; + dispatch_client_method( + "acp.sess-1.client.session.request_permission", + parsed, + payload, + Some("_INBOX.reply".to_string()), + &ctx, + ) + .await; + + assert_eq!(nats.published_messages(), vec!["_INBOX.reply"]); + } + + #[tokio::test] + async fn dispatch_client_method_dispatches_request_permission_client_error_publishes_error_reply() + { + let nats = MockNatsClient::new(); + let client = MockClient::new(); + let session_id = AcpSessionId::new("sess-1").unwrap(); + + let request = RequestPermissionRequest::new( + "sess-1", + agent_client_protocol::ToolCallUpdate::new( + "call-1", + agent_client_protocol::ToolCallUpdateFields::new(), + ), + vec![], + ); + let envelope = Request { + id: RequestId::Number(1), + method: std::sync::Arc::from("session/request_permission"), + params: Some(request), + }; + let payload = bytes::Bytes::from(serde_json::to_vec(&envelope).unwrap()); + + let parsed = crate::nats::ParsedClientSubject { + session_id, + method: ClientMethod::SessionRequestPermission, + }; + + let ctx = DispatchContext { + nats: &nats, + client: &client, + serializer: &StdJsonSerialize, + }; + dispatch_client_method( + "acp.sess-1.client.session.request_permission", + parsed, + payload, + Some("_INBOX.err".to_string()), + &ctx, + ) + .await; + + assert_eq!(nats.published_messages(), vec!["_INBOX.err"]); + } + + #[tokio::test] + async fn dispatch_client_method_dispatches_request_permission_with_advanced_mock() { + let nats = AdvancedMockNatsClient::new(); + let client = MockClient::new(); + let session_id = AcpSessionId::new("sess-1").unwrap(); + + let request = RequestPermissionRequest::new( + "sess-1", + agent_client_protocol::ToolCallUpdate::new( + "call-1", + agent_client_protocol::ToolCallUpdateFields::new(), + ), + vec![], + ); + let envelope = Request { + id: RequestId::Number(1), + method: std::sync::Arc::from("session/request_permission"), + params: Some(request), + }; + let payload = bytes::Bytes::from(serde_json::to_vec(&envelope).unwrap()); + + let parsed = crate::nats::ParsedClientSubject { + session_id, + method: ClientMethod::SessionRequestPermission, + }; + + let ctx = DispatchContext { + nats: &nats, + client: &client, + serializer: &StdJsonSerialize, + }; + dispatch_client_method( + "acp.sess-1.client.session.request_permission", + parsed, + payload, + Some("_INBOX.err".to_string()), + &ctx, + ) + .await; + + assert_eq!(nats.published_messages(), vec!["_INBOX.err"]); + } + + #[tokio::test] + async fn dispatch_client_method_dispatches_request_permission_client_error_serialization_fallback() + { + let nats = MockNatsClient::new(); + let client = MockClient::new(); + let serializer = FailNextSerialize::new(1); + let session_id = AcpSessionId::new("sess-1").unwrap(); + + let request = RequestPermissionRequest::new( + "sess-1", + agent_client_protocol::ToolCallUpdate::new( + "call-1", + agent_client_protocol::ToolCallUpdateFields::new(), + ), + vec![], + ); + let envelope = Request { + id: RequestId::Number(1), + method: std::sync::Arc::from("session/request_permission"), + params: Some(request), + }; + let payload = bytes::Bytes::from(serde_json::to_vec(&envelope).unwrap()); + + let parsed = crate::nats::ParsedClientSubject { + session_id, + method: ClientMethod::SessionRequestPermission, + }; + + let ctx = DispatchContext { + nats: &nats, + client: &client, + serializer: &serializer, + }; + dispatch_client_method( + "acp.sess-1.client.session.request_permission", + parsed, + payload, + Some("_INBOX.err".to_string()), + &ctx, + ) + .await; + + assert_eq!(nats.published_messages(), vec!["_INBOX.err"]); + } + #[tokio::test] async fn process_message_invalid_subject_no_reply_does_not_publish() { let nats = MockNatsClient::new(); diff --git a/rsworkspace/crates/acp-nats/src/client/request_permission.rs b/rsworkspace/crates/acp-nats/src/client/request_permission.rs new file mode 100644 index 000000000..4bceff656 --- /dev/null +++ b/rsworkspace/crates/acp-nats/src/client/request_permission.rs @@ -0,0 +1,542 @@ +use crate::client::rpc_reply; +use crate::jsonrpc::extract_request_id; +use crate::nats::{FlushClient, PublishClient}; +use agent_client_protocol::{ + Client, ErrorCode, Request, RequestPermissionRequest, RequestPermissionResponse, Response, +}; +use bytes::Bytes; +use serde::de::Error as SerdeDeError; +use tracing::{instrument, warn}; +use trogon_std::JsonSerialize; + +#[derive(Debug)] +pub enum RequestPermissionError { + InvalidRequest(serde_json::Error), + ClientError(agent_client_protocol::Error), +} + +impl std::fmt::Display for RequestPermissionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::InvalidRequest(e) => write!(f, "invalid request: {}", e), + Self::ClientError(e) => write!(f, "client error: {}", e), + } + } +} + +impl std::error::Error for RequestPermissionError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::InvalidRequest(e) => Some(e), + Self::ClientError(e) => Some(e), + } + } +} + +pub fn error_code_and_message(e: &RequestPermissionError) -> (ErrorCode, String) { + match e { + RequestPermissionError::InvalidRequest(inner) => ( + ErrorCode::InvalidParams, + format!("Invalid request_permission request: {}", inner), + ), + RequestPermissionError::ClientError(inner) => (inner.code, inner.message.clone()), + } +} + +/// Handles session/request_permission: parses JSON-RPC request, calls client, wraps response in +/// JSON-RPC envelope, and publishes to reply subject. Reply is required (request-reply pattern). +#[instrument( + name = "acp.client.session.request_permission", + skip(payload, client, nats, serializer) +)] +pub async fn handle( + payload: &[u8], + client: &C, + reply: Option<&str>, + nats: &N, + session_id: &str, + serializer: &S, +) { + let reply_to = match reply { + Some(r) => r, + None => { + warn!( + session_id = %session_id, + "request_permission requires reply subject; ignoring message" + ); + return; + } + }; + + let request_id = extract_request_id(payload); + match forward_to_client(payload, client, session_id).await { + Ok(response) => { + let (response_bytes, content_type) = serializer + .to_vec(&Response::Result { + id: request_id.clone(), + result: response, + }) + .map(|v| (Bytes::from(v), rpc_reply::CONTENT_TYPE_JSON)) + .unwrap_or_else(|e| { + warn!(error = %e, "JSON serialization of response failed, sending error reply"); + rpc_reply::error_response_bytes( + serializer, + request_id, + ErrorCode::InternalError, + &format!("Failed to serialize response: {}", e), + ) + }); + rpc_reply::publish_reply( + nats, + reply_to, + response_bytes, + content_type, + "request_permission reply", + ) + .await; + } + Err(e) => { + let (code, message) = error_code_and_message(&e); + warn!( + error = %e, + session_id = %session_id, + "Failed to handle request_permission" + ); + let (bytes, content_type) = + rpc_reply::error_response_bytes(serializer, request_id, code, &message); + rpc_reply::publish_reply( + nats, + reply_to, + bytes, + content_type, + "request_permission error reply", + ) + .await; + } + } +} + +async fn forward_to_client( + payload: &[u8], + client: &C, + expected_session_id: &str, +) -> Result { + let envelope: Request = + serde_json::from_slice(payload).map_err(RequestPermissionError::InvalidRequest)?; + let request = envelope.params.ok_or_else(|| { + RequestPermissionError::InvalidRequest(serde_json::Error::custom( + "params is null or missing", + )) + })?; + let params_session_id = request.session_id.to_string(); + if params_session_id != expected_session_id { + return Err(RequestPermissionError::InvalidRequest( + serde_json::Error::custom(format!( + "params.sessionId ({}) does not match subject session id ({})", + params_session_id, expected_session_id + )), + )); + } + client + .request_permission(request) + .await + .map_err(RequestPermissionError::ClientError) +} + +#[cfg(test)] +mod tests { + use super::*; + use agent_client_protocol::{ + ContentBlock, ContentChunk, PermissionOption, PermissionOptionKind, RequestId, + RequestPermissionOutcome, RequestPermissionResponse, SessionNotification, SessionUpdate, + ToolCallUpdate, ToolCallUpdateFields, + }; + use std::error::Error; + use trogon_nats::{AdvancedMockNatsClient, MockNatsClient}; + use trogon_std::{FailNextSerialize, StdJsonSerialize}; + + struct MockClient { + outcome: RequestPermissionOutcome, + } + + impl MockClient { + fn new(outcome: RequestPermissionOutcome) -> Self { + Self { outcome } + } + } + + #[async_trait::async_trait(?Send)] + impl Client for MockClient { + async fn session_notification( + &self, + _: SessionNotification, + ) -> agent_client_protocol::Result<()> { + Ok(()) + } + + async fn request_permission( + &self, + _: RequestPermissionRequest, + ) -> agent_client_protocol::Result { + Ok(RequestPermissionResponse::new(self.outcome.clone())) + } + } + + struct FailingClient; + + #[async_trait::async_trait(?Send)] + impl Client for FailingClient { + async fn session_notification( + &self, + _: SessionNotification, + ) -> agent_client_protocol::Result<()> { + Ok(()) + } + + async fn request_permission( + &self, + _: RequestPermissionRequest, + ) -> agent_client_protocol::Result { + Err(agent_client_protocol::Error::new( + i32::from(ErrorCode::InvalidParams), + "permission denied", + )) + } + } + + #[tokio::test] + async fn mock_client_session_notification_returns_ok() { + let client = MockClient::new(RequestPermissionOutcome::Cancelled); + let notification = SessionNotification::new( + "sess-1", + SessionUpdate::AgentMessageChunk(ContentChunk::new(ContentBlock::from("hi"))), + ); + let result = client.session_notification(notification).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn failing_client_session_notification_returns_ok() { + let client = FailingClient; + let notification = SessionNotification::new( + "sess-1", + SessionUpdate::AgentMessageChunk(ContentChunk::new(ContentBlock::from("hi"))), + ); + let result = client.session_notification(notification).await; + assert!(result.is_ok()); + } + + fn make_envelope(request: RequestPermissionRequest) -> Vec { + let envelope = Request { + id: RequestId::Number(1), + method: std::sync::Arc::from("session/request_permission"), + params: Some(request), + }; + serde_json::to_vec(&envelope).unwrap() + } + + #[tokio::test] + async fn request_permission_forwards_request_and_returns_response() { + let client = MockClient::new(RequestPermissionOutcome::Cancelled); + let tool_call = ToolCallUpdate::new("call-1", ToolCallUpdateFields::new()); + let options = vec![PermissionOption::new( + "allow-once", + "Allow once", + PermissionOptionKind::AllowOnce, + )]; + let request = RequestPermissionRequest::new("session-001", tool_call, options); + let payload = make_envelope(request); + + let result = forward_to_client(&payload, &client, "session-001").await; + assert!(result.is_ok()); + let response = result.unwrap(); + assert_eq!(response.outcome, RequestPermissionOutcome::Cancelled); + } + + #[tokio::test] + async fn request_permission_returns_error_when_payload_is_invalid_json() { + let client = MockClient::new(RequestPermissionOutcome::Cancelled); + let result = forward_to_client(b"not json", &client, "session-001").await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn request_permission_returns_client_error_when_client_fails() { + let client = FailingClient; + let tool_call = ToolCallUpdate::new("call-1", ToolCallUpdateFields::new()); + let request = RequestPermissionRequest::new("session-001", tool_call, vec![]); + let payload = make_envelope(request); + + let result = forward_to_client(&payload, &client, "session-001").await; + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + RequestPermissionError::ClientError(_) + )); + } + + #[tokio::test] + async fn request_permission_returns_invalid_request_when_params_missing() { + let client = MockClient::new(RequestPermissionOutcome::Cancelled); + let envelope = Request:: { + id: RequestId::Number(1), + method: std::sync::Arc::from("session/request_permission"), + params: None, + }; + let payload = serde_json::to_vec(&envelope).unwrap(); + + let result = forward_to_client(&payload, &client, "session-001").await; + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + RequestPermissionError::InvalidRequest(_) + )); + } + + #[tokio::test] + async fn request_permission_returns_invalid_request_when_session_id_mismatch() { + let client = MockClient::new(RequestPermissionOutcome::Cancelled); + let tool_call = ToolCallUpdate::new("call-1", ToolCallUpdateFields::new()); + let request = RequestPermissionRequest::new("session-other", tool_call, vec![]); + let payload = make_envelope(request); + + let result = forward_to_client(&payload, &client, "session-001").await; + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + RequestPermissionError::InvalidRequest(_) + )); + } + + #[test] + fn error_code_and_message_invalid_request_returns_invalid_params() { + let err = serde_json::from_slice::(b"not json").unwrap_err(); + let rp_err = RequestPermissionError::InvalidRequest(err); + let (code, _) = error_code_and_message(&rp_err); + assert_eq!(code, ErrorCode::InvalidParams); + } + + #[test] + fn error_code_and_message_client_error_preserves_client_code() { + let client_err = + agent_client_protocol::Error::new(ErrorCode::InvalidParams.into(), "denied"); + let rp_err = RequestPermissionError::ClientError(client_err); + let (code, message) = error_code_and_message(&rp_err); + assert_eq!(code, ErrorCode::InvalidParams); + assert_eq!(message, "denied"); + } + + #[test] + fn request_permission_error_display() { + let err = serde_json::from_slice::(b"not json").unwrap_err(); + let rp_err = RequestPermissionError::InvalidRequest(err); + assert!(rp_err.to_string().contains("invalid request")); + + let client_err = + agent_client_protocol::Error::new(ErrorCode::InvalidParams.into(), "permission denied"); + let rp_err = RequestPermissionError::ClientError(client_err); + assert!(rp_err.to_string().contains("client error")); + } + + #[test] + fn request_permission_error_source() { + let err = serde_json::from_slice::(b"not json").unwrap_err(); + let rp_err = RequestPermissionError::InvalidRequest(err); + assert!(rp_err.source().is_some()); + + let client_err = + agent_client_protocol::Error::new(ErrorCode::InvalidParams.into(), "denied"); + let rp_err = RequestPermissionError::ClientError(client_err); + assert!(rp_err.source().is_some()); + } + + #[tokio::test] + async fn handle_success_publishes_response_to_reply_subject() { + let nats = MockNatsClient::new(); + let client = MockClient::new(RequestPermissionOutcome::Cancelled); + let tool_call = ToolCallUpdate::new("call-1", ToolCallUpdateFields::new()); + let request = RequestPermissionRequest::new("session-001", tool_call, vec![]); + let payload = make_envelope(request); + + handle( + &payload, + &client, + Some("_INBOX.reply"), + &nats, + "session-001", + &StdJsonSerialize, + ) + .await; + + assert_eq!(nats.published_messages(), vec!["_INBOX.reply"]); + } + + #[tokio::test] + async fn handle_no_reply_does_not_publish() { + let nats = MockNatsClient::new(); + let client = MockClient::new(RequestPermissionOutcome::Cancelled); + let tool_call = ToolCallUpdate::new("call-1", ToolCallUpdateFields::new()); + let request = RequestPermissionRequest::new("session-001", tool_call, vec![]); + let payload = make_envelope(request); + + handle( + &payload, + &client, + None, + &nats, + "session-001", + &StdJsonSerialize, + ) + .await; + + assert!(nats.published_messages().is_empty()); + } + + #[tokio::test] + async fn handle_session_id_mismatch_publishes_error_reply() { + let nats = MockNatsClient::new(); + let client = MockClient::new(RequestPermissionOutcome::Cancelled); + let tool_call = ToolCallUpdate::new("call-1", ToolCallUpdateFields::new()); + let request = RequestPermissionRequest::new("session-other", tool_call, vec![]); + let payload = make_envelope(request); + + handle( + &payload, + &client, + Some("_INBOX.err"), + &nats, + "session-001", + &StdJsonSerialize, + ) + .await; + + assert_eq!(nats.published_messages(), vec!["_INBOX.err"]); + } + + #[tokio::test] + async fn handle_invalid_payload_publishes_error_reply() { + let nats = MockNatsClient::new(); + let client = MockClient::new(RequestPermissionOutcome::Cancelled); + + handle( + b"not json", + &client, + Some("_INBOX.err"), + &nats, + "session-001", + &StdJsonSerialize, + ) + .await; + + assert_eq!(nats.published_messages(), vec!["_INBOX.err"]); + } + + #[tokio::test] + async fn handle_client_error_publishes_error_reply() { + let nats = MockNatsClient::new(); + let client = FailingClient; + let tool_call = ToolCallUpdate::new("call-1", ToolCallUpdateFields::new()); + let request = RequestPermissionRequest::new("session-001", tool_call, vec![]); + let payload = make_envelope(request); + + handle( + &payload, + &client, + Some("_INBOX.err"), + &nats, + "session-001", + &StdJsonSerialize, + ) + .await; + + assert_eq!(nats.published_messages(), vec!["_INBOX.err"]); + } + + #[tokio::test] + async fn handle_success_serialization_fallback_sends_error_reply() { + let nats = MockNatsClient::new(); + let client = MockClient::new(RequestPermissionOutcome::Cancelled); + let serializer = FailNextSerialize::new(1); + let tool_call = ToolCallUpdate::new("call-1", ToolCallUpdateFields::new()); + let request = RequestPermissionRequest::new("session-001", tool_call, vec![]); + let payload = make_envelope(request); + + handle( + &payload, + &client, + Some("_INBOX.reply"), + &nats, + "session-001", + &serializer, + ) + .await; + + assert_eq!(nats.published_messages(), vec!["_INBOX.reply"]); + } + + #[tokio::test] + async fn handle_success_flush_failure_exercises_warn_path() { + let nats = AdvancedMockNatsClient::new(); + nats.fail_next_flush(); + let client = MockClient::new(RequestPermissionOutcome::Cancelled); + let tool_call = ToolCallUpdate::new("call-1", ToolCallUpdateFields::new()); + let request = RequestPermissionRequest::new("session-001", tool_call, vec![]); + let payload = make_envelope(request); + + handle( + &payload, + &client, + Some("_INBOX.reply"), + &nats, + "session-001", + &StdJsonSerialize, + ) + .await; + + assert_eq!(nats.published_messages(), vec!["_INBOX.reply"]); + } + + #[tokio::test] + async fn handle_success_publish_failure_exercises_error_path() { + let nats = AdvancedMockNatsClient::new(); + nats.fail_next_publish(); + let client = MockClient::new(RequestPermissionOutcome::Cancelled); + let tool_call = ToolCallUpdate::new("call-1", ToolCallUpdateFields::new()); + let request = RequestPermissionRequest::new("session-001", tool_call, vec![]); + let payload = make_envelope(request); + + handle( + &payload, + &client, + Some("_INBOX.reply"), + &nats, + "session-001", + &StdJsonSerialize, + ) + .await; + + assert!(nats.published_messages().is_empty()); + } + + #[tokio::test] + async fn handle_client_error_serialization_last_resort_returns_plain_text() { + let nats = MockNatsClient::new(); + let client = FailingClient; + let serializer = FailNextSerialize::new(2); + let tool_call = ToolCallUpdate::new("call-1", ToolCallUpdateFields::new()); + let request = RequestPermissionRequest::new("session-001", tool_call, vec![]); + let payload = make_envelope(request); + + handle( + &payload, + &client, + Some("_INBOX.err"), + &nats, + "session-001", + &serializer, + ) + .await; + + assert_eq!(nats.published_messages(), vec!["_INBOX.err"]); + } +} diff --git a/rsworkspace/crates/acp-nats/src/client/rpc_reply.rs b/rsworkspace/crates/acp-nats/src/client/rpc_reply.rs new file mode 100644 index 000000000..caf28a5ea --- /dev/null +++ b/rsworkspace/crates/acp-nats/src/client/rpc_reply.rs @@ -0,0 +1,103 @@ +use crate::nats::{FlushClient, PublishClient, headers_with_trace_context}; +use agent_client_protocol::{Error, ErrorCode, RequestId, Response}; +use bytes::Bytes; +use tracing::warn; +use trogon_std::JsonSerialize; + +pub const CONTENT_TYPE_JSON: &str = "application/json"; +pub const CONTENT_TYPE_PLAIN: &str = "text/plain"; + +pub fn error_response_fallback_bytes(serializer: &S) -> (Bytes, &'static str) { + match serializer.to_vec(&Response::<()>::Error { + id: RequestId::Null, + error: Error::new(-32603, "Internal error"), + }) { + Ok(v) => (Bytes::from(v), CONTENT_TYPE_JSON), + Err(e) => { + warn!( + error = %e, + "Fallback JSON serialization failed, response may not be valid JSON-RPC" + ); + (Bytes::from("Internal error"), CONTENT_TYPE_PLAIN) + } + } +} + +pub async fn publish_reply( + nats: &N, + reply_to: &str, + bytes: Bytes, + content_type: &str, + context: &str, +) { + let mut headers = headers_with_trace_context(); + headers.insert("Content-Type", content_type); + if let Err(e) = nats + .publish_with_headers(reply_to.to_string(), headers, bytes) + .await + { + warn!(error = %e, "Failed to publish {}", context); + } + if let Err(e) = nats.flush().await { + warn!(error = %e, "Failed to flush {}", context); + } +} + +pub fn error_response_bytes( + serializer: &S, + request_id: RequestId, + code: ErrorCode, + message: &str, +) -> (Bytes, &'static str) { + let response = Response::<()>::Error { + id: request_id, + error: Error::new(i32::from(code), message), + }; + match serializer.to_vec(&response) { + Ok(v) => (Bytes::from(v), CONTENT_TYPE_JSON), + Err(e) => { + warn!(error = %e, "JSON serialization failed, using fallback error"); + error_response_fallback_bytes(serializer) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use agent_client_protocol::{ErrorCode, RequestId}; + use trogon_std::{FailNextSerialize, StdJsonSerialize}; + + #[test] + fn error_response_bytes_first_fallback_uses_null_id() { + let mock = FailNextSerialize::new(1); + let (bytes, content_type) = error_response_bytes( + &mock, + RequestId::Number(42), + ErrorCode::InvalidParams, + "test message", + ); + assert_eq!(content_type, "application/json"); + let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap(); + assert_eq!(parsed["id"], serde_json::Value::Null); + assert_eq!(parsed["error"]["code"], -32603); + } + + #[test] + fn error_response_bytes_last_resort_returns_plain_text() { + let mock = FailNextSerialize::new(2); + let (bytes, content_type) = + error_response_bytes(&mock, RequestId::Number(1), ErrorCode::InternalError, "msg"); + assert_eq!(content_type, "text/plain"); + assert_eq!(bytes.as_ref(), b"Internal error"); + } + + #[test] + fn error_response_fallback_bytes_std_serializer_returns_json() { + let (bytes, content_type) = error_response_fallback_bytes(&StdJsonSerialize); + assert_eq!(content_type, "application/json"); + let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap(); + assert_eq!(parsed["id"], serde_json::Value::Null); + assert_eq!(parsed["error"]["code"], -32603); + } +} diff --git a/rsworkspace/crates/acp-nats/src/nats/client_method.rs b/rsworkspace/crates/acp-nats/src/nats/client_method.rs index 97dc29bb5..adddf6b7b 100644 --- a/rsworkspace/crates/acp-nats/src/nats/client_method.rs +++ b/rsworkspace/crates/acp-nats/src/nats/client_method.rs @@ -1,6 +1,7 @@ #[derive(Debug, Clone, PartialEq, Eq)] pub enum ClientMethod { FsReadTextFile, + SessionRequestPermission, SessionUpdate, } @@ -8,6 +9,7 @@ impl ClientMethod { pub fn from_subject_suffix(suffix: &str) -> Option { match suffix { "client.fs.read_text_file" => Some(Self::FsReadTextFile), + "client.session.request_permission" => Some(Self::SessionRequestPermission), "client.session.update" => Some(Self::SessionUpdate), _ => None, } diff --git a/rsworkspace/crates/acp-nats/src/nats/parsing.rs b/rsworkspace/crates/acp-nats/src/nats/parsing.rs index 49db37d7c..5f1693aa4 100644 --- a/rsworkspace/crates/acp-nats/src/nats/parsing.rs +++ b/rsworkspace/crates/acp-nats/src/nats/parsing.rs @@ -28,6 +28,14 @@ mod tests { assert_eq!(parsed.method, ClientMethod::FsReadTextFile); } + #[test] + fn test_parse_session_request_permission() { + let subject = "acp.sess123.client.session.request_permission"; + let parsed = parse_client_subject(subject).unwrap(); + assert_eq!(parsed.session_id.as_str(), "sess123"); + assert_eq!(parsed.method, ClientMethod::SessionRequestPermission); + } + #[test] fn test_parse_session_update() { let subject = "acp.sess123.client.session.update"; diff --git a/rsworkspace/crates/acp-nats/src/nats/subjects.rs b/rsworkspace/crates/acp-nats/src/nats/subjects.rs index 95835d975..782c94fa0 100644 --- a/rsworkspace/crates/acp-nats/src/nats/subjects.rs +++ b/rsworkspace/crates/acp-nats/src/nats/subjects.rs @@ -43,6 +43,13 @@ pub mod client { format!("{}.{}.client.fs.read_text_file", prefix, session_id) } + pub fn session_request_permission(prefix: &str, session_id: &str) -> String { + format!( + "{}.{}.client.session.request_permission", + prefix, session_id + ) + } + pub fn session_update(prefix: &str, session_id: &str) -> String { format!("{}.{}.client.session.update", prefix, session_id) } @@ -67,6 +74,14 @@ mod tests { ); } + #[test] + fn client_session_request_permission_subject() { + assert_eq!( + client::session_request_permission("acp", "s1"), + "acp.s1.client.session.request_permission" + ); + } + #[test] fn client_session_update_subject() { assert_eq!(