Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 215 additions & 1 deletion rsworkspace/crates/acp-nats/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub(crate) mod rpc_reply;
pub(crate) mod session_update;
pub(crate) mod terminal_create;
pub(crate) mod terminal_kill;
pub(crate) mod terminal_output;

use crate::agent::Bridge;
use crate::error::AGENT_UNAVAILABLE;
Expand Down Expand Up @@ -225,6 +226,17 @@ async fn dispatch_client_method<
)
.await;
}
ClientMethod::TerminalOutput => {
terminal_output::handle(
&payload,
ctx.client,
reply.as_deref(),
ctx.nats,
parsed.session_id.as_str(),
ctx.serializer,
)
.await;
}
}
}

Expand All @@ -237,7 +249,7 @@ mod tests {
KillTerminalCommandRequest, KillTerminalCommandResponse, ReadTextFileRequest,
ReadTextFileResponse, Request, RequestId, RequestPermissionOutcome,
RequestPermissionRequest, RequestPermissionResponse, SessionNotification, SessionUpdate,
ToolCallUpdate, ToolCallUpdateFields,
TerminalOutputRequest, TerminalOutputResponse, ToolCallUpdate, ToolCallUpdateFields,
};
use async_trait::async_trait;
use std::cell::RefCell;
Expand All @@ -248,19 +260,25 @@ mod tests {
pub(super) struct MockClient {
notifications: RefCell<Vec<String>>,
kill_terminal_calls: RefCell<usize>,
terminal_output_calls: RefCell<usize>,
}

impl MockClient {
pub(super) fn new() -> Self {
Self {
notifications: RefCell::new(Vec::new()),
kill_terminal_calls: RefCell::new(0),
terminal_output_calls: RefCell::new(0),
}
}

pub(super) fn kill_terminal_call_count(&self) -> usize {
*self.kill_terminal_calls.borrow()
}

pub(super) fn terminal_output_call_count(&self) -> usize {
*self.terminal_output_calls.borrow()
}
}

#[async_trait(?Send)]
Expand Down Expand Up @@ -304,6 +322,17 @@ mod tests {
*self.kill_terminal_calls.borrow_mut() += 1;
Ok(KillTerminalCommandResponse::new())
}

async fn terminal_output(
&self,
_: TerminalOutputRequest,
) -> agent_client_protocol::Result<TerminalOutputResponse> {
*self.terminal_output_calls.borrow_mut() += 1;
Ok(TerminalOutputResponse::new(
"mock output".to_string(),
false,
))
}
Comment thread
yordis marked this conversation as resolved.
Comment thread
yordis marked this conversation as resolved.
}

fn make_msg(subject: &str, payload: &[u8], reply: Option<&str>) -> async_nats::Message {
Expand Down Expand Up @@ -577,6 +606,171 @@ mod tests {
assert!(response.get("error").is_none());
}

#[tokio::test]
async fn dispatch_client_method_dispatches_terminal_output() {
let nats = MockNatsClient::new();
let client = MockClient::new();
let session_id = AcpSessionId::new("sess-1").unwrap();

let envelope = Request {
id: RequestId::Number(1),
method: std::sync::Arc::from("terminal/output"),
params: Some(TerminalOutputRequest::new("sess-1", "term-001")),
};
let payload = bytes::Bytes::from(serde_json::to_vec(&envelope).unwrap());

let parsed = crate::nats::ParsedClientSubject {
session_id,
method: ClientMethod::TerminalOutput,
};

let ctx = DispatchContext {
nats: &nats,
client: &client,
serializer: &StdJsonSerialize,
};
dispatch_client_method(
"acp.sess-1.client.terminal.output",
parsed,
payload,
Some("_INBOX.reply".to_string()),
&ctx,
)
.await;

assert_eq!(nats.published_messages(), vec!["_INBOX.reply"]);
let payloads = nats.published_payloads();
assert_eq!(payloads.len(), 1);
let response: serde_json::Value = serde_json::from_slice(payloads[0].as_ref()).unwrap();
assert_eq!(response.get("id"), Some(&serde_json::Value::from(1)));
assert!(response.get("result").is_some());
assert!(response.get("error").is_none());
assert_eq!(
client.terminal_output_call_count(),
1,
"terminal_output handler must run"
);
assert_eq!(
client.kill_terminal_call_count(),
0,
"kill handler must not run"
);
}

#[tokio::test]
async fn dispatch_client_method_dispatches_terminal_output_client_error_publishes_error_reply()
{
let nats = MockNatsClient::new();
let client = TerminalKillFailingClient;
let session_id = AcpSessionId::new("sess-1").unwrap();

let envelope = Request {
id: RequestId::Number(1),
method: std::sync::Arc::from("terminal/output"),
params: Some(TerminalOutputRequest::new("sess-1", "term-001")),
};
let payload = bytes::Bytes::from(serde_json::to_vec(&envelope).unwrap());

let parsed = crate::nats::ParsedClientSubject {
session_id,
method: ClientMethod::TerminalOutput,
};

let ctx = DispatchContext {
nats: &nats,
client: &client,
serializer: &StdJsonSerialize,
};
dispatch_client_method(
"acp.sess-1.client.terminal.output",
parsed,
payload,
Some("_INBOX.reply".to_string()),
&ctx,
)
.await;

assert_eq!(nats.published_messages(), vec!["_INBOX.reply"]);
let payloads = nats.published_payloads();
let response: serde_json::Value = serde_json::from_slice(payloads[0].as_ref()).unwrap();
assert!(response.get("error").is_some());
assert_eq!(
response.get("error").and_then(|e| e.get("code")),
Some(&serde_json::Value::from(-32603))
);
}

#[tokio::test]
async fn dispatch_client_method_dispatches_terminal_output_with_rpc_mock_client() {
let nats = MockNatsClient::new();
let client = RpcMockClient;
let session_id = AcpSessionId::new("sess-1").unwrap();

let envelope = Request {
id: RequestId::Number(1),
method: std::sync::Arc::from("terminal/output"),
params: Some(TerminalOutputRequest::new("sess-1", "term-001")),
};
let payload = bytes::Bytes::from(serde_json::to_vec(&envelope).unwrap());

let parsed = crate::nats::ParsedClientSubject {
session_id,
method: ClientMethod::TerminalOutput,
};

let ctx = DispatchContext {
nats: &nats,
client: &client,
serializer: &StdJsonSerialize,
};
dispatch_client_method(
"acp.sess-1.client.terminal.output",
parsed,
payload,
Some("_INBOX.reply".to_string()),
&ctx,
)
.await;

assert_eq!(nats.published_messages(), vec!["_INBOX.reply"]);
}

#[tokio::test]
async fn dispatch_client_method_dispatches_terminal_output_serialization_fallback() {
let nats = MockNatsClient::new();
let client = MockClient::new();
let serializer = FailNextSerialize::new(1);
let session_id = AcpSessionId::new("sess-1").unwrap();

let envelope = Request {
id: RequestId::Number(1),
method: std::sync::Arc::from("terminal/output"),
params: Some(TerminalOutputRequest::new("sess-1", "term-001")),
};
let payload = bytes::Bytes::from(serde_json::to_vec(&envelope).unwrap());

let parsed = crate::nats::ParsedClientSubject {
session_id,
method: ClientMethod::TerminalOutput,
};

let ctx = DispatchContext {
nats: &nats,
client: &client,
serializer: &serializer,
};
dispatch_client_method(
"acp.sess-1.client.terminal.output",
parsed,
payload,
Some("_INBOX.reply".to_string()),
&ctx,
)
.await;

assert_eq!(nats.published_messages(), vec!["_INBOX.reply"]);
}

#[tokio::test]
async fn dispatch_client_method_terminal_kill_no_reply_does_not_call_client_or_publish() {
let nats = MockNatsClient::new();
Expand Down Expand Up @@ -753,6 +947,16 @@ mod tests {
"mock kill_terminal_command failure",
))
}

async fn terminal_output(
&self,
_: TerminalOutputRequest,
) -> agent_client_protocol::Result<TerminalOutputResponse> {
Err(agent_client_protocol::Error::new(
-32603,
"mock terminal_output failure",
))
}
}

#[tokio::test]
Expand Down Expand Up @@ -1259,6 +1463,16 @@ mod tests {
) -> agent_client_protocol::Result<ReadTextFileResponse> {
Ok(ReadTextFileResponse::new("file contents".to_string()))
}

async fn terminal_output(
&self,
_: TerminalOutputRequest,
) -> agent_client_protocol::Result<TerminalOutputResponse> {
Ok(TerminalOutputResponse::new(
"rpc mock output".to_string(),
false,
))
}
}

#[tokio::test]
Expand Down
Loading