diff --git a/rsworkspace/crates/acp-nats-agent/src/connection.rs b/rsworkspace/crates/acp-nats-agent/src/connection.rs index abd826348..3b19e8012 100644 --- a/rsworkspace/crates/acp-nats-agent/src/connection.rs +++ b/rsworkspace/crates/acp-nats-agent/src/connection.rs @@ -1,7 +1,13 @@ +use acp_nats::jetstream::consumers::commands_observer; +use acp_nats::jetstream::streams::commands_stream_name; +use acp_nats::nats::agent::wildcards::GlobalAllSubject; +use acp_nats::nats::session::wildcards::{AllAgentExtSubject, AllAgentSubject}; use acp_nats::nats::{ GlobalAgentMethod, ParsedAgentSubject, SessionAgentMethod, parse_agent_subject, }; -use acp_nats::{AcpPrefix, AcpSessionId, NatsClientProxy}; +use acp_nats::{ + AcpPrefix, AcpSessionId, NatsClientProxy, PromptResponseSubject, ReqId, ResponseSubject, +}; use agent_client_protocol::{ Agent, AuthenticateRequest, CancelNotification, CloseSessionRequest, ExtNotification, ExtRequest, ForkSessionRequest, InitializeRequest, ListSessionsRequest, LoadSessionRequest, @@ -158,8 +164,8 @@ where N: SubscribeClient + PublishClient + FlushClient + Clone + 'static, A: Agent + 'static, { - let global_wildcard = acp_nats::nats::agent::wildcards::GlobalAllSubject::new(prefix); - let session_wildcard = acp_nats::nats::session::wildcards::AllAgentSubject::new(prefix); + let global_wildcard = GlobalAllSubject::new(prefix); + let session_wildcard = AllAgentSubject::new(prefix); info!( global = %global_wildcard, @@ -204,8 +210,8 @@ where N: SubscribeClient + PublishClient + FlushClient + Clone + 'static, A: Agent + 'static, { - let global_wildcard = acp_nats::nats::agent::wildcards::GlobalAllSubject::new(prefix); - let ext_wildcard = acp_nats::nats::session::wildcards::AllAgentExtSubject::new(prefix); + let global_wildcard = GlobalAllSubject::new(prefix); + let ext_wildcard = AllAgentExtSubject::new(prefix); info!( global = %global_wildcard, @@ -469,9 +475,8 @@ where }; } _ = keepalive.tick() => { - if let Err(e) = js_msg.ack_with(AckKind::Progress).await { - warn!(error = %e, "Failed to send in_progress keepalive"); - } + let _ = js_msg.ack_with(AckKind::Progress).await + .inspect_err(|e| warn!(error = %e, "Failed to send in_progress keepalive")); } } } @@ -490,8 +495,8 @@ where trogon_nats::jetstream::JsMessageOf: JsDispatchMessage, A: Agent + 'static, { - let stream_name = acp_nats::jetstream::streams::commands_stream_name(prefix); - let config = acp_nats::jetstream::consumers::commands_observer(); + let stream_name = commands_stream_name(prefix); + let config = commands_observer(); info!(stream = %stream_name, "Starting JetStream consumer for COMMANDS stream"); @@ -567,18 +572,14 @@ async fn dispatch_js_message = match (&req_id, &method) { - (Some(rid), SessionAgentMethod::Prompt) => Some( - acp_nats::nats::session::agent::PromptResponseSubject::new(prefix, &session_id, rid) - .to_string(), - ), + (Some(rid), SessionAgentMethod::Prompt) => { + Some(PromptResponseSubject::new(prefix, &session_id, rid).to_string()) + } (_, SessionAgentMethod::Cancel) => None, - (Some(rid), _) => Some( - acp_nats::nats::session::agent::ResponseSubject::new(prefix, &session_id, rid) - .to_string(), - ), + (Some(rid), _) => Some(ResponseSubject::new(prefix, &session_id, rid).to_string()), (None, _) => { warn!(subject, "JetStream message missing X-Req-Id header"); None @@ -673,20 +674,20 @@ async fn dispatch_js_message { - if let Err(e) = js_msg.ack().await { - warn!(subject, error = %e, "Failed to ack after notification handler error"); - } + let _ = js_msg.ack().await.inspect_err( + |e| warn!(subject, error = %e, "Failed to ack after notification handler error"), + ); } } - if let Err(e) = result { + let _ = result.inspect_err(|e| { warn!( subject, session_id = session_id.as_str(), error = %e, "Error handling JetStream request" ); - } + }); } #[cfg(test)] @@ -702,6 +703,7 @@ mod tests { struct MockAgent { initialized: RefCell, cancelled: RefCell>, + fail_cancel: bool, } impl MockAgent { @@ -709,6 +711,15 @@ mod tests { Self { initialized: RefCell::new(false), cancelled: RefCell::new(Vec::new()), + fail_cancel: false, + } + } + + fn failing_cancel() -> Self { + Self { + initialized: RefCell::new(false), + cancelled: RefCell::new(Vec::new()), + fail_cancel: true, } } } @@ -754,6 +765,9 @@ mod tests { } async fn cancel(&self, args: CancelNotification) -> agent_client_protocol::Result<()> { + if self.fail_cancel { + return Err(AcpError::internal_error()); + } self.cancelled .borrow_mut() .push(args.session_id.to_string()); @@ -1844,6 +1858,32 @@ mod tests { dispatch_js_message(js_msg, &agent, &nats, &test_prefix()).await; } + #[tokio::test] + async fn dispatch_js_message_cancel_notification_handler_error_ack_failure() { + use tracing_subscriber::util::SubscriberInitExt; + let _guard = tracing_subscriber::fmt().with_test_writer().set_default(); + + let nats = MockNatsClient::new(); + let agent = MockAgent::failing_cancel(); + let payload = serialize(&CancelNotification::new("s1")); + let js_msg = MockJsMessage::with_failing_signals(async_nats::Message { + subject: "acp.session.s1.agent.cancel".into(), + reply: None, + payload: Bytes::copy_from_slice(&payload), + headers: None, + status: None, + description: None, + length: payload.len(), + }); + dispatch_js_message(js_msg, &agent, &nats, &test_prefix()).await; + } + + fn init_handler_error( + _: InitializeRequest, + ) -> std::future::Ready> { + std::future::ready(Err(AcpError::internal_error())) + } + #[tokio::test] async fn handle_request_with_keepalive_completes_fast() { let nats = MockNatsClient::new(); @@ -1871,12 +1911,7 @@ mod tests { )); let msg = make_nats_message("acp.agent.initialize", &payload, None); let js_msg = make_js_msg("acp.agent.initialize", &payload, None); - - let result = - handle_request_with_keepalive(&msg, &nats, &js_msg, |_: InitializeRequest| async { - Err::(agent_client_protocol::Error::new(-1, "not called")) - }) - .await; + let result = handle_request_with_keepalive(&msg, &nats, &js_msg, init_handler_error).await; assert!(result.is_err()); } @@ -1885,15 +1920,23 @@ mod tests { let nats = MockNatsClient::new(); let msg = make_nats_message("acp.agent.initialize", b"not json", Some("_INBOX.1")); let js_msg = make_js_msg("acp.agent.initialize", b"not json", Some("_INBOX.1")); - - let result = - handle_request_with_keepalive(&msg, &nats, &js_msg, |_: InitializeRequest| async { - Err::(agent_client_protocol::Error::new(-1, "not called")) - }) - .await; + let result = handle_request_with_keepalive(&msg, &nats, &js_msg, init_handler_error).await; assert!(result.is_err()); } + #[tokio::test] + async fn handle_request_with_keepalive_handler_returns_error() { + let nats = MockNatsClient::new(); + let payload = serialize(&InitializeRequest::new( + agent_client_protocol::ProtocolVersion::V0, + )); + let msg = make_nats_message("acp.agent.initialize", &payload, Some("_INBOX.1")); + let js_msg = make_js_msg("acp.agent.initialize", &payload, Some("_INBOX.1")); + let result = handle_request_with_keepalive(&msg, &nats, &js_msg, init_handler_error).await; + assert!(result.is_ok()); + assert!(!nats.published_messages().is_empty()); + } + #[tokio::test(start_paused = true)] async fn handle_request_with_keepalive_progress_ack_failure() { use tracing_subscriber::util::SubscriberInitExt; diff --git a/rsworkspace/crates/acp-nats/src/nats/subjects/responses/prompt_response.rs b/rsworkspace/crates/acp-nats/src/nats/subjects/responses/prompt_response.rs index 558563f14..24f730dd9 100644 --- a/rsworkspace/crates/acp-nats/src/nats/subjects/responses/prompt_response.rs +++ b/rsworkspace/crates/acp-nats/src/nats/subjects/responses/prompt_response.rs @@ -44,3 +44,18 @@ impl super::super::stream::StreamAssignment for PromptResponseSubject { const STREAM: Option = Some(super::super::stream::AcpStream::Responses); } + +#[cfg(test)] +mod tests { + use super::*; + use async_nats::subject::ToSubject as _; + + #[test] + fn to_subject_matches_display() { + let prefix = crate::acp_prefix::AcpPrefix::new("acp").expect("prefix"); + let session_id = crate::session_id::AcpSessionId::new("s1").expect("session_id"); + let req_id = crate::req_id::ReqId::from_header("r1"); + let subject = PromptResponseSubject::new(&prefix, &session_id, &req_id); + assert_eq!(subject.to_subject().as_str(), subject.to_string()); + } +} diff --git a/rsworkspace/crates/acp-nats/src/nats/subjects/responses/update.rs b/rsworkspace/crates/acp-nats/src/nats/subjects/responses/update.rs index e1bf89b7d..a168227ba 100644 --- a/rsworkspace/crates/acp-nats/src/nats/subjects/responses/update.rs +++ b/rsworkspace/crates/acp-nats/src/nats/subjects/responses/update.rs @@ -44,3 +44,18 @@ impl super::super::stream::StreamAssignment for UpdateSubject { const STREAM: Option = Some(super::super::stream::AcpStream::Notifications); } + +#[cfg(test)] +mod tests { + use super::*; + use async_nats::subject::ToSubject as _; + + #[test] + fn to_subject_matches_display() { + let prefix = crate::acp_prefix::AcpPrefix::new("acp").expect("prefix"); + let session_id = crate::session_id::AcpSessionId::new("s1").expect("session_id"); + let req_id = crate::req_id::ReqId::from_header("r1"); + let subject = UpdateSubject::new(&prefix, &session_id, &req_id); + assert_eq!(subject.to_subject().as_str(), subject.to_string()); + } +}