Skip to content

Commit 2f5d636

Browse files
committed
feat(acp-nats): add fs_read_text_file client handler
Signed-off-by: Yordis Prieto <yordis.prieto@gmail.com>
1 parent c4beead commit 2f5d636

8 files changed

Lines changed: 344 additions & 11 deletions

File tree

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
use agent_client_protocol::{Client, ErrorCode, ReadTextFileRequest, Request};
2+
use tracing::instrument;
3+
4+
#[derive(Debug)]
5+
pub enum FsReadTextFileError {
6+
InvalidRequest(serde_json::Error),
7+
ClientError(agent_client_protocol::Error),
8+
SerializationError(serde_json::Error),
9+
}
10+
11+
impl std::fmt::Display for FsReadTextFileError {
12+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
13+
match self {
14+
Self::InvalidRequest(e) => write!(f, "invalid request: {}", e),
15+
Self::ClientError(e) => write!(f, "client error: {}", e),
16+
Self::SerializationError(e) => write!(f, "serialization error: {}", e),
17+
}
18+
}
19+
}
20+
21+
impl std::error::Error for FsReadTextFileError {
22+
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
23+
match self {
24+
Self::InvalidRequest(e) => Some(e),
25+
Self::ClientError(e) => Some(e),
26+
Self::SerializationError(e) => Some(e),
27+
}
28+
}
29+
}
30+
31+
pub fn error_code_and_message(e: &FsReadTextFileError) -> (ErrorCode, String) {
32+
match e {
33+
FsReadTextFileError::InvalidRequest(inner) => (
34+
ErrorCode::InvalidParams,
35+
format!("Invalid read_text_file request: {}", inner),
36+
),
37+
FsReadTextFileError::ClientError(inner) => (inner.code, inner.message.clone()),
38+
FsReadTextFileError::SerializationError(inner) => (
39+
ErrorCode::InternalError,
40+
format!("Failed to serialize read_text_file response: {}", inner),
41+
),
42+
}
43+
}
44+
45+
/// Forwards read_text_file to the client. NATS enforces payload limits when publishing.
46+
/// Expects a JSON-RPC envelope (id, method, params) compatible with agent_client_protocol::Request.
47+
#[instrument(name = "acp.client.fs.read_text_file", skip(payload, client))]
48+
pub async fn handle<C: Client>(payload: &[u8], client: &C) -> Result<Vec<u8>, FsReadTextFileError> {
49+
let envelope: Request<ReadTextFileRequest> =
50+
serde_json::from_slice(payload).map_err(FsReadTextFileError::InvalidRequest)?;
51+
let request = envelope.params.ok_or_else(|| {
52+
FsReadTextFileError::InvalidRequest(
53+
serde_json::from_value::<ReadTextFileRequest>(serde_json::Value::Null).unwrap_err(),
54+
)
55+
})?;
56+
let response = client
57+
.read_text_file(request)
58+
.await
59+
.map_err(FsReadTextFileError::ClientError)?;
60+
serde_json::to_vec(&response).map_err(FsReadTextFileError::SerializationError)
61+
}
62+
63+
#[cfg(test)]
64+
mod tests {
65+
use super::*;
66+
use agent_client_protocol::{
67+
ReadTextFileRequest, ReadTextFileResponse, Request, RequestId, RequestPermissionRequest,
68+
RequestPermissionResponse, SessionNotification,
69+
};
70+
use async_trait::async_trait;
71+
72+
struct MockClient {
73+
content: String,
74+
}
75+
76+
impl MockClient {
77+
fn new(content: &str) -> Self {
78+
Self {
79+
content: content.to_string(),
80+
}
81+
}
82+
}
83+
84+
#[async_trait(?Send)]
85+
impl Client for MockClient {
86+
async fn session_notification(
87+
&self,
88+
_: SessionNotification,
89+
) -> agent_client_protocol::Result<()> {
90+
Ok(())
91+
}
92+
93+
async fn request_permission(
94+
&self,
95+
_: RequestPermissionRequest,
96+
) -> agent_client_protocol::Result<RequestPermissionResponse> {
97+
Err(agent_client_protocol::Error::new(
98+
-32603,
99+
"not implemented in test mock",
100+
))
101+
}
102+
103+
async fn read_text_file(
104+
&self,
105+
_: ReadTextFileRequest,
106+
) -> agent_client_protocol::Result<ReadTextFileResponse> {
107+
Ok(ReadTextFileResponse::new(self.content.clone()))
108+
}
109+
}
110+
111+
#[tokio::test]
112+
async fn fs_read_text_file_forwards_request_and_returns_response() {
113+
let client = MockClient::new("hello world");
114+
let request = ReadTextFileRequest::new(
115+
agent_client_protocol::SessionId::from("sess-1"),
116+
"/tmp/foo.txt".to_string(),
117+
);
118+
let envelope = Request {
119+
id: RequestId::Number(1),
120+
method: std::sync::Arc::from("fs/read_text_file"),
121+
params: Some(request),
122+
};
123+
let payload = serde_json::to_vec(&envelope).unwrap();
124+
125+
let result = handle(&payload, &client).await;
126+
assert!(result.is_ok());
127+
let response = serde_json::from_slice::<ReadTextFileResponse>(&result.unwrap()).unwrap();
128+
assert_eq!(response.content, "hello world");
129+
}
130+
131+
#[tokio::test]
132+
async fn fs_read_text_file_returns_error_when_payload_is_invalid_json() {
133+
let client = MockClient::new("hello");
134+
let result = handle(b"not json", &client).await;
135+
assert!(result.is_err());
136+
}
137+
138+
#[test]
139+
fn error_code_and_message_invalid_request_returns_invalid_params() {
140+
let err = serde_json::from_slice::<ReadTextFileRequest>(b"not json").unwrap_err();
141+
let fs_err = FsReadTextFileError::InvalidRequest(err);
142+
let (code, message) = error_code_and_message(&fs_err);
143+
assert_eq!(code, ErrorCode::InvalidParams);
144+
assert!(message.contains("Invalid read_text_file request"));
145+
}
146+
147+
#[test]
148+
fn error_code_and_message_client_error_preserves_client_code() {
149+
let client_err =
150+
agent_client_protocol::Error::new(ErrorCode::InvalidParams.into(), "file not found");
151+
let fs_err = FsReadTextFileError::ClientError(client_err);
152+
let (code, message) = error_code_and_message(&fs_err);
153+
assert_eq!(code, ErrorCode::InvalidParams);
154+
assert_eq!(message, "file not found");
155+
}
156+
157+
#[test]
158+
fn error_code_and_message_serialization_error_returns_internal_error() {
159+
struct Unserializable;
160+
impl serde::Serialize for Unserializable {
161+
fn serialize<S: serde::Serializer>(&self, _: S) -> Result<S::Ok, S::Error> {
162+
Err(serde::ser::Error::custom("test serialization failure"))
163+
}
164+
}
165+
let err = serde_json::to_vec(&Unserializable).unwrap_err();
166+
let fs_err = FsReadTextFileError::SerializationError(err);
167+
let (code, message) = error_code_and_message(&fs_err);
168+
assert_eq!(code, ErrorCode::InternalError);
169+
assert!(message.contains("Failed to serialize read_text_file response"));
170+
}
171+
}

rsworkspace/crates/acp-nats/src/client/mod.rs

Lines changed: 137 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1+
pub(crate) mod fs_read_text_file;
12
pub(crate) mod session_update;
23

34
use crate::agent::Bridge;
5+
use crate::error::AGENT_UNAVAILABLE;
46
use crate::in_flight_slot_guard::InFlightSlotGuard;
7+
use crate::jsonrpc::extract_request_id;
58
use crate::nats::{
69
ClientMethod, FlushClient, PublishClient, RequestClient, SubscribeClient, client,
7-
parse_client_subject,
10+
headers_with_trace_context, parse_client_subject,
811
};
9-
use agent_client_protocol::Client;
12+
use agent_client_protocol::{Client, Error, ErrorCode, RequestId, Response};
1013
use async_nats::Message;
1114
use bytes::Bytes;
1215
use futures::StreamExt;
@@ -84,6 +87,9 @@ async fn process_message<
8487
}
8588
};
8689

90+
let payload = msg.payload.clone();
91+
let reply = msg.reply.as_ref().map(|r| r.to_string());
92+
8793
let current_in_flight = in_flight.get();
8894
if current_in_flight >= max_concurrent {
8995
warn!(
@@ -96,10 +102,42 @@ async fn process_message<
96102
.metrics
97103
.record_error("client", "client_backpressure_rejected");
98104

105+
if let Some(reply_to) = &reply {
106+
let request_id = extract_request_id(&payload);
107+
let bytes = serde_json::to_vec(&Response::<()>::Error {
108+
id: request_id,
109+
error: Error::new(
110+
i32::from(ErrorCode::Other(AGENT_UNAVAILABLE)),
111+
"Client proxy overloaded; retry with backoff",
112+
),
113+
})
114+
.unwrap_or_else(|e| {
115+
serde_json::to_vec(&Response::<()>::Error {
116+
id: RequestId::Null,
117+
error: Error::new(
118+
i32::from(ErrorCode::Other(AGENT_UNAVAILABLE)),
119+
format!(
120+
"Client proxy overloaded; retry with backoff (serialization failed: {})",
121+
e
122+
),
123+
),
124+
})
125+
.unwrap()
126+
})
127+
.into();
128+
let headers = headers_with_trace_context();
129+
if let Err(e) = nats
130+
.publish_with_headers(reply_to.clone(), headers, bytes)
131+
.await
132+
{
133+
error!(error = %e, "Failed to publish backpressure error reply");
134+
}
135+
if let Err(e) = nats.flush().await {
136+
warn!(error = %e, "Failed to flush backpressure error reply");
137+
}
138+
}
99139
return;
100140
}
101-
102-
let payload = msg.payload.clone();
103141
let nats = nats.clone();
104142

105143
let bridge_clone = bridge.clone();
@@ -110,6 +148,7 @@ async fn process_message<
110148
&subject,
111149
parsed,
112150
payload,
151+
reply,
113152
&nats,
114153
client.as_ref(),
115154
bridge_clone.as_ref(),
@@ -118,7 +157,7 @@ async fn process_message<
118157
});
119158
}
120159

121-
#[instrument(skip(payload, _nats, client, _bridge), fields(subject = %subject, session_id = tracing::field::Empty))]
160+
#[instrument(skip(payload, nats, client, _bridge), fields(subject = %subject, session_id = tracing::field::Empty))]
122161
async fn dispatch_client_method<
123162
N: SubscribeClient + RequestClient + PublishClient + FlushClient,
124163
Cl: Client,
@@ -127,13 +166,76 @@ async fn dispatch_client_method<
127166
subject: &str,
128167
parsed: crate::nats::ParsedClientSubject,
129168
payload: Bytes,
130-
_nats: &N,
169+
reply: Option<String>,
170+
nats: &N,
131171
client: &Cl,
132172
_bridge: &Bridge<N, C>,
133173
) {
134174
Span::current().record("session_id", parsed.session_id.as_str());
135175

136176
match parsed.method {
177+
ClientMethod::FsReadTextFile => {
178+
let request_id = extract_request_id(&payload);
179+
match fs_read_text_file::handle(&payload, client).await {
180+
Ok(bytes) => {
181+
if let Some(reply_to) = &reply {
182+
let result =
183+
serde_json::from_slice(&bytes).unwrap_or(serde_json::Value::Null);
184+
let response_bytes = serde_json::to_vec(&Response::Result {
185+
id: request_id,
186+
result,
187+
})
188+
.unwrap()
189+
.into();
190+
let headers = headers_with_trace_context();
191+
if let Err(e) = nats
192+
.publish_with_headers(reply_to.clone(), headers, response_bytes)
193+
.await
194+
{
195+
error!(error = %e, "Failed to publish fs_read_text_file reply");
196+
}
197+
if let Err(e) = nats.flush().await {
198+
warn!(error = %e, "Failed to flush fs_read_text_file reply");
199+
}
200+
}
201+
}
202+
Err(e) => {
203+
let (code, message) = fs_read_text_file::error_code_and_message(&e);
204+
warn!(
205+
error = %e,
206+
session_id = %parsed.session_id,
207+
"Failed to handle fs_read_text_file"
208+
);
209+
if let Some(reply_to) = &reply {
210+
let bytes = serde_json::to_vec(&Response::<()>::Error {
211+
id: request_id,
212+
error: Error::new(i32::from(code), message.as_str()),
213+
})
214+
.unwrap_or_else(|e| {
215+
serde_json::to_vec(&Response::<()>::Error {
216+
id: RequestId::Null,
217+
error: Error::new(
218+
i32::from(code),
219+
format!("{} (serialization failed: {})", message, e),
220+
),
221+
})
222+
.unwrap()
223+
})
224+
.into();
225+
let headers = headers_with_trace_context();
226+
if let Err(e) = nats
227+
.publish_with_headers(reply_to.clone(), headers, bytes)
228+
.await
229+
{
230+
error!(error = %e, "Failed to publish fs_read_text_file error reply");
231+
}
232+
if let Err(e) = nats.flush().await {
233+
warn!(error = %e, "Failed to flush fs_read_text_file error reply");
234+
}
235+
}
236+
}
237+
}
238+
}
137239
ClientMethod::SessionUpdate => {
138240
session_update::handle(&payload, client, &parsed.session_id).await;
139241
}
@@ -145,8 +247,8 @@ mod tests {
145247
use super::*;
146248
use crate::session_id::AcpSessionId;
147249
use agent_client_protocol::{
148-
ContentBlock, ContentChunk, RequestPermissionRequest, RequestPermissionResponse,
149-
SessionNotification, SessionUpdate,
250+
ContentBlock, ContentChunk, ReadTextFileRequest, Request, RequestId,
251+
RequestPermissionRequest, RequestPermissionResponse, SessionNotification, SessionUpdate,
150252
};
151253
use async_trait::async_trait;
152254
use std::cell::RefCell;
@@ -286,6 +388,7 @@ mod tests {
286388
"acp.sess-1.client.session.update",
287389
parsed,
288390
payload,
391+
None,
289392
&nats,
290393
&client,
291394
&bridge,
@@ -334,6 +437,32 @@ mod tests {
334437
assert!(nats.published_messages().is_empty());
335438
}
336439

440+
#[tokio::test]
441+
async fn process_message_backpressure_with_reply_publishes_error() {
442+
let nats = MockNatsClient::new();
443+
let bridge = make_bridge(nats.clone());
444+
let client = Rc::new(MockClient::new());
445+
let in_flight = Rc::new(Cell::new(1usize));
446+
447+
let envelope = Request {
448+
id: RequestId::Number(1),
449+
method: std::sync::Arc::from("fs/read_text_file"),
450+
params: Some(ReadTextFileRequest::new(
451+
agent_client_protocol::SessionId::from("sess1"),
452+
"/tmp/foo.txt".to_string(),
453+
)),
454+
};
455+
let payload = serde_json::to_vec(&envelope).unwrap();
456+
let msg = make_msg(
457+
"acp.sess1.client.fs.read_text_file",
458+
&payload,
459+
Some("_INBOX.reply"),
460+
);
461+
process_message(msg, &nats, client, bridge, &in_flight, 1).await;
462+
463+
assert_eq!(nats.published_messages(), vec!["_INBOX.reply"]);
464+
}
465+
337466
#[tokio::test]
338467
async fn process_message_valid_dispatch_spawns_task() {
339468
let local = tokio::task::LocalSet::new();

0 commit comments

Comments
 (0)