Skip to content

Commit e170bf3

Browse files
committed
feat(acp-nats): add fs_read_text_file client handler
Signed-off-by: Yordis Prieto <yordis.prieto@gmail.com>
1 parent c4beead commit e170bf3

8 files changed

Lines changed: 406 additions & 10 deletions

File tree

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
use crate::jsonrpc::JsonRpcRequest;
2+
use agent_client_protocol::{Client, ErrorCode, ReadTextFileRequest};
3+
use tracing::instrument;
4+
5+
#[derive(Debug)]
6+
pub enum FsReadTextFileError {
7+
InvalidRequest(serde_json::Error),
8+
ClientError(agent_client_protocol::Error),
9+
SerializationError(serde_json::Error),
10+
}
11+
12+
impl std::fmt::Display for FsReadTextFileError {
13+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14+
match self {
15+
Self::InvalidRequest(e) => write!(f, "invalid request: {}", e),
16+
Self::ClientError(e) => write!(f, "client error: {}", e),
17+
Self::SerializationError(e) => write!(f, "serialization error: {}", e),
18+
}
19+
}
20+
}
21+
22+
impl std::error::Error for FsReadTextFileError {
23+
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
24+
match self {
25+
Self::InvalidRequest(e) => Some(e),
26+
Self::ClientError(e) => Some(e),
27+
Self::SerializationError(e) => Some(e),
28+
}
29+
}
30+
}
31+
32+
pub fn error_code_and_message(e: &FsReadTextFileError) -> (ErrorCode, String) {
33+
match e {
34+
FsReadTextFileError::InvalidRequest(inner) => (
35+
ErrorCode::InvalidParams,
36+
format!("Invalid read_text_file request: {}", inner),
37+
),
38+
FsReadTextFileError::ClientError(inner) => (inner.code, inner.message.clone()),
39+
FsReadTextFileError::SerializationError(inner) => (
40+
ErrorCode::InternalError,
41+
format!("Failed to serialize read_text_file response: {}", inner),
42+
),
43+
}
44+
}
45+
46+
/// Forwards read_text_file to the client. NATS enforces payload limits when publishing.
47+
/// Expects a JSON-RPC envelope with a "params" object containing the ReadTextFileRequest fields.
48+
#[instrument(name = "acp.client.fs.read_text_file", skip(payload, client))]
49+
pub async fn handle<C: Client>(payload: &[u8], client: &C) -> Result<Vec<u8>, FsReadTextFileError> {
50+
let envelope: JsonRpcRequest<ReadTextFileRequest> =
51+
serde_json::from_slice(payload).map_err(FsReadTextFileError::InvalidRequest)?;
52+
let request = envelope.params.ok_or_else(|| {
53+
FsReadTextFileError::InvalidRequest(
54+
serde_json::from_value::<ReadTextFileRequest>(serde_json::Value::Null).unwrap_err(),
55+
)
56+
})?;
57+
let response = client
58+
.read_text_file(request)
59+
.await
60+
.map_err(FsReadTextFileError::ClientError)?;
61+
serde_json::to_vec(&response).map_err(FsReadTextFileError::SerializationError)
62+
}
63+
64+
#[cfg(test)]
65+
mod tests {
66+
use super::*;
67+
use crate::jsonrpc::JsonRpcRequest;
68+
use agent_client_protocol::RequestPermissionRequest;
69+
use agent_client_protocol::RequestPermissionResponse;
70+
use agent_client_protocol::SessionNotification;
71+
use agent_client_protocol::{ReadTextFileRequest, ReadTextFileResponse};
72+
use async_trait::async_trait;
73+
74+
struct MockClient {
75+
content: String,
76+
}
77+
78+
impl MockClient {
79+
fn new(content: &str) -> Self {
80+
Self {
81+
content: content.to_string(),
82+
}
83+
}
84+
}
85+
86+
#[async_trait(?Send)]
87+
impl Client for MockClient {
88+
async fn session_notification(
89+
&self,
90+
_: SessionNotification,
91+
) -> agent_client_protocol::Result<()> {
92+
Ok(())
93+
}
94+
95+
async fn request_permission(
96+
&self,
97+
_: RequestPermissionRequest,
98+
) -> agent_client_protocol::Result<RequestPermissionResponse> {
99+
Err(agent_client_protocol::Error::new(
100+
-32603,
101+
"not implemented in test mock",
102+
))
103+
}
104+
105+
async fn read_text_file(
106+
&self,
107+
_: ReadTextFileRequest,
108+
) -> agent_client_protocol::Result<ReadTextFileResponse> {
109+
Ok(ReadTextFileResponse::new(self.content.clone()))
110+
}
111+
}
112+
113+
#[tokio::test]
114+
async fn fs_read_text_file_forwards_request_and_returns_response() {
115+
let client = MockClient::new("hello world");
116+
let request = ReadTextFileRequest::new(
117+
agent_client_protocol::SessionId::from("sess-1"),
118+
"/tmp/foo.txt".to_string(),
119+
);
120+
let envelope = JsonRpcRequest {
121+
jsonrpc: "2.0".to_string(),
122+
id: serde_json::json!(1),
123+
method: Some("fs/read_text_file".to_string()),
124+
params: Some(request),
125+
};
126+
let payload = serde_json::to_vec(&envelope).unwrap();
127+
128+
let result = handle(&payload, &client).await;
129+
assert!(result.is_ok());
130+
let response = serde_json::from_slice::<ReadTextFileResponse>(&result.unwrap()).unwrap();
131+
assert_eq!(response.content, "hello world");
132+
}
133+
134+
#[tokio::test]
135+
async fn fs_read_text_file_returns_error_when_payload_is_invalid_json() {
136+
let client = MockClient::new("hello");
137+
let result = handle(b"not json", &client).await;
138+
assert!(result.is_err());
139+
}
140+
141+
#[test]
142+
fn error_code_and_message_invalid_request_returns_invalid_params() {
143+
let err = serde_json::from_slice::<ReadTextFileRequest>(b"not json").unwrap_err();
144+
let fs_err = FsReadTextFileError::InvalidRequest(err);
145+
let (code, message) = error_code_and_message(&fs_err);
146+
assert_eq!(code, ErrorCode::InvalidParams);
147+
assert!(message.contains("Invalid read_text_file request"));
148+
}
149+
150+
#[test]
151+
fn error_code_and_message_client_error_preserves_client_code() {
152+
let client_err =
153+
agent_client_protocol::Error::new(ErrorCode::InvalidParams.into(), "file not found");
154+
let fs_err = FsReadTextFileError::ClientError(client_err);
155+
let (code, message) = error_code_and_message(&fs_err);
156+
assert_eq!(code, ErrorCode::InvalidParams);
157+
assert_eq!(message, "file not found");
158+
}
159+
160+
#[test]
161+
fn error_code_and_message_serialization_error_returns_internal_error() {
162+
struct Unserializable;
163+
impl serde::Serialize for Unserializable {
164+
fn serialize<S: serde::Serializer>(&self, _: S) -> Result<S::Ok, S::Error> {
165+
Err(serde::ser::Error::custom("test serialization failure"))
166+
}
167+
}
168+
let err = serde_json::to_vec(&Unserializable).unwrap_err();
169+
let fs_err = FsReadTextFileError::SerializationError(err);
170+
let (code, message) = error_code_and_message(&fs_err);
171+
assert_eq!(code, ErrorCode::InternalError);
172+
assert!(message.contains("Failed to serialize read_text_file response"));
173+
}
174+
}

rsworkspace/crates/acp-nats/src/client/mod.rs

Lines changed: 128 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1+
pub(crate) mod fs_read_text_file;
12
pub(crate) mod session_update;
23

34
use crate::agent::Bridge;
5+
use crate::error::AGENT_UNAVAILABLE;
46
use crate::in_flight_slot_guard::InFlightSlotGuard;
7+
use crate::jsonrpc::{ErrorResponse, ResultResponse, extract_request_id};
58
use crate::nats::{
69
ClientMethod, FlushClient, PublishClient, RequestClient, SubscribeClient, client,
7-
parse_client_subject,
10+
headers_with_trace_context, parse_client_subject,
811
};
912
use agent_client_protocol::Client;
13+
use agent_client_protocol::ErrorCode;
1014
use async_nats::Message;
1115
use bytes::Bytes;
1216
use futures::StreamExt;
@@ -84,6 +88,9 @@ async fn process_message<
8488
}
8589
};
8690

91+
let payload = msg.payload.clone();
92+
let reply = msg.reply.as_ref().map(|r| r.to_string());
93+
8794
let current_in_flight = in_flight.get();
8895
if current_in_flight >= max_concurrent {
8996
warn!(
@@ -96,10 +103,39 @@ async fn process_message<
96103
.metrics
97104
.record_error("client", "client_backpressure_rejected");
98105

106+
if let Some(reply_to) = &reply {
107+
let request_id = extract_request_id(&payload);
108+
let bytes = ErrorResponse::new(
109+
request_id,
110+
ErrorCode::Other(AGENT_UNAVAILABLE),
111+
"Client proxy overloaded; retry with backoff",
112+
)
113+
.to_bytes()
114+
.unwrap_or_else(|e| {
115+
ErrorResponse::new(
116+
serde_json::Value::Null,
117+
ErrorCode::Other(AGENT_UNAVAILABLE),
118+
format!(
119+
"Client proxy overloaded; retry with backoff (serialization failed: {})",
120+
e
121+
),
122+
)
123+
.to_bytes()
124+
.unwrap()
125+
});
126+
let headers = headers_with_trace_context();
127+
if let Err(e) = nats
128+
.publish_with_headers(reply_to.clone(), headers, bytes)
129+
.await
130+
{
131+
error!(error = %e, "Failed to publish backpressure error reply");
132+
}
133+
if let Err(e) = nats.flush().await {
134+
warn!(error = %e, "Failed to flush backpressure error reply");
135+
}
136+
}
99137
return;
100138
}
101-
102-
let payload = msg.payload.clone();
103139
let nats = nats.clone();
104140

105141
let bridge_clone = bridge.clone();
@@ -110,6 +146,7 @@ async fn process_message<
110146
&subject,
111147
parsed,
112148
payload,
149+
reply,
113150
&nats,
114151
client.as_ref(),
115152
bridge_clone.as_ref(),
@@ -118,7 +155,7 @@ async fn process_message<
118155
});
119156
}
120157

121-
#[instrument(skip(payload, _nats, client, _bridge), fields(subject = %subject, session_id = tracing::field::Empty))]
158+
#[instrument(skip(payload, nats, client, _bridge), fields(subject = %subject, session_id = tracing::field::Empty))]
122159
async fn dispatch_client_method<
123160
N: SubscribeClient + RequestClient + PublishClient + FlushClient,
124161
Cl: Client,
@@ -127,13 +164,68 @@ async fn dispatch_client_method<
127164
subject: &str,
128165
parsed: crate::nats::ParsedClientSubject,
129166
payload: Bytes,
130-
_nats: &N,
167+
reply: Option<String>,
168+
nats: &N,
131169
client: &Cl,
132170
_bridge: &Bridge<N, C>,
133171
) {
134172
Span::current().record("session_id", parsed.session_id.as_str());
135173

136174
match parsed.method {
175+
ClientMethod::FsReadTextFile => {
176+
let request_id = extract_request_id(&payload);
177+
match fs_read_text_file::handle(&payload, client).await {
178+
Ok(bytes) => {
179+
if let Some(reply_to) = &reply {
180+
let result =
181+
serde_json::from_slice(&bytes).unwrap_or(serde_json::Value::Null);
182+
let response_bytes =
183+
ResultResponse::new(request_id, result).to_bytes().unwrap();
184+
let headers = headers_with_trace_context();
185+
if let Err(e) = nats
186+
.publish_with_headers(reply_to.clone(), headers, response_bytes)
187+
.await
188+
{
189+
error!(error = %e, "Failed to publish fs_read_text_file reply");
190+
}
191+
if let Err(e) = nats.flush().await {
192+
warn!(error = %e, "Failed to flush fs_read_text_file reply");
193+
}
194+
}
195+
}
196+
Err(e) => {
197+
let (code, message) = fs_read_text_file::error_code_and_message(&e);
198+
warn!(
199+
error = %e,
200+
session_id = %parsed.session_id,
201+
"Failed to handle fs_read_text_file"
202+
);
203+
if let Some(reply_to) = &reply {
204+
let bytes = ErrorResponse::new(request_id, code, message.as_str())
205+
.to_bytes()
206+
.unwrap_or_else(|e| {
207+
ErrorResponse::new(
208+
serde_json::Value::Null,
209+
code,
210+
format!("{} (serialization failed: {})", message, e),
211+
)
212+
.to_bytes()
213+
.unwrap()
214+
});
215+
let headers = headers_with_trace_context();
216+
if let Err(e) = nats
217+
.publish_with_headers(reply_to.clone(), headers, bytes)
218+
.await
219+
{
220+
error!(error = %e, "Failed to publish fs_read_text_file error reply");
221+
}
222+
if let Err(e) = nats.flush().await {
223+
warn!(error = %e, "Failed to flush fs_read_text_file error reply");
224+
}
225+
}
226+
}
227+
}
228+
}
137229
ClientMethod::SessionUpdate => {
138230
session_update::handle(&payload, client, &parsed.session_id).await;
139231
}
@@ -143,10 +235,11 @@ async fn dispatch_client_method<
143235
#[cfg(test)]
144236
mod tests {
145237
use super::*;
238+
use crate::jsonrpc::JsonRpcRequest;
146239
use crate::session_id::AcpSessionId;
147240
use agent_client_protocol::{
148-
ContentBlock, ContentChunk, RequestPermissionRequest, RequestPermissionResponse,
149-
SessionNotification, SessionUpdate,
241+
ContentBlock, ContentChunk, ReadTextFileRequest, RequestPermissionRequest,
242+
RequestPermissionResponse, SessionNotification, SessionUpdate,
150243
};
151244
use async_trait::async_trait;
152245
use std::cell::RefCell;
@@ -286,6 +379,7 @@ mod tests {
286379
"acp.sess-1.client.session.update",
287380
parsed,
288381
payload,
382+
None,
289383
&nats,
290384
&client,
291385
&bridge,
@@ -334,6 +428,33 @@ mod tests {
334428
assert!(nats.published_messages().is_empty());
335429
}
336430

431+
#[tokio::test]
432+
async fn process_message_backpressure_with_reply_publishes_error() {
433+
let nats = MockNatsClient::new();
434+
let bridge = make_bridge(nats.clone());
435+
let client = Rc::new(MockClient::new());
436+
let in_flight = Rc::new(Cell::new(1usize));
437+
438+
let envelope = JsonRpcRequest {
439+
jsonrpc: "2.0".to_string(),
440+
id: serde_json::json!(1),
441+
method: Some("fs/read_text_file".to_string()),
442+
params: Some(ReadTextFileRequest::new(
443+
agent_client_protocol::SessionId::from("sess1"),
444+
"/tmp/foo.txt".to_string(),
445+
)),
446+
};
447+
let payload = serde_json::to_vec(&envelope).unwrap();
448+
let msg = make_msg(
449+
"acp.sess1.client.fs.read_text_file",
450+
&payload,
451+
Some("_INBOX.reply"),
452+
);
453+
process_message(msg, &nats, client, bridge, &in_flight, 1).await;
454+
455+
assert_eq!(nats.published_messages(), vec!["_INBOX.reply"]);
456+
}
457+
337458
#[tokio::test]
338459
async fn process_message_valid_dispatch_spawns_task() {
339460
let local = tokio::task::LocalSet::new();

0 commit comments

Comments
 (0)