Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 94 additions & 76 deletions rsworkspace/crates/acp-nats-agent/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ mod tests {
use super::*;
use agent_client_protocol::{
AuthenticateResponse, Error as AcpError, ErrorCode, InitializeResponse, LogoutResponse,
PromptResponse, StopReason,
NewSessionResponse, PromptResponse, StopReason,
};
use std::cell::RefCell;
use trogon_nats::MockNatsClient;
Expand Down Expand Up @@ -866,7 +866,9 @@ mod tests {

#[tokio::test]
async fn dispatch_logout_publishes_response() {
assert_dispatch_publishes("acp.agent.logout", &LogoutRequest::new()).await;
let (nats, _) = dispatch("acp.agent.logout", &LogoutRequest::new(), Some("_INBOX.r")).await;
assert_eq!(nats.published_messages(), vec!["_INBOX.r"]);
let _: LogoutResponse = published_response(&nats);
Comment thread
yordis marked this conversation as resolved.
}

#[tokio::test]
Expand All @@ -878,7 +880,7 @@ mod tests {
)
.await;

assert_eq!(agent.cancelled.borrow().len(), 1);
assert_eq!(agent.cancelled.borrow().as_slice(), ["s1"]);
assert!(nats.published_messages().is_empty());
}

Expand Down Expand Up @@ -926,6 +928,11 @@ mod tests {
)
.await;
assert_eq!(nats.published_messages(), vec!["_INBOX.specific"]);
let response: InitializeResponse = published_response(&nats);
assert_eq!(
response.protocol_version,
agent_client_protocol::ProtocolVersion::V0
);
}

#[test]
Expand Down Expand Up @@ -984,6 +991,8 @@ mod tests {
)
.await;
assert_eq!(nats.published_messages(), vec!["_INBOX.ext"]);
let value: serde_json::Value = published_response(&nats);
assert!(value.is_null());
}

#[tokio::test]
Expand All @@ -997,19 +1006,30 @@ mod tests {
assert!(nats.published_messages().is_empty());
}

async fn assert_dispatch_publishes<T: serde::Serialize>(subject: &str, args: &T) {
async fn assert_dispatch_method_not_found<T: serde::Serialize>(subject: &str, args: &T) {
let (nats, _) = dispatch(subject, args, Some("_INBOX.r")).await;
assert_eq!(nats.published_messages(), vec!["_INBOX.r"]);
let error: AcpError = published_response(&nats);
assert_eq!(error.code, ErrorCode::MethodNotFound);
}

#[tokio::test]
async fn dispatch_new_session_publishes_response() {
assert_dispatch_publishes("acp.agent.session.new", &NewSessionRequest::new("/tmp")).await;
let (nats, _) = dispatch(
"acp.agent.session.new",
&NewSessionRequest::new("/tmp"),
Some("_INBOX.r"),
)
.await;

assert_eq!(nats.published_messages(), vec!["_INBOX.r"]);
let response: NewSessionResponse = published_response(&nats);
assert_eq!(response.session_id.to_string(), "sess-1");
Comment thread
yordis marked this conversation as resolved.
}

#[tokio::test]
async fn dispatch_session_load_publishes_response() {
assert_dispatch_publishes(
assert_dispatch_method_not_found(
"acp.session.s1.agent.load",
&LoadSessionRequest::new("s1", "/tmp"),
)
Expand All @@ -1018,12 +1038,13 @@ mod tests {

#[tokio::test]
async fn dispatch_list_sessions_publishes_response() {
assert_dispatch_publishes("acp.agent.session.list", &ListSessionsRequest::new()).await;
assert_dispatch_method_not_found("acp.agent.session.list", &ListSessionsRequest::new())
.await;
}

#[tokio::test]
async fn dispatch_set_session_mode_publishes_response() {
assert_dispatch_publishes(
assert_dispatch_method_not_found(
"acp.session.s1.agent.set_mode",
&SetSessionModeRequest::new("s1", "code"),
)
Expand All @@ -1032,7 +1053,7 @@ mod tests {

#[tokio::test]
async fn dispatch_set_session_config_option_publishes_response() {
assert_dispatch_publishes(
assert_dispatch_method_not_found(
"acp.session.s1.agent.set_config_option",
&SetSessionConfigOptionRequest::new("s1", "key", "val"),
)
Expand All @@ -1041,7 +1062,7 @@ mod tests {

#[tokio::test]
async fn dispatch_set_session_model_publishes_response() {
assert_dispatch_publishes(
assert_dispatch_method_not_found(
"acp.session.s1.agent.set_model",
&SetSessionModelRequest::new("s1", "gpt-4"),
)
Expand All @@ -1050,7 +1071,7 @@ mod tests {

#[tokio::test]
async fn dispatch_fork_session_publishes_response() {
assert_dispatch_publishes(
assert_dispatch_method_not_found(
"acp.session.s1.agent.fork",
&ForkSessionRequest::new("s1", "/tmp"),
)
Expand All @@ -1059,7 +1080,7 @@ mod tests {

#[tokio::test]
async fn dispatch_resume_session_publishes_response() {
assert_dispatch_publishes(
assert_dispatch_method_not_found(
"acp.session.s1.agent.resume",
&ResumeSessionRequest::new("s1", "/tmp"),
)
Expand All @@ -1068,7 +1089,7 @@ mod tests {

#[tokio::test]
async fn dispatch_close_session_publishes_response() {
assert_dispatch_publishes(
assert_dispatch_method_not_found(
"acp.session.s1.agent.close",
&CloseSessionRequest::new("s1"),
)
Expand Down Expand Up @@ -1251,7 +1272,12 @@ mod tests {
tokio::task::yield_now().await;
tokio::task::yield_now().await;

assert!(!nats.published_messages().is_empty());
assert_eq!(nats.published_messages(), vec!["_INBOX.serve"]);
let response: InitializeResponse = published_response(&nats);
assert_eq!(
response.protocol_version,
agent_client_protocol::ProtocolVersion::V0
);
})
.await;
}
Expand Down Expand Up @@ -1301,7 +1327,7 @@ mod tests {
tokio::task::yield_now().await;
tokio::task::yield_now().await;

assert!(!nats.published_messages().is_empty());
assert_js_response_method_not_found(&nats, "acp.session.s1.agent.response.req-1");
})
.await;
}
Expand Down Expand Up @@ -1364,18 +1390,6 @@ mod tests {
.await;
}

#[tokio::test]
async fn dispatch_js_message_success_acks() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let payload = serialize(&LoadSessionRequest::new("s1", "/tmp"));
let js_msg = make_js_msg("acp.session.s1.agent.load", &payload, None);

dispatch_js_message(js_msg, &agent, &nats, &test_prefix()).await;

assert!(!nats.published_messages().is_empty());
}

#[tokio::test]
async fn dispatch_js_message_unknown_subject_terms() {
let nats = MockNatsClient::new();
Expand All @@ -1395,9 +1409,11 @@ mod tests {

dispatch_js_message(js_msg, &agent, &nats, &test_prefix()).await;

let payloads = nats.published_payloads();
assert_eq!(payloads.len(), 1);
let error: agent_client_protocol::Error = serde_json::from_slice(&payloads[0]).unwrap();
assert_eq!(
nats.published_messages(),
vec!["acp.session.s1.agent.response.req-1"]
);
let error: AcpError = published_response(&nats);
assert_eq!(error.code, ErrorCode::InvalidParams);
}

Expand Down Expand Up @@ -1528,14 +1544,12 @@ mod tests {

dispatch_js_message(js_msg, &agent, &nats, &test_prefix()).await;

let subjects = nats.published_messages();
assert!(
subjects
.iter()
.any(|s| s.starts_with("acp.session.s1.agent.prompt.response.")),
"expected prompt.response subject, got: {:?}",
subjects
assert_eq!(
nats.published_messages(),
vec!["acp.session.s1.agent.prompt.response.req-1"]
);
let response: PromptResponse = published_response(&nats);
assert_eq!(response.stop_reason, StopReason::EndTurn);
}

#[tokio::test]
Expand All @@ -1547,14 +1561,7 @@ mod tests {

dispatch_js_message(js_msg, &agent, &nats, &test_prefix()).await;

let subjects = nats.published_messages();
assert!(
subjects
.iter()
.any(|s| s.starts_with("acp.session.s1.agent.response.")),
"expected response subject, got: {:?}",
subjects
);
assert_js_response_method_not_found(&nats, "acp.session.s1.agent.response.req-1");
}

#[tokio::test]
Expand Down Expand Up @@ -1610,8 +1617,12 @@ mod tests {
tokio::task::yield_now().await;
tokio::task::yield_now().await;

assert_eq!(nats.published_messages().len(), 1);
assert_eq!(nats.published_messages()[0], "_INBOX.serve");
assert_eq!(nats.published_messages(), vec!["_INBOX.serve"]);
let response: InitializeResponse = published_response(&nats);
assert_eq!(
response.protocol_version,
agent_client_protocol::ProtocolVersion::V0
);
})
.await;
}
Expand Down Expand Up @@ -1688,87 +1699,85 @@ mod tests {

dispatch_js_message(js_msg, &agent, &nats, &test_prefix()).await;

assert_eq!(agent.cancelled.borrow().len(), 1);
assert_eq!(agent.cancelled.borrow().as_slice(), ["s1"]);
}

fn assert_js_response_method_not_found(nats: &MockNatsClient, expected_subject: &str) {
assert_eq!(nats.published_messages(), vec![expected_subject]);
let error: AcpError = published_response(nats);
assert_eq!(error.code, ErrorCode::MethodNotFound);
}

#[tokio::test]
async fn dispatch_js_message_set_mode() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let payload = serialize(&SetSessionModeRequest::new("s1", "code"));
let js_msg = make_js_msg("acp.session.s1.agent.set_mode", &payload, Some("_INBOX.r"));
let js_msg = make_js_msg("acp.session.s1.agent.set_mode", &payload, None);

dispatch_js_message(js_msg, &agent, &nats, &test_prefix()).await;

assert!(!nats.published_messages().is_empty());
assert_js_response_method_not_found(&nats, "acp.session.s1.agent.response.req-1");
}

#[tokio::test]
async fn dispatch_js_message_close_session() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let payload = serialize(&CloseSessionRequest::new("s1"));
let js_msg = make_js_msg("acp.session.s1.agent.close", &payload, Some("_INBOX.r"));
let js_msg = make_js_msg("acp.session.s1.agent.close", &payload, None);

dispatch_js_message(js_msg, &agent, &nats, &test_prefix()).await;

assert!(!nats.published_messages().is_empty());
assert_js_response_method_not_found(&nats, "acp.session.s1.agent.response.req-1");
}

#[tokio::test]
async fn dispatch_js_message_fork_session() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let payload = serialize(&ForkSessionRequest::new("s1", "/tmp"));
let js_msg = make_js_msg("acp.session.s1.agent.fork", &payload, Some("_INBOX.r"));
let js_msg = make_js_msg("acp.session.s1.agent.fork", &payload, None);

dispatch_js_message(js_msg, &agent, &nats, &test_prefix()).await;

assert!(!nats.published_messages().is_empty());
assert_js_response_method_not_found(&nats, "acp.session.s1.agent.response.req-1");
}

#[tokio::test]
async fn dispatch_js_message_set_config_option() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let payload = serialize(&SetSessionConfigOptionRequest::new("s1", "key", "val"));
let js_msg = make_js_msg(
"acp.session.s1.agent.set_config_option",
&payload,
Some("_INBOX.r"),
);
let js_msg = make_js_msg("acp.session.s1.agent.set_config_option", &payload, None);

dispatch_js_message(js_msg, &agent, &nats, &test_prefix()).await;
assert!(!nats.published_messages().is_empty());

assert_js_response_method_not_found(&nats, "acp.session.s1.agent.response.req-1");
}

#[tokio::test]
async fn dispatch_js_message_set_model() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let payload = serialize(&SetSessionModelRequest::new("s1", "gpt-4"));
let js_msg = make_js_msg("acp.session.s1.agent.set_model", &payload, Some("_INBOX.r"));
let js_msg = make_js_msg("acp.session.s1.agent.set_model", &payload, None);

dispatch_js_message(js_msg, &agent, &nats, &test_prefix()).await;
assert!(!nats.published_messages().is_empty());

assert_js_response_method_not_found(&nats, "acp.session.s1.agent.response.req-1");
}

#[tokio::test]
async fn dispatch_js_message_resume_session() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let payload = serialize(&ResumeSessionRequest::new("s1", "/tmp"));
let js_msg = make_js_msg("acp.session.s1.agent.resume", &payload, Some("_INBOX.r"));
dispatch_js_message(js_msg, &agent, &nats, &test_prefix()).await;
assert!(!nats.published_messages().is_empty());
}
let js_msg = make_js_msg("acp.session.s1.agent.resume", &payload, None);

#[tokio::test]
async fn dispatch_js_message_prompt() {
let nats = MockNatsClient::new();
let agent = MockAgent::new();
let payload = serialize(&PromptRequest::new("s1", vec![]));
let js_msg = make_js_msg("acp.session.s1.agent.prompt", &payload, Some("_INBOX.r"));
dispatch_js_message(js_msg, &agent, &nats, &test_prefix()).await;
assert!(!nats.published_messages().is_empty());

assert_js_response_method_not_found(&nats, "acp.session.s1.agent.response.req-1");
}

#[tokio::test]
Expand Down Expand Up @@ -1900,7 +1909,12 @@ mod tests {
})
.await;
assert!(result.is_ok());
assert!(!nats.published_messages().is_empty());
assert_eq!(nats.published_messages(), vec!["_INBOX.1"]);
let response: InitializeResponse = published_response(&nats);
assert_eq!(
response.protocol_version,
agent_client_protocol::ProtocolVersion::V0
);
}

#[tokio::test]
Expand Down Expand Up @@ -1934,7 +1948,9 @@ mod tests {
let js_msg = make_js_msg("acp.agent.initialize", &payload, Some("_INBOX.1"));
let result = handle_request_with_keepalive(&msg, &nats, &js_msg, init_handler_error).await;
assert!(result.is_ok());
assert!(!nats.published_messages().is_empty());
assert_eq!(nats.published_messages(), vec!["_INBOX.1"]);
let error: AcpError = published_response(&nats);
assert_eq!(error.code, ErrorCode::InternalError);
}

#[tokio::test(start_paused = true)]
Expand Down Expand Up @@ -1984,6 +2000,8 @@ mod tests {
})
.await;
assert!(result.is_ok());
assert!(!nats.published_messages().is_empty());
assert_eq!(nats.published_messages(), vec!["_INBOX.1"]);
let error: AcpError = published_response(&nats);
assert_eq!(error.code, ErrorCode::MethodNotFound);
}
}
Loading