Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 34 additions & 7 deletions rsworkspace/crates/acp-nats/src/agent/prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::Bridge;
use crate::config::PROMPT_TIMEOUT_MESSAGE_SECS_THRESHOLD;
use crate::error::AGENT_UNAVAILABLE;
use crate::nats::{self, FlushClient, PublishClient, RequestClient, agent};
use crate::pending_prompt_waiters::PromptToken;
use crate::session_id::AcpSessionId;
use agent_client_protocol::ErrorCode;
use agent_client_protocol::{Error, PromptRequest, PromptResponse, Result};
Expand Down Expand Up @@ -49,6 +50,16 @@ fn duplicate_waiter_error<N: RequestClient + PublishClient + FlushClient, C: Get
)
}

fn add_prompt_id_to_request(args: &PromptRequest, prompt_token: PromptToken) -> PromptRequest {
let mut meta = args
.meta
.as_ref()
.cloned()
.unwrap_or_else(serde_json::Map::new);
meta.insert("prompt_id".to_string(), serde_json::json!(prompt_token.0));
args.clone().meta(meta)
}

#[instrument(
name = "acp.session.prompt",
skip(bridge, args),
Expand Down Expand Up @@ -76,19 +87,21 @@ pub async fn handle<N: RequestClient + PublishClient + FlushClient, C: GetElapse
let nats = bridge.nats();
let subject = agent::session_prompt(bridge.config.acp_prefix(), session_id.as_str());

let (rx, _waiter_guard) = match bridge
let (rx, _waiter_guard, prompt_token) = match bridge
.pending_session_prompt_responses
.register_waiter(args.session_id.clone())
{
Ok(waiter) => waiter,
Err(()) => return Err(duplicate_waiter_error(bridge, &args.session_id)),
};

let request_with_token = add_prompt_id_to_request(&args, prompt_token);

let publish_options = nats::PublishOptions::builder()
.flush_policy(nats::FlushPolicy::no_retries())
.build();

if let Err(e) = nats::publish(nats, &subject, &args, publish_options).await {
if let Err(e) = nats::publish(nats, &subject, &request_with_token, publish_options).await {
bridge
.metrics
.record_error("prompt", "prompt_publish_failed");
Expand Down Expand Up @@ -139,7 +152,11 @@ pub async fn handle<N: RequestClient + PublishClient + FlushClient, C: GetElapse
bridge.metrics.record_error("prompt", "prompt_timeout");
bridge
.pending_session_prompt_responses
.mark_prompt_waiter_timed_out(args.session_id.clone(), &bridge.clock);
.mark_prompt_waiter_timed_out(
args.session_id.clone(),
prompt_token,
&bridge.clock,
);

let timeout = bridge.config.prompt_timeout();
let timeout_msg = if timeout >= PROMPT_TIMEOUT_MESSAGE_SECS_THRESHOLD {
Expand Down Expand Up @@ -346,6 +363,7 @@ mod tests {
.pending_session_prompt_responses
.resolve_waiter(
&SessionId::from("s1"),
PromptToken(0),
Ok(PromptResponse::new(StopReason::EndTurn)),
);
let result = handle1.await.unwrap();
Expand All @@ -357,7 +375,7 @@ mod tests {
#[tokio::test]
async fn prompt_rejects_duplicate_waiter_for_same_session() {
let (_mock, bridge) = mock_bridge();
let (_rx, _guard) = bridge
let (_rx, _guard, _) = bridge
.pending_session_prompt_responses
.register_waiter(agent_client_protocol::SessionId::from("s1"))
.unwrap();
Expand Down Expand Up @@ -385,14 +403,15 @@ mod tests {
tokio::time::sleep(Duration::from_millis(5)).await;
handle.abort();
let _ = handle.await;
let (rx, _guard) = bridge_after
let (rx, _guard, token) = bridge_after
.pending_session_prompt_responses
.register_waiter(SessionId::from("s1"))
.expect("waiter should be free after cancelled prompt dropped guard");
bridge_after
.pending_session_prompt_responses
.resolve_waiter(
&SessionId::from("s1"),
token,
Ok(PromptResponse::new(StopReason::EndTurn)),
);
let result = rx.await.unwrap().unwrap();
Expand Down Expand Up @@ -430,13 +449,14 @@ mod tests {
#[tokio::test]
async fn prompt_resolves_waiter_with_response() {
let (_mock, bridge) = mock_bridge();
let (rx, _guard) = bridge
let (rx, _guard, token) = bridge
.pending_session_prompt_responses
.register_waiter(agent_client_protocol::SessionId::from("s1"))
.unwrap();

bridge.pending_session_prompt_responses.resolve_waiter(
&agent_client_protocol::SessionId::from("s1"),
token,
Ok(PromptResponse::new(StopReason::EndTurn)),
);

Expand All @@ -461,6 +481,7 @@ mod tests {
.pending_session_prompt_responses
.resolve_waiter(
&SessionId::from("s1"),
PromptToken(0),
Ok(PromptResponse::new(StopReason::EndTurn)),
);
let result = prompt_handle.await.unwrap();
Expand Down Expand Up @@ -513,6 +534,7 @@ mod tests {
.pending_session_prompt_responses
.resolve_waiter(
&SessionId::from("s1"),
PromptToken(0),
Ok(PromptResponse::new(StopReason::EndTurn)),
);
let result = handle1.await.unwrap();
Expand Down Expand Up @@ -631,7 +653,11 @@ mod tests {
tokio::time::sleep(Duration::from_millis(5)).await;
bridge_resolve
.pending_session_prompt_responses
.resolve_waiter(&SessionId::from("s1"), Err("parse error".to_string()));
.resolve_waiter(
&SessionId::from("s1"),
PromptToken(0),
Err("parse error".to_string()),
);
let result = prompt_handle.await.unwrap();
let err = result.unwrap_err();
assert!(err.to_string().contains("parse failed"));
Expand Down Expand Up @@ -681,6 +707,7 @@ mod tests {
.pending_session_prompt_responses
.resolve_waiter(
&SessionId::from("s1"),
PromptToken(0),
Ok(PromptResponse::new(StopReason::EndTurn)),
);
let result = prompt_handle.await.unwrap();
Expand Down
Loading