@@ -2,6 +2,7 @@ use super::Bridge;
22use crate :: config:: PROMPT_TIMEOUT_MESSAGE_SECS_THRESHOLD ;
33use crate :: error:: AGENT_UNAVAILABLE ;
44use crate :: nats:: { self , FlushClient , PublishClient , RequestClient , agent} ;
5+ use crate :: pending_prompt_waiters:: PromptToken ;
56use crate :: session_id:: AcpSessionId ;
67use agent_client_protocol:: ErrorCode ;
78use agent_client_protocol:: { Error , PromptRequest , PromptResponse , Result } ;
@@ -49,6 +50,16 @@ fn duplicate_waiter_error<N: RequestClient + PublishClient + FlushClient, C: Get
4950 )
5051}
5152
53+ fn add_prompt_id_to_request ( args : & PromptRequest , prompt_token : PromptToken ) -> PromptRequest {
54+ let mut meta = args
55+ . meta
56+ . as_ref ( )
57+ . cloned ( )
58+ . unwrap_or_else ( serde_json:: Map :: new) ;
59+ meta. insert ( "prompt_id" . to_string ( ) , serde_json:: json!( prompt_token. 0 ) ) ;
60+ args. clone ( ) . meta ( meta)
61+ }
62+
5263#[ instrument(
5364 name = "acp.session.prompt" ,
5465 skip( bridge, args) ,
@@ -76,19 +87,21 @@ pub async fn handle<N: RequestClient + PublishClient + FlushClient, C: GetElapse
7687 let nats = bridge. nats ( ) ;
7788 let subject = agent:: session_prompt ( bridge. config . acp_prefix ( ) , session_id. as_str ( ) ) ;
7889
79- let ( rx, _waiter_guard) = match bridge
90+ let ( rx, _waiter_guard, prompt_token ) = match bridge
8091 . pending_session_prompt_responses
8192 . register_waiter ( args. session_id . clone ( ) )
8293 {
8394 Ok ( waiter) => waiter,
8495 Err ( ( ) ) => return Err ( duplicate_waiter_error ( bridge, & args. session_id ) ) ,
8596 } ;
8697
98+ let request_with_token = add_prompt_id_to_request ( & args, prompt_token) ;
99+
87100 let publish_options = nats:: PublishOptions :: builder ( )
88101 . flush_policy ( nats:: FlushPolicy :: no_retries ( ) )
89102 . build ( ) ;
90103
91- if let Err ( e) = nats:: publish ( nats, & subject, & args , publish_options) . await {
104+ if let Err ( e) = nats:: publish ( nats, & subject, & request_with_token , publish_options) . await {
92105 bridge
93106 . metrics
94107 . record_error ( "prompt" , "prompt_publish_failed" ) ;
@@ -139,7 +152,11 @@ pub async fn handle<N: RequestClient + PublishClient + FlushClient, C: GetElapse
139152 bridge. metrics . record_error ( "prompt" , "prompt_timeout" ) ;
140153 bridge
141154 . pending_session_prompt_responses
142- . mark_prompt_waiter_timed_out ( args. session_id . clone ( ) , & bridge. clock ) ;
155+ . mark_prompt_waiter_timed_out (
156+ args. session_id . clone ( ) ,
157+ prompt_token,
158+ & bridge. clock ,
159+ ) ;
143160
144161 let timeout = bridge. config . prompt_timeout ( ) ;
145162 let timeout_msg = if timeout >= PROMPT_TIMEOUT_MESSAGE_SECS_THRESHOLD {
@@ -346,6 +363,7 @@ mod tests {
346363 . pending_session_prompt_responses
347364 . resolve_waiter (
348365 & SessionId :: from ( "s1" ) ,
366+ PromptToken ( 0 ) ,
349367 Ok ( PromptResponse :: new ( StopReason :: EndTurn ) ) ,
350368 ) ;
351369 let result = handle1. await . unwrap ( ) ;
@@ -357,7 +375,7 @@ mod tests {
357375 #[ tokio:: test]
358376 async fn prompt_rejects_duplicate_waiter_for_same_session ( ) {
359377 let ( _mock, bridge) = mock_bridge ( ) ;
360- let ( _rx, _guard) = bridge
378+ let ( _rx, _guard, _ ) = bridge
361379 . pending_session_prompt_responses
362380 . register_waiter ( agent_client_protocol:: SessionId :: from ( "s1" ) )
363381 . unwrap ( ) ;
@@ -385,14 +403,15 @@ mod tests {
385403 tokio:: time:: sleep ( Duration :: from_millis ( 5 ) ) . await ;
386404 handle. abort ( ) ;
387405 let _ = handle. await ;
388- let ( rx, _guard) = bridge_after
406+ let ( rx, _guard, token ) = bridge_after
389407 . pending_session_prompt_responses
390408 . register_waiter ( SessionId :: from ( "s1" ) )
391409 . expect ( "waiter should be free after cancelled prompt dropped guard" ) ;
392410 bridge_after
393411 . pending_session_prompt_responses
394412 . resolve_waiter (
395413 & SessionId :: from ( "s1" ) ,
414+ token,
396415 Ok ( PromptResponse :: new ( StopReason :: EndTurn ) ) ,
397416 ) ;
398417 let result = rx. await . unwrap ( ) . unwrap ( ) ;
@@ -430,13 +449,14 @@ mod tests {
430449 #[ tokio:: test]
431450 async fn prompt_resolves_waiter_with_response ( ) {
432451 let ( _mock, bridge) = mock_bridge ( ) ;
433- let ( rx, _guard) = bridge
452+ let ( rx, _guard, token ) = bridge
434453 . pending_session_prompt_responses
435454 . register_waiter ( agent_client_protocol:: SessionId :: from ( "s1" ) )
436455 . unwrap ( ) ;
437456
438457 bridge. pending_session_prompt_responses . resolve_waiter (
439458 & agent_client_protocol:: SessionId :: from ( "s1" ) ,
459+ token,
440460 Ok ( PromptResponse :: new ( StopReason :: EndTurn ) ) ,
441461 ) ;
442462
@@ -461,6 +481,7 @@ mod tests {
461481 . pending_session_prompt_responses
462482 . resolve_waiter (
463483 & SessionId :: from ( "s1" ) ,
484+ PromptToken ( 0 ) ,
464485 Ok ( PromptResponse :: new ( StopReason :: EndTurn ) ) ,
465486 ) ;
466487 let result = prompt_handle. await . unwrap ( ) ;
@@ -513,6 +534,7 @@ mod tests {
513534 . pending_session_prompt_responses
514535 . resolve_waiter (
515536 & SessionId :: from ( "s1" ) ,
537+ PromptToken ( 0 ) ,
516538 Ok ( PromptResponse :: new ( StopReason :: EndTurn ) ) ,
517539 ) ;
518540 let result = handle1. await . unwrap ( ) ;
@@ -631,7 +653,11 @@ mod tests {
631653 tokio:: time:: sleep ( Duration :: from_millis ( 5 ) ) . await ;
632654 bridge_resolve
633655 . pending_session_prompt_responses
634- . resolve_waiter ( & SessionId :: from ( "s1" ) , Err ( "parse error" . to_string ( ) ) ) ;
656+ . resolve_waiter (
657+ & SessionId :: from ( "s1" ) ,
658+ PromptToken ( 0 ) ,
659+ Err ( "parse error" . to_string ( ) ) ,
660+ ) ;
635661 let result = prompt_handle. await . unwrap ( ) ;
636662 let err = result. unwrap_err ( ) ;
637663 assert ! ( err. to_string( ) . contains( "parse failed" ) ) ;
@@ -681,6 +707,7 @@ mod tests {
681707 . pending_session_prompt_responses
682708 . resolve_waiter (
683709 & SessionId :: from ( "s1" ) ,
710+ PromptToken ( 0 ) ,
684711 Ok ( PromptResponse :: new ( StopReason :: EndTurn ) ) ,
685712 ) ;
686713 let result = prompt_handle. await . unwrap ( ) ;
0 commit comments