Skip to content

Commit 7e90a3c

Browse files
committed
feat(acp-nats): add ext_session_prompt_response client handler
1 parent ad1d584 commit 7e90a3c

File tree

6 files changed

+243
-1
lines changed

6 files changed

+243
-1
lines changed
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
use super::Bridge;
2+
use crate::session_id::AcpSessionId;
3+
use crate::nats::{FlushClient, PublishClient, RequestClient};
4+
use agent_client_protocol::{PromptResponse, SessionId};
5+
use tracing::{instrument, warn};
6+
use trogon_std::time::GetElapsed;
7+
8+
#[instrument(
9+
name = "acp.client.ext.session.prompt_response",
10+
skip(payload, bridge),
11+
fields(session_id = %session_id)
12+
)]
13+
pub async fn handle<
14+
N: RequestClient + PublishClient + FlushClient,
15+
C: GetElapsed,
16+
>(
17+
session_id: &str,
18+
payload: &[u8],
19+
bridge: &Bridge<N, C>,
20+
) {
21+
let Ok(validated) = AcpSessionId::new(session_id) else {
22+
warn!(
23+
session_id = %session_id,
24+
"Invalid session_id in prompt response notification"
25+
);
26+
bridge
27+
.metrics
28+
.record_error("client.ext.session.prompt_response", "invalid_session_id");
29+
return;
30+
};
31+
32+
let session_id_typed: SessionId = validated.as_str().to_string().into();
33+
34+
match serde_json::from_slice::<PromptResponse>(payload) {
35+
Ok(response) => {
36+
bridge
37+
.pending_session_prompt_responses
38+
.purge_expired_timed_out_waiters(&bridge.clock);
39+
let suppress_missing_waiter_warning = bridge
40+
.pending_session_prompt_responses
41+
.should_suppress_missing_waiter_warning(&session_id_typed, &bridge.clock);
42+
43+
if !bridge
44+
.pending_session_prompt_responses
45+
.resolve_waiter(&session_id_typed, Ok(response))
46+
{
47+
if !suppress_missing_waiter_warning {
48+
warn!(
49+
session_id = %session_id,
50+
"No pending prompt response waiter found for session"
51+
);
52+
}
53+
}
54+
}
55+
Err(e) => {
56+
warn!(error = %e, session_id = %session_id, "Failed to parse prompt response");
57+
bridge
58+
.pending_session_prompt_responses
59+
.purge_expired_timed_out_waiters(&bridge.clock);
60+
let suppress_missing_waiter_warning = bridge
61+
.pending_session_prompt_responses
62+
.should_suppress_missing_waiter_warning(&session_id_typed, &bridge.clock);
63+
64+
if !bridge
65+
.pending_session_prompt_responses
66+
.resolve_waiter(&session_id_typed, Err(e.to_string()))
67+
{
68+
if !suppress_missing_waiter_warning {
69+
warn!(
70+
session_id = %session_id,
71+
"No pending prompt response waiter found for session"
72+
);
73+
}
74+
}
75+
bridge
76+
.metrics
77+
.record_error("client.ext.session.prompt_response", "prompt_response_parse_failed");
78+
}
79+
}
80+
}
81+
82+
#[cfg(test)]
83+
mod tests {
84+
use super::*;
85+
use crate::agent::Bridge;
86+
use crate::config::Config;
87+
use agent_client_protocol::StopReason;
88+
use trogon_nats::MockNatsClient;
89+
use trogon_std::time::MockClock;
90+
91+
fn make_bridge() -> Bridge<MockNatsClient, MockClock> {
92+
Bridge::new(
93+
MockNatsClient::new(),
94+
MockClock::new(),
95+
&opentelemetry::global::meter("acp-nats-test"),
96+
Config::for_test("acp"),
97+
)
98+
}
99+
100+
#[tokio::test]
101+
async fn resolves_waiter() {
102+
let bridge = make_bridge();
103+
let session_id: SessionId = "prompt-resp-001".into();
104+
105+
let (rx, _guard) = bridge
106+
.pending_session_prompt_responses
107+
.register_waiter(session_id.clone())
108+
.unwrap();
109+
110+
let response = PromptResponse::new(StopReason::EndTurn);
111+
let payload = serde_json::to_vec(&response).unwrap();
112+
113+
super::handle("prompt-resp-001", &payload, &bridge).await;
114+
115+
let result = rx
116+
.await
117+
.expect("Should receive response")
118+
.expect("Prompt response should not include error");
119+
assert_eq!(result.stop_reason, StopReason::EndTurn);
120+
}
121+
122+
#[tokio::test]
123+
async fn no_waiter_does_not_panic() {
124+
let bridge = make_bridge();
125+
126+
let response = PromptResponse::new(StopReason::EndTurn);
127+
let payload = serde_json::to_vec(&response).unwrap();
128+
129+
super::handle("no-waiter-session", &payload, &bridge).await;
130+
}
131+
132+
#[tokio::test]
133+
async fn invalid_payload_removes_waiter() {
134+
let bridge = make_bridge();
135+
let session_id: SessionId = "bad-payload-001".into();
136+
137+
let (rx, _guard) = bridge
138+
.pending_session_prompt_responses
139+
.register_waiter(session_id.clone())
140+
.unwrap();
141+
142+
super::handle("bad-payload-001", b"not json", &bridge).await;
143+
144+
let result = rx
145+
.await
146+
.expect("Should receive resolved parse error")
147+
.err()
148+
.expect("Parse failure should be forwarded to waiter");
149+
assert!(!result.is_empty(), "Expected parse error to be forwarded");
150+
}
151+
152+
#[tokio::test]
153+
async fn invalid_session_id_is_rejected() {
154+
let bridge = make_bridge();
155+
let session_id: SessionId = "valid-session".into();
156+
157+
let (rx, _guard) = bridge
158+
.pending_session_prompt_responses
159+
.register_waiter(session_id.clone())
160+
.unwrap();
161+
162+
let response = PromptResponse::new(StopReason::EndTurn);
163+
let payload = serde_json::to_vec(&response).unwrap();
164+
165+
super::handle("session.with.dots", &payload, &bridge).await;
166+
super::handle("session*wild", &payload, &bridge).await;
167+
super::handle("session id", &payload, &bridge).await;
168+
169+
assert!(
170+
bridge
171+
.pending_session_prompt_responses
172+
.has_waiter(&session_id),
173+
"invalid session IDs should not resolve valid waiter",
174+
);
175+
176+
bridge
177+
.pending_session_prompt_responses
178+
.remove_waiter_for_test(&session_id);
179+
assert!(
180+
!bridge
181+
.pending_session_prompt_responses
182+
.has_waiter(&session_id),
183+
"waiter should be removed"
184+
);
185+
drop(rx);
186+
}
187+
}

rsworkspace/crates/acp-nats/src/client/mod.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
pub(crate) mod ext_session_prompt_response;
12
pub(crate) mod fs_read_text_file;
23
pub(crate) mod fs_write_text_file;
34
pub(crate) mod request_permission;
@@ -222,6 +223,21 @@ async fn dispatch_client_method<
222223
ClientMethod::SessionUpdate => {
223224
session_update::handle(&payload, ctx.client, &parsed.session_id).await;
224225
}
226+
ClientMethod::ExtSessionPromptResponse => {
227+
if reply.is_some() {
228+
warn!(
229+
session_id = %parsed.session_id,
230+
method = ?parsed.method,
231+
"Unexpected reply subject on prompt response notification"
232+
);
233+
}
234+
ext_session_prompt_response::handle(
235+
parsed.session_id.as_str(),
236+
&payload,
237+
ctx.bridge,
238+
)
239+
.await;
240+
}
225241
ClientMethod::TerminalCreate => {
226242
terminal_create::handle(
227243
&payload,

rsworkspace/crates/acp-nats/src/nats/client_method.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ pub enum ClientMethod {
99
TerminalOutput,
1010
TerminalRelease,
1111
TerminalWaitForExit,
12+
ExtSessionPromptResponse,
1213
}
1314

1415
impl ClientMethod {
@@ -23,6 +24,7 @@ impl ClientMethod {
2324
"client.terminal.output" => Some(Self::TerminalOutput),
2425
"client.terminal.release" => Some(Self::TerminalRelease),
2526
"client.terminal.wait_for_exit" => Some(Self::TerminalWaitForExit),
27+
"client.ext.session.prompt_response" => Some(Self::ExtSessionPromptResponse),
2628
_ => None,
2729
}
2830
}

rsworkspace/crates/acp-nats/src/nats/parsing.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,14 @@ mod tests {
9292
assert_eq!(parsed.method, ClientMethod::TerminalWaitForExit);
9393
}
9494

95+
#[test]
96+
fn test_parse_ext_session_prompt_response() {
97+
let subject = "acp.sess999.client.ext.session.prompt_response";
98+
let parsed = parse_client_subject(subject).unwrap();
99+
assert_eq!(parsed.session_id.as_str(), "sess999");
100+
assert_eq!(parsed.method, ClientMethod::ExtSessionPromptResponse);
101+
}
102+
95103
#[test]
96104
fn test_parse_with_custom_prefix() {
97105
let subject = "myapp.sess123.client.session.update";

rsworkspace/crates/acp-nats/src/nats/subjects.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,14 @@ pub mod client {
7878
format!("{}.{}.client.terminal.wait_for_exit", prefix, session_id)
7979
}
8080

81+
pub fn ext(prefix: &str, session_id: &str, method: &str) -> String {
82+
format!("{}.{}.client.ext.{}", prefix, session_id, method)
83+
}
84+
85+
pub fn ext_session_prompt_response(prefix: &str, session_id: &str) -> String {
86+
ext(prefix, session_id, "session.prompt_response")
87+
}
88+
8189
pub mod wildcards {
8290
pub fn all(prefix: &str) -> String {
8391
format!("{}.*.client.>", prefix)
@@ -162,6 +170,14 @@ mod tests {
162170
);
163171
}
164172

173+
#[test]
174+
fn client_ext_session_prompt_response_subject() {
175+
assert_eq!(
176+
client::ext_session_prompt_response("acp", "s1"),
177+
"acp.s1.client.ext.session.prompt_response"
178+
);
179+
}
180+
165181
#[test]
166182
fn client_wildcards_all() {
167183
assert_eq!(client::wildcards::all("foo"), "foo.*.client.>");

rsworkspace/crates/acp-nats/src/pending_prompt_waiters.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,16 @@ impl<I: Copy> PendingSessionPromptResponseWaiters<I> {
126126
});
127127
}
128128

129+
/// Returns true if a late prompt response for this session should not emit a missing-waiter warning.
130+
pub(crate) fn should_suppress_missing_waiter_warning<C: GetElapsed<Instant = I>>(
131+
&self,
132+
session_id: &SessionId,
133+
_clock: &C,
134+
) -> bool {
135+
self.timed_out.lock().unwrap().contains_key(session_id)
136+
}
137+
129138
/// Delivers a backend prompt result to the currently waiting caller for `session_id`.
130-
#[allow(dead_code)]
131139
pub fn resolve_waiter(
132140
&self,
133141
session_id: &SessionId,
@@ -152,6 +160,11 @@ impl<I: Copy> PendingSessionPromptResponseWaiters<I> {
152160
}
153161
}
154162

163+
#[cfg(test)]
164+
pub(crate) fn has_waiter(&self, session_id: &SessionId) -> bool {
165+
self.waiters.lock().unwrap().contains_key(session_id)
166+
}
167+
155168
#[cfg(test)]
156169
pub(crate) fn remove_waiter_for_test(&self, session_id: &SessionId) {
157170
self.waiters.lock().unwrap().remove(session_id);

0 commit comments

Comments
 (0)