Skip to content

Commit 2d61ce7

Browse files
committed
feat(acp-nats, acp-nats-agent): wire JetStream into bridge and agent
Wire Layer 3 of JetStream integration — the actual message flow. Bridge prompt handler branches on bridge.js(): when JetStream is available, publishes to COMMANDS stream and consumes from NOTIFICATIONS/RESPONSES streams instead of subscribe-before-publish. Session.ready skips the 100ms sleep when JetStream is available — the RESPONSES stream captures the message regardless of timing. Agent library gains serve_js() alongside serve() — creates a consumer on the COMMANDS stream and dispatches with ack/nak/term signals. Runs in parallel with core NATS serve via with_jetstream() constructor. Binary crates (stdio, ws) create NatsJetStreamClient from the NATS connection and pass to Bridge::with_jetstream(). () implements JetStream traits as no-ops for backward compatibility when JetStream is not configured. Signed-off-by: Yordis Prieto <yordis.prieto@gmail.com>
1 parent 43f9048 commit 2d61ce7

13 files changed

Lines changed: 540 additions & 42 deletions

File tree

rsworkspace/crates/acp-nats-agent/src/connection.rs

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@ use trogon_nats::{FlushClient, PublishClient, RequestClient, SubscribeClient};
2020

2121
pub enum ConnectionError {
2222
Subscribe(Box<dyn std::error::Error + Send + Sync>),
23+
JetStream(Box<dyn std::error::Error + Send + Sync>),
2324
}
2425

2526
impl std::fmt::Debug for ConnectionError {
2627
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2728
match self {
2829
Self::Subscribe(e) => f.debug_tuple("Subscribe").field(e).finish(),
30+
Self::JetStream(e) => f.debug_tuple("JetStream").field(e).finish(),
2931
}
3032
}
3133
}
@@ -34,6 +36,7 @@ impl std::fmt::Display for ConnectionError {
3436
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
3537
match self {
3638
Self::Subscribe(e) => write!(f, "failed to subscribe: {}", e),
39+
Self::JetStream(e) => write!(f, "jetstream error: {}", e),
3740
}
3841
}
3942
}
@@ -95,6 +98,47 @@ where
9598
(conn, io_task)
9699
}
97100

101+
pub fn with_jetstream<J>(
102+
agent: impl Agent + 'static,
103+
nats: N,
104+
js: J,
105+
acp_prefix: AcpPrefix,
106+
spawn: impl Fn(LocalBoxFuture<'static, ()>) + Copy + 'static,
107+
) -> (
108+
Self,
109+
impl std::future::Future<Output = Result<(), ConnectionError>>,
110+
)
111+
where
112+
J: JetStreamConsumerFactory + 'static,
113+
{
114+
let nats_for_serve = nats.clone();
115+
let nats_for_js = nats.clone();
116+
let prefix = acp_prefix.as_str().to_string();
117+
let prefix_js = prefix.clone();
118+
119+
let io_task = async move {
120+
let (agent1, agent2) = {
121+
let agent = Rc::new(agent);
122+
(agent.clone(), agent)
123+
};
124+
125+
let core = serve(agent1, nats_for_serve, &prefix, spawn);
126+
let jetstream = serve_js(agent2, nats_for_js, js, &prefix_js, spawn);
127+
128+
tokio::select! {
129+
result = core => result,
130+
result = jetstream => result,
131+
}
132+
};
133+
134+
let conn = Self {
135+
nats,
136+
acp_prefix,
137+
operation_timeout: DEFAULT_OPERATION_TIMEOUT,
138+
};
139+
(conn, io_task)
140+
}
141+
98142
pub fn client_for_session(&self, session_id: AcpSessionId) -> NatsClientProxy<N> {
99143
NatsClientProxy::new(
100144
self.nats.clone(),
@@ -313,6 +357,183 @@ where
313357
.map_err(DispatchError::NotificationHandler)
314358
}
315359

360+
use trogon_nats::jetstream::{JetStreamConsumer as _, JetStreamConsumerFactory, JsMessage};
361+
362+
async fn serve_js<N, J, A>(
363+
agent: A,
364+
nats: N,
365+
js: J,
366+
prefix: &str,
367+
spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
368+
) -> Result<(), ConnectionError>
369+
where
370+
N: PublishClient + FlushClient + Clone + 'static,
371+
J: JetStreamConsumerFactory + 'static,
372+
A: Agent + 'static,
373+
{
374+
let stream_name = acp_nats::jetstream::streams::commands_stream_name(prefix);
375+
let config = acp_nats::jetstream::consumers::commands_observer();
376+
377+
info!(stream = %stream_name, "Starting JetStream consumer for COMMANDS stream");
378+
379+
let consumer = js
380+
.create_consumer(&stream_name, config)
381+
.await
382+
.map_err(|e| ConnectionError::JetStream(Box::new(e)))?;
383+
384+
let mut messages = consumer
385+
.messages()
386+
.await
387+
.map_err(|e| ConnectionError::JetStream(Box::new(e)))?;
388+
389+
let agent = Rc::new(agent);
390+
let nats = Rc::new(nats);
391+
392+
while let Some(msg_result) = messages.next().await {
393+
match msg_result {
394+
Ok(js_msg) => {
395+
let agent = agent.clone();
396+
let nats = nats.clone();
397+
spawn(Box::pin(async move {
398+
dispatch_js_message(js_msg, agent.as_ref(), nats.as_ref()).await;
399+
}));
400+
}
401+
Err(e) => {
402+
warn!(error = %e, "JetStream consumer error");
403+
}
404+
}
405+
}
406+
407+
info!("JetStream COMMANDS consumer ended");
408+
Ok(())
409+
}
410+
411+
async fn dispatch_js_message<N: PublishClient + FlushClient, A: Agent>(
412+
js_msg: JsMessage,
413+
agent: &A,
414+
nats: &N,
415+
) {
416+
let msg = js_msg.message();
417+
let subject = msg.subject.as_str();
418+
419+
let parsed = match parse_agent_subject(subject) {
420+
Some(p) => p,
421+
None => {
422+
if let Err(e) = js_msg.term().await {
423+
warn!(error = %e, subject, "Failed to term unknown subject");
424+
}
425+
return;
426+
}
427+
};
428+
429+
let result = match parsed.method {
430+
AgentMethod::Initialize => {
431+
handle_request(msg, nats, |req: InitializeRequest| agent.initialize(req)).await
432+
}
433+
AgentMethod::Authenticate => {
434+
handle_request(msg, nats, |req: AuthenticateRequest| {
435+
agent.authenticate(req)
436+
})
437+
.await
438+
}
439+
AgentMethod::SessionNew => {
440+
handle_request(msg, nats, |req: NewSessionRequest| agent.new_session(req)).await
441+
}
442+
AgentMethod::SessionList => {
443+
handle_request(msg, nats, |req: ListSessionsRequest| {
444+
agent.list_sessions(req)
445+
})
446+
.await
447+
}
448+
AgentMethod::SessionLoad => {
449+
handle_request(msg, nats, |req: LoadSessionRequest| agent.load_session(req)).await
450+
}
451+
AgentMethod::SessionPrompt => {
452+
handle_request(msg, nats, |req: PromptRequest| agent.prompt(req)).await
453+
}
454+
AgentMethod::SessionCancel => {
455+
handle_notification(msg, |req: CancelNotification| agent.cancel(req)).await
456+
}
457+
AgentMethod::SessionSetMode => {
458+
handle_request(msg, nats, |req: SetSessionModeRequest| {
459+
agent.set_session_mode(req)
460+
})
461+
.await
462+
}
463+
AgentMethod::SessionSetConfigOption => {
464+
handle_request(msg, nats, |req: SetSessionConfigOptionRequest| {
465+
agent.set_session_config_option(req)
466+
})
467+
.await
468+
}
469+
AgentMethod::SessionSetModel => {
470+
handle_request(msg, nats, |req: SetSessionModelRequest| {
471+
agent.set_session_model(req)
472+
})
473+
.await
474+
}
475+
AgentMethod::SessionFork => {
476+
handle_request(msg, nats, |req: ForkSessionRequest| agent.fork_session(req)).await
477+
}
478+
AgentMethod::SessionResume => {
479+
handle_request(msg, nats, |req: ResumeSessionRequest| {
480+
agent.resume_session(req)
481+
})
482+
.await
483+
}
484+
AgentMethod::SessionClose => {
485+
handle_request(msg, nats, |req: CloseSessionRequest| {
486+
agent.close_session(req)
487+
})
488+
.await
489+
}
490+
AgentMethod::Ext(_) => {
491+
if msg.reply.is_some() {
492+
handle_request(msg, nats, |req: ExtRequest| agent.ext_method(req)).await
493+
} else {
494+
handle_notification(msg, |req: ExtNotification| agent.ext_notification(req)).await
495+
}
496+
}
497+
};
498+
499+
match &result {
500+
Ok(()) => {
501+
if let Err(e) = js_msg.ack().await {
502+
warn!(subject, error = %e, "Failed to ack JetStream message");
503+
}
504+
}
505+
Err(DispatchError::DeserializeRequest(_) | DispatchError::DeserializeNotification(_)) => {
506+
if let Err(e) = js_msg.term().await {
507+
warn!(subject, error = %e, "Failed to term bad payload");
508+
}
509+
}
510+
Err(DispatchError::NoReplySubject) => {
511+
if let Err(e) = js_msg.term().await {
512+
warn!(subject, error = %e, "Failed to term missing reply subject");
513+
}
514+
}
515+
Err(DispatchError::Reply(_)) => {
516+
if let Err(e) = js_msg.nak().await {
517+
warn!(subject, error = %e, "Failed to nak after reply failure");
518+
}
519+
}
520+
Err(DispatchError::NotificationHandler(_)) => {
521+
if let Err(e) = js_msg.ack().await {
522+
warn!(subject, error = %e, "Failed to ack after notification handler error");
523+
}
524+
}
525+
}
526+
527+
if let Err(e) = result {
528+
let sid = parsed
529+
.session_id
530+
.as_ref()
531+
.map(|s| s.as_str())
532+
.unwrap_or("-");
533+
warn!(subject, session_id = sid, error = %e, "Error handling agent request");
534+
}
535+
}
536+
316537
#[cfg(test)]
317538
mod tests {
318539
use super::*;

rsworkspace/crates/acp-nats-stdio/src/main.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
3232
let stdin = async_compat::Compat::new(tokio::io::stdin());
3333
let stdout = async_compat::Compat::new(tokio::io::stdout());
3434

35+
let js_context = async_nats::jetstream::new(nats_client.clone());
36+
let js = acp_nats::NatsJetStreamClient::new(js_context);
37+
3538
let local = tokio::task::LocalSet::new();
3639
let result = local
3740
.run_until(run_bridge(
3841
nats_client,
42+
js,
3943
&config,
4044
stdout,
4145
stdin,
@@ -57,8 +61,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
5761
#[cfg(coverage)]
5862
fn main() {}
5963

60-
async fn run_bridge<N, W, R>(
64+
async fn run_bridge<N, J, W, R>(
6165
nats_client: N,
66+
js: J,
6267
config: &acp_nats::Config,
6368
stdout: W,
6469
stdin: R,
@@ -70,13 +75,16 @@ where
7075
+ acp_nats::FlushClient
7176
+ acp_nats::SubscribeClient
7277
+ 'static,
78+
J: acp_nats::JetStreamPublisher + acp_nats::JetStreamConsumerFactory + 'static,
7379
W: futures::AsyncWrite + Unpin + 'static,
7480
R: futures::AsyncRead + Unpin + 'static,
7581
{
7682
let meter = acp_telemetry::meter("acp-io-bridge-nats");
7783
let (notification_tx, notification_rx) = tokio::sync::mpsc::channel::<SessionNotification>(64);
78-
let bridge = Rc::new(Bridge::new(
84+
85+
let bridge = Rc::new(Bridge::with_jetstream(
7986
nats_client.clone(),
87+
js,
8088
SystemClock,
8189
&meter,
8290
config.clone(),
@@ -166,6 +174,7 @@ mod tests {
166174
let result = local
167175
.run_until(run_bridge(
168176
mock,
177+
(),
169178
&config,
170179
stdout,
171180
stdin,
@@ -198,6 +207,7 @@ mod tests {
198207
let result = local
199208
.run_until(run_bridge(
200209
mock,
210+
(),
201211
&config,
202212
stdout,
203213
stdin,

rsworkspace/crates/acp-nats-ws/src/connection.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use acp_nats::{JetStreamConsumerFactory, JetStreamPublisher};
12
use acp_nats::{StdJsonSerialize, agent::Bridge, client, spawn_notification_forwarder};
23
use agent_client_protocol::{AgentSideConnection, SessionNotification};
34
use axum::extract::ws::{Message, WebSocket};
@@ -12,9 +13,10 @@ use trogon_std::time::SystemClock;
1213
use crate::constants::DUPLEX_BUFFER_SIZE;
1314

1415
/// Handles a single WebSocket connection by bridging it to NATS via ACP.
15-
pub async fn handle<N>(
16+
pub async fn handle<N, J>(
1617
socket: WebSocket,
1718
nats_client: N,
19+
js: J,
1820
config: acp_nats::Config,
1921
mut shutdown_rx: watch::Receiver<bool>,
2022
) where
@@ -24,6 +26,7 @@ pub async fn handle<N>(
2426
+ acp_nats::SubscribeClient
2527
+ Clone
2628
+ 'static,
29+
J: JetStreamPublisher + JetStreamConsumerFactory + 'static,
2730
{
2831
let (ws_sender, ws_receiver) = socket.split();
2932

@@ -35,8 +38,9 @@ pub async fn handle<N>(
3538

3639
let meter = acp_telemetry::meter("acp-nats-ws");
3740
let (notification_tx, notification_rx) = tokio::sync::mpsc::channel::<SessionNotification>(64);
38-
let bridge = Rc::new(Bridge::new(
41+
let bridge = Rc::new(Bridge::with_jetstream(
3942
nats_client.clone(),
43+
js,
4044
SystemClock,
4145
&meter,
4246
config,

0 commit comments

Comments
 (0)