Skip to content

Commit 608cbc2

Browse files
committed
refactor(acp-nats-agent): clean up test helpers and reduce boilerplate
Signed-off-by: Yordis Prieto <yordis.prieto@gmail.com>
1 parent 54c5f3e commit 608cbc2

File tree

1 file changed

+124
-164
lines changed

1 file changed

+124
-164
lines changed

rsworkspace/crates/acp-nats-agent/src/connection.rs

Lines changed: 124 additions & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -397,21 +397,48 @@ mod tests {
397397
serde_json::to_vec(value).unwrap()
398398
}
399399

400-
#[tokio::test]
401-
async fn dispatch_initialize_calls_agent_and_publishes_response() {
400+
async fn dispatch<T: serde::Serialize>(
401+
subject: &str,
402+
args: &T,
403+
reply: Option<&str>,
404+
) -> (MockNatsClient, MockAgent) {
402405
let nats = MockNatsClient::new();
403406
let agent = MockAgent::new();
404-
let payload = serialize(&InitializeRequest::new(
405-
agent_client_protocol::ProtocolVersion::V0,
406-
));
407-
let msg = make_nats_message("acp.agent.initialize", &payload, Some("_INBOX.1"));
407+
let payload = serialize(args);
408+
let msg = make_nats_message(subject, &payload, reply);
409+
dispatch_message(msg, &agent, &nats).await;
410+
(nats, agent)
411+
}
408412

413+
async fn dispatch_raw(
414+
subject: &str,
415+
payload: &[u8],
416+
reply: Option<&str>,
417+
) -> (MockNatsClient, MockAgent) {
418+
let nats = MockNatsClient::new();
419+
let agent = MockAgent::new();
420+
let msg = make_nats_message(subject, payload, reply);
409421
dispatch_message(msg, &agent, &nats).await;
422+
(nats, agent)
423+
}
424+
425+
fn published_response<T: serde::de::DeserializeOwned>(nats: &MockNatsClient) -> T {
426+
let payloads = nats.published_payloads();
427+
assert_eq!(payloads.len(), 1);
428+
serde_json::from_slice(&payloads[0]).unwrap()
429+
}
430+
431+
fn init_request() -> InitializeRequest {
432+
InitializeRequest::new(agent_client_protocol::ProtocolVersion::V0)
433+
}
434+
435+
#[tokio::test]
436+
async fn dispatch_initialize_calls_agent_and_publishes_response() {
437+
let (nats, agent) =
438+
dispatch("acp.agent.initialize", &init_request(), Some("_INBOX.1")).await;
410439

411440
assert!(*agent.initialized.borrow());
412-
let published = nats.published_payloads();
413-
assert_eq!(published.len(), 1);
414-
let response: InitializeResponse = serde_json::from_slice(&published[0]).unwrap();
441+
let response: InitializeResponse = published_response(&nats);
415442
assert_eq!(
416443
response.protocol_version,
417444
agent_client_protocol::ProtocolVersion::V0
@@ -420,98 +447,73 @@ mod tests {
420447

421448
#[tokio::test]
422449
async fn dispatch_authenticate_error_publishes_acp_error() {
423-
let nats = MockNatsClient::new();
424-
let agent = MockAgent::new();
425-
let payload = serialize(&AuthenticateRequest::new("basic"));
426-
let msg = make_nats_message("acp.agent.authenticate", &payload, Some("_INBOX.2"));
427-
428-
dispatch_message(msg, &agent, &nats).await;
450+
let (nats, _) = dispatch(
451+
"acp.agent.authenticate",
452+
&AuthenticateRequest::new("basic"),
453+
Some("_INBOX.2"),
454+
)
455+
.await;
429456

430-
let published = nats.published_payloads();
431-
assert_eq!(published.len(), 1);
432-
let error: AcpError = serde_json::from_slice(&published[0]).unwrap();
457+
let error: AcpError = published_response(&nats);
433458
assert_eq!(error.code, ErrorCode::MethodNotFound);
434459
}
435460

436461
#[tokio::test]
437462
async fn dispatch_cancel_is_notification_no_reply_published() {
438-
let nats = MockNatsClient::new();
439-
let agent = MockAgent::new();
440-
let payload = serialize(&CancelNotification::new("sess-1"));
441-
let msg = make_nats_message("acp.s1.agent.session.cancel", &payload, None);
442-
443-
dispatch_message(msg, &agent, &nats).await;
463+
let (nats, agent) = dispatch(
464+
"acp.s1.agent.session.cancel",
465+
&CancelNotification::new("sess-1"),
466+
None,
467+
)
468+
.await;
444469

445470
assert_eq!(agent.cancelled.borrow().len(), 1);
446471
assert!(nats.published_messages().is_empty());
447472
}
448473

449474
#[tokio::test]
450475
async fn dispatch_invalid_payload_publishes_error_reply() {
451-
let nats = MockNatsClient::new();
452-
let agent = MockAgent::new();
453-
let msg = make_nats_message("acp.agent.initialize", b"not json", Some("_INBOX.err"));
454-
455-
dispatch_message(msg, &agent, &nats).await;
476+
let (nats, agent) =
477+
dispatch_raw("acp.agent.initialize", b"not json", Some("_INBOX.err")).await;
456478

457479
assert!(!*agent.initialized.borrow());
458-
let published = nats.published_payloads();
459-
assert_eq!(published.len(), 1);
460-
let error: AcpError = serde_json::from_slice(&published[0]).unwrap();
480+
let error: AcpError = published_response(&nats);
461481
assert_eq!(error.code, ErrorCode::InvalidParams);
462482
}
463483

464484
#[tokio::test]
465485
async fn dispatch_request_without_reply_subject_does_not_publish() {
466-
let nats = MockNatsClient::new();
467-
let agent = MockAgent::new();
468-
let payload = serialize(&InitializeRequest::new(
469-
agent_client_protocol::ProtocolVersion::V0,
470-
));
471-
let msg = make_nats_message("acp.agent.initialize", &payload, None);
472-
473-
dispatch_message(msg, &agent, &nats).await;
474-
486+
let (nats, _) = dispatch("acp.agent.initialize", &init_request(), None).await;
475487
assert!(nats.published_messages().is_empty());
476488
}
477489

478490
#[tokio::test]
479491
async fn dispatch_unknown_subject_is_silently_ignored() {
480-
let nats = MockNatsClient::new();
481-
let agent = MockAgent::new();
482-
let msg = make_nats_message("acp.something.else", b"{}", Some("_INBOX.1"));
483-
484-
dispatch_message(msg, &agent, &nats).await;
485-
492+
let (nats, _) = dispatch_raw("acp.something.else", b"{}", Some("_INBOX.1")).await;
486493
assert!(nats.published_messages().is_empty());
487494
}
488495

489496
#[tokio::test]
490497
async fn dispatch_prompt_returns_stop_reason() {
491-
let nats = MockNatsClient::new();
492-
let agent = MockAgent::new();
493-
let payload = serialize(&PromptRequest::new("sess-1", vec![]));
494-
let msg = make_nats_message("acp.s1.agent.session.prompt", &payload, Some("_INBOX.3"));
495-
496-
dispatch_message(msg, &agent, &nats).await;
498+
let (nats, _) = dispatch(
499+
"acp.s1.agent.session.prompt",
500+
&PromptRequest::new("sess-1", vec![]),
501+
Some("_INBOX.3"),
502+
)
503+
.await;
497504

498-
let published = nats.published_payloads();
499-
assert_eq!(published.len(), 1);
500-
let response: PromptResponse = serde_json::from_slice(&published[0]).unwrap();
505+
let response: PromptResponse = published_response(&nats);
501506
assert_eq!(response.stop_reason, StopReason::EndTurn);
502507
}
503508

504509
#[tokio::test]
505510
async fn dispatch_publishes_to_correct_reply_subject() {
506-
let nats = MockNatsClient::new();
507-
let agent = MockAgent::new();
508-
let payload = serialize(&InitializeRequest::new(
509-
agent_client_protocol::ProtocolVersion::V0,
510-
));
511-
let msg = make_nats_message("acp.agent.initialize", &payload, Some("_INBOX.specific"));
512-
513-
dispatch_message(msg, &agent, &nats).await;
514-
511+
let (nats, _) = dispatch(
512+
"acp.agent.initialize",
513+
&init_request(),
514+
Some("_INBOX.specific"),
515+
)
516+
.await;
515517
assert_eq!(nats.published_messages(), vec!["_INBOX.specific"]);
516518
}
517519

@@ -558,150 +560,108 @@ mod tests {
558560
);
559561
}
560562

563+
fn raw_value(json: &str) -> std::sync::Arc<serde_json::value::RawValue> {
564+
std::sync::Arc::from(serde_json::value::RawValue::from_string(json.to_string()).unwrap())
565+
}
566+
561567
#[tokio::test]
562568
async fn dispatch_ext_with_reply_calls_ext_method() {
563-
let nats = MockNatsClient::new();
564-
let agent = MockAgent::new();
565-
let payload = serialize(&agent_client_protocol::ExtRequest::new(
566-
"my_tool",
567-
std::sync::Arc::from(
568-
serde_json::value::RawValue::from_string("{}".to_string()).unwrap(),
569-
),
570-
));
571-
let msg = make_nats_message("acp.agent.ext.my_tool", &payload, Some("_INBOX.ext"));
572-
573-
dispatch_message(msg, &agent, &nats).await;
574-
569+
let (nats, _) = dispatch(
570+
"acp.agent.ext.my_tool",
571+
&agent_client_protocol::ExtRequest::new("my_tool", raw_value("{}")),
572+
Some("_INBOX.ext"),
573+
)
574+
.await;
575575
assert_eq!(nats.published_messages(), vec!["_INBOX.ext"]);
576576
}
577577

578578
#[tokio::test]
579579
async fn dispatch_ext_without_reply_calls_ext_notification() {
580-
let nats = MockNatsClient::new();
581-
let agent = MockAgent::new();
582-
let payload = serialize(&agent_client_protocol::ExtNotification::new(
583-
"my_tool",
584-
std::sync::Arc::from(
585-
serde_json::value::RawValue::from_string("{}".to_string()).unwrap(),
586-
),
587-
));
588-
let msg = make_nats_message("acp.agent.ext.my_tool", &payload, None);
589-
590-
dispatch_message(msg, &agent, &nats).await;
591-
580+
let (nats, _) = dispatch(
581+
"acp.agent.ext.my_tool",
582+
&agent_client_protocol::ExtNotification::new("my_tool", raw_value("{}")),
583+
None,
584+
)
585+
.await;
592586
assert!(nats.published_messages().is_empty());
593587
}
594588

589+
async fn assert_dispatch_publishes<T: serde::Serialize>(subject: &str, args: &T) {
590+
let (nats, _) = dispatch(subject, args, Some("_INBOX.r")).await;
591+
assert_eq!(nats.published_messages(), vec!["_INBOX.r"]);
592+
}
593+
595594
#[tokio::test]
596595
async fn dispatch_new_session_publishes_response() {
597-
let nats = MockNatsClient::new();
598-
let agent = MockAgent::new();
599-
let payload = serialize(&NewSessionRequest::new("/tmp"));
600-
let msg = make_nats_message("acp.agent.session.new", &payload, Some("_INBOX.r"));
601-
602-
dispatch_message(msg, &agent, &nats).await;
603-
604-
assert_eq!(nats.published_messages(), vec!["_INBOX.r"]);
596+
assert_dispatch_publishes("acp.agent.session.new", &NewSessionRequest::new("/tmp")).await;
605597
}
606598

607599
#[tokio::test]
608600
async fn dispatch_session_load_publishes_response() {
609-
let nats = MockNatsClient::new();
610-
let agent = MockAgent::new();
611-
let payload = serialize(&LoadSessionRequest::new("sess-1", "/tmp"));
612-
let msg = make_nats_message("acp.s1.agent.session.load", &payload, Some("_INBOX.r"));
613-
614-
dispatch_message(msg, &agent, &nats).await;
615-
616-
assert_eq!(nats.published_messages(), vec!["_INBOX.r"]);
601+
assert_dispatch_publishes(
602+
"acp.s1.agent.session.load",
603+
&LoadSessionRequest::new("sess-1", "/tmp"),
604+
)
605+
.await;
617606
}
618607

619608
#[tokio::test]
620609
async fn dispatch_list_sessions_publishes_response() {
621-
let nats = MockNatsClient::new();
622-
let agent = MockAgent::new();
623-
let payload = serialize(&ListSessionsRequest::new());
624-
let msg = make_nats_message("acp.agent.session.list", &payload, Some("_INBOX.r"));
625-
626-
dispatch_message(msg, &agent, &nats).await;
627-
628-
assert_eq!(nats.published_messages(), vec!["_INBOX.r"]);
610+
assert_dispatch_publishes("acp.agent.session.list", &ListSessionsRequest::new()).await;
629611
}
630612

631613
#[tokio::test]
632614
async fn dispatch_set_session_mode_publishes_response() {
633-
let nats = MockNatsClient::new();
634-
let agent = MockAgent::new();
635-
let payload = serialize(&SetSessionModeRequest::new("sess-1", "code"));
636-
let msg = make_nats_message("acp.s1.agent.session.set_mode", &payload, Some("_INBOX.r"));
637-
638-
dispatch_message(msg, &agent, &nats).await;
639-
640-
assert_eq!(nats.published_messages(), vec!["_INBOX.r"]);
615+
assert_dispatch_publishes(
616+
"acp.s1.agent.session.set_mode",
617+
&SetSessionModeRequest::new("sess-1", "code"),
618+
)
619+
.await;
641620
}
642621

643622
#[tokio::test]
644623
async fn dispatch_set_session_config_option_publishes_response() {
645-
let nats = MockNatsClient::new();
646-
let agent = MockAgent::new();
647-
let payload = serialize(&SetSessionConfigOptionRequest::new("sess-1", "key", "val"));
648-
let msg = make_nats_message(
624+
assert_dispatch_publishes(
649625
"acp.s1.agent.session.set_config_option",
650-
&payload,
651-
Some("_INBOX.r"),
652-
);
653-
654-
dispatch_message(msg, &agent, &nats).await;
655-
656-
assert_eq!(nats.published_messages(), vec!["_INBOX.r"]);
626+
&SetSessionConfigOptionRequest::new("sess-1", "key", "val"),
627+
)
628+
.await;
657629
}
658630

659631
#[tokio::test]
660632
async fn dispatch_set_session_model_publishes_response() {
661-
let nats = MockNatsClient::new();
662-
let agent = MockAgent::new();
663-
let payload = serialize(&SetSessionModelRequest::new("sess-1", "gpt-4"));
664-
let msg = make_nats_message("acp.s1.agent.session.set_model", &payload, Some("_INBOX.r"));
665-
666-
dispatch_message(msg, &agent, &nats).await;
667-
668-
assert_eq!(nats.published_messages(), vec!["_INBOX.r"]);
633+
assert_dispatch_publishes(
634+
"acp.s1.agent.session.set_model",
635+
&SetSessionModelRequest::new("sess-1", "gpt-4"),
636+
)
637+
.await;
669638
}
670639

671640
#[tokio::test]
672641
async fn dispatch_fork_session_publishes_response() {
673-
let nats = MockNatsClient::new();
674-
let agent = MockAgent::new();
675-
let payload = serialize(&ForkSessionRequest::new("sess-1", "/tmp"));
676-
let msg = make_nats_message("acp.s1.agent.session.fork", &payload, Some("_INBOX.r"));
677-
678-
dispatch_message(msg, &agent, &nats).await;
679-
680-
assert_eq!(nats.published_messages(), vec!["_INBOX.r"]);
642+
assert_dispatch_publishes(
643+
"acp.s1.agent.session.fork",
644+
&ForkSessionRequest::new("sess-1", "/tmp"),
645+
)
646+
.await;
681647
}
682648

683649
#[tokio::test]
684650
async fn dispatch_resume_session_publishes_response() {
685-
let nats = MockNatsClient::new();
686-
let agent = MockAgent::new();
687-
let payload = serialize(&ResumeSessionRequest::new("sess-1", "/tmp"));
688-
let msg = make_nats_message("acp.s1.agent.session.resume", &payload, Some("_INBOX.r"));
689-
690-
dispatch_message(msg, &agent, &nats).await;
691-
692-
assert_eq!(nats.published_messages(), vec!["_INBOX.r"]);
651+
assert_dispatch_publishes(
652+
"acp.s1.agent.session.resume",
653+
&ResumeSessionRequest::new("sess-1", "/tmp"),
654+
)
655+
.await;
693656
}
694657

695658
#[tokio::test]
696659
async fn dispatch_close_session_publishes_response() {
697-
let nats = MockNatsClient::new();
698-
let agent = MockAgent::new();
699-
let payload = serialize(&CloseSessionRequest::new("sess-1"));
700-
let msg = make_nats_message("acp.s1.agent.session.close", &payload, Some("_INBOX.r"));
701-
702-
dispatch_message(msg, &agent, &nats).await;
703-
704-
assert_eq!(nats.published_messages(), vec!["_INBOX.r"]);
660+
assert_dispatch_publishes(
661+
"acp.s1.agent.session.close",
662+
&CloseSessionRequest::new("sess-1"),
663+
)
664+
.await;
705665
}
706666

707667
#[test]

0 commit comments

Comments
 (0)