Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 79 additions & 36 deletions rsworkspace/crates/acp-nats-agent/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
use acp_nats::jetstream::consumers::commands_observer;
use acp_nats::jetstream::streams::commands_stream_name;
use acp_nats::nats::agent::wildcards::GlobalAllSubject;
use acp_nats::nats::session::wildcards::{AllAgentExtSubject, AllAgentSubject};
use acp_nats::nats::{
GlobalAgentMethod, ParsedAgentSubject, SessionAgentMethod, parse_agent_subject,
};
use acp_nats::{AcpPrefix, AcpSessionId, NatsClientProxy};
use acp_nats::{
AcpPrefix, AcpSessionId, NatsClientProxy, PromptResponseSubject, ReqId, ResponseSubject,
};
use agent_client_protocol::{
Agent, AuthenticateRequest, CancelNotification, CloseSessionRequest, ExtNotification,
ExtRequest, ForkSessionRequest, InitializeRequest, ListSessionsRequest, LoadSessionRequest,
Expand Down Expand Up @@ -158,8 +164,8 @@ where
N: SubscribeClient + PublishClient + FlushClient + Clone + 'static,
A: Agent + 'static,
{
let global_wildcard = acp_nats::nats::agent::wildcards::GlobalAllSubject::new(prefix);
let session_wildcard = acp_nats::nats::session::wildcards::AllAgentSubject::new(prefix);
let global_wildcard = GlobalAllSubject::new(prefix);
let session_wildcard = AllAgentSubject::new(prefix);

info!(
global = %global_wildcard,
Expand Down Expand Up @@ -204,8 +210,8 @@ where
N: SubscribeClient + PublishClient + FlushClient + Clone + 'static,
A: Agent + 'static,
{
let global_wildcard = acp_nats::nats::agent::wildcards::GlobalAllSubject::new(prefix);
let ext_wildcard = acp_nats::nats::session::wildcards::AllAgentExtSubject::new(prefix);
let global_wildcard = GlobalAllSubject::new(prefix);
let ext_wildcard = AllAgentExtSubject::new(prefix);

info!(
global = %global_wildcard,
Expand Down Expand Up @@ -469,9 +475,8 @@ where
};
}
_ = keepalive.tick() => {
if let Err(e) = js_msg.ack_with(AckKind::Progress).await {
warn!(error = %e, "Failed to send in_progress keepalive");
}
let _ = js_msg.ack_with(AckKind::Progress).await
.inspect_err(|e| warn!(error = %e, "Failed to send in_progress keepalive"));
}
}
}
Expand All @@ -490,8 +495,8 @@ where
trogon_nats::jetstream::JsMessageOf<J>: JsDispatchMessage,
A: Agent + 'static,
{
let stream_name = acp_nats::jetstream::streams::commands_stream_name(prefix);
let config = acp_nats::jetstream::consumers::commands_observer();
let stream_name = commands_stream_name(prefix);
let config = commands_observer();

info!(stream = %stream_name, "Starting JetStream consumer for COMMANDS stream");

Expand Down Expand Up @@ -567,18 +572,14 @@ async fn dispatch_js_message<N: PublishClient + FlushClient, A: Agent, M: JsDisp
.headers
.as_ref()
.and_then(|h| h.get(trogon_nats::REQ_ID_HEADER))
.map(|v| acp_nats::ReqId::from_header(v.as_str()));
.map(|v| ReqId::from_header(v.as_str()));

let reply_subject: Option<String> = match (&req_id, &method) {
(Some(rid), SessionAgentMethod::Prompt) => Some(
acp_nats::nats::session::agent::PromptResponseSubject::new(prefix, &session_id, rid)
.to_string(),
),
(Some(rid), SessionAgentMethod::Prompt) => {
Some(PromptResponseSubject::new(prefix, &session_id, rid).to_string())
}
(_, SessionAgentMethod::Cancel) => None,
(Some(rid), _) => Some(
acp_nats::nats::session::agent::ResponseSubject::new(prefix, &session_id, rid)
.to_string(),
),
(Some(rid), _) => Some(ResponseSubject::new(prefix, &session_id, rid).to_string()),
(None, _) => {
warn!(subject, "JetStream message missing X-Req-Id header");
None
Expand Down Expand Up @@ -673,20 +674,20 @@ async fn dispatch_js_message<N: PublishClient + FlushClient, A: Agent, M: JsDisp
}
}
Err(DispatchError::NotificationHandler(_)) => {
if let Err(e) = js_msg.ack().await {
warn!(subject, error = %e, "Failed to ack after notification handler error");
}
let _ = js_msg.ack().await.inspect_err(
|e| warn!(subject, error = %e, "Failed to ack after notification handler error"),
);
}
}

if let Err(e) = result {
let _ = result.inspect_err(|e| {
warn!(
subject,
session_id = session_id.as_str(),
error = %e,
"Error handling JetStream request"
);
}
});
}

#[cfg(test)]
Expand All @@ -702,13 +703,23 @@ mod tests {
struct MockAgent {
initialized: RefCell<bool>,
cancelled: RefCell<Vec<String>>,
fail_cancel: bool,
}

impl MockAgent {
fn new() -> Self {
Self {
initialized: RefCell::new(false),
cancelled: RefCell::new(Vec::new()),
fail_cancel: false,
}
}

fn failing_cancel() -> Self {
Self {
initialized: RefCell::new(false),
cancelled: RefCell::new(Vec::new()),
fail_cancel: true,
}
}
}
Expand Down Expand Up @@ -754,6 +765,9 @@ mod tests {
}

async fn cancel(&self, args: CancelNotification) -> agent_client_protocol::Result<()> {
if self.fail_cancel {
return Err(AcpError::internal_error());
}
self.cancelled
.borrow_mut()
.push(args.session_id.to_string());
Expand Down Expand Up @@ -1844,6 +1858,32 @@ mod tests {
dispatch_js_message(js_msg, &agent, &nats, &test_prefix()).await;
}

#[tokio::test]
async fn dispatch_js_message_cancel_notification_handler_error_ack_failure() {
use tracing_subscriber::util::SubscriberInitExt;
let _guard = tracing_subscriber::fmt().with_test_writer().set_default();

let nats = MockNatsClient::new();
let agent = MockAgent::failing_cancel();
let payload = serialize(&CancelNotification::new("s1"));
let js_msg = MockJsMessage::with_failing_signals(async_nats::Message {
subject: "acp.session.s1.agent.cancel".into(),
reply: None,
payload: Bytes::copy_from_slice(&payload),
headers: None,
status: None,
description: None,
length: payload.len(),
});
dispatch_js_message(js_msg, &agent, &nats, &test_prefix()).await;
}

fn init_handler_error(
_: InitializeRequest,
) -> std::future::Ready<agent_client_protocol::Result<InitializeResponse>> {
std::future::ready(Err(AcpError::internal_error()))
}

#[tokio::test]
async fn handle_request_with_keepalive_completes_fast() {
let nats = MockNatsClient::new();
Expand Down Expand Up @@ -1871,12 +1911,7 @@ mod tests {
));
let msg = make_nats_message("acp.agent.initialize", &payload, None);
let js_msg = make_js_msg("acp.agent.initialize", &payload, None);

let result =
handle_request_with_keepalive(&msg, &nats, &js_msg, |_: InitializeRequest| async {
Err::<InitializeResponse, _>(agent_client_protocol::Error::new(-1, "not called"))
})
.await;
let result = handle_request_with_keepalive(&msg, &nats, &js_msg, init_handler_error).await;
assert!(result.is_err());
}

Expand All @@ -1885,15 +1920,23 @@ mod tests {
let nats = MockNatsClient::new();
let msg = make_nats_message("acp.agent.initialize", b"not json", Some("_INBOX.1"));
let js_msg = make_js_msg("acp.agent.initialize", b"not json", Some("_INBOX.1"));

let result =
handle_request_with_keepalive(&msg, &nats, &js_msg, |_: InitializeRequest| async {
Err::<InitializeResponse, _>(agent_client_protocol::Error::new(-1, "not called"))
})
.await;
let result = handle_request_with_keepalive(&msg, &nats, &js_msg, init_handler_error).await;
assert!(result.is_err());
}

#[tokio::test]
async fn handle_request_with_keepalive_handler_returns_error() {
let nats = MockNatsClient::new();
let payload = serialize(&InitializeRequest::new(
agent_client_protocol::ProtocolVersion::V0,
));
let msg = make_nats_message("acp.agent.initialize", &payload, Some("_INBOX.1"));
let js_msg = make_js_msg("acp.agent.initialize", &payload, Some("_INBOX.1"));
let result = handle_request_with_keepalive(&msg, &nats, &js_msg, init_handler_error).await;
assert!(result.is_ok());
assert!(!nats.published_messages().is_empty());
}

#[tokio::test(start_paused = true)]
async fn handle_request_with_keepalive_progress_ack_failure() {
use tracing_subscriber::util::SubscriberInitExt;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,18 @@ impl super::super::stream::StreamAssignment for PromptResponseSubject {
const STREAM: Option<super::super::stream::AcpStream> =
Some(super::super::stream::AcpStream::Responses);
}

#[cfg(test)]
mod tests {
use super::*;
use async_nats::subject::ToSubject as _;

#[test]
fn to_subject_matches_display() {
let prefix = crate::acp_prefix::AcpPrefix::new("acp").expect("prefix");
let session_id = crate::session_id::AcpSessionId::new("s1").expect("session_id");
let req_id = crate::req_id::ReqId::from_header("r1");
let subject = PromptResponseSubject::new(&prefix, &session_id, &req_id);
assert_eq!(subject.to_subject().as_str(), subject.to_string());
}
}
15 changes: 15 additions & 0 deletions rsworkspace/crates/acp-nats/src/nats/subjects/responses/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,18 @@ impl super::super::stream::StreamAssignment for UpdateSubject {
const STREAM: Option<super::super::stream::AcpStream> =
Some(super::super::stream::AcpStream::Notifications);
}

#[cfg(test)]
mod tests {
use super::*;
use async_nats::subject::ToSubject as _;

#[test]
fn to_subject_matches_display() {
let prefix = crate::acp_prefix::AcpPrefix::new("acp").expect("prefix");
let session_id = crate::session_id::AcpSessionId::new("s1").expect("session_id");
let req_id = crate::req_id::ReqId::from_header("r1");
let subject = UpdateSubject::new(&prefix, &session_id, &req_id);
assert_eq!(subject.to_subject().as_str(), subject.to_string());
}
}
Loading