Skip to content

Commit 73bcf4d

Browse files
committed
fix(acp-nats): replace unimplemented!() in session_update test mock
Signed-off-by: Yordis Prieto <yordis.prieto@gmail.com>
1 parent f4759e9 commit 73bcf4d

2 files changed

Lines changed: 197 additions & 75 deletions

File tree

rsworkspace/crates/acp-nats/src/client/mod.rs

Lines changed: 193 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::nats::{
77
headers_with_trace_context, parse_client_subject,
88
};
99
use agent_client_protocol::{Client, Error, JsonRpcMessage, RequestId, Response};
10+
use async_nats::Message;
1011
use bytes::Bytes;
1112
use futures::StreamExt;
1213
use std::cell::Cell;
@@ -66,90 +67,105 @@ pub async fn run<
6667
let max_concurrent = bridge.config.max_concurrent_client_tasks();
6768

6869
while let Some(msg) = subscriber.next().await {
69-
let subject = msg.subject.to_string();
70-
71-
// Validate subject before backpressure so unrecognised methods always
72-
// get InvalidParams, not a misleading "Bridge overloaded" error.
73-
let parsed = match parse_client_subject(&subject) {
74-
Some(parsed) => parsed,
75-
None => {
76-
let request_id = extract_request_id(&msg.payload);
77-
warn!(subject = %subject, "Failed to parse client subject");
78-
if let Some(reply_to) = msg.reply.as_ref().map(|reply| reply.to_string()) {
79-
let bytes = error_response_bytes(
80-
request_id,
81-
Error::invalid_params().data("Invalid client subject"),
82-
);
83-
let headers = headers_with_trace_context();
84-
if let Err(e) = nats.publish_with_headers(reply_to, headers, bytes).await {
85-
error!(error = %e, "Failed to publish invalid subject response");
86-
}
87-
if let Err(error) = nats.flush().await {
88-
warn!(error = %error, "Failed to flush invalid subject response");
89-
}
90-
}
91-
continue;
92-
}
93-
};
70+
process_message(msg, &nats, client.clone(), bridge.clone(), &in_flight, max_concurrent)
71+
.await;
72+
}
9473

95-
if in_flight.get() >= max_concurrent {
96-
let request_id = extract_request_id(&msg.payload);
97-
warn!(
98-
in_flight = in_flight.get(),
99-
method = ?parsed.method,
100-
subject = %subject,
101-
"Client task backpressure — rejecting message"
102-
);
103-
bridge
104-
.metrics
105-
.record_error("client", "client_backpressure_rejected");
74+
info!("Client proxy subscriber ended");
75+
}
10676

77+
async fn process_message<
78+
N: SubscribeClient + RequestClient + PublishClient + FlushClient,
79+
Cl: Client + 'static,
80+
C: GetElapsed + 'static,
81+
>(
82+
msg: Message,
83+
nats: &N,
84+
client: Rc<Cl>,
85+
bridge: Rc<Bridge<N, C>>,
86+
in_flight: &Rc<Cell<usize>>,
87+
max_concurrent: usize,
88+
) {
89+
let subject = msg.subject.to_string();
90+
91+
// Validate subject before backpressure so unrecognised methods always
92+
// get InvalidParams, not a misleading "Bridge overloaded" error.
93+
let parsed = match parse_client_subject(&subject) {
94+
Some(parsed) => parsed,
95+
None => {
96+
let request_id = extract_request_id(&msg.payload);
97+
warn!(subject = %subject, "Failed to parse client subject");
10798
if let Some(reply_to) = msg.reply.as_ref().map(|reply| reply.to_string()) {
10899
let bytes = error_response_bytes(
109100
request_id,
110-
Error::internal_error()
111-
.data("Bridge overloaded - too many concurrent requests"),
101+
Error::invalid_params().data("Invalid client subject"),
112102
);
113103
let headers = headers_with_trace_context();
114104
if let Err(e) = nats.publish_with_headers(reply_to, headers, bytes).await {
115-
error!(error = %e, "Failed to publish backpressure response");
105+
error!(error = %e, "Failed to publish invalid subject response");
116106
}
117107
if let Err(error) = nats.flush().await {
118-
warn!(error = %error, "Failed to flush backpressure response");
108+
warn!(error = %error, "Failed to flush invalid subject response");
119109
}
120-
} else {
121-
warn!(
122-
subject = %subject,
123-
in_flight = in_flight.get(),
124-
"No reply_to on request; dropping due to backpressure"
125-
);
126110
}
127-
continue;
111+
return;
128112
}
113+
};
129114

130-
let reply = msg.reply.as_ref().map(|reply| reply.to_string());
131-
let payload = msg.payload.clone();
132-
let nats = nats.clone();
133-
let client = client.clone();
134-
135-
let bridge_clone = bridge.clone();
136-
let in_flight_guard = InFlightSlotGuard::new(in_flight.clone());
137-
tokio::task::spawn_local(async move {
138-
let _in_flight_guard = in_flight_guard;
139-
handle_client_request(
140-
&subject,
141-
parsed,
142-
payload,
143-
reply,
144-
&nats,
145-
client.as_ref(),
146-
bridge_clone.as_ref(),
147-
)
148-
.await;
149-
});
115+
if in_flight.get() >= max_concurrent {
116+
let request_id = extract_request_id(&msg.payload);
117+
warn!(
118+
in_flight = in_flight.get(),
119+
method = ?parsed.method,
120+
subject = %subject,
121+
"Client task backpressure — rejecting message"
122+
);
123+
bridge
124+
.metrics
125+
.record_error("client", "client_backpressure_rejected");
126+
127+
if let Some(reply_to) = msg.reply.as_ref().map(|reply| reply.to_string()) {
128+
let bytes = error_response_bytes(
129+
request_id,
130+
Error::internal_error()
131+
.data("Bridge overloaded - too many concurrent requests"),
132+
);
133+
let headers = headers_with_trace_context();
134+
if let Err(e) = nats.publish_with_headers(reply_to, headers, bytes).await {
135+
error!(error = %e, "Failed to publish backpressure response");
136+
}
137+
if let Err(error) = nats.flush().await {
138+
warn!(error = %error, "Failed to flush backpressure response");
139+
}
140+
} else {
141+
warn!(
142+
subject = %subject,
143+
in_flight = in_flight.get(),
144+
"No reply_to on request; dropping due to backpressure"
145+
);
146+
}
147+
return;
150148
}
151149

152-
info!("Client proxy subscriber ended");
150+
let reply = msg.reply.as_ref().map(|reply| reply.to_string());
151+
let payload = msg.payload.clone();
152+
let nats = nats.clone();
153+
154+
let bridge_clone = bridge.clone();
155+
let in_flight_guard = InFlightSlotGuard::new(in_flight.clone());
156+
tokio::task::spawn_local(async move {
157+
let _in_flight_guard = in_flight_guard;
158+
handle_client_request(
159+
&subject,
160+
parsed,
161+
payload,
162+
reply,
163+
&nats,
164+
client.as_ref(),
165+
bridge_clone.as_ref(),
166+
)
167+
.await;
168+
});
153169
}
154170

155171
#[instrument(skip(payload, reply, _nats, client, _bridge), fields(subject = %subject, session_id = tracing::field::Empty))]
@@ -221,15 +237,33 @@ mod tests {
221237
}
222238
}
223239

224-
#[tokio::test]
225-
async fn run_returns_early_when_subscribe_fails() {
226-
let nats = MockNatsClient::new();
227-
let bridge = Rc::new(Bridge::new(
228-
nats.clone(),
240+
fn make_msg(subject: &str, payload: &[u8], reply: Option<&str>) -> async_nats::Message {
241+
async_nats::Message {
242+
subject: subject.into(),
243+
reply: reply.map(|r| r.into()),
244+
payload: payload.to_vec().into(),
245+
headers: None,
246+
length: payload.len(),
247+
status: None,
248+
description: None,
249+
}
250+
}
251+
252+
fn make_bridge(
253+
nats: MockNatsClient,
254+
) -> Rc<Bridge<MockNatsClient, SystemClock>> {
255+
Rc::new(Bridge::new(
256+
nats,
229257
SystemClock,
230258
&opentelemetry::global::meter("acp-nats-test"),
231259
crate::config::Config::for_test("acp"),
232-
));
260+
))
261+
}
262+
263+
#[tokio::test]
264+
async fn run_returns_early_when_subscribe_fails() {
265+
let nats = MockNatsClient::new();
266+
let bridge = make_bridge(nats.clone());
233267
let client = Rc::new(MockClient::new());
234268

235269
run(nats, client, bridge).await;
@@ -304,4 +338,89 @@ mod tests {
304338

305339
assert_eq!(client.notifications.borrow().len(), 1);
306340
}
341+
342+
#[tokio::test]
343+
async fn process_message_invalid_subject_no_reply_does_not_publish() {
344+
let nats = MockNatsClient::new();
345+
let bridge = make_bridge(nats.clone());
346+
let client = Rc::new(MockClient::new());
347+
let in_flight = Rc::new(Cell::new(0usize));
348+
349+
let msg = make_msg("acp.sess.unknown.method", b"{}", None);
350+
process_message(msg, &nats, client, bridge, &in_flight, 256).await;
351+
352+
assert!(nats.published_messages().is_empty());
353+
}
354+
355+
#[tokio::test]
356+
async fn process_message_invalid_subject_with_reply_publishes_error() {
357+
let nats = MockNatsClient::new();
358+
let bridge = make_bridge(nats.clone());
359+
let client = Rc::new(MockClient::new());
360+
let in_flight = Rc::new(Cell::new(0usize));
361+
362+
let msg = make_msg("acp.sess.unknown.method", b"{}", Some("_INBOX.reply"));
363+
process_message(msg, &nats, client, bridge, &in_flight, 256).await;
364+
365+
let published = nats.published_messages();
366+
assert_eq!(published, vec!["_INBOX.reply"]);
367+
}
368+
369+
#[tokio::test]
370+
async fn process_message_backpressure_no_reply_does_not_publish() {
371+
let nats = MockNatsClient::new();
372+
let bridge = make_bridge(nats.clone());
373+
let client = Rc::new(MockClient::new());
374+
let in_flight = Rc::new(Cell::new(1usize));
375+
376+
let msg = make_msg("acp.sess1.client.session.update", b"{}", None);
377+
process_message(msg, &nats, client, bridge, &in_flight, 1).await;
378+
379+
assert!(nats.published_messages().is_empty());
380+
}
381+
382+
#[tokio::test]
383+
async fn process_message_backpressure_with_reply_publishes_error() {
384+
let nats = MockNatsClient::new();
385+
let bridge = make_bridge(nats.clone());
386+
let client = Rc::new(MockClient::new());
387+
let in_flight = Rc::new(Cell::new(1usize));
388+
389+
let msg = make_msg(
390+
"acp.sess1.client.session.update",
391+
b"{}",
392+
Some("_INBOX.reply"),
393+
);
394+
process_message(msg, &nats, client, bridge, &in_flight, 1).await;
395+
396+
let published = nats.published_messages();
397+
assert_eq!(published, vec!["_INBOX.reply"]);
398+
}
399+
400+
#[tokio::test]
401+
async fn process_message_valid_dispatch_spawns_task() {
402+
let local = tokio::task::LocalSet::new();
403+
local
404+
.run_until(async {
405+
let nats = MockNatsClient::new();
406+
let bridge = make_bridge(nats.clone());
407+
let client = Rc::new(MockClient::new());
408+
let in_flight = Rc::new(Cell::new(0usize));
409+
410+
let notification = SessionNotification::new(
411+
"sess1",
412+
SessionUpdate::AgentMessageChunk(ContentChunk::new(ContentBlock::from("hi"))),
413+
);
414+
let payload = serde_json::to_vec(&notification).unwrap();
415+
416+
let msg = make_msg("acp.sess1.client.session.update", &payload, None);
417+
process_message(msg, &nats, client.clone(), bridge, &in_flight, 256).await;
418+
419+
// Yield to allow the spawned local task to run.
420+
tokio::task::yield_now().await;
421+
422+
assert_eq!(client.notifications.borrow().len(), 1);
423+
})
424+
.await;
425+
}
307426
}

rsworkspace/crates/acp-nats/src/client/session_update.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,10 @@ mod tests {
9595
&self,
9696
_: RequestPermissionRequest,
9797
) -> Result<RequestPermissionResponse, agent_client_protocol::Error> {
98-
unimplemented!()
98+
Err(agent_client_protocol::Error::new(
99+
-32603,
100+
"not implemented in test mock",
101+
))
99102
}
100103
}
101104

0 commit comments

Comments
 (0)