Skip to content

Commit a86f530

Browse files
committed
feat(acp-nats, acp-nats-agent): wire JetStream into bridge and agent
Wire Layer 3 of JetStream integration — the actual message flow. JsMessage is now a zero-cost trait abstraction: NatsJsMessage wraps real jetstream::Message for production, MockJsMessage records signals for testing. All dispatch and prompt code is generic over the trait. Bridge prompt handler branches on bridge.js(): when JetStream is available, publishes to COMMANDS stream and consumes from NOTIFICATIONS/RESPONSES streams instead of subscribe-before-publish. Session.ready skips the 100ms sleep when JetStream is available. Agent library gains serve_js() alongside serve() with full ack/nak/term signal handling. Runs in parallel via with_jetstream(). Binary crates (stdio, ws) create NatsJetStreamClient and pass to Bridge::with_jetstream(). () implements JetStream traits for backward compatibility. Signed-off-by: Yordis Prieto <yordis.prieto@gmail.com>
1 parent 43f9048 commit a86f530

15 files changed

Lines changed: 826 additions & 158 deletions

File tree

rsworkspace/crates/acp-nats-agent/src/connection.rs

Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@ use trogon_nats::{FlushClient, PublishClient, RequestClient, SubscribeClient};
2020

2121
pub enum ConnectionError {
2222
Subscribe(Box<dyn std::error::Error + Send + Sync>),
23+
JetStream(Box<dyn std::error::Error + Send + Sync>),
2324
}
2425

2526
impl std::fmt::Debug for ConnectionError {
2627
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2728
match self {
2829
Self::Subscribe(e) => f.debug_tuple("Subscribe").field(e).finish(),
30+
Self::JetStream(e) => f.debug_tuple("JetStream").field(e).finish(),
2931
}
3032
}
3133
}
@@ -34,6 +36,7 @@ impl std::fmt::Display for ConnectionError {
3436
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
3537
match self {
3638
Self::Subscribe(e) => write!(f, "failed to subscribe: {}", e),
39+
Self::JetStream(e) => write!(f, "jetstream error: {}", e),
3740
}
3841
}
3942
}
@@ -95,6 +98,47 @@ where
9598
(conn, io_task)
9699
}
97100

101+
pub fn with_jetstream<J>(
102+
agent: impl Agent + 'static,
103+
nats: N,
104+
js: J,
105+
acp_prefix: AcpPrefix,
106+
spawn: impl Fn(LocalBoxFuture<'static, ()>) + Copy + 'static,
107+
) -> (
108+
Self,
109+
impl std::future::Future<Output = Result<(), ConnectionError>>,
110+
)
111+
where
112+
J: JetStreamConsumerFactory + 'static,
113+
{
114+
let nats_for_serve = nats.clone();
115+
let nats_for_js = nats.clone();
116+
let prefix = acp_prefix.as_str().to_string();
117+
let prefix_js = prefix.clone();
118+
119+
let io_task = async move {
120+
let (agent1, agent2) = {
121+
let agent = Rc::new(agent);
122+
(agent.clone(), agent)
123+
};
124+
125+
let core = serve(agent1, nats_for_serve, &prefix, spawn);
126+
let jetstream = serve_js(agent2, nats_for_js, js, &prefix_js, spawn);
127+
128+
tokio::select! {
129+
result = core => result,
130+
result = jetstream => result,
131+
}
132+
};
133+
134+
let conn = Self {
135+
nats,
136+
acp_prefix,
137+
operation_timeout: DEFAULT_OPERATION_TIMEOUT,
138+
};
139+
(conn, io_task)
140+
}
141+
98142
pub fn client_for_session(&self, session_id: AcpSessionId) -> NatsClientProxy<N> {
99143
NatsClientProxy::new(
100144
self.nats.clone(),
@@ -313,6 +357,253 @@ where
313357
.map_err(DispatchError::NotificationHandler)
314358
}
315359

360+
use trogon_nats::jetstream::{JetStreamConsumer as _, JetStreamConsumerFactory, JsMessage};
361+
362+
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(15);
363+
364+
async fn handle_request_with_keepalive<N, Resp, ReqT, F, M>(
365+
msg: &Message,
366+
nats: &N,
367+
js_msg: &M,
368+
handler: impl FnOnce(ReqT) -> F,
369+
) -> Result<(), DispatchError>
370+
where
371+
N: PublishClient + FlushClient,
372+
ReqT: serde::de::DeserializeOwned,
373+
F: std::future::Future<Output = agent_client_protocol::Result<Resp>>,
374+
Resp: serde::Serialize,
375+
M: JsMessage,
376+
{
377+
let reply_to = msg.reply.as_deref().ok_or(DispatchError::NoReplySubject)?;
378+
379+
let request: ReqT = match serde_json::from_slice(&msg.payload) {
380+
Ok(req) => req,
381+
Err(e) => {
382+
let error = agent_client_protocol::Error::new(
383+
agent_client_protocol::ErrorCode::InvalidParams.into(),
384+
format!("Failed to deserialize request: {}", e),
385+
);
386+
let _ = reply(nats, reply_to, &error).await;
387+
return Err(DispatchError::DeserializeRequest(e));
388+
}
389+
};
390+
391+
let handler_fut = handler(request);
392+
tokio::pin!(handler_fut);
393+
394+
let mut keepalive = tokio::time::interval(KEEPALIVE_INTERVAL);
395+
keepalive.tick().await;
396+
397+
loop {
398+
tokio::select! {
399+
result = &mut handler_fut => {
400+
return match result {
401+
Ok(resp) => reply(nats, reply_to, &resp).await,
402+
Err(err) => reply(nats, reply_to, &err).await,
403+
};
404+
}
405+
_ = keepalive.tick() => {
406+
if let Err(e) = js_msg.in_progress().await {
407+
warn!(error = %e, "Failed to send in_progress keepalive");
408+
}
409+
}
410+
}
411+
}
412+
}
413+
414+
async fn serve_js<N, J, A>(
415+
agent: A,
416+
nats: N,
417+
js: J,
418+
prefix: &str,
419+
spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
420+
) -> Result<(), ConnectionError>
421+
where
422+
N: PublishClient + FlushClient + Clone + 'static,
423+
J: JetStreamConsumerFactory + 'static,
424+
A: Agent + 'static,
425+
{
426+
let stream_name = acp_nats::jetstream::streams::commands_stream_name(prefix);
427+
let config = acp_nats::jetstream::consumers::commands_observer();
428+
429+
info!(stream = %stream_name, "Starting JetStream consumer for COMMANDS stream");
430+
431+
let consumer = js
432+
.create_consumer(&stream_name, config)
433+
.await
434+
.map_err(|e| ConnectionError::JetStream(Box::new(e)))?;
435+
436+
let mut messages = consumer
437+
.messages()
438+
.await
439+
.map_err(|e| ConnectionError::JetStream(Box::new(e)))?;
440+
441+
let agent = Rc::new(agent);
442+
let nats = Rc::new(nats);
443+
444+
while let Some(msg_result) = messages.next().await {
445+
match msg_result {
446+
Ok(js_msg) => {
447+
let agent = agent.clone();
448+
let nats = nats.clone();
449+
spawn(Box::pin(async move {
450+
dispatch_js_message(js_msg, agent.as_ref(), nats.as_ref()).await;
451+
}));
452+
}
453+
Err(e) => {
454+
warn!(error = %e, "JetStream consumer error");
455+
}
456+
}
457+
}
458+
459+
info!("JetStream COMMANDS consumer ended");
460+
Ok(())
461+
}
462+
463+
async fn dispatch_js_message<N: PublishClient + FlushClient, A: Agent, M: JsMessage>(
464+
js_msg: M,
465+
agent: &A,
466+
nats: &N,
467+
) {
468+
let subject = js_msg.subject().to_string();
469+
let msg = Message {
470+
subject: subject.as_str().into(),
471+
reply: js_msg.reply().map(|s| s.into()),
472+
payload: js_msg.payload().clone(),
473+
headers: js_msg.headers().cloned(),
474+
status: None,
475+
description: None,
476+
length: js_msg.payload().len(),
477+
};
478+
let subject = msg.subject.as_str();
479+
480+
let parsed = match parse_agent_subject(subject) {
481+
Some(p) => p,
482+
None => {
483+
if let Err(e) = js_msg.term().await {
484+
warn!(error = %e, subject, "Failed to term unknown subject");
485+
}
486+
return;
487+
}
488+
};
489+
490+
let result = match parsed.method {
491+
AgentMethod::Initialize => {
492+
handle_request(&msg, nats, |req: InitializeRequest| agent.initialize(req)).await
493+
}
494+
AgentMethod::Authenticate => {
495+
handle_request(&msg, nats, |req: AuthenticateRequest| {
496+
agent.authenticate(req)
497+
})
498+
.await
499+
}
500+
AgentMethod::SessionNew => {
501+
handle_request(&msg, nats, |req: NewSessionRequest| agent.new_session(req)).await
502+
}
503+
AgentMethod::SessionList => {
504+
handle_request(&msg, nats, |req: ListSessionsRequest| {
505+
agent.list_sessions(req)
506+
})
507+
.await
508+
}
509+
AgentMethod::SessionLoad => {
510+
handle_request(&msg, nats, |req: LoadSessionRequest| {
511+
agent.load_session(req)
512+
})
513+
.await
514+
}
515+
AgentMethod::SessionPrompt => {
516+
handle_request_with_keepalive(&msg, nats, &js_msg, |req: PromptRequest| {
517+
agent.prompt(req)
518+
})
519+
.await
520+
}
521+
AgentMethod::SessionCancel => {
522+
handle_notification(&msg, |req: CancelNotification| agent.cancel(req)).await
523+
}
524+
AgentMethod::SessionSetMode => {
525+
handle_request(&msg, nats, |req: SetSessionModeRequest| {
526+
agent.set_session_mode(req)
527+
})
528+
.await
529+
}
530+
AgentMethod::SessionSetConfigOption => {
531+
handle_request(&msg, nats, |req: SetSessionConfigOptionRequest| {
532+
agent.set_session_config_option(req)
533+
})
534+
.await
535+
}
536+
AgentMethod::SessionSetModel => {
537+
handle_request(&msg, nats, |req: SetSessionModelRequest| {
538+
agent.set_session_model(req)
539+
})
540+
.await
541+
}
542+
AgentMethod::SessionFork => {
543+
handle_request(&msg, nats, |req: ForkSessionRequest| {
544+
agent.fork_session(req)
545+
})
546+
.await
547+
}
548+
AgentMethod::SessionResume => {
549+
handle_request(&msg, nats, |req: ResumeSessionRequest| {
550+
agent.resume_session(req)
551+
})
552+
.await
553+
}
554+
AgentMethod::SessionClose => {
555+
handle_request(&msg, nats, |req: CloseSessionRequest| {
556+
agent.close_session(req)
557+
})
558+
.await
559+
}
560+
AgentMethod::Ext(_) => {
561+
if msg.reply.is_some() {
562+
handle_request(&msg, nats, |req: ExtRequest| agent.ext_method(req)).await
563+
} else {
564+
handle_notification(&msg, |req: ExtNotification| agent.ext_notification(req)).await
565+
}
566+
}
567+
};
568+
569+
match &result {
570+
Ok(()) => {
571+
if let Err(e) = js_msg.ack().await {
572+
warn!(subject, error = %e, "Failed to ack JetStream message");
573+
}
574+
}
575+
Err(DispatchError::DeserializeRequest(_) | DispatchError::DeserializeNotification(_)) => {
576+
if let Err(e) = js_msg.term().await {
577+
warn!(subject, error = %e, "Failed to term bad payload");
578+
}
579+
}
580+
Err(DispatchError::NoReplySubject) => {
581+
if let Err(e) = js_msg.term().await {
582+
warn!(subject, error = %e, "Failed to term missing reply subject");
583+
}
584+
}
585+
Err(DispatchError::Reply(_)) => {
586+
if let Err(e) = js_msg.nak().await {
587+
warn!(subject, error = %e, "Failed to nak after reply failure");
588+
}
589+
}
590+
Err(DispatchError::NotificationHandler(_)) => {
591+
if let Err(e) = js_msg.ack().await {
592+
warn!(subject, error = %e, "Failed to ack after notification handler error");
593+
}
594+
}
595+
}
596+
597+
if let Err(e) = result {
598+
let sid = parsed
599+
.session_id
600+
.as_ref()
601+
.map(|s| s.as_str())
602+
.unwrap_or("-");
603+
warn!(subject, session_id = sid, error = %e, "Error handling agent request");
604+
}
605+
}
606+
316607
#[cfg(test)]
317608
mod tests {
318609
use super::*;

rsworkspace/crates/acp-nats-stdio/src/main.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
3232
let stdin = async_compat::Compat::new(tokio::io::stdin());
3333
let stdout = async_compat::Compat::new(tokio::io::stdout());
3434

35+
let js_context = async_nats::jetstream::new(nats_client.clone());
36+
let js = acp_nats::NatsJetStreamClient::new(js_context);
37+
3538
let local = tokio::task::LocalSet::new();
3639
let result = local
3740
.run_until(run_bridge(
3841
nats_client,
42+
js,
3943
&config,
4044
stdout,
4145
stdin,
@@ -57,8 +61,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
5761
#[cfg(coverage)]
5862
fn main() {}
5963

60-
async fn run_bridge<N, W, R>(
64+
async fn run_bridge<N, J, W, R>(
6165
nats_client: N,
66+
js: J,
6267
config: &acp_nats::Config,
6368
stdout: W,
6469
stdin: R,
@@ -70,13 +75,16 @@ where
7075
+ acp_nats::FlushClient
7176
+ acp_nats::SubscribeClient
7277
+ 'static,
78+
J: acp_nats::JetStreamPublisher + acp_nats::JetStreamConsumerFactory + 'static,
7379
W: futures::AsyncWrite + Unpin + 'static,
7480
R: futures::AsyncRead + Unpin + 'static,
7581
{
7682
let meter = acp_telemetry::meter("acp-io-bridge-nats");
7783
let (notification_tx, notification_rx) = tokio::sync::mpsc::channel::<SessionNotification>(64);
78-
let bridge = Rc::new(Bridge::new(
84+
85+
let bridge = Rc::new(Bridge::with_jetstream(
7986
nats_client.clone(),
87+
js,
8088
SystemClock,
8189
&meter,
8290
config.clone(),
@@ -166,6 +174,7 @@ mod tests {
166174
let result = local
167175
.run_until(run_bridge(
168176
mock,
177+
(),
169178
&config,
170179
stdout,
171180
stdin,
@@ -198,6 +207,7 @@ mod tests {
198207
let result = local
199208
.run_until(run_bridge(
200209
mock,
210+
(),
201211
&config,
202212
stdout,
203213
stdin,

0 commit comments

Comments
 (0)