Skip to content

Commit ce67f76

Browse files
committed
feat(acp-nats): add prompt handler (#20)
- Add prompt handler with session validation, cancel pre-flight, backpressure - Expand Bridge with CancelledSessions, PendingSessionPromptResponseWaiters - Update cancel to mark sessions cancelled and resolve pending prompt waiters - Add prompt_timeout and max_concurrent_client_tasks to config - Add session_prompt NATS subject Signed-off-by: Yordis Prieto <yordis.prieto@gmail.com>
1 parent 263c14b commit ce67f76

9 files changed

Lines changed: 919 additions & 71 deletions

File tree

rsworkspace/crates/AGENTS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Prefer domain-specific value objects over primitives (e.g. `AcpPrefix` not `String`). Each type's factory must guarantee correctness at construction—invalid instances should be unrepresentable. Validate per-type, not per-aggregate: avoid validating unrelated fields together in a single constructor.
22

3-
Every value object lives in its own file named after the type (e.g. `acp_prefix.rs`, `ext_method_name.rs`, `session_id.rs`). Never inline a value object into a config, aggregate, or service file.
3+
Every value object lives in its own file named after the type (e.g. `acp_prefix.rs`, `ext_method_name.rs`, `session_id.rs`). Never inline a value object into a config, aggregate, or service file. File layout: `src/{type_snake_case}.rs`; export in `lib.rs` as `pub use {module}::{Type, TypeError}`.
44

55
You must use the `test-support` feature to share test helpers between crates.
66
Prefer one trait per operation over a single trait with multiple operations.

rsworkspace/crates/acp-nats/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ async-nats = "0.45.0"
1313
async-trait = "0.1.89"
1414
serde = { version = "1.0.228", features = ["derive"] }
1515
serde_json = "1.0.149"
16-
tokio = { version = "1.49.0", features = ["rt", "macros", "sync", "time"] }
16+
tokio = { version = "1.49.0", features = ["rt", "macros", "sync", "time", "test-util"] }
1717
tracing = "0.1.44"
1818

1919
trogon-nats = { path = "../trogon-nats" }

rsworkspace/crates/acp-nats/src/agent/cancel.rs

Lines changed: 57 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
use super::Bridge;
22
use crate::nats::{self, FlushClient, PublishClient, RequestClient, agent};
33
use crate::session_id::AcpSessionId;
4-
use agent_client_protocol::{CancelNotification, Error, ErrorCode, Result};
4+
use agent_client_protocol::{CancelNotification, PromptResponse, Result, StopReason};
5+
use agent_client_protocol::{Error, ErrorCode};
56
use tracing::{info, instrument, warn};
67
use trogon_std::time::GetElapsed;
78

8-
/// Publishes the cancel notification to the backend via NATS (fire-and-forget).
9-
/// The publish failure is logged and recorded as a metric but does not propagate
10-
/// to the caller, so the client always receives `Ok(())`.
9+
/// Handles cancel notification requests.
10+
///
11+
/// Marks the session as cancelled, resolves any pending prompt waiters, and publishes
12+
/// the cancellation to the backend. The backend publish is fire-and-forget - if it fails,
13+
/// the error is logged and recorded in metrics, but the method still returns `Ok(())`.
1114
#[instrument(
1215
name = "acp.session.cancel",
1316
skip(bridge, args),
@@ -25,15 +28,22 @@ pub async fn handle<N: RequestClient + PublishClient + FlushClient, C: GetElapse
2528
bridge
2629
.metrics
2730
.record_request("cancel", bridge.clock.elapsed(start).as_secs_f64(), false);
28-
bridge
29-
.metrics
30-
.record_error("session_validate", "invalid_session_id");
31+
bridge.metrics.record_error("cancel", "invalid_session_id");
3132
Error::new(
3233
ErrorCode::InvalidParams.into(),
3334
format!("Invalid session ID: {}", e),
3435
)
3536
})?;
3637

38+
bridge
39+
.cancelled_sessions
40+
.mark_cancelled(args.session_id.clone(), &bridge.clock);
41+
42+
bridge.pending_session_prompt_responses.resolve_waiter(
43+
&args.session_id,
44+
Ok(PromptResponse::new(StopReason::Cancelled)),
45+
);
46+
3747
let subject = agent::session_cancel(bridge.config.acp_prefix(), &args.session_id.to_string());
3848

3949
let publish_result = nats::publish(
@@ -46,20 +56,26 @@ pub async fn handle<N: RequestClient + PublishClient + FlushClient, C: GetElapse
4656
)
4757
.await;
4858

49-
if let Err(error) = publish_result {
50-
warn!(
51-
session_id = %args.session_id,
52-
error = %error,
53-
"Failed to publish cancel notification to backend"
54-
);
55-
bridge
56-
.metrics
57-
.record_error("cancel", "cancel_publish_failed");
58-
}
59+
let publish_ok = match publish_result {
60+
Ok(()) => true,
61+
Err(error) => {
62+
warn!(
63+
session_id = %args.session_id,
64+
error = %error,
65+
"Failed to publish cancel notification to backend"
66+
);
67+
bridge
68+
.metrics
69+
.record_error("cancel", "cancel_publish_failed");
70+
false
71+
}
72+
};
5973

60-
bridge
61-
.metrics
62-
.record_request("cancel", bridge.clock.elapsed(start).as_secs_f64(), true);
74+
bridge.metrics.record_request(
75+
"cancel",
76+
bridge.clock.elapsed(start).as_secs_f64(),
77+
publish_ok,
78+
);
6379

6480
Ok(())
6581
}
@@ -183,6 +199,23 @@ mod tests {
183199
.is_some()
184200
}
185201

202+
#[tokio::test]
203+
async fn cancel_resolves_pending_prompt_waiter_with_cancelled() {
204+
let (_mock, bridge) = mock_bridge();
205+
let rx = bridge
206+
.pending_session_prompt_responses
207+
.register_waiter(agent_client_protocol::SessionId::from("s1"))
208+
.unwrap();
209+
210+
let _ = bridge.cancel(CancelNotification::new("s1")).await;
211+
212+
let result = rx.await.unwrap().unwrap();
213+
assert_eq!(
214+
result.stop_reason,
215+
agent_client_protocol::StopReason::Cancelled
216+
);
217+
}
218+
186219
#[tokio::test]
187220
async fn cancel_publishes_to_correct_subject() {
188221
let (mock, bridge) = mock_bridge();
@@ -223,8 +256,8 @@ mod tests {
223256
"expected acp.request.count with method=cancel, success=false on validation failure"
224257
);
225258
assert!(
226-
has_error_metric(&finished_metrics, "session_validate", "invalid_session_id"),
227-
"expected acp.errors.total with operation=session_validate, reason=invalid_session_id"
259+
has_error_metric(&finished_metrics, "cancel", "invalid_session_id"),
260+
"expected acp.errors.total with operation=cancel, reason=invalid_session_id"
228261
);
229262
provider.shutdown().unwrap();
230263
}
@@ -258,8 +291,8 @@ mod tests {
258291
"expected acp.errors.total with operation=cancel, reason=cancel_publish_failed"
259292
);
260293
assert!(
261-
has_request_metric(&finished_metrics, "cancel", true),
262-
"publish failure is fire-and-forget; caller still gets Ok, so success=true"
294+
has_request_metric(&finished_metrics, "cancel", false),
295+
"request metric records publish outcome; success=false when publish fails"
263296
);
264297
provider.shutdown().unwrap();
265298
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
//! Pre-flight checks to avoid sending prompt requests for cancelled sessions.
2+
//!
3+
//! Entries are evicted after `CANCELLED_SESSION_TTL` to prevent unbounded growth
4+
//! when sessions are cancelled but never receive a subsequent prompt.
5+
6+
use std::cell::{Cell, RefCell};
7+
use std::collections::HashMap;
8+
use std::time::Duration;
9+
10+
use agent_client_protocol::SessionId;
11+
use trogon_std::time::GetElapsed;
12+
13+
const CANCELLED_SESSION_TTL: Duration = Duration::from_secs(300); // 5 minutes
14+
const CLEANUP_EVERY: usize = 16;
15+
16+
/// Pre-flight checks to avoid sending prompt requests for cancelled sessions.
17+
///
18+
/// Entries are evicted after `CANCELLED_SESSION_TTL` to prevent unbounded growth
19+
/// when sessions are cancelled but never receive a subsequent prompt.
20+
pub(crate) struct CancelledSessions<I: Copy> {
21+
map: RefCell<HashMap<SessionId, I>>,
22+
cleanup_counter: Cell<usize>,
23+
}
24+
25+
impl<I: Copy> CancelledSessions<I> {
26+
pub fn new() -> Self {
27+
Self {
28+
map: RefCell::new(HashMap::new()),
29+
cleanup_counter: Cell::new(0),
30+
}
31+
}
32+
33+
pub fn mark_cancelled<C: GetElapsed<Instant = I>>(&self, session_id: SessionId, clock: &C) {
34+
let mut map = self.map.borrow_mut();
35+
map.insert(session_id, clock.now());
36+
let cleanup_due_count = self.cleanup_counter.get().wrapping_add(1);
37+
self.cleanup_counter.set(cleanup_due_count);
38+
if cleanup_due_count.is_multiple_of(CLEANUP_EVERY) {
39+
map.retain(|_, ts| clock.elapsed(*ts) < CANCELLED_SESSION_TTL);
40+
}
41+
}
42+
43+
/// Atomically checks if a session is cancelled and clears it if so.
44+
///
45+
/// Returns `Some(())` if the session was cancelled (and has now been cleared),
46+
/// or `None` if the session was not found or has expired.
47+
pub fn take_if_cancelled<C: GetElapsed<Instant = I>>(
48+
&self,
49+
session_id: &SessionId,
50+
clock: &C,
51+
) -> Option<()> {
52+
let mut map = self.map.borrow_mut();
53+
54+
let is_valid = map
55+
.get(session_id)
56+
.is_some_and(|ts| clock.elapsed(*ts) < CANCELLED_SESSION_TTL);
57+
58+
if is_valid {
59+
map.remove(session_id);
60+
Some(())
61+
} else {
62+
map.remove(session_id);
63+
None
64+
}
65+
}
66+
}
Lines changed: 43 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,79 @@
1+
// Safety: `CancelledSessions` and `PendingSessionPromptResponseWaiters` use `RefCell`
2+
// for interior mutability. This is sound because the `Bridge` must be driven from a
3+
// single task (or a single-threaded `LocalSet`). The fire-and-forget publish in
4+
// `new_session` uses `tokio::spawn` and captures only cloned, `Send` values —
5+
// it never touches the `RefCell` fields.
6+
17
mod authenticate;
28
mod cancel;
9+
mod cancelled_sessions;
310
mod ext_method;
411
mod ext_notification;
512
mod initialize;
613
mod load_session;
714
mod new_session;
15+
mod pending_prompt_waiters;
16+
mod prompt;
817
mod set_session_mode;
918

19+
use cancelled_sessions::CancelledSessions;
20+
use pending_prompt_waiters::PendingSessionPromptResponseWaiters;
21+
1022
use crate::config::Config;
1123
use crate::nats::{FlushClient, PublishClient, RequestClient};
1224
use crate::telemetry::metrics::Metrics;
13-
use agent_client_protocol::ErrorCode;
1425
use agent_client_protocol::{
15-
Agent, AuthenticateRequest, AuthenticateResponse, CancelNotification, Error, ExtNotification,
26+
Agent, AuthenticateRequest, AuthenticateResponse, CancelNotification, ExtNotification,
1627
ExtRequest, ExtResponse, InitializeRequest, InitializeResponse, LoadSessionRequest,
1728
LoadSessionResponse, NewSessionRequest, NewSessionResponse, PromptRequest, PromptResponse,
1829
Result, SetSessionModeRequest, SetSessionModeResponse,
1930
};
2031
use opentelemetry::metrics::Meter;
32+
use std::cell::Cell;
33+
use std::marker::PhantomData;
2134
use trogon_std::time::GetElapsed;
2235

2336
pub struct Bridge<N: RequestClient + PublishClient + FlushClient, C: GetElapsed> {
2437
pub(crate) nats: N,
2538
pub(crate) clock: C,
26-
pub(crate) config: Config,
2739
pub(crate) metrics: Metrics,
40+
pub(crate) cancelled_sessions: CancelledSessions<C::Instant>,
41+
pub(crate) pending_session_prompt_responses: PendingSessionPromptResponseWaiters<C::Instant>,
42+
pub(crate) config: Config,
43+
_not_send_sync: PhantomData<std::rc::Rc<()>>,
44+
agent_prompt_requests_in_flight: Cell<usize>,
2845
}
2946

3047
impl<N: RequestClient + PublishClient + FlushClient, C: GetElapsed> Bridge<N, C> {
3148
pub fn new(nats: N, clock: C, meter: &Meter, config: Config) -> Self {
3249
Self {
3350
nats,
3451
clock,
35-
config,
3652
metrics: Metrics::new(meter),
53+
cancelled_sessions: CancelledSessions::new(),
54+
pending_session_prompt_responses: PendingSessionPromptResponseWaiters::new(),
55+
config,
56+
_not_send_sync: PhantomData,
57+
agent_prompt_requests_in_flight: Cell::new(0),
58+
}
59+
}
60+
61+
pub(crate) fn try_acquire_prompt_slot(&self) -> bool {
62+
let max = self.config.max_concurrent_client_tasks();
63+
if self.agent_prompt_requests_in_flight.get() >= max {
64+
false
65+
} else {
66+
self.agent_prompt_requests_in_flight
67+
.set(self.agent_prompt_requests_in_flight.get() + 1);
68+
true
3769
}
3870
}
3971

72+
pub(crate) fn release_prompt_slot(&self) {
73+
self.agent_prompt_requests_in_flight
74+
.set(self.agent_prompt_requests_in_flight.get().saturating_sub(1));
75+
}
76+
4077
pub(crate) fn nats(&self) -> &N {
4178
&self.nats
4279
}
@@ -67,11 +104,8 @@ impl<N: RequestClient + PublishClient + FlushClient, C: GetElapsed> Agent for Br
67104
set_session_mode::handle(self, args).await
68105
}
69106

70-
async fn prompt(&self, _args: PromptRequest) -> Result<PromptResponse> {
71-
Err(Error::new(
72-
ErrorCode::InternalError.into(),
73-
"not yet implemented",
74-
))
107+
async fn prompt(&self, args: PromptRequest) -> Result<PromptResponse> {
108+
prompt::handle(self, args).await
75109
}
76110

77111
async fn cancel(&self, args: CancelNotification) -> Result<()> {
@@ -86,39 +120,3 @@ impl<N: RequestClient + PublishClient + FlushClient, C: GetElapsed> Agent for Br
86120
ext_notification::handle(self, args).await
87121
}
88122
}
89-
90-
#[cfg(test)]
91-
mod tests {
92-
use super::Bridge;
93-
use crate::config::Config;
94-
use agent_client_protocol::{Agent, PromptRequest};
95-
use trogon_nats::AdvancedMockNatsClient;
96-
97-
fn mock_bridge() -> Bridge<AdvancedMockNatsClient, trogon_std::time::SystemClock> {
98-
Bridge::new(
99-
AdvancedMockNatsClient::new(),
100-
trogon_std::time::SystemClock,
101-
&opentelemetry::global::meter("acp-nats-test"),
102-
Config::for_test("acp"),
103-
)
104-
}
105-
106-
#[tokio::test]
107-
async fn stub_methods_return_not_implemented() {
108-
let bridge = mock_bridge();
109-
let msg = "not yet implemented";
110-
111-
assert!(
112-
bridge
113-
.prompt(PromptRequest::new("s1", vec![]))
114-
.await
115-
.is_err()
116-
);
117-
118-
let err = bridge
119-
.prompt(PromptRequest::new("s1", vec![]))
120-
.await
121-
.unwrap_err();
122-
assert!(err.to_string().contains(msg));
123-
}
124-
}

0 commit comments

Comments
 (0)