Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
288 changes: 124 additions & 164 deletions rsworkspace/crates/acp-nats-agent/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,21 +397,48 @@ mod tests {
serde_json::to_vec(value).unwrap()
}

#[tokio::test]
async fn dispatch_initialize_calls_agent_and_publishes_response() {
async fn dispatch<T: serde::Serialize>(
subject: &str,
args: &T,
reply: Option<&str>,
) -> (MockNatsClient, MockAgent) {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let payload = serialize(&InitializeRequest::new(
agent_client_protocol::ProtocolVersion::V0,
));
let msg = make_nats_message("acp.agent.initialize", &payload, Some("_INBOX.1"));
let payload = serialize(args);
let msg = make_nats_message(subject, &payload, reply);
dispatch_message(msg, &agent, &nats).await;
(nats, agent)
}

async fn dispatch_raw(
subject: &str,
payload: &[u8],
reply: Option<&str>,
) -> (MockNatsClient, MockAgent) {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let msg = make_nats_message(subject, payload, reply);
dispatch_message(msg, &agent, &nats).await;
(nats, agent)
}

fn published_response<T: serde::de::DeserializeOwned>(nats: &MockNatsClient) -> T {
let payloads = nats.published_payloads();
assert_eq!(payloads.len(), 1);
serde_json::from_slice(&payloads[0]).unwrap()
}

fn init_request() -> InitializeRequest {
InitializeRequest::new(agent_client_protocol::ProtocolVersion::V0)
}

#[tokio::test]
async fn dispatch_initialize_calls_agent_and_publishes_response() {
let (nats, agent) =
dispatch("acp.agent.initialize", &init_request(), Some("_INBOX.1")).await;

assert!(*agent.initialized.borrow());
let published = nats.published_payloads();
assert_eq!(published.len(), 1);
let response: InitializeResponse = serde_json::from_slice(&published[0]).unwrap();
let response: InitializeResponse = published_response(&nats);
assert_eq!(
response.protocol_version,
agent_client_protocol::ProtocolVersion::V0
Expand All @@ -420,98 +447,73 @@ mod tests {

#[tokio::test]
async fn dispatch_authenticate_error_publishes_acp_error() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let payload = serialize(&AuthenticateRequest::new("basic"));
let msg = make_nats_message("acp.agent.authenticate", &payload, Some("_INBOX.2"));

dispatch_message(msg, &agent, &nats).await;
let (nats, _) = dispatch(
"acp.agent.authenticate",
&AuthenticateRequest::new("basic"),
Some("_INBOX.2"),
)
.await;

let published = nats.published_payloads();
assert_eq!(published.len(), 1);
let error: AcpError = serde_json::from_slice(&published[0]).unwrap();
let error: AcpError = published_response(&nats);
assert_eq!(error.code, ErrorCode::MethodNotFound);
}

#[tokio::test]
async fn dispatch_cancel_is_notification_no_reply_published() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let payload = serialize(&CancelNotification::new("sess-1"));
let msg = make_nats_message("acp.s1.agent.session.cancel", &payload, None);

dispatch_message(msg, &agent, &nats).await;
let (nats, agent) = dispatch(
"acp.s1.agent.session.cancel",
&CancelNotification::new("sess-1"),
None,
)
.await;

assert_eq!(agent.cancelled.borrow().len(), 1);
assert!(nats.published_messages().is_empty());
}

#[tokio::test]
async fn dispatch_invalid_payload_publishes_error_reply() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let msg = make_nats_message("acp.agent.initialize", b"not json", Some("_INBOX.err"));

dispatch_message(msg, &agent, &nats).await;
let (nats, agent) =
dispatch_raw("acp.agent.initialize", b"not json", Some("_INBOX.err")).await;

assert!(!*agent.initialized.borrow());
let published = nats.published_payloads();
assert_eq!(published.len(), 1);
let error: AcpError = serde_json::from_slice(&published[0]).unwrap();
let error: AcpError = published_response(&nats);
assert_eq!(error.code, ErrorCode::InvalidParams);
}

#[tokio::test]
async fn dispatch_request_without_reply_subject_does_not_publish() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let payload = serialize(&InitializeRequest::new(
agent_client_protocol::ProtocolVersion::V0,
));
let msg = make_nats_message("acp.agent.initialize", &payload, None);

dispatch_message(msg, &agent, &nats).await;

let (nats, _) = dispatch("acp.agent.initialize", &init_request(), None).await;
assert!(nats.published_messages().is_empty());
}

#[tokio::test]
async fn dispatch_unknown_subject_is_silently_ignored() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let msg = make_nats_message("acp.something.else", b"{}", Some("_INBOX.1"));

dispatch_message(msg, &agent, &nats).await;

let (nats, _) = dispatch_raw("acp.something.else", b"{}", Some("_INBOX.1")).await;
assert!(nats.published_messages().is_empty());
}

#[tokio::test]
async fn dispatch_prompt_returns_stop_reason() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let payload = serialize(&PromptRequest::new("sess-1", vec![]));
let msg = make_nats_message("acp.s1.agent.session.prompt", &payload, Some("_INBOX.3"));

dispatch_message(msg, &agent, &nats).await;
let (nats, _) = dispatch(
"acp.s1.agent.session.prompt",
&PromptRequest::new("sess-1", vec![]),
Some("_INBOX.3"),
)
.await;

let published = nats.published_payloads();
assert_eq!(published.len(), 1);
let response: PromptResponse = serde_json::from_slice(&published[0]).unwrap();
let response: PromptResponse = published_response(&nats);
assert_eq!(response.stop_reason, StopReason::EndTurn);
}

#[tokio::test]
async fn dispatch_publishes_to_correct_reply_subject() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let payload = serialize(&InitializeRequest::new(
agent_client_protocol::ProtocolVersion::V0,
));
let msg = make_nats_message("acp.agent.initialize", &payload, Some("_INBOX.specific"));

dispatch_message(msg, &agent, &nats).await;

let (nats, _) = dispatch(
"acp.agent.initialize",
&init_request(),
Some("_INBOX.specific"),
)
.await;
assert_eq!(nats.published_messages(), vec!["_INBOX.specific"]);
}

Expand Down Expand Up @@ -558,150 +560,108 @@ mod tests {
);
}

fn raw_value(json: &str) -> std::sync::Arc<serde_json::value::RawValue> {
std::sync::Arc::from(serde_json::value::RawValue::from_string(json.to_string()).unwrap())
}

#[tokio::test]
async fn dispatch_ext_with_reply_calls_ext_method() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let payload = serialize(&agent_client_protocol::ExtRequest::new(
"my_tool",
std::sync::Arc::from(
serde_json::value::RawValue::from_string("{}".to_string()).unwrap(),
),
));
let msg = make_nats_message("acp.agent.ext.my_tool", &payload, Some("_INBOX.ext"));

dispatch_message(msg, &agent, &nats).await;

let (nats, _) = dispatch(
"acp.agent.ext.my_tool",
&agent_client_protocol::ExtRequest::new("my_tool", raw_value("{}")),
Some("_INBOX.ext"),
)
.await;
assert_eq!(nats.published_messages(), vec!["_INBOX.ext"]);
}

#[tokio::test]
async fn dispatch_ext_without_reply_calls_ext_notification() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let payload = serialize(&agent_client_protocol::ExtNotification::new(
"my_tool",
std::sync::Arc::from(
serde_json::value::RawValue::from_string("{}".to_string()).unwrap(),
),
));
let msg = make_nats_message("acp.agent.ext.my_tool", &payload, None);

dispatch_message(msg, &agent, &nats).await;

let (nats, _) = dispatch(
"acp.agent.ext.my_tool",
&agent_client_protocol::ExtNotification::new("my_tool", raw_value("{}")),
None,
)
.await;
assert!(nats.published_messages().is_empty());
}

async fn assert_dispatch_publishes<T: serde::Serialize>(subject: &str, args: &T) {
let (nats, _) = dispatch(subject, args, Some("_INBOX.r")).await;
assert_eq!(nats.published_messages(), vec!["_INBOX.r"]);
}

#[tokio::test]
async fn dispatch_new_session_publishes_response() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let payload = serialize(&NewSessionRequest::new("/tmp"));
let msg = make_nats_message("acp.agent.session.new", &payload, Some("_INBOX.r"));

dispatch_message(msg, &agent, &nats).await;

assert_eq!(nats.published_messages(), vec!["_INBOX.r"]);
assert_dispatch_publishes("acp.agent.session.new", &NewSessionRequest::new("/tmp")).await;
}

#[tokio::test]
async fn dispatch_session_load_publishes_response() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let payload = serialize(&LoadSessionRequest::new("sess-1", "/tmp"));
let msg = make_nats_message("acp.s1.agent.session.load", &payload, Some("_INBOX.r"));

dispatch_message(msg, &agent, &nats).await;

assert_eq!(nats.published_messages(), vec!["_INBOX.r"]);
assert_dispatch_publishes(
"acp.s1.agent.session.load",
&LoadSessionRequest::new("sess-1", "/tmp"),
)
.await;
}

#[tokio::test]
async fn dispatch_list_sessions_publishes_response() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let payload = serialize(&ListSessionsRequest::new());
let msg = make_nats_message("acp.agent.session.list", &payload, Some("_INBOX.r"));

dispatch_message(msg, &agent, &nats).await;

assert_eq!(nats.published_messages(), vec!["_INBOX.r"]);
assert_dispatch_publishes("acp.agent.session.list", &ListSessionsRequest::new()).await;
}

#[tokio::test]
async fn dispatch_set_session_mode_publishes_response() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let payload = serialize(&SetSessionModeRequest::new("sess-1", "code"));
let msg = make_nats_message("acp.s1.agent.session.set_mode", &payload, Some("_INBOX.r"));

dispatch_message(msg, &agent, &nats).await;

assert_eq!(nats.published_messages(), vec!["_INBOX.r"]);
assert_dispatch_publishes(
"acp.s1.agent.session.set_mode",
&SetSessionModeRequest::new("sess-1", "code"),
)
.await;
}

#[tokio::test]
async fn dispatch_set_session_config_option_publishes_response() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let payload = serialize(&SetSessionConfigOptionRequest::new("sess-1", "key", "val"));
let msg = make_nats_message(
assert_dispatch_publishes(
"acp.s1.agent.session.set_config_option",
&payload,
Some("_INBOX.r"),
);

dispatch_message(msg, &agent, &nats).await;

assert_eq!(nats.published_messages(), vec!["_INBOX.r"]);
&SetSessionConfigOptionRequest::new("sess-1", "key", "val"),
)
.await;
}

#[tokio::test]
async fn dispatch_set_session_model_publishes_response() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let payload = serialize(&SetSessionModelRequest::new("sess-1", "gpt-4"));
let msg = make_nats_message("acp.s1.agent.session.set_model", &payload, Some("_INBOX.r"));

dispatch_message(msg, &agent, &nats).await;

assert_eq!(nats.published_messages(), vec!["_INBOX.r"]);
assert_dispatch_publishes(
"acp.s1.agent.session.set_model",
&SetSessionModelRequest::new("sess-1", "gpt-4"),
)
.await;
}

#[tokio::test]
async fn dispatch_fork_session_publishes_response() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let payload = serialize(&ForkSessionRequest::new("sess-1", "/tmp"));
let msg = make_nats_message("acp.s1.agent.session.fork", &payload, Some("_INBOX.r"));

dispatch_message(msg, &agent, &nats).await;

assert_eq!(nats.published_messages(), vec!["_INBOX.r"]);
assert_dispatch_publishes(
"acp.s1.agent.session.fork",
&ForkSessionRequest::new("sess-1", "/tmp"),
)
.await;
}

#[tokio::test]
async fn dispatch_resume_session_publishes_response() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let payload = serialize(&ResumeSessionRequest::new("sess-1", "/tmp"));
let msg = make_nats_message("acp.s1.agent.session.resume", &payload, Some("_INBOX.r"));

dispatch_message(msg, &agent, &nats).await;

assert_eq!(nats.published_messages(), vec!["_INBOX.r"]);
assert_dispatch_publishes(
"acp.s1.agent.session.resume",
&ResumeSessionRequest::new("sess-1", "/tmp"),
)
.await;
}

#[tokio::test]
async fn dispatch_close_session_publishes_response() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let payload = serialize(&CloseSessionRequest::new("sess-1"));
let msg = make_nats_message("acp.s1.agent.session.close", &payload, Some("_INBOX.r"));

dispatch_message(msg, &agent, &nats).await;

assert_eq!(nats.published_messages(), vec!["_INBOX.r"]);
assert_dispatch_publishes(
"acp.s1.agent.session.close",
&CloseSessionRequest::new("sess-1"),
)
.await;
}

#[test]
Expand Down
Loading