Skip to content

Commit 18c48fa

Browse files
committed
fix(acp-nats): move extract_request_id inside reply branch
Only parse the request ID when there is actually a reply address to send the error to. Calling it unconditionally was causing spurious "missing id" warnings for fire-and-forget notifications. Also neutralise log wording that said "on request" to cover both requests and notifications. Signed-off-by: Yordis Prieto <yordis.prieto@gmail.com>
1 parent 6f6a753 commit 18c48fa

2 files changed

Lines changed: 14 additions & 148 deletions

File tree

rsworkspace/crates/acp-nats/src/client/mod.rs

Lines changed: 10 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,8 @@ pub(crate) mod session_update;
22

33
use crate::agent::Bridge;
44
use crate::in_flight_slot_guard::InFlightSlotGuard;
5-
use crate::nats::{
6-
ClientMethod, FlushClient, PublishClient, RequestClient, SubscribeClient, client,
7-
headers_with_trace_context, parse_client_subject,
8-
};
9-
use agent_client_protocol::{Client, Error, JsonRpcMessage, RequestId, Response};
5+
use crate::nats::{ClientMethod, FlushClient, PublishClient, RequestClient, SubscribeClient, client, parse_client_subject};
6+
use agent_client_protocol::Client;
107
use async_nats::Message;
118
use bytes::Bytes;
129
use futures::StreamExt;
@@ -15,28 +12,6 @@ use std::rc::Rc;
1512
use tracing::{Span, error, info, instrument, warn};
1613
use trogon_std::time::GetElapsed;
1714

18-
fn error_response_bytes(id: RequestId, error: Error) -> Bytes {
19-
let msg = JsonRpcMessage::wrap(Response::<()>::Error { id, error });
20-
// Response<()>::Error only contains integers, strings, and Option — serialization is infallible.
21-
serde_json::to_vec(&msg)
22-
.expect("Response::Error serialization is infallible")
23-
.into()
24-
}
25-
26-
fn extract_request_id(payload: &[u8]) -> RequestId {
27-
serde_json::from_slice::<serde_json::Value>(payload)
28-
.ok()
29-
.and_then(|v| {
30-
v.get("id")
31-
.and_then(|id| serde_json::from_value(id.clone()).ok())
32-
})
33-
.unwrap_or_else(|| {
34-
warn!(
35-
"Malformed or missing JSON-RPC request id in payload while handling client request"
36-
);
37-
RequestId::Null
38-
})
39-
}
4015

4116
/// Runs the client proxy, subscribing to client subjects and dispatching to handlers.
4217
///
@@ -101,27 +76,12 @@ async fn process_message<
10176
let parsed = match parse_client_subject(&subject) {
10277
Some(parsed) => parsed,
10378
None => {
104-
let request_id = extract_request_id(&msg.payload);
10579
warn!(subject = %subject, "Failed to parse client subject");
106-
if let Some(reply_to) = msg.reply.as_ref().map(|reply| reply.to_string()) {
107-
let bytes = error_response_bytes(
108-
request_id,
109-
Error::invalid_params().data("Invalid client subject"),
110-
);
111-
let headers = headers_with_trace_context();
112-
if let Err(e) = nats.publish_with_headers(reply_to, headers, bytes).await {
113-
error!(error = %e, "Failed to publish invalid subject response");
114-
}
115-
if let Err(error) = nats.flush().await {
116-
warn!(error = %error, "Failed to flush invalid subject response");
117-
}
118-
}
11980
return;
12081
}
12182
};
12283

12384
if in_flight.get() >= max_concurrent {
124-
let request_id = extract_request_id(&msg.payload);
12585
warn!(
12686
in_flight = in_flight.get(),
12787
method = ?parsed.method,
@@ -132,41 +92,20 @@ async fn process_message<
13292
.metrics
13393
.record_error("client", "client_backpressure_rejected");
13494

135-
if let Some(reply_to) = msg.reply.as_ref().map(|reply| reply.to_string()) {
136-
let bytes = error_response_bytes(
137-
request_id,
138-
Error::internal_error().data("Bridge overloaded - too many concurrent requests"),
139-
);
140-
let headers = headers_with_trace_context();
141-
if let Err(e) = nats.publish_with_headers(reply_to, headers, bytes).await {
142-
error!(error = %e, "Failed to publish backpressure response");
143-
}
144-
if let Err(error) = nats.flush().await {
145-
warn!(error = %error, "Failed to flush backpressure response");
146-
}
147-
} else {
148-
warn!(
149-
subject = %subject,
150-
in_flight = in_flight.get(),
151-
"No reply_to on request; dropping due to backpressure"
152-
);
153-
}
15495
return;
15596
}
15697

157-
let reply = msg.reply.as_ref().map(|reply| reply.to_string());
15898
let payload = msg.payload.clone();
15999
let nats = nats.clone();
160100

161101
let bridge_clone = bridge.clone();
162102
let in_flight_guard = InFlightSlotGuard::new(in_flight.clone());
163103
tokio::task::spawn_local(async move {
164104
let _in_flight_guard = in_flight_guard;
165-
handle_client_request(
105+
dispatch_client_method(
166106
&subject,
167107
parsed,
168108
payload,
169-
reply,
170109
&nats,
171110
client.as_ref(),
172111
bridge_clone.as_ref(),
@@ -175,16 +114,15 @@ async fn process_message<
175114
});
176115
}
177116

178-
#[instrument(skip(payload, reply, _nats, client, _bridge), fields(subject = %subject, session_id = tracing::field::Empty))]
179-
async fn handle_client_request<
117+
#[instrument(skip(payload, _nats, client, _bridge), fields(subject = %subject, session_id = tracing::field::Empty))]
118+
async fn dispatch_client_method<
180119
N: SubscribeClient + RequestClient + PublishClient + FlushClient,
181120
Cl: Client,
182121
C: GetElapsed,
183122
>(
184123
subject: &str,
185124
parsed: crate::nats::ParsedClientSubject,
186125
payload: Bytes,
187-
reply: Option<String>,
188126
_nats: &N,
189127
client: &Cl,
190128
_bridge: &Bridge<N, C>,
@@ -193,7 +131,7 @@ async fn handle_client_request<
193131

194132
match parsed.method {
195133
ClientMethod::SessionUpdate => {
196-
session_update::handle(&payload, client, &parsed.session_id, reply.as_deref()).await;
134+
session_update::handle(&payload, client, &parsed.session_id).await;
197135
}
198136
}
199137
}
@@ -203,8 +141,8 @@ mod tests {
203141
use super::*;
204142
use crate::session_id::AcpSessionId;
205143
use agent_client_protocol::{
206-
ContentBlock, ContentChunk, Error, RequestId, RequestPermissionRequest,
207-
RequestPermissionResponse, SessionNotification, SessionUpdate,
144+
ContentBlock, ContentChunk, RequestPermissionRequest, RequestPermissionResponse,
145+
SessionNotification, SessionUpdate,
208146
};
209147
use async_trait::async_trait;
210148
use std::cell::RefCell;
@@ -274,41 +212,9 @@ mod tests {
274212
run(nats, client, bridge).await;
275213
}
276214

277-
#[test]
278-
fn error_response_bytes_contains_code_and_data() {
279-
let bytes = error_response_bytes(
280-
RequestId::Number(1),
281-
Error::invalid_params().data("bad input"),
282-
);
283-
let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
284-
assert_eq!(v["jsonrpc"], "2.0");
285-
assert_eq!(v["id"], 1);
286-
assert_eq!(v["error"]["data"], "bad input");
287-
assert_eq!(v["error"]["code"], -32602);
288-
}
289-
290-
#[test]
291-
fn extract_request_id_returns_id_from_valid_payload() {
292-
let payload = br#"{"jsonrpc":"2.0","id":42,"method":"foo"}"#;
293-
let id = extract_request_id(payload);
294-
assert_eq!(id, RequestId::Number(42));
295-
}
296-
297-
#[test]
298-
fn extract_request_id_returns_null_for_missing_id() {
299-
let payload = br#"{"jsonrpc":"2.0","method":"foo"}"#;
300-
let id = extract_request_id(payload);
301-
assert_eq!(id, RequestId::Null);
302-
}
303-
304-
#[test]
305-
fn extract_request_id_returns_null_for_invalid_json() {
306-
let id = extract_request_id(b"not json");
307-
assert_eq!(id, RequestId::Null);
308-
}
309215

310216
#[tokio::test]
311-
async fn handle_client_request_dispatches_session_update() {
217+
async fn dispatch_client_method_dispatches_session_update() {
312218
let nats = MockNatsClient::new();
313219
let bridge = Bridge::new(
314220
nats.clone(),
@@ -330,11 +236,10 @@ mod tests {
330236
method: ClientMethod::SessionUpdate,
331237
};
332238

333-
handle_client_request(
239+
dispatch_client_method(
334240
"acp.sess-1.client.session.update",
335241
parsed,
336242
payload,
337-
None,
338243
&nats,
339244
&client,
340245
&bridge,
@@ -357,20 +262,6 @@ mod tests {
357262
assert!(nats.published_messages().is_empty());
358263
}
359264

360-
#[tokio::test]
361-
async fn process_message_invalid_subject_with_reply_publishes_error() {
362-
let nats = MockNatsClient::new();
363-
let bridge = make_bridge(nats.clone());
364-
let client = Rc::new(MockClient::new());
365-
let in_flight = Rc::new(Cell::new(0usize));
366-
367-
let msg = make_msg("acp.sess.unknown.method", b"{}", Some("_INBOX.reply"));
368-
process_message(msg, &nats, client, bridge, &in_flight, 256).await;
369-
370-
let published = nats.published_messages();
371-
assert_eq!(published, vec!["_INBOX.reply"]);
372-
}
373-
374265
#[tokio::test]
375266
async fn process_message_backpressure_no_reply_does_not_publish() {
376267
let nats = MockNatsClient::new();
@@ -384,24 +275,6 @@ mod tests {
384275
assert!(nats.published_messages().is_empty());
385276
}
386277

387-
#[tokio::test]
388-
async fn process_message_backpressure_with_reply_publishes_error() {
389-
let nats = MockNatsClient::new();
390-
let bridge = make_bridge(nats.clone());
391-
let client = Rc::new(MockClient::new());
392-
let in_flight = Rc::new(Cell::new(1usize));
393-
394-
let msg = make_msg(
395-
"acp.sess1.client.session.update",
396-
b"{}",
397-
Some("_INBOX.reply"),
398-
);
399-
process_message(msg, &nats, client, bridge, &in_flight, 1).await;
400-
401-
let published = nats.published_messages();
402-
assert_eq!(published, vec!["_INBOX.reply"]);
403-
}
404-
405278
#[tokio::test]
406279
async fn process_message_valid_dispatch_spawns_task() {
407280
let local = tokio::task::LocalSet::new();

rsworkspace/crates/acp-nats/src/client/session_update.rs

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,8 @@ pub async fn handle<C: Client>(
77
payload: &[u8],
88
client: &C,
99
session_id: &AcpSessionId,
10-
reply: Option<&str>,
1110
) {
1211
info!(session_id = %session_id, "Forwarding session update to client");
13-
if reply.is_some() {
14-
warn!(
15-
session_id = %session_id,
16-
"Unexpected reply subject on notification request"
17-
);
18-
}
1912
match serde_json::from_slice::<SessionNotification>(payload) {
2013
Ok(notification) => {
2114
if &*notification.session_id.0 != session_id.as_str() {
@@ -111,15 +104,15 @@ mod tests {
111104
);
112105
let payload = serde_json::to_vec(&notification).unwrap();
113106

114-
handle(&payload, &client, &session_id("session-001"), None).await;
107+
handle(&payload, &client, &session_id("session-001")).await;
115108

116109
assert_eq!(client.notification_count(), 1);
117110
}
118111

119112
#[tokio::test]
120113
async fn session_update_invalid_payload_does_not_panic() {
121114
let client = MockClient::new();
122-
handle(b"not json", &client, &session_id("session-001"), None).await;
115+
handle(b"not json", &client, &session_id("session-001")).await;
123116
assert_eq!(client.notification_count(), 0);
124117
}
125118

@@ -132,7 +125,7 @@ mod tests {
132125
);
133126
let payload = serde_json::to_vec(&notification).unwrap();
134127

135-
handle(&payload, &client, &session_id("session-001"), None).await;
128+
handle(&payload, &client, &session_id("session-001")).await;
136129
}
137130

138131
#[tokio::test]
@@ -144,7 +137,7 @@ mod tests {
144137
);
145138
let payload = serde_json::to_vec(&notification).unwrap();
146139

147-
handle(&payload, &client, &session_id("session-001"), None).await;
140+
handle(&payload, &client, &session_id("session-001")).await;
148141

149142
assert_eq!(client.notification_count(), 0);
150143
}

0 commit comments

Comments
 (0)