Skip to content

Commit 82d3e12

Browse files
authored
refactor(acp-nats): extract Bridge into its own file (#49)
Signed-off-by: Yordis Prieto <yordis.prieto@gmail.com>
1 parent 9100b5c commit 82d3e12

2 files changed

Lines changed: 202 additions & 197 deletions

File tree

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
use std::cell::RefCell;
2+
use std::time::Duration;
3+
4+
use crate::config::Config;
5+
use crate::nats::{
6+
self, ExtSessionReady, FlushClient, FlushPolicy, PublishClient, PublishOptions, RequestClient,
7+
RetryPolicy, SubscribeClient, agent,
8+
};
9+
use crate::pending_prompt_waiters::PendingSessionPromptResponseWaiters;
10+
use crate::telemetry::metrics::Metrics;
11+
use agent_client_protocol::{
12+
Agent, AuthenticateRequest, AuthenticateResponse, CancelNotification, CloseSessionRequest,
13+
CloseSessionResponse, ExtNotification, ExtRequest, ExtResponse, ForkSessionRequest,
14+
ForkSessionResponse, InitializeRequest, InitializeResponse, ListSessionsRequest,
15+
ListSessionsResponse, LoadSessionRequest, LoadSessionResponse, NewSessionRequest,
16+
NewSessionResponse, PromptRequest, PromptResponse, Result, ResumeSessionRequest,
17+
ResumeSessionResponse, SessionId, SessionNotification, SetSessionConfigOptionRequest,
18+
SetSessionConfigOptionResponse, SetSessionModeRequest, SetSessionModeResponse,
19+
SetSessionModelRequest, SetSessionModelResponse,
20+
};
21+
use opentelemetry::metrics::Meter;
22+
use tokio::sync::mpsc;
23+
use tokio::task::JoinHandle;
24+
use tracing::{info, warn};
25+
use trogon_std::time::GetElapsed;
26+
27+
use super::{
28+
authenticate, cancel, close_session, ext_method, ext_notification, fork_session, initialize,
29+
list_sessions, load_session, new_session, prompt, resume_session, set_session_config_option,
30+
set_session_mode, set_session_model,
31+
};
32+
33+
/// Delay before publishing `session.ready` to NATS.
34+
///
35+
/// The `Agent` trait returns the response value *before* the transport layer
36+
/// serializes and writes it to the client. Without a delay the spawned task
37+
/// could publish `session.ready` before the client has received the
38+
/// `session/new` response, violating the ordering guarantee.
39+
const SESSION_READY_DELAY: Duration = Duration::from_millis(100);
40+
41+
pub struct Bridge<N, C: GetElapsed> {
42+
pub(crate) nats: N,
43+
pub(crate) clock: C,
44+
pub(crate) config: Config,
45+
pub(crate) metrics: Metrics,
46+
pub(crate) notification_sender: mpsc::Sender<SessionNotification>,
47+
pub(crate) pending_session_prompt_responses: PendingSessionPromptResponseWaiters<C::Instant>,
48+
pub(crate) background_tasks: RefCell<Vec<JoinHandle<()>>>,
49+
}
50+
51+
impl<N, C: GetElapsed> Bridge<N, C> {
52+
pub fn new(
53+
nats: N,
54+
clock: C,
55+
meter: &Meter,
56+
config: Config,
57+
notification_sender: mpsc::Sender<SessionNotification>,
58+
) -> Self {
59+
Self {
60+
nats,
61+
clock,
62+
config,
63+
metrics: Metrics::new(meter),
64+
notification_sender,
65+
pending_session_prompt_responses: PendingSessionPromptResponseWaiters::new(),
66+
background_tasks: RefCell::new(Vec::new()),
67+
}
68+
}
69+
70+
pub(crate) fn nats(&self) -> &N {
71+
&self.nats
72+
}
73+
74+
pub(crate) fn spawn_background(&self, task: JoinHandle<()>) {
75+
self.background_tasks.borrow_mut().push(task);
76+
}
77+
78+
pub async fn drain_background_tasks(&self) {
79+
let tasks: Vec<_> = self.background_tasks.borrow_mut().drain(..).collect();
80+
for task in tasks {
81+
let _ = task.await;
82+
}
83+
}
84+
}
85+
86+
impl<N: PublishClient + FlushClient + Clone + Send + 'static, C: GetElapsed> Bridge<N, C> {
87+
pub(crate) fn schedule_session_ready(&self, session_id: SessionId) {
88+
let nats = self.nats.clone();
89+
let prefix = self.config.acp_prefix().to_string();
90+
let metrics = self.metrics.clone();
91+
let handle = tokio::spawn(async move {
92+
publish_session_ready(&nats, &prefix, &session_id, &metrics).await;
93+
});
94+
self.spawn_background(handle);
95+
}
96+
}
97+
98+
async fn publish_session_ready<N: PublishClient + FlushClient>(
99+
nats: &N,
100+
prefix: &str,
101+
session_id: &SessionId,
102+
metrics: &Metrics,
103+
) {
104+
tokio::time::sleep(SESSION_READY_DELAY).await;
105+
106+
let subject = agent::ext_session_ready(prefix, &session_id.to_string());
107+
info!(session_id = %session_id, subject = %subject, "Publishing session.ready");
108+
109+
let message = ExtSessionReady::new(session_id.clone());
110+
111+
let options = PublishOptions::builder()
112+
.publish_retry_policy(RetryPolicy::standard())
113+
.flush_policy(FlushPolicy::standard())
114+
.build();
115+
116+
if let Err(e) = nats::publish(nats, &subject, &message, options).await {
117+
warn!(
118+
error = %e,
119+
session_id = %session_id,
120+
"Failed to publish session.ready"
121+
);
122+
metrics.record_error("session_ready", "session_ready_publish_failed");
123+
} else {
124+
info!(session_id = %session_id, "Published session.ready");
125+
}
126+
}
127+
128+
#[async_trait::async_trait(?Send)]
129+
impl<N: RequestClient + PublishClient + SubscribeClient + FlushClient, C: GetElapsed> Agent
130+
for Bridge<N, C>
131+
{
132+
async fn initialize(&self, args: InitializeRequest) -> Result<InitializeResponse> {
133+
initialize::handle(self, args).await
134+
}
135+
136+
async fn authenticate(&self, args: AuthenticateRequest) -> Result<AuthenticateResponse> {
137+
authenticate::handle(self, args).await
138+
}
139+
140+
async fn new_session(&self, args: NewSessionRequest) -> Result<NewSessionResponse> {
141+
new_session::handle(self, args).await
142+
}
143+
144+
async fn load_session(&self, args: LoadSessionRequest) -> Result<LoadSessionResponse> {
145+
load_session::handle(self, args).await
146+
}
147+
148+
async fn set_session_mode(
149+
&self,
150+
args: SetSessionModeRequest,
151+
) -> Result<SetSessionModeResponse> {
152+
set_session_mode::handle(self, args).await
153+
}
154+
155+
async fn prompt(&self, args: PromptRequest) -> Result<PromptResponse> {
156+
prompt::handle(self, args, &trogon_std::StdJsonSerialize).await
157+
}
158+
159+
async fn cancel(&self, args: CancelNotification) -> Result<()> {
160+
cancel::handle(self, args).await
161+
}
162+
163+
async fn list_sessions(&self, args: ListSessionsRequest) -> Result<ListSessionsResponse> {
164+
list_sessions::handle(self, args).await
165+
}
166+
167+
async fn set_session_config_option(
168+
&self,
169+
args: SetSessionConfigOptionRequest,
170+
) -> Result<SetSessionConfigOptionResponse> {
171+
set_session_config_option::handle(self, args).await
172+
}
173+
174+
async fn set_session_model(
175+
&self,
176+
args: SetSessionModelRequest,
177+
) -> Result<SetSessionModelResponse> {
178+
set_session_model::handle(self, args).await
179+
}
180+
181+
async fn fork_session(&self, args: ForkSessionRequest) -> Result<ForkSessionResponse> {
182+
fork_session::handle(self, args).await
183+
}
184+
185+
async fn resume_session(&self, args: ResumeSessionRequest) -> Result<ResumeSessionResponse> {
186+
resume_session::handle(self, args).await
187+
}
188+
189+
async fn close_session(&self, args: CloseSessionRequest) -> Result<CloseSessionResponse> {
190+
close_session::handle(self, args).await
191+
}
192+
193+
async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
194+
ext_method::handle(self, args).await
195+
}
196+
197+
async fn ext_notification(&self, args: ExtNotification) -> Result<()> {
198+
ext_notification::handle(self, args).await
199+
}
200+
}

0 commit comments

Comments
 (0)