Skip to content

Commit e91db52

Browse files
committed
feat(acp-nats): add new_session handler with session.ready lifecycle
- Add new_session handler aligned with initialize.rs patterns (handler-specific error mapping, single-line generics, local subject variable) - Add spawn_session_ready to publish ext.session.ready after successful session creation - Add SessionReady extension type for bridge-backend coordination - Add session_new and ext_session_ready NATS subject functions - Add record_session_created() and record_error() to Metrics - Add Bridge task tracking for session_ready publish tasks - Add comprehensive test suite mirroring initialize handler tests Signed-off-by: Yordis Prieto <yordis.prieto@gmail.com>
1 parent 2f05030 commit e91db52

6 files changed

Lines changed: 401 additions & 13 deletions

File tree

rsworkspace/crates/acp-nats/src/agent/mod.rs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
mod initialize;
2+
mod new_session;
23

34
use crate::config::Config;
45
use crate::nats::{FlushClient, PublishClient, RequestClient};
@@ -48,11 +49,8 @@ impl<N: RequestClient + PublishClient + FlushClient, C: GetElapsed> Agent for Br
4849
))
4950
}
5051

51-
async fn new_session(&self, _args: NewSessionRequest) -> Result<NewSessionResponse> {
52-
Err(Error::new(
53-
ErrorCode::InternalError.into(),
54-
"not yet implemented",
55-
))
52+
async fn new_session(&self, args: NewSessionRequest) -> Result<NewSessionResponse> {
53+
new_session::handle(self, args).await
5654
}
5755

5856
async fn load_session(&self, _args: LoadSessionRequest) -> Result<LoadSessionResponse> {
@@ -109,7 +107,7 @@ mod tests {
109107
use crate::config::Config;
110108
use agent_client_protocol::{
111109
Agent, AuthenticateRequest, CancelNotification, ExtNotification, ExtRequest,
112-
LoadSessionRequest, NewSessionRequest, PromptRequest, SetSessionModeRequest,
110+
LoadSessionRequest, PromptRequest, SetSessionModeRequest,
113111
};
114112
use trogon_nats::AdvancedMockNatsClient;
115113

@@ -137,12 +135,6 @@ mod tests {
137135
.await
138136
.is_err()
139137
);
140-
assert!(
141-
bridge
142-
.new_session(NewSessionRequest::new("."))
143-
.await
144-
.is_err()
145-
);
146138
assert!(
147139
bridge
148140
.load_session(LoadSessionRequest::new("s1", "."))
Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,350 @@
1+
use super::Bridge;
2+
use crate::error::AGENT_UNAVAILABLE;
3+
use crate::nats::{
4+
self, ExtSessionReady, FlushClient, FlushPolicy, PublishClient, PublishOptions, RequestClient,
5+
RetryPolicy, agent,
6+
};
7+
use crate::telemetry::metrics::Metrics;
8+
use agent_client_protocol::{
9+
Error, ErrorCode, NewSessionRequest, NewSessionResponse, Result, SessionId,
10+
};
11+
use tracing::{Span, info, instrument, warn};
12+
use trogon_nats::NatsError;
13+
use trogon_std::time::GetElapsed;
14+
15+
fn map_new_session_error(e: NatsError) -> Error {
16+
match &e {
17+
NatsError::Timeout { subject } => {
18+
warn!(subject = %subject, "new_session request timed out");
19+
Error::new(
20+
ErrorCode::Other(AGENT_UNAVAILABLE).into(),
21+
"New session request timed out; agent may be overloaded or unavailable",
22+
)
23+
}
24+
NatsError::Request { subject, error } => {
25+
warn!(subject = %subject, error = %error, "new_session NATS request failed");
26+
Error::new(
27+
ErrorCode::Other(AGENT_UNAVAILABLE).into(),
28+
format!("Agent unavailable: {}", error),
29+
)
30+
}
31+
NatsError::Serialize(inner) => {
32+
warn!(error = %inner, "failed to serialize new_session request");
33+
Error::new(
34+
ErrorCode::InternalError.into(),
35+
format!("Failed to serialize new_session request: {}", inner),
36+
)
37+
}
38+
NatsError::Deserialize(inner) => {
39+
warn!(error = %inner, "failed to deserialize new_session response");
40+
Error::new(
41+
ErrorCode::InternalError.into(),
42+
"Invalid response from agent",
43+
)
44+
}
45+
_ => {
46+
warn!(error = %e, "new_session NATS request failed");
47+
Error::new(ErrorCode::InternalError.into(), "New session request failed")
48+
}
49+
}
50+
}
51+
52+
#[instrument(
53+
name = "acp.session.new",
54+
skip(bridge, args),
55+
fields(cwd = ?args.cwd, mcp_servers = args.mcp_servers.len())
56+
)]
57+
pub async fn handle<N: RequestClient + PublishClient + FlushClient, C: GetElapsed>(
58+
bridge: &Bridge<N, C>,
59+
args: NewSessionRequest,
60+
) -> Result<NewSessionResponse> {
61+
let start = bridge.clock.now();
62+
63+
info!(cwd = ?args.cwd, mcp_servers = args.mcp_servers.len(), "New session request");
64+
65+
let nats = bridge.nats();
66+
let subject = agent::session_new(bridge.config.acp_prefix());
67+
68+
let result = nats::request_with_timeout::<N, NewSessionRequest, NewSessionResponse>(
69+
nats,
70+
&subject,
71+
&args,
72+
bridge.config.operation_timeout,
73+
)
74+
.await
75+
.map_err(map_new_session_error);
76+
77+
if let Ok(ref response) = result {
78+
Span::current().record("session_id", response.session_id.to_string().as_str());
79+
info!(session_id = %response.session_id, "Session created");
80+
spawn_session_ready(
81+
bridge.nats.clone(),
82+
bridge.config.acp_prefix(),
83+
&response.session_id,
84+
bridge.metrics.clone(),
85+
);
86+
}
87+
88+
bridge.metrics.record_request(
89+
"new_session",
90+
bridge.clock.elapsed(start).as_secs_f64(),
91+
result.is_ok(),
92+
);
93+
94+
result
95+
}
96+
97+
// TODO: track the JoinHandle so we can drain in-flight publishes on graceful shutdown.
98+
fn spawn_session_ready<N: PublishClient + FlushClient + 'static>(
99+
nats: N,
100+
prefix: &str,
101+
session_id: &SessionId,
102+
metrics: Metrics,
103+
) {
104+
let prefix = prefix.to_owned();
105+
let session_id = session_id.clone();
106+
107+
tokio::spawn(async move {
108+
let subject = agent::ext_session_ready(&prefix, &session_id.to_string());
109+
info!(session_id = %session_id, subject = %subject, "Publishing session.ready");
110+
111+
let message = ExtSessionReady::new(session_id.clone());
112+
113+
let options = PublishOptions::builder()
114+
.publish_retry_policy(RetryPolicy::standard())
115+
.flush_policy(FlushPolicy::standard())
116+
.build();
117+
118+
if let Err(e) = nats::publish(&nats, &subject, &message, options).await {
119+
warn!(
120+
error = %e,
121+
session_id = %session_id,
122+
"Failed to publish session.ready"
123+
);
124+
metrics.record_error("session_ready", "session_ready_publish_failed");
125+
} else {
126+
info!(session_id = %session_id, "Published session.ready");
127+
}
128+
});
129+
}
130+
131+
#[cfg(test)]
132+
mod tests {
133+
use super::{Bridge, map_new_session_error};
134+
use crate::config::Config;
135+
use crate::error::AGENT_UNAVAILABLE;
136+
use agent_client_protocol::{
137+
Agent, ErrorCode, NewSessionRequest, NewSessionResponse, SessionId,
138+
};
139+
use opentelemetry::Value;
140+
use opentelemetry::metrics::MeterProvider;
141+
use opentelemetry_sdk::metrics::data::{AggregatedMetrics, MetricData};
142+
use opentelemetry_sdk::metrics::{
143+
PeriodicReader, SdkMeterProvider, in_memory_exporter::InMemoryMetricExporter,
144+
};
145+
use std::time::Duration;
146+
use trogon_nats::{AdvancedMockNatsClient, NatsError};
147+
148+
fn assert_new_session_metric_recorded(
149+
finished_metrics: &[opentelemetry_sdk::metrics::data::ResourceMetrics],
150+
expected_success: bool,
151+
) {
152+
let found = finished_metrics
153+
.iter()
154+
.flat_map(|rm| rm.scope_metrics())
155+
.any(|sm| {
156+
sm.metrics().any(|metric| {
157+
if metric.name() != "acp.request.count" {
158+
return false;
159+
}
160+
let data = metric.data();
161+
let sum = match data {
162+
AggregatedMetrics::U64(MetricData::Sum(s)) => s,
163+
_ => return false,
164+
};
165+
sum.data_points().any(|dp| {
166+
let mut method_ok = false;
167+
let mut success_ok = false;
168+
for attr in dp.attributes() {
169+
if attr.key.as_str() == "method" {
170+
method_ok = attr.value.as_str() == "new_session";
171+
} else if attr.key.as_str() == "success" {
172+
success_ok = attr.value == Value::from(expected_success);
173+
}
174+
}
175+
method_ok && success_ok
176+
})
177+
})
178+
});
179+
assert!(
180+
found,
181+
"expected acp.request.count datapoint with method=new_session, success={}",
182+
expected_success
183+
);
184+
}
185+
186+
fn mock_bridge_with_metrics() -> (
187+
AdvancedMockNatsClient,
188+
Bridge<AdvancedMockNatsClient, trogon_std::time::SystemClock>,
189+
InMemoryMetricExporter,
190+
SdkMeterProvider,
191+
) {
192+
let exporter = InMemoryMetricExporter::default();
193+
let reader = PeriodicReader::builder(exporter.clone())
194+
.with_interval(Duration::from_millis(100))
195+
.build();
196+
let provider = SdkMeterProvider::builder().with_reader(reader).build();
197+
let meter = provider.meter("acp-nats-test");
198+
199+
let mock = AdvancedMockNatsClient::new();
200+
let bridge = Bridge::new(
201+
mock.clone(),
202+
trogon_std::time::SystemClock,
203+
&meter,
204+
Config::for_test("acp"),
205+
);
206+
(mock, bridge, exporter, provider)
207+
}
208+
209+
fn mock_bridge() -> (
210+
AdvancedMockNatsClient,
211+
Bridge<AdvancedMockNatsClient, trogon_std::time::SystemClock>,
212+
) {
213+
let mock = AdvancedMockNatsClient::new();
214+
let bridge = Bridge::new(
215+
mock.clone(),
216+
trogon_std::time::SystemClock,
217+
&opentelemetry::global::meter("acp-nats-test"),
218+
Config::for_test("acp"),
219+
);
220+
(mock, bridge)
221+
}
222+
223+
fn set_json_response<T: serde::Serialize>(
224+
mock: &AdvancedMockNatsClient,
225+
subject: &str,
226+
resp: &T,
227+
) {
228+
let bytes = serde_json::to_vec(resp).unwrap();
229+
mock.set_response(subject, bytes.into());
230+
}
231+
232+
#[tokio::test]
233+
async fn new_session_forwards_request_and_returns_response() {
234+
let (mock, bridge) = mock_bridge();
235+
let session_id = SessionId::from("test-session-1");
236+
let expected = NewSessionResponse::new(session_id.clone());
237+
set_json_response(&mock, "acp.agent.session.new", &expected);
238+
239+
let request = NewSessionRequest::new(".");
240+
let result = bridge.new_session(request).await;
241+
242+
assert!(result.is_ok());
243+
let response = result.unwrap();
244+
assert_eq!(response.session_id, session_id);
245+
}
246+
247+
#[tokio::test]
248+
async fn new_session_returns_error_when_nats_request_fails() {
249+
let (mock, bridge) = mock_bridge();
250+
mock.fail_next_request();
251+
252+
let request = NewSessionRequest::new(".");
253+
let err = bridge.new_session(request).await.unwrap_err();
254+
255+
assert!(err.to_string().contains("Agent unavailable"));
256+
assert_eq!(err.code, ErrorCode::Other(AGENT_UNAVAILABLE));
257+
}
258+
259+
#[tokio::test]
260+
async fn new_session_returns_error_when_response_is_invalid_json() {
261+
let (mock, bridge) = mock_bridge();
262+
mock.set_response("acp.agent.session.new", "not json".into());
263+
264+
let request = NewSessionRequest::new(".");
265+
let err = bridge.new_session(request).await.unwrap_err();
266+
267+
assert!(err.to_string().contains("Invalid response from agent"));
268+
assert_eq!(err.code, ErrorCode::InternalError);
269+
}
270+
271+
#[tokio::test]
272+
async fn new_session_records_metrics_on_success() {
273+
let (mock, bridge, exporter, provider) = mock_bridge_with_metrics();
274+
let session_id = SessionId::from("test-session-1");
275+
set_json_response(
276+
&mock,
277+
"acp.agent.session.new",
278+
&NewSessionResponse::new(session_id),
279+
);
280+
281+
let _ = bridge.new_session(NewSessionRequest::new(".")).await;
282+
283+
provider.force_flush().unwrap();
284+
let finished_metrics = exporter.get_finished_metrics().unwrap();
285+
assert_new_session_metric_recorded(&finished_metrics, true);
286+
provider.shutdown().unwrap();
287+
}
288+
289+
#[tokio::test]
290+
async fn new_session_records_metrics_on_failure() {
291+
let (mock, bridge, exporter, provider) = mock_bridge_with_metrics();
292+
mock.fail_next_request();
293+
294+
let _ = bridge.new_session(NewSessionRequest::new(".")).await;
295+
296+
provider.force_flush().unwrap();
297+
let finished_metrics = exporter.get_finished_metrics().unwrap();
298+
assert_new_session_metric_recorded(&finished_metrics, false);
299+
provider.shutdown().unwrap();
300+
}
301+
302+
#[test]
303+
fn map_new_session_error_timeout() {
304+
let err = map_new_session_error(NatsError::Timeout {
305+
subject: "acp.agent.session.new".into(),
306+
});
307+
assert!(err.to_string().contains("timed out"));
308+
assert_eq!(err.code, ErrorCode::Other(AGENT_UNAVAILABLE));
309+
}
310+
311+
#[test]
312+
fn map_new_session_error_request() {
313+
let err = map_new_session_error(NatsError::Request {
314+
subject: "acp.agent.session.new".into(),
315+
error: "connection refused".into(),
316+
});
317+
assert!(err.to_string().contains("Agent unavailable"));
318+
assert_eq!(err.code, ErrorCode::Other(AGENT_UNAVAILABLE));
319+
}
320+
321+
#[test]
322+
fn map_new_session_error_serialize() {
323+
let serde_err = serde_json::to_vec(&FailsSerialize).unwrap_err();
324+
let err = map_new_session_error(NatsError::Serialize(serde_err));
325+
assert!(err.to_string().contains("serialize"));
326+
assert_eq!(err.code, ErrorCode::InternalError);
327+
}
328+
329+
#[test]
330+
fn map_new_session_error_deserialize() {
331+
let serde_err = serde_json::from_str::<NewSessionResponse>("{}").unwrap_err();
332+
let err = map_new_session_error(NatsError::Deserialize(serde_err));
333+
assert!(err.to_string().contains("Invalid response from agent"));
334+
assert_eq!(err.code, ErrorCode::InternalError);
335+
}
336+
337+
#[test]
338+
fn map_new_session_error_other() {
339+
let err = map_new_session_error(NatsError::Other("misc failure".into()));
340+
assert!(err.to_string().contains("New session request failed"));
341+
assert_eq!(err.code, ErrorCode::InternalError);
342+
}
343+
344+
struct FailsSerialize;
345+
impl serde::Serialize for FailsSerialize {
346+
fn serialize<S: serde::Serializer>(&self, _s: S) -> Result<S::Ok, S::Error> {
347+
Err(serde::ser::Error::custom("test serialize failure"))
348+
}
349+
}
350+
}

0 commit comments

Comments
 (0)