Skip to content

Commit e387fa8

Browse files
committed
feat(acp-nats): add prompt handler (#20)
- Add prompt handler with session validation, cancel pre-flight, backpressure - Expand Bridge with CancelledSessions, PendingSessionPromptResponseWaiters - Update cancel to mark sessions cancelled and resolve pending prompt waiters - Add prompt_timeout and max_concurrent_client_tasks to config - Add session_prompt NATS subject Signed-off-by: Yordis Prieto <yordis.prieto@gmail.com>
1 parent 263c14b commit e387fa8

File tree

12 files changed

+1025
-70
lines changed

12 files changed

+1025
-70
lines changed

rsworkspace/crates/AGENTS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Prefer domain-specific value objects over primitives (e.g. `AcpPrefix` not `String`). Each type's factory must guarantee correctness at construction—invalid instances should be unrepresentable. Validate per-type, not per-aggregate: avoid validating unrelated fields together in a single constructor.
22

3-
Every value object lives in its own file named after the type (e.g. `acp_prefix.rs`, `ext_method_name.rs`, `session_id.rs`). Never inline a value object into a config, aggregate, or service file.
3+
Every value object lives in its own file named after the type (e.g. `acp_prefix.rs`, `ext_method_name.rs`, `session_id.rs`). Never inline a value object into a config, aggregate, or service file. File layout: `src/{type_snake_case}.rs`; export in `lib.rs` as `pub use {module}::{Type, TypeError}`.
44

55
You must use the `test-support` feature to share test helpers between crates.
66
Prefer one trait per operation over a single trait with multiple operations.

rsworkspace/crates/acp-nats/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,6 @@ trogon-std = { path = "../trogon-std" }
2121

2222
[dev-dependencies]
2323
opentelemetry_sdk = { version = "0.31.0", features = ["rt-tokio", "metrics", "testing"] }
24+
tokio = { version = "1.49.0", features = ["test-util"] }
2425
trogon-nats = { path = "../trogon-nats", features = ["test-support"] }
2526
trogon-std = { path = "../trogon-std", features = ["test-support"] }

rsworkspace/crates/acp-nats/src/agent/cancel.rs

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@ use agent_client_protocol::{CancelNotification, Error, ErrorCode, Result};
55
use tracing::{info, instrument, warn};
66
use trogon_std::time::GetElapsed;
77

8-
/// Publishes the cancel notification to the backend via NATS (fire-and-forget).
9-
/// The publish failure is logged and recorded as a metric but does not propagate
10-
/// to the caller, so the client always receives `Ok(())`.
8+
/// Handles cancel notification requests.
9+
///
10+
/// Validates the session ID and publishes the cancellation to the backend (fire-and-forget).
11+
/// The backend owns session state and will respond to the in-flight prompt with `stopReason: cancelled`.
12+
/// Publish failure is logged and recorded in metrics but does not propagate to the caller.
1113
#[instrument(
1214
name = "acp.session.cancel",
1315
skip(bridge, args),
@@ -25,9 +27,7 @@ pub async fn handle<N: RequestClient + PublishClient + FlushClient, C: GetElapse
2527
bridge
2628
.metrics
2729
.record_request("cancel", bridge.clock.elapsed(start).as_secs_f64(), false);
28-
bridge
29-
.metrics
30-
.record_error("session_validate", "invalid_session_id");
30+
bridge.metrics.record_error("cancel", "invalid_session_id");
3131
Error::new(
3232
ErrorCode::InvalidParams.into(),
3333
format!("Invalid session ID: {}", e),
@@ -46,7 +46,7 @@ pub async fn handle<N: RequestClient + PublishClient + FlushClient, C: GetElapse
4646
)
4747
.await;
4848

49-
if let Err(error) = publish_result {
49+
if let Err(error) = &publish_result {
5050
warn!(
5151
session_id = %args.session_id,
5252
error = %error,
@@ -57,9 +57,11 @@ pub async fn handle<N: RequestClient + PublishClient + FlushClient, C: GetElapse
5757
.record_error("cancel", "cancel_publish_failed");
5858
}
5959

60-
bridge
61-
.metrics
62-
.record_request("cancel", bridge.clock.elapsed(start).as_secs_f64(), true);
60+
bridge.metrics.record_request(
61+
"cancel",
62+
bridge.clock.elapsed(start).as_secs_f64(),
63+
publish_result.is_ok(),
64+
);
6365

6466
Ok(())
6567
}
@@ -223,8 +225,8 @@ mod tests {
223225
"expected acp.request.count with method=cancel, success=false on validation failure"
224226
);
225227
assert!(
226-
has_error_metric(&finished_metrics, "session_validate", "invalid_session_id"),
227-
"expected acp.errors.total with operation=session_validate, reason=invalid_session_id"
228+
has_error_metric(&finished_metrics, "cancel", "invalid_session_id"),
229+
"expected acp.errors.total with operation=cancel, reason=invalid_session_id"
228230
);
229231
provider.shutdown().unwrap();
230232
}
@@ -258,8 +260,8 @@ mod tests {
258260
"expected acp.errors.total with operation=cancel, reason=cancel_publish_failed"
259261
);
260262
assert!(
261-
has_request_metric(&finished_metrics, "cancel", true),
262-
"publish failure is fire-and-forget; caller still gets Ok, so success=true"
263+
has_request_metric(&finished_metrics, "cancel", false),
264+
"request metric records publish outcome; success=false when publish fails"
263265
);
264266
provider.shutdown().unwrap();
265267
}

rsworkspace/crates/acp-nats/src/agent/load_session.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use super::Bridge;
22
use crate::acp_prefix::AcpPrefix;
3+
use crate::config::SESSION_READY_DELAY;
34
use crate::error::AGENT_UNAVAILABLE;
45
use crate::nats::{
56
self, ExtSessionReady, FlushClient, FlushPolicy, PublishClient, PublishOptions, RequestClient,
@@ -8,13 +9,10 @@ use crate::nats::{
89
use crate::session_id::AcpSessionId;
910
use crate::telemetry::metrics::Metrics;
1011
use agent_client_protocol::{Error, ErrorCode, LoadSessionRequest, LoadSessionResponse, Result};
11-
use std::time::Duration;
1212
use tracing::{info, instrument, warn};
1313
use trogon_nats::NatsError;
1414
use trogon_std::time::GetElapsed;
1515

16-
const SESSION_READY_DELAY: Duration = Duration::from_millis(100);
17-
1816
fn map_load_session_error(e: NatsError) -> Error {
1917
match &e {
2018
NatsError::Timeout { subject } => {

rsworkspace/crates/acp-nats/src/agent/mod.rs

Lines changed: 20 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,18 @@ mod ext_notification;
55
mod initialize;
66
mod load_session;
77
mod new_session;
8+
mod pending_prompt_waiters;
9+
mod prompt;
810
mod set_session_mode;
911

12+
use pending_prompt_waiters::PendingSessionPromptResponseWaiters;
13+
1014
use crate::config::Config;
1115
use crate::nats::{FlushClient, PublishClient, RequestClient};
16+
use crate::prompt_slot_counter::PromptSlotCounter;
1217
use crate::telemetry::metrics::Metrics;
13-
use agent_client_protocol::ErrorCode;
1418
use agent_client_protocol::{
15-
Agent, AuthenticateRequest, AuthenticateResponse, CancelNotification, Error, ExtNotification,
19+
Agent, AuthenticateRequest, AuthenticateResponse, CancelNotification, ExtNotification,
1620
ExtRequest, ExtResponse, InitializeRequest, InitializeResponse, LoadSessionRequest,
1721
LoadSessionResponse, NewSessionRequest, NewSessionResponse, PromptRequest, PromptResponse,
1822
Result, SetSessionModeRequest, SetSessionModeResponse,
@@ -23,17 +27,22 @@ use trogon_std::time::GetElapsed;
2327
pub struct Bridge<N: RequestClient + PublishClient + FlushClient, C: GetElapsed> {
2428
pub(crate) nats: N,
2529
pub(crate) clock: C,
26-
pub(crate) config: Config,
2730
pub(crate) metrics: Metrics,
31+
pub(crate) pending_session_prompt_responses: PendingSessionPromptResponseWaiters<C::Instant>,
32+
pub(crate) prompt_slot_counter: PromptSlotCounter,
33+
pub(crate) config: Config,
2834
}
2935

3036
impl<N: RequestClient + PublishClient + FlushClient, C: GetElapsed> Bridge<N, C> {
3137
pub fn new(nats: N, clock: C, meter: &Meter, config: Config) -> Self {
38+
let max_concurrent = config.max_concurrent_client_tasks();
3239
Self {
3340
nats,
3441
clock,
3542
config,
3643
metrics: Metrics::new(meter),
44+
pending_session_prompt_responses: PendingSessionPromptResponseWaiters::new(),
45+
prompt_slot_counter: PromptSlotCounter::new(max_concurrent),
3746
}
3847
}
3948

@@ -67,11 +76,8 @@ impl<N: RequestClient + PublishClient + FlushClient, C: GetElapsed> Agent for Br
6776
set_session_mode::handle(self, args).await
6877
}
6978

70-
async fn prompt(&self, _args: PromptRequest) -> Result<PromptResponse> {
71-
Err(Error::new(
72-
ErrorCode::InternalError.into(),
73-
"not yet implemented",
74-
))
79+
async fn prompt(&self, args: PromptRequest) -> Result<PromptResponse> {
80+
prompt::handle(self, args).await
7581
}
7682

7783
async fn cancel(&self, args: CancelNotification) -> Result<()> {
@@ -88,37 +94,14 @@ impl<N: RequestClient + PublishClient + FlushClient, C: GetElapsed> Agent for Br
8894
}
8995

9096
#[cfg(test)]
91-
mod tests {
97+
mod send_sync_tests {
9298
use super::Bridge;
93-
use crate::config::Config;
94-
use agent_client_protocol::{Agent, PromptRequest};
9599
use trogon_nats::AdvancedMockNatsClient;
100+
use trogon_std::time::SystemClock;
96101

97-
fn mock_bridge() -> Bridge<AdvancedMockNatsClient, trogon_std::time::SystemClock> {
98-
Bridge::new(
99-
AdvancedMockNatsClient::new(),
100-
trogon_std::time::SystemClock,
101-
&opentelemetry::global::meter("acp-nats-test"),
102-
Config::for_test("acp"),
103-
)
104-
}
105-
106-
#[tokio::test]
107-
async fn stub_methods_return_not_implemented() {
108-
let bridge = mock_bridge();
109-
let msg = "not yet implemented";
110-
111-
assert!(
112-
bridge
113-
.prompt(PromptRequest::new("s1", vec![]))
114-
.await
115-
.is_err()
116-
);
117-
118-
let err = bridge
119-
.prompt(PromptRequest::new("s1", vec![]))
120-
.await
121-
.unwrap_err();
122-
assert!(err.to_string().contains(msg));
102+
#[test]
103+
fn bridge_is_send_and_sync() {
104+
fn assert_send_sync<T: Send + Sync>() {}
105+
assert_send_sync::<Bridge<AdvancedMockNatsClient, SystemClock>>();
123106
}
124107
}

rsworkspace/crates/acp-nats/src/agent/new_session.rs

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use super::Bridge;
2+
use crate::config::SESSION_READY_DELAY;
23
use crate::error::AGENT_UNAVAILABLE;
34
use crate::nats::{
45
self, ExtSessionReady, FlushClient, FlushPolicy, PublishClient, PublishOptions, RequestClient,
@@ -8,25 +9,10 @@ use crate::telemetry::metrics::Metrics;
89
use agent_client_protocol::{
910
Error, ErrorCode, NewSessionRequest, NewSessionResponse, Result, SessionId,
1011
};
11-
use std::time::Duration;
1212
use tracing::{Span, info, instrument, warn};
1313
use trogon_nats::NatsError;
1414
use trogon_std::time::GetElapsed;
1515

16-
/// Delay before publishing `session.ready` to NATS.
17-
///
18-
/// The `Agent` trait returns the response value *before* the transport layer
19-
/// serializes and writes it to the client. Without a delay the spawned task
20-
/// could publish `session.ready` to NATS before the client has received the
21-
/// `session/new` response, violating the ordering guarantee documented on
22-
/// [`ExtSessionReady`].
23-
///
24-
/// A post-send callback from the transport would be the ideal fix, but the
25-
/// external `agent_client_protocol` crate does not expose one. This constant
26-
/// delay provides a practical safety margin (serialization + write is typically
27-
/// sub-millisecond).
28-
const SESSION_READY_DELAY: Duration = Duration::from_millis(100);
29-
3016
fn map_new_session_error(e: NatsError) -> Error {
3117
match &e {
3218
NatsError::Timeout { subject } => {
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
//! Waiter registry for bridging prompt request/response over NATS notifications.
2+
//!
3+
//! **When to use**
4+
//! - In the ACP prompt path where request and response are decoupled (`publish` now, response
5+
//! arrives later via `client.ext.session.prompt_response`).
6+
//! - Before publishing prompt work, so an immediate backend response cannot race ahead of waiter
7+
//! registration.
8+
//!
9+
//! **Why this exists**
10+
//! - Prompt responses are correlated by `SessionId`, not by direct request/reply transport.
11+
//! - Enforcing one active waiter per session avoids ambiguous delivery when clients duplicate
12+
//! prompt calls.
13+
//! - Timed-out sessions are tracked briefly to suppress noisy duplicate timeout-related warnings
14+
//! during late-response windows.
15+
16+
use std::collections::HashMap;
17+
use std::sync::Mutex;
18+
19+
use agent_client_protocol::{PromptResponse, SessionId};
20+
use tokio::sync::oneshot;
21+
use trogon_std::time::GetElapsed;
22+
use crate::config::PROMPT_TIMEOUT_WARNING_SUPPRESSION_WINDOW;
23+
24+
type PromptResponseReceiver = oneshot::Receiver<std::result::Result<PromptResponse, String>>;
25+
26+
/// Lifetime token for a registered session waiter.
27+
///
28+
/// Dropping the guard removes the waiter so cancellations and task aborts do not leak entries.
29+
pub(crate) struct PromptWaiterGuard<'a, I: Copy> {
30+
waiters: &'a PendingSessionPromptResponseWaiters<I>,
31+
session_id: SessionId,
32+
}
33+
34+
impl<'a, I: Copy> PromptWaiterGuard<'a, I> {
35+
fn new(waiters: &'a PendingSessionPromptResponseWaiters<I>, session_id: SessionId) -> Self {
36+
Self {
37+
waiters,
38+
session_id,
39+
}
40+
}
41+
}
42+
43+
impl<'a, I: Copy> Drop for PromptWaiterGuard<'a, I> {
44+
fn drop(&mut self) {
45+
self.waiters.remove_waiter(&self.session_id);
46+
}
47+
}
48+
49+
/// Process-local map of in-flight prompt waiters keyed by session.
50+
///
51+
/// Scope is intentionally local to this agent process; cross-process correlation belongs to NATS
52+
/// subjects and backend state.
53+
pub(crate) struct PendingSessionPromptResponseWaiters<I: Copy> {
54+
waiters:
55+
Mutex<HashMap<SessionId, oneshot::Sender<std::result::Result<PromptResponse, String>>>>,
56+
timed_out: Mutex<HashMap<SessionId, I>>,
57+
}
58+
59+
impl<I: Copy> PendingSessionPromptResponseWaiters<I> {
60+
/// Creates an empty waiter registry.
61+
pub fn new() -> Self {
62+
Self {
63+
waiters: Mutex::new(HashMap::new()),
64+
timed_out: Mutex::new(HashMap::new()),
65+
}
66+
}
67+
68+
/// Registers the receiver for the next prompt response of `session_id`.
69+
///
70+
/// Returns `Err(())` when another waiter is already active for the same session.
71+
pub fn register_waiter(
72+
&self,
73+
session_id: SessionId,
74+
) -> std::result::Result<(PromptResponseReceiver, PromptWaiterGuard<'_, I>), ()> {
75+
let (tx, rx) = oneshot::channel();
76+
let mut waiters = self.waiters.lock().unwrap();
77+
if waiters.contains_key(&session_id) {
78+
return Err(());
79+
}
80+
self.timed_out.lock().unwrap().remove(&session_id);
81+
waiters.insert(session_id.clone(), tx);
82+
Ok((rx, PromptWaiterGuard::new(self, session_id)))
83+
}
84+
85+
/// Marks a session as timed out to suppress transient duplicate warnings for late responses.
86+
pub(crate) fn mark_prompt_waiter_timed_out<C: GetElapsed<Instant = I>>(
87+
&self,
88+
session_id: SessionId,
89+
clock: &C,
90+
) {
91+
self.purge_expired_timed_out_waiters(clock);
92+
self.timed_out
93+
.lock()
94+
.unwrap()
95+
.insert(session_id, clock.now());
96+
}
97+
98+
/// Drops timeout-suppression markers after a short window.
99+
///
100+
/// This keeps suppression bounded so future requests for the same session can emit warnings
101+
/// again if they truly timeout.
102+
pub(crate) fn purge_expired_timed_out_waiters<C: GetElapsed<Instant = I>>(&self, clock: &C) {
103+
self.timed_out.lock().unwrap().retain(|_, seen_at| {
104+
clock.elapsed(*seen_at) < PROMPT_TIMEOUT_WARNING_SUPPRESSION_WINDOW
105+
});
106+
}
107+
108+
/// Delivers a backend prompt result to the currently waiting caller for `session_id`.
109+
#[allow(dead_code)]
110+
pub fn resolve_waiter(
111+
&self,
112+
session_id: &SessionId,
113+
response: std::result::Result<PromptResponse, String>,
114+
) -> bool {
115+
let sender = self.waiters.lock().unwrap().remove(session_id);
116+
self.timed_out.lock().unwrap().remove(session_id);
117+
if let Some(sender) = sender {
118+
sender.send(response).is_ok()
119+
} else {
120+
false
121+
}
122+
}
123+
124+
/// Removes a waiter for `session_id` without delivering a response.
125+
///
126+
/// Used by cancellation/drop paths where the caller is no longer waiting.
127+
pub fn remove_waiter(&self, session_id: &SessionId) {
128+
self.waiters.lock().unwrap().remove(session_id);
129+
}
130+
}

0 commit comments

Comments
 (0)