diff --git a/rsworkspace/crates/acp-nats-agent/src/connection.rs b/rsworkspace/crates/acp-nats-agent/src/connection.rs index 4750546dc..5fb4c906a 100644 --- a/rsworkspace/crates/acp-nats-agent/src/connection.rs +++ b/rsworkspace/crates/acp-nats-agent/src/connection.rs @@ -397,21 +397,48 @@ mod tests { serde_json::to_vec(value).unwrap() } - #[tokio::test] - async fn dispatch_initialize_calls_agent_and_publishes_response() { + async fn dispatch( + subject: &str, + args: &T, + reply: Option<&str>, + ) -> (MockNatsClient, MockAgent) { let nats = MockNatsClient::new(); let agent = MockAgent::new(); - let payload = serialize(&InitializeRequest::new( - agent_client_protocol::ProtocolVersion::V0, - )); - let msg = make_nats_message("acp.agent.initialize", &payload, Some("_INBOX.1")); + let payload = serialize(args); + let msg = make_nats_message(subject, &payload, reply); + dispatch_message(msg, &agent, &nats).await; + (nats, agent) + } + async fn dispatch_raw( + subject: &str, + payload: &[u8], + reply: Option<&str>, + ) -> (MockNatsClient, MockAgent) { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let msg = make_nats_message(subject, payload, reply); dispatch_message(msg, &agent, &nats).await; + (nats, agent) + } + + fn published_response(nats: &MockNatsClient) -> T { + let payloads = nats.published_payloads(); + assert_eq!(payloads.len(), 1); + serde_json::from_slice(&payloads[0]).unwrap() + } + + fn init_request() -> InitializeRequest { + InitializeRequest::new(agent_client_protocol::ProtocolVersion::V0) + } + + #[tokio::test] + async fn dispatch_initialize_calls_agent_and_publishes_response() { + let (nats, agent) = + dispatch("acp.agent.initialize", &init_request(), Some("_INBOX.1")).await; assert!(*agent.initialized.borrow()); - let published = nats.published_payloads(); - assert_eq!(published.len(), 1); - let response: InitializeResponse = serde_json::from_slice(&published[0]).unwrap(); + let response: InitializeResponse = published_response(&nats); assert_eq!( response.protocol_version, agent_client_protocol::ProtocolVersion::V0 @@ -420,27 +447,25 @@ mod tests { #[tokio::test] async fn dispatch_authenticate_error_publishes_acp_error() { - let nats = MockNatsClient::new(); - let agent = MockAgent::new(); - let payload = serialize(&AuthenticateRequest::new("basic")); - let msg = make_nats_message("acp.agent.authenticate", &payload, Some("_INBOX.2")); - - dispatch_message(msg, &agent, &nats).await; + let (nats, _) = dispatch( + "acp.agent.authenticate", + &AuthenticateRequest::new("basic"), + Some("_INBOX.2"), + ) + .await; - let published = nats.published_payloads(); - assert_eq!(published.len(), 1); - let error: AcpError = serde_json::from_slice(&published[0]).unwrap(); + let error: AcpError = published_response(&nats); assert_eq!(error.code, ErrorCode::MethodNotFound); } #[tokio::test] async fn dispatch_cancel_is_notification_no_reply_published() { - let nats = MockNatsClient::new(); - let agent = MockAgent::new(); - let payload = serialize(&CancelNotification::new("sess-1")); - let msg = make_nats_message("acp.s1.agent.session.cancel", &payload, None); - - dispatch_message(msg, &agent, &nats).await; + let (nats, agent) = dispatch( + "acp.s1.agent.session.cancel", + &CancelNotification::new("sess-1"), + None, + ) + .await; assert_eq!(agent.cancelled.borrow().len(), 1); assert!(nats.published_messages().is_empty()); @@ -448,70 +473,47 @@ mod tests { #[tokio::test] async fn dispatch_invalid_payload_publishes_error_reply() { - let nats = MockNatsClient::new(); - let agent = MockAgent::new(); - let msg = make_nats_message("acp.agent.initialize", b"not json", Some("_INBOX.err")); - - dispatch_message(msg, &agent, &nats).await; + let (nats, agent) = + dispatch_raw("acp.agent.initialize", b"not json", Some("_INBOX.err")).await; assert!(!*agent.initialized.borrow()); - let published = nats.published_payloads(); - assert_eq!(published.len(), 1); - let error: AcpError = serde_json::from_slice(&published[0]).unwrap(); + let error: AcpError = published_response(&nats); assert_eq!(error.code, ErrorCode::InvalidParams); } #[tokio::test] async fn dispatch_request_without_reply_subject_does_not_publish() { - let nats = MockNatsClient::new(); - let agent = MockAgent::new(); - let payload = serialize(&InitializeRequest::new( - agent_client_protocol::ProtocolVersion::V0, - )); - let msg = make_nats_message("acp.agent.initialize", &payload, None); - - dispatch_message(msg, &agent, &nats).await; - + let (nats, _) = dispatch("acp.agent.initialize", &init_request(), None).await; assert!(nats.published_messages().is_empty()); } #[tokio::test] async fn dispatch_unknown_subject_is_silently_ignored() { - let nats = MockNatsClient::new(); - let agent = MockAgent::new(); - let msg = make_nats_message("acp.something.else", b"{}", Some("_INBOX.1")); - - dispatch_message(msg, &agent, &nats).await; - + let (nats, _) = dispatch_raw("acp.something.else", b"{}", Some("_INBOX.1")).await; assert!(nats.published_messages().is_empty()); } #[tokio::test] async fn dispatch_prompt_returns_stop_reason() { - let nats = MockNatsClient::new(); - let agent = MockAgent::new(); - let payload = serialize(&PromptRequest::new("sess-1", vec![])); - let msg = make_nats_message("acp.s1.agent.session.prompt", &payload, Some("_INBOX.3")); - - dispatch_message(msg, &agent, &nats).await; + let (nats, _) = dispatch( + "acp.s1.agent.session.prompt", + &PromptRequest::new("sess-1", vec![]), + Some("_INBOX.3"), + ) + .await; - let published = nats.published_payloads(); - assert_eq!(published.len(), 1); - let response: PromptResponse = serde_json::from_slice(&published[0]).unwrap(); + let response: PromptResponse = published_response(&nats); assert_eq!(response.stop_reason, StopReason::EndTurn); } #[tokio::test] async fn dispatch_publishes_to_correct_reply_subject() { - let nats = MockNatsClient::new(); - let agent = MockAgent::new(); - let payload = serialize(&InitializeRequest::new( - agent_client_protocol::ProtocolVersion::V0, - )); - let msg = make_nats_message("acp.agent.initialize", &payload, Some("_INBOX.specific")); - - dispatch_message(msg, &agent, &nats).await; - + let (nats, _) = dispatch( + "acp.agent.initialize", + &init_request(), + Some("_INBOX.specific"), + ) + .await; assert_eq!(nats.published_messages(), vec!["_INBOX.specific"]); } @@ -558,150 +560,108 @@ mod tests { ); } + fn raw_value(json: &str) -> std::sync::Arc { + std::sync::Arc::from(serde_json::value::RawValue::from_string(json.to_string()).unwrap()) + } + #[tokio::test] async fn dispatch_ext_with_reply_calls_ext_method() { - let nats = MockNatsClient::new(); - let agent = MockAgent::new(); - let payload = serialize(&agent_client_protocol::ExtRequest::new( - "my_tool", - std::sync::Arc::from( - serde_json::value::RawValue::from_string("{}".to_string()).unwrap(), - ), - )); - let msg = make_nats_message("acp.agent.ext.my_tool", &payload, Some("_INBOX.ext")); - - dispatch_message(msg, &agent, &nats).await; - + let (nats, _) = dispatch( + "acp.agent.ext.my_tool", + &agent_client_protocol::ExtRequest::new("my_tool", raw_value("{}")), + Some("_INBOX.ext"), + ) + .await; assert_eq!(nats.published_messages(), vec!["_INBOX.ext"]); } #[tokio::test] async fn dispatch_ext_without_reply_calls_ext_notification() { - let nats = MockNatsClient::new(); - let agent = MockAgent::new(); - let payload = serialize(&agent_client_protocol::ExtNotification::new( - "my_tool", - std::sync::Arc::from( - serde_json::value::RawValue::from_string("{}".to_string()).unwrap(), - ), - )); - let msg = make_nats_message("acp.agent.ext.my_tool", &payload, None); - - dispatch_message(msg, &agent, &nats).await; - + let (nats, _) = dispatch( + "acp.agent.ext.my_tool", + &agent_client_protocol::ExtNotification::new("my_tool", raw_value("{}")), + None, + ) + .await; assert!(nats.published_messages().is_empty()); } + async fn assert_dispatch_publishes(subject: &str, args: &T) { + let (nats, _) = dispatch(subject, args, Some("_INBOX.r")).await; + assert_eq!(nats.published_messages(), vec!["_INBOX.r"]); + } + #[tokio::test] async fn dispatch_new_session_publishes_response() { - let nats = MockNatsClient::new(); - let agent = MockAgent::new(); - let payload = serialize(&NewSessionRequest::new("/tmp")); - let msg = make_nats_message("acp.agent.session.new", &payload, Some("_INBOX.r")); - - dispatch_message(msg, &agent, &nats).await; - - assert_eq!(nats.published_messages(), vec!["_INBOX.r"]); + assert_dispatch_publishes("acp.agent.session.new", &NewSessionRequest::new("/tmp")).await; } #[tokio::test] async fn dispatch_session_load_publishes_response() { - let nats = MockNatsClient::new(); - let agent = MockAgent::new(); - let payload = serialize(&LoadSessionRequest::new("sess-1", "/tmp")); - let msg = make_nats_message("acp.s1.agent.session.load", &payload, Some("_INBOX.r")); - - dispatch_message(msg, &agent, &nats).await; - - assert_eq!(nats.published_messages(), vec!["_INBOX.r"]); + assert_dispatch_publishes( + "acp.s1.agent.session.load", + &LoadSessionRequest::new("sess-1", "/tmp"), + ) + .await; } #[tokio::test] async fn dispatch_list_sessions_publishes_response() { - let nats = MockNatsClient::new(); - let agent = MockAgent::new(); - let payload = serialize(&ListSessionsRequest::new()); - let msg = make_nats_message("acp.agent.session.list", &payload, Some("_INBOX.r")); - - dispatch_message(msg, &agent, &nats).await; - - assert_eq!(nats.published_messages(), vec!["_INBOX.r"]); + assert_dispatch_publishes("acp.agent.session.list", &ListSessionsRequest::new()).await; } #[tokio::test] async fn dispatch_set_session_mode_publishes_response() { - let nats = MockNatsClient::new(); - let agent = MockAgent::new(); - let payload = serialize(&SetSessionModeRequest::new("sess-1", "code")); - let msg = make_nats_message("acp.s1.agent.session.set_mode", &payload, Some("_INBOX.r")); - - dispatch_message(msg, &agent, &nats).await; - - assert_eq!(nats.published_messages(), vec!["_INBOX.r"]); + assert_dispatch_publishes( + "acp.s1.agent.session.set_mode", + &SetSessionModeRequest::new("sess-1", "code"), + ) + .await; } #[tokio::test] async fn dispatch_set_session_config_option_publishes_response() { - let nats = MockNatsClient::new(); - let agent = MockAgent::new(); - let payload = serialize(&SetSessionConfigOptionRequest::new("sess-1", "key", "val")); - let msg = make_nats_message( + assert_dispatch_publishes( "acp.s1.agent.session.set_config_option", - &payload, - Some("_INBOX.r"), - ); - - dispatch_message(msg, &agent, &nats).await; - - assert_eq!(nats.published_messages(), vec!["_INBOX.r"]); + &SetSessionConfigOptionRequest::new("sess-1", "key", "val"), + ) + .await; } #[tokio::test] async fn dispatch_set_session_model_publishes_response() { - let nats = MockNatsClient::new(); - let agent = MockAgent::new(); - let payload = serialize(&SetSessionModelRequest::new("sess-1", "gpt-4")); - let msg = make_nats_message("acp.s1.agent.session.set_model", &payload, Some("_INBOX.r")); - - dispatch_message(msg, &agent, &nats).await; - - assert_eq!(nats.published_messages(), vec!["_INBOX.r"]); + assert_dispatch_publishes( + "acp.s1.agent.session.set_model", + &SetSessionModelRequest::new("sess-1", "gpt-4"), + ) + .await; } #[tokio::test] async fn dispatch_fork_session_publishes_response() { - let nats = MockNatsClient::new(); - let agent = MockAgent::new(); - let payload = serialize(&ForkSessionRequest::new("sess-1", "/tmp")); - let msg = make_nats_message("acp.s1.agent.session.fork", &payload, Some("_INBOX.r")); - - dispatch_message(msg, &agent, &nats).await; - - assert_eq!(nats.published_messages(), vec!["_INBOX.r"]); + assert_dispatch_publishes( + "acp.s1.agent.session.fork", + &ForkSessionRequest::new("sess-1", "/tmp"), + ) + .await; } #[tokio::test] async fn dispatch_resume_session_publishes_response() { - let nats = MockNatsClient::new(); - let agent = MockAgent::new(); - let payload = serialize(&ResumeSessionRequest::new("sess-1", "/tmp")); - let msg = make_nats_message("acp.s1.agent.session.resume", &payload, Some("_INBOX.r")); - - dispatch_message(msg, &agent, &nats).await; - - assert_eq!(nats.published_messages(), vec!["_INBOX.r"]); + assert_dispatch_publishes( + "acp.s1.agent.session.resume", + &ResumeSessionRequest::new("sess-1", "/tmp"), + ) + .await; } #[tokio::test] async fn dispatch_close_session_publishes_response() { - let nats = MockNatsClient::new(); - let agent = MockAgent::new(); - let payload = serialize(&CloseSessionRequest::new("sess-1")); - let msg = make_nats_message("acp.s1.agent.session.close", &payload, Some("_INBOX.r")); - - dispatch_message(msg, &agent, &nats).await; - - assert_eq!(nats.published_messages(), vec!["_INBOX.r"]); + assert_dispatch_publishes( + "acp.s1.agent.session.close", + &CloseSessionRequest::new("sess-1"), + ) + .await; } #[test]