@@ -20,12 +20,14 @@ use trogon_nats::{FlushClient, PublishClient, RequestClient, SubscribeClient};
2020
2121pub enum ConnectionError {
2222 Subscribe ( Box < dyn std:: error:: Error + Send + Sync > ) ,
23+ JetStream ( Box < dyn std:: error:: Error + Send + Sync > ) ,
2324}
2425
2526impl std:: fmt:: Debug for ConnectionError {
2627 fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
2728 match self {
2829 Self :: Subscribe ( e) => f. debug_tuple ( "Subscribe" ) . field ( e) . finish ( ) ,
30+ Self :: JetStream ( e) => f. debug_tuple ( "JetStream" ) . field ( e) . finish ( ) ,
2931 }
3032 }
3133}
@@ -34,6 +36,7 @@ impl std::fmt::Display for ConnectionError {
3436 fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
3537 match self {
3638 Self :: Subscribe ( e) => write ! ( f, "failed to subscribe: {}" , e) ,
39+ Self :: JetStream ( e) => write ! ( f, "jetstream error: {}" , e) ,
3740 }
3841 }
3942}
9598 ( conn, io_task)
9699 }
97100
101+ pub fn with_jetstream < J > (
102+ agent : impl Agent + ' static ,
103+ nats : N ,
104+ js : J ,
105+ acp_prefix : AcpPrefix ,
106+ spawn : impl Fn ( LocalBoxFuture < ' static , ( ) > ) + Copy + ' static ,
107+ ) -> (
108+ Self ,
109+ impl std:: future:: Future < Output = Result < ( ) , ConnectionError > > ,
110+ )
111+ where
112+ J : JetStreamConsumerFactory + ' static ,
113+ {
114+ let nats_for_serve = nats. clone ( ) ;
115+ let nats_for_js = nats. clone ( ) ;
116+ let prefix = acp_prefix. as_str ( ) . to_string ( ) ;
117+ let prefix_js = prefix. clone ( ) ;
118+
119+ let io_task = async move {
120+ let ( agent1, agent2) = {
121+ let agent = Rc :: new ( agent) ;
122+ ( agent. clone ( ) , agent)
123+ } ;
124+
125+ let core = serve ( agent1, nats_for_serve, & prefix, spawn) ;
126+ let jetstream = serve_js ( agent2, nats_for_js, js, & prefix_js, spawn) ;
127+
128+ tokio:: select! {
129+ result = core => result,
130+ result = jetstream => result,
131+ }
132+ } ;
133+
134+ let conn = Self {
135+ nats,
136+ acp_prefix,
137+ operation_timeout : DEFAULT_OPERATION_TIMEOUT ,
138+ } ;
139+ ( conn, io_task)
140+ }
141+
98142 pub fn client_for_session ( & self , session_id : AcpSessionId ) -> NatsClientProxy < N > {
99143 NatsClientProxy :: new (
100144 self . nats . clone ( ) ,
@@ -313,6 +357,253 @@ where
313357 . map_err ( DispatchError :: NotificationHandler )
314358}
315359
360+ use trogon_nats:: jetstream:: { JetStreamConsumer as _, JetStreamConsumerFactory , JsMessage } ;
361+
362+ const KEEPALIVE_INTERVAL : Duration = Duration :: from_secs ( 15 ) ;
363+
364+ async fn handle_request_with_keepalive < N , Resp , ReqT , F , M > (
365+ msg : & Message ,
366+ nats : & N ,
367+ js_msg : & M ,
368+ handler : impl FnOnce ( ReqT ) -> F ,
369+ ) -> Result < ( ) , DispatchError >
370+ where
371+ N : PublishClient + FlushClient ,
372+ ReqT : serde:: de:: DeserializeOwned ,
373+ F : std:: future:: Future < Output = agent_client_protocol:: Result < Resp > > ,
374+ Resp : serde:: Serialize ,
375+ M : JsMessage ,
376+ {
377+ let reply_to = msg. reply . as_deref ( ) . ok_or ( DispatchError :: NoReplySubject ) ?;
378+
379+ let request: ReqT = match serde_json:: from_slice ( & msg. payload ) {
380+ Ok ( req) => req,
381+ Err ( e) => {
382+ let error = agent_client_protocol:: Error :: new (
383+ agent_client_protocol:: ErrorCode :: InvalidParams . into ( ) ,
384+ format ! ( "Failed to deserialize request: {}" , e) ,
385+ ) ;
386+ let _ = reply ( nats, reply_to, & error) . await ;
387+ return Err ( DispatchError :: DeserializeRequest ( e) ) ;
388+ }
389+ } ;
390+
391+ let handler_fut = handler ( request) ;
392+ tokio:: pin!( handler_fut) ;
393+
394+ let mut keepalive = tokio:: time:: interval ( KEEPALIVE_INTERVAL ) ;
395+ keepalive. tick ( ) . await ;
396+
397+ loop {
398+ tokio:: select! {
399+ result = & mut handler_fut => {
400+ return match result {
401+ Ok ( resp) => reply( nats, reply_to, & resp) . await ,
402+ Err ( err) => reply( nats, reply_to, & err) . await ,
403+ } ;
404+ }
405+ _ = keepalive. tick( ) => {
406+ if let Err ( e) = js_msg. in_progress( ) . await {
407+ warn!( error = %e, "Failed to send in_progress keepalive" ) ;
408+ }
409+ }
410+ }
411+ }
412+ }
413+
414+ async fn serve_js < N , J , A > (
415+ agent : A ,
416+ nats : N ,
417+ js : J ,
418+ prefix : & str ,
419+ spawn : impl Fn ( LocalBoxFuture < ' static , ( ) > ) + ' static ,
420+ ) -> Result < ( ) , ConnectionError >
421+ where
422+ N : PublishClient + FlushClient + Clone + ' static ,
423+ J : JetStreamConsumerFactory + ' static ,
424+ A : Agent + ' static ,
425+ {
426+ let stream_name = acp_nats:: jetstream:: streams:: commands_stream_name ( prefix) ;
427+ let config = acp_nats:: jetstream:: consumers:: commands_observer ( ) ;
428+
429+ info ! ( stream = %stream_name, "Starting JetStream consumer for COMMANDS stream" ) ;
430+
431+ let consumer = js
432+ . create_consumer ( & stream_name, config)
433+ . await
434+ . map_err ( |e| ConnectionError :: JetStream ( Box :: new ( e) ) ) ?;
435+
436+ let mut messages = consumer
437+ . messages ( )
438+ . await
439+ . map_err ( |e| ConnectionError :: JetStream ( Box :: new ( e) ) ) ?;
440+
441+ let agent = Rc :: new ( agent) ;
442+ let nats = Rc :: new ( nats) ;
443+
444+ while let Some ( msg_result) = messages. next ( ) . await {
445+ match msg_result {
446+ Ok ( js_msg) => {
447+ let agent = agent. clone ( ) ;
448+ let nats = nats. clone ( ) ;
449+ spawn ( Box :: pin ( async move {
450+ dispatch_js_message ( js_msg, agent. as_ref ( ) , nats. as_ref ( ) ) . await ;
451+ } ) ) ;
452+ }
453+ Err ( e) => {
454+ warn ! ( error = %e, "JetStream consumer error" ) ;
455+ }
456+ }
457+ }
458+
459+ info ! ( "JetStream COMMANDS consumer ended" ) ;
460+ Ok ( ( ) )
461+ }
462+
463+ async fn dispatch_js_message < N : PublishClient + FlushClient , A : Agent , M : JsMessage > (
464+ js_msg : M ,
465+ agent : & A ,
466+ nats : & N ,
467+ ) {
468+ let subject = js_msg. subject ( ) . to_string ( ) ;
469+ let msg = Message {
470+ subject : subject. as_str ( ) . into ( ) ,
471+ reply : js_msg. reply ( ) . map ( |s| s. into ( ) ) ,
472+ payload : js_msg. payload ( ) . clone ( ) ,
473+ headers : js_msg. headers ( ) . cloned ( ) ,
474+ status : None ,
475+ description : None ,
476+ length : js_msg. payload ( ) . len ( ) ,
477+ } ;
478+ let subject = msg. subject . as_str ( ) ;
479+
480+ let parsed = match parse_agent_subject ( subject) {
481+ Some ( p) => p,
482+ None => {
483+ if let Err ( e) = js_msg. term ( ) . await {
484+ warn ! ( error = %e, subject, "Failed to term unknown subject" ) ;
485+ }
486+ return ;
487+ }
488+ } ;
489+
490+ let result = match parsed. method {
491+ AgentMethod :: Initialize => {
492+ handle_request ( & msg, nats, |req : InitializeRequest | agent. initialize ( req) ) . await
493+ }
494+ AgentMethod :: Authenticate => {
495+ handle_request ( & msg, nats, |req : AuthenticateRequest | {
496+ agent. authenticate ( req)
497+ } )
498+ . await
499+ }
500+ AgentMethod :: SessionNew => {
501+ handle_request ( & msg, nats, |req : NewSessionRequest | agent. new_session ( req) ) . await
502+ }
503+ AgentMethod :: SessionList => {
504+ handle_request ( & msg, nats, |req : ListSessionsRequest | {
505+ agent. list_sessions ( req)
506+ } )
507+ . await
508+ }
509+ AgentMethod :: SessionLoad => {
510+ handle_request ( & msg, nats, |req : LoadSessionRequest | {
511+ agent. load_session ( req)
512+ } )
513+ . await
514+ }
515+ AgentMethod :: SessionPrompt => {
516+ handle_request_with_keepalive ( & msg, nats, & js_msg, |req : PromptRequest | {
517+ agent. prompt ( req)
518+ } )
519+ . await
520+ }
521+ AgentMethod :: SessionCancel => {
522+ handle_notification ( & msg, |req : CancelNotification | agent. cancel ( req) ) . await
523+ }
524+ AgentMethod :: SessionSetMode => {
525+ handle_request ( & msg, nats, |req : SetSessionModeRequest | {
526+ agent. set_session_mode ( req)
527+ } )
528+ . await
529+ }
530+ AgentMethod :: SessionSetConfigOption => {
531+ handle_request ( & msg, nats, |req : SetSessionConfigOptionRequest | {
532+ agent. set_session_config_option ( req)
533+ } )
534+ . await
535+ }
536+ AgentMethod :: SessionSetModel => {
537+ handle_request ( & msg, nats, |req : SetSessionModelRequest | {
538+ agent. set_session_model ( req)
539+ } )
540+ . await
541+ }
542+ AgentMethod :: SessionFork => {
543+ handle_request ( & msg, nats, |req : ForkSessionRequest | {
544+ agent. fork_session ( req)
545+ } )
546+ . await
547+ }
548+ AgentMethod :: SessionResume => {
549+ handle_request ( & msg, nats, |req : ResumeSessionRequest | {
550+ agent. resume_session ( req)
551+ } )
552+ . await
553+ }
554+ AgentMethod :: SessionClose => {
555+ handle_request ( & msg, nats, |req : CloseSessionRequest | {
556+ agent. close_session ( req)
557+ } )
558+ . await
559+ }
560+ AgentMethod :: Ext ( _) => {
561+ if msg. reply . is_some ( ) {
562+ handle_request ( & msg, nats, |req : ExtRequest | agent. ext_method ( req) ) . await
563+ } else {
564+ handle_notification ( & msg, |req : ExtNotification | agent. ext_notification ( req) ) . await
565+ }
566+ }
567+ } ;
568+
569+ match & result {
570+ Ok ( ( ) ) => {
571+ if let Err ( e) = js_msg. ack ( ) . await {
572+ warn ! ( subject, error = %e, "Failed to ack JetStream message" ) ;
573+ }
574+ }
575+ Err ( DispatchError :: DeserializeRequest ( _) | DispatchError :: DeserializeNotification ( _) ) => {
576+ if let Err ( e) = js_msg. term ( ) . await {
577+ warn ! ( subject, error = %e, "Failed to term bad payload" ) ;
578+ }
579+ }
580+ Err ( DispatchError :: NoReplySubject ) => {
581+ if let Err ( e) = js_msg. term ( ) . await {
582+ warn ! ( subject, error = %e, "Failed to term missing reply subject" ) ;
583+ }
584+ }
585+ Err ( DispatchError :: Reply ( _) ) => {
586+ if let Err ( e) = js_msg. nak ( ) . await {
587+ warn ! ( subject, error = %e, "Failed to nak after reply failure" ) ;
588+ }
589+ }
590+ Err ( DispatchError :: NotificationHandler ( _) ) => {
591+ if let Err ( e) = js_msg. ack ( ) . await {
592+ warn ! ( subject, error = %e, "Failed to ack after notification handler error" ) ;
593+ }
594+ }
595+ }
596+
597+ if let Err ( e) = result {
598+ let sid = parsed
599+ . session_id
600+ . as_ref ( )
601+ . map ( |s| s. as_str ( ) )
602+ . unwrap_or ( "-" ) ;
603+ warn ! ( subject, session_id = sid, error = %e, "Error handling agent request" ) ;
604+ }
605+ }
606+
316607#[ cfg( test) ]
317608mod tests {
318609 use super :: * ;
0 commit comments