Skip to content

Commit 1e2a8f6

Browse files
authored
feat(acp-nats): add ext_session_prompt_response client handler (#31)
Signed-off-by: Yordis Prieto <yordis.prieto@gmail.com>
1 parent ad1d584 commit 1e2a8f6

File tree

7 files changed

+444
-32
lines changed

7 files changed

+444
-32
lines changed

rsworkspace/crates/acp-nats/src/agent/prompt.rs

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use super::Bridge;
22
use crate::config::PROMPT_TIMEOUT_MESSAGE_SECS_THRESHOLD;
33
use crate::error::AGENT_UNAVAILABLE;
44
use crate::nats::{self, FlushClient, PublishClient, RequestClient, agent};
5+
use crate::pending_prompt_waiters::PromptToken;
56
use crate::session_id::AcpSessionId;
67
use agent_client_protocol::ErrorCode;
78
use agent_client_protocol::{Error, PromptRequest, PromptResponse, Result};
@@ -49,6 +50,16 @@ fn duplicate_waiter_error<N: RequestClient + PublishClient + FlushClient, C: Get
4950
)
5051
}
5152

53+
fn add_prompt_id_to_request(args: &PromptRequest, prompt_token: PromptToken) -> PromptRequest {
54+
let mut meta = args
55+
.meta
56+
.as_ref()
57+
.cloned()
58+
.unwrap_or_else(serde_json::Map::new);
59+
meta.insert("prompt_id".to_string(), serde_json::json!(prompt_token.0));
60+
args.clone().meta(meta)
61+
}
62+
5263
#[instrument(
5364
name = "acp.session.prompt",
5465
skip(bridge, args),
@@ -76,19 +87,21 @@ pub async fn handle<N: RequestClient + PublishClient + FlushClient, C: GetElapse
7687
let nats = bridge.nats();
7788
let subject = agent::session_prompt(bridge.config.acp_prefix(), session_id.as_str());
7889

79-
let (rx, _waiter_guard) = match bridge
90+
let (rx, _waiter_guard, prompt_token) = match bridge
8091
.pending_session_prompt_responses
8192
.register_waiter(args.session_id.clone())
8293
{
8394
Ok(waiter) => waiter,
8495
Err(()) => return Err(duplicate_waiter_error(bridge, &args.session_id)),
8596
};
8697

98+
let request_with_token = add_prompt_id_to_request(&args, prompt_token);
99+
87100
let publish_options = nats::PublishOptions::builder()
88101
.flush_policy(nats::FlushPolicy::no_retries())
89102
.build();
90103

91-
if let Err(e) = nats::publish(nats, &subject, &args, publish_options).await {
104+
if let Err(e) = nats::publish(nats, &subject, &request_with_token, publish_options).await {
92105
bridge
93106
.metrics
94107
.record_error("prompt", "prompt_publish_failed");
@@ -139,7 +152,11 @@ pub async fn handle<N: RequestClient + PublishClient + FlushClient, C: GetElapse
139152
bridge.metrics.record_error("prompt", "prompt_timeout");
140153
bridge
141154
.pending_session_prompt_responses
142-
.mark_prompt_waiter_timed_out(args.session_id.clone(), &bridge.clock);
155+
.mark_prompt_waiter_timed_out(
156+
args.session_id.clone(),
157+
prompt_token,
158+
&bridge.clock,
159+
);
143160

144161
let timeout = bridge.config.prompt_timeout();
145162
let timeout_msg = if timeout >= PROMPT_TIMEOUT_MESSAGE_SECS_THRESHOLD {
@@ -346,6 +363,7 @@ mod tests {
346363
.pending_session_prompt_responses
347364
.resolve_waiter(
348365
&SessionId::from("s1"),
366+
PromptToken(0),
349367
Ok(PromptResponse::new(StopReason::EndTurn)),
350368
);
351369
let result = handle1.await.unwrap();
@@ -357,7 +375,7 @@ mod tests {
357375
#[tokio::test]
358376
async fn prompt_rejects_duplicate_waiter_for_same_session() {
359377
let (_mock, bridge) = mock_bridge();
360-
let (_rx, _guard) = bridge
378+
let (_rx, _guard, _) = bridge
361379
.pending_session_prompt_responses
362380
.register_waiter(agent_client_protocol::SessionId::from("s1"))
363381
.unwrap();
@@ -385,14 +403,15 @@ mod tests {
385403
tokio::time::sleep(Duration::from_millis(5)).await;
386404
handle.abort();
387405
let _ = handle.await;
388-
let (rx, _guard) = bridge_after
406+
let (rx, _guard, token) = bridge_after
389407
.pending_session_prompt_responses
390408
.register_waiter(SessionId::from("s1"))
391409
.expect("waiter should be free after cancelled prompt dropped guard");
392410
bridge_after
393411
.pending_session_prompt_responses
394412
.resolve_waiter(
395413
&SessionId::from("s1"),
414+
token,
396415
Ok(PromptResponse::new(StopReason::EndTurn)),
397416
);
398417
let result = rx.await.unwrap().unwrap();
@@ -430,13 +449,14 @@ mod tests {
430449
#[tokio::test]
431450
async fn prompt_resolves_waiter_with_response() {
432451
let (_mock, bridge) = mock_bridge();
433-
let (rx, _guard) = bridge
452+
let (rx, _guard, token) = bridge
434453
.pending_session_prompt_responses
435454
.register_waiter(agent_client_protocol::SessionId::from("s1"))
436455
.unwrap();
437456

438457
bridge.pending_session_prompt_responses.resolve_waiter(
439458
&agent_client_protocol::SessionId::from("s1"),
459+
token,
440460
Ok(PromptResponse::new(StopReason::EndTurn)),
441461
);
442462

@@ -461,6 +481,7 @@ mod tests {
461481
.pending_session_prompt_responses
462482
.resolve_waiter(
463483
&SessionId::from("s1"),
484+
PromptToken(0),
464485
Ok(PromptResponse::new(StopReason::EndTurn)),
465486
);
466487
let result = prompt_handle.await.unwrap();
@@ -513,6 +534,7 @@ mod tests {
513534
.pending_session_prompt_responses
514535
.resolve_waiter(
515536
&SessionId::from("s1"),
537+
PromptToken(0),
516538
Ok(PromptResponse::new(StopReason::EndTurn)),
517539
);
518540
let result = handle1.await.unwrap();
@@ -631,7 +653,11 @@ mod tests {
631653
tokio::time::sleep(Duration::from_millis(5)).await;
632654
bridge_resolve
633655
.pending_session_prompt_responses
634-
.resolve_waiter(&SessionId::from("s1"), Err("parse error".to_string()));
656+
.resolve_waiter(
657+
&SessionId::from("s1"),
658+
PromptToken(0),
659+
Err("parse error".to_string()),
660+
);
635661
let result = prompt_handle.await.unwrap();
636662
let err = result.unwrap_err();
637663
assert!(err.to_string().contains("parse failed"));
@@ -681,6 +707,7 @@ mod tests {
681707
.pending_session_prompt_responses
682708
.resolve_waiter(
683709
&SessionId::from("s1"),
710+
PromptToken(0),
684711
Ok(PromptResponse::new(StopReason::EndTurn)),
685712
);
686713
let result = prompt_handle.await.unwrap();

0 commit comments

Comments
 (0)