diff --git a/rsworkspace/crates/acp-nats/src/agent/prompt.rs b/rsworkspace/crates/acp-nats/src/agent/prompt.rs index 8ae3ec41c..0d7aa7b73 100644 --- a/rsworkspace/crates/acp-nats/src/agent/prompt.rs +++ b/rsworkspace/crates/acp-nats/src/agent/prompt.rs @@ -2,6 +2,7 @@ use super::Bridge; use crate::config::PROMPT_TIMEOUT_MESSAGE_SECS_THRESHOLD; use crate::error::AGENT_UNAVAILABLE; use crate::nats::{self, FlushClient, PublishClient, RequestClient, agent}; +use crate::pending_prompt_waiters::PromptToken; use crate::session_id::AcpSessionId; use agent_client_protocol::ErrorCode; use agent_client_protocol::{Error, PromptRequest, PromptResponse, Result}; @@ -49,6 +50,16 @@ fn duplicate_waiter_error PromptRequest { + let mut meta = args + .meta + .as_ref() + .cloned() + .unwrap_or_else(serde_json::Map::new); + meta.insert("prompt_id".to_string(), serde_json::json!(prompt_token.0)); + args.clone().meta(meta) +} + #[instrument( name = "acp.session.prompt", skip(bridge, args), @@ -76,7 +87,7 @@ pub async fn handle return Err(duplicate_waiter_error(bridge, &args.session_id)), }; + let request_with_token = add_prompt_id_to_request(&args, prompt_token); + let publish_options = nats::PublishOptions::builder() .flush_policy(nats::FlushPolicy::no_retries()) .build(); - if let Err(e) = nats::publish(nats, &subject, &args, publish_options).await { + if let Err(e) = nats::publish(nats, &subject, &request_with_token, publish_options).await { bridge .metrics .record_error("prompt", "prompt_publish_failed"); @@ -139,7 +152,11 @@ pub async fn handle= PROMPT_TIMEOUT_MESSAGE_SECS_THRESHOLD { @@ -346,6 +363,7 @@ mod tests { .pending_session_prompt_responses .resolve_waiter( &SessionId::from("s1"), + PromptToken(0), Ok(PromptResponse::new(StopReason::EndTurn)), ); let result = handle1.await.unwrap(); @@ -357,7 +375,7 @@ mod tests { #[tokio::test] async fn prompt_rejects_duplicate_waiter_for_same_session() { let (_mock, bridge) = mock_bridge(); - let (_rx, _guard) = bridge + let (_rx, _guard, _) = bridge .pending_session_prompt_responses .register_waiter(agent_client_protocol::SessionId::from("s1")) .unwrap(); @@ -385,7 +403,7 @@ mod tests { tokio::time::sleep(Duration::from_millis(5)).await; handle.abort(); let _ = handle.await; - let (rx, _guard) = bridge_after + let (rx, _guard, token) = bridge_after .pending_session_prompt_responses .register_waiter(SessionId::from("s1")) .expect("waiter should be free after cancelled prompt dropped guard"); @@ -393,6 +411,7 @@ mod tests { .pending_session_prompt_responses .resolve_waiter( &SessionId::from("s1"), + token, Ok(PromptResponse::new(StopReason::EndTurn)), ); let result = rx.await.unwrap().unwrap(); @@ -430,13 +449,14 @@ mod tests { #[tokio::test] async fn prompt_resolves_waiter_with_response() { let (_mock, bridge) = mock_bridge(); - let (rx, _guard) = bridge + let (rx, _guard, token) = bridge .pending_session_prompt_responses .register_waiter(agent_client_protocol::SessionId::from("s1")) .unwrap(); bridge.pending_session_prompt_responses.resolve_waiter( &agent_client_protocol::SessionId::from("s1"), + token, Ok(PromptResponse::new(StopReason::EndTurn)), ); @@ -461,6 +481,7 @@ mod tests { .pending_session_prompt_responses .resolve_waiter( &SessionId::from("s1"), + PromptToken(0), Ok(PromptResponse::new(StopReason::EndTurn)), ); let result = prompt_handle.await.unwrap(); @@ -513,6 +534,7 @@ mod tests { .pending_session_prompt_responses .resolve_waiter( &SessionId::from("s1"), + PromptToken(0), Ok(PromptResponse::new(StopReason::EndTurn)), ); let result = handle1.await.unwrap(); @@ -631,7 +653,11 @@ mod tests { tokio::time::sleep(Duration::from_millis(5)).await; bridge_resolve .pending_session_prompt_responses - .resolve_waiter(&SessionId::from("s1"), Err("parse error".to_string())); + .resolve_waiter( + &SessionId::from("s1"), + PromptToken(0), + Err("parse error".to_string()), + ); let result = prompt_handle.await.unwrap(); let err = result.unwrap_err(); assert!(err.to_string().contains("parse failed")); @@ -681,6 +707,7 @@ mod tests { .pending_session_prompt_responses .resolve_waiter( &SessionId::from("s1"), + PromptToken(0), Ok(PromptResponse::new(StopReason::EndTurn)), ); let result = prompt_handle.await.unwrap(); diff --git a/rsworkspace/crates/acp-nats/src/client/ext_session_prompt_response.rs b/rsworkspace/crates/acp-nats/src/client/ext_session_prompt_response.rs new file mode 100644 index 000000000..b996d0ae3 --- /dev/null +++ b/rsworkspace/crates/acp-nats/src/client/ext_session_prompt_response.rs @@ -0,0 +1,286 @@ +use super::Bridge; +use crate::nats::{FlushClient, PublishClient, RequestClient}; +use crate::pending_prompt_waiters::PromptToken; +use crate::session_id::AcpSessionId; +use agent_client_protocol::{PromptResponse, SessionId}; +use tracing::{instrument, warn}; +use trogon_std::time::GetElapsed; + +#[instrument( + name = "acp.client.ext.session.prompt_response", + skip(payload, bridge), + fields(session_id = %session_id) +)] +pub async fn handle( + session_id: &str, + payload: &[u8], + reply: Option<&str>, + bridge: &Bridge, +) { + if reply.is_some() { + warn!( + session_id = %session_id, + "Unexpected reply subject on prompt response notification" + ); + } + + let Ok(validated) = AcpSessionId::new(session_id) else { + warn!( + session_id = %session_id, + "Invalid session_id in prompt response notification" + ); + bridge + .metrics + .record_error("client.ext.session.prompt_response", "invalid_session_id"); + return; + }; + + let session_id_typed: SessionId = validated.as_str().to_string().into(); + + let (prompt_token_opt, response_result) = + match serde_json::from_slice::(payload) { + Ok(response) => (extract_prompt_token(&response), Ok(response)), + Err(e) => { + let token = extract_prompt_token_from_raw(payload); + (token, Err(e.to_string())) + } + }; + + let Some(prompt_token) = prompt_token_opt else { + warn!( + session_id = %session_id, + "Prompt response missing prompt_id in meta; cannot correlate" + ); + bridge + .metrics + .record_error("client.ext.session.prompt_response", "missing_prompt_id"); + return; + }; + + bridge + .pending_session_prompt_responses + .purge_expired_timed_out_waiters(&bridge.clock); + let suppress_missing_waiter_warning = bridge + .pending_session_prompt_responses + .should_suppress_missing_waiter_warning(&session_id_typed, prompt_token, &bridge.clock); + + let parse_failed = response_result.is_err(); + if !bridge.pending_session_prompt_responses.resolve_waiter( + &session_id_typed, + prompt_token, + response_result, + ) && !suppress_missing_waiter_warning + { + warn!( + session_id = %session_id, + "No pending prompt response waiter found for session" + ); + } + + if parse_failed { + bridge.metrics.record_error( + "client.ext.session.prompt_response", + "prompt_response_parse_failed", + ); + } +} + +fn extract_prompt_token(response: &PromptResponse) -> Option { + response + .meta + .as_ref() + .and_then(|m| m.get("prompt_id")) + .and_then(|v| v.as_u64()) + .map(PromptToken) +} + +fn extract_prompt_token_from_raw(payload: &[u8]) -> Option { + serde_json::from_slice::(payload) + .ok() + .and_then(|v| { + v.get("meta") + .and_then(|m| m.get("prompt_id")) + .and_then(|p| p.as_u64()) + }) + .map(PromptToken) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::agent::Bridge; + use crate::config::Config; + use agent_client_protocol::StopReason; + use trogon_nats::MockNatsClient; + use trogon_std::time::MockClock; + + fn make_bridge() -> Bridge { + Bridge::new( + MockNatsClient::new(), + MockClock::new(), + &opentelemetry::global::meter("acp-nats-test"), + Config::for_test("acp"), + ) + } + + fn response_with_prompt_id(stop_reason: StopReason, prompt_token: PromptToken) -> Vec { + let mut meta = serde_json::Map::new(); + meta.insert("prompt_id".to_string(), serde_json::json!(prompt_token.0)); + let response = PromptResponse::new(stop_reason).meta(meta); + serde_json::to_vec(&response).unwrap() + } + + #[tokio::test] + async fn resolves_waiter() { + let bridge = make_bridge(); + let session_id: SessionId = "prompt-resp-001".into(); + + let (rx, _guard, token) = bridge + .pending_session_prompt_responses + .register_waiter(session_id.clone()) + .unwrap(); + + let payload = response_with_prompt_id(StopReason::EndTurn, token); + + super::handle("prompt-resp-001", &payload, None, &bridge).await; + + let result = rx + .await + .expect("Should receive response") + .expect("Prompt response should not include error"); + assert_eq!(result.stop_reason, StopReason::EndTurn); + } + + #[tokio::test] + async fn no_waiter_does_not_panic() { + let bridge = make_bridge(); + let payload = response_with_prompt_id(StopReason::EndTurn, PromptToken(0)); + + super::handle("no-waiter-session", &payload, None, &bridge).await; + } + + #[tokio::test] + async fn invalid_payload_with_prompt_id_forwards_parse_error() { + let bridge = make_bridge(); + let session_id: SessionId = "bad-payload-001".into(); + + let (rx, _guard, token) = bridge + .pending_session_prompt_responses + .register_waiter(session_id.clone()) + .unwrap(); + + let payload = format!( + r#"{{"meta":{{"prompt_id":{}}},"stop_reason":"invalid"}}"#, + token.0 + ); + + super::handle("bad-payload-001", payload.as_bytes(), None, &bridge).await; + + let result = rx + .await + .expect("Should receive resolved parse error") + .expect_err("Parse failure should be forwarded to waiter"); + assert!(!result.is_empty(), "Expected parse error to be forwarded"); + } + + #[tokio::test] + async fn missing_prompt_id_is_rejected() { + let bridge = make_bridge(); + let session_id: SessionId = "no-token-session".into(); + + let (rx, _guard, _) = bridge + .pending_session_prompt_responses + .register_waiter(session_id.clone()) + .unwrap(); + + let response = PromptResponse::new(StopReason::EndTurn); + let payload = serde_json::to_vec(&response).unwrap(); + + super::handle("no-token-session", &payload, None, &bridge).await; + + assert!( + bridge + .pending_session_prompt_responses + .has_waiter(&session_id), + "waiter should remain when response lacks prompt_id" + ); + bridge + .pending_session_prompt_responses + .remove_waiter_for_test(&session_id); + drop(rx); + } + + #[tokio::test] + async fn invalid_session_id_is_rejected() { + let bridge = make_bridge(); + let session_id: SessionId = "valid-session".into(); + + let (rx, _guard, token) = bridge + .pending_session_prompt_responses + .register_waiter(session_id.clone()) + .unwrap(); + + let payload = response_with_prompt_id(StopReason::EndTurn, token); + + super::handle("session.with.dots", &payload, None, &bridge).await; + super::handle("session*wild", &payload, None, &bridge).await; + super::handle("session id", &payload, None, &bridge).await; + + assert!( + bridge + .pending_session_prompt_responses + .has_waiter(&session_id), + "invalid session IDs should not resolve valid waiter", + ); + + bridge + .pending_session_prompt_responses + .remove_waiter_for_test(&session_id); + assert!( + !bridge + .pending_session_prompt_responses + .has_waiter(&session_id), + "waiter should be removed" + ); + drop(rx); + } + + #[tokio::test] + async fn late_response_with_wrong_token_does_not_resolve_new_prompt() { + let bridge = make_bridge(); + let session_id: SessionId = "same-session".into(); + + let (_rx1, _guard1, token1) = bridge + .pending_session_prompt_responses + .register_waiter(session_id.clone()) + .unwrap(); + bridge.pending_session_prompt_responses.resolve_waiter( + &session_id, + token1, + Ok(PromptResponse::new(StopReason::EndTurn)), + ); + let _ = _rx1.await; + + let (rx2, _guard2, token2) = bridge + .pending_session_prompt_responses + .register_waiter(session_id.clone()) + .unwrap(); + + let late_payload = response_with_prompt_id(StopReason::EndTurn, token1); + super::handle("same-session", &late_payload, None, &bridge).await; + + assert!( + bridge + .pending_session_prompt_responses + .has_waiter(&session_id), + "late response with old token must not resolve new prompt" + ); + bridge.pending_session_prompt_responses.resolve_waiter( + &session_id, + token2, + Ok(PromptResponse::new(StopReason::EndTurn)), + ); + let result = rx2.await.unwrap().unwrap(); + assert_eq!(result.stop_reason, StopReason::EndTurn); + } +} diff --git a/rsworkspace/crates/acp-nats/src/client/mod.rs b/rsworkspace/crates/acp-nats/src/client/mod.rs index a6457893b..a666357df 100644 --- a/rsworkspace/crates/acp-nats/src/client/mod.rs +++ b/rsworkspace/crates/acp-nats/src/client/mod.rs @@ -1,3 +1,4 @@ +pub(crate) mod ext_session_prompt_response; pub(crate) mod fs_read_text_file; pub(crate) mod fs_write_text_file; pub(crate) mod request_permission; @@ -222,6 +223,15 @@ async fn dispatch_client_method< ClientMethod::SessionUpdate => { session_update::handle(&payload, ctx.client, &parsed.session_id).await; } + ClientMethod::ExtSessionPromptResponse => { + ext_session_prompt_response::handle( + parsed.session_id.as_str(), + &payload, + reply.as_deref(), + ctx.bridge, + ) + .await; + } ClientMethod::TerminalCreate => { terminal_create::handle( &payload, diff --git a/rsworkspace/crates/acp-nats/src/nats/client_method.rs b/rsworkspace/crates/acp-nats/src/nats/client_method.rs index fdeae1694..c4f852038 100644 --- a/rsworkspace/crates/acp-nats/src/nats/client_method.rs +++ b/rsworkspace/crates/acp-nats/src/nats/client_method.rs @@ -9,6 +9,7 @@ pub enum ClientMethod { TerminalOutput, TerminalRelease, TerminalWaitForExit, + ExtSessionPromptResponse, } impl ClientMethod { @@ -23,6 +24,7 @@ impl ClientMethod { "client.terminal.output" => Some(Self::TerminalOutput), "client.terminal.release" => Some(Self::TerminalRelease), "client.terminal.wait_for_exit" => Some(Self::TerminalWaitForExit), + "client.ext.session.prompt_response" => Some(Self::ExtSessionPromptResponse), _ => None, } } diff --git a/rsworkspace/crates/acp-nats/src/nats/parsing.rs b/rsworkspace/crates/acp-nats/src/nats/parsing.rs index 928b08082..28aa9f57e 100644 --- a/rsworkspace/crates/acp-nats/src/nats/parsing.rs +++ b/rsworkspace/crates/acp-nats/src/nats/parsing.rs @@ -92,6 +92,14 @@ mod tests { assert_eq!(parsed.method, ClientMethod::TerminalWaitForExit); } + #[test] + fn test_parse_ext_session_prompt_response() { + let subject = "acp.sess999.client.ext.session.prompt_response"; + let parsed = parse_client_subject(subject).unwrap(); + assert_eq!(parsed.session_id.as_str(), "sess999"); + assert_eq!(parsed.method, ClientMethod::ExtSessionPromptResponse); + } + #[test] fn test_parse_with_custom_prefix() { let subject = "myapp.sess123.client.session.update"; diff --git a/rsworkspace/crates/acp-nats/src/nats/subjects.rs b/rsworkspace/crates/acp-nats/src/nats/subjects.rs index e3b611bb0..4133a821a 100644 --- a/rsworkspace/crates/acp-nats/src/nats/subjects.rs +++ b/rsworkspace/crates/acp-nats/src/nats/subjects.rs @@ -78,6 +78,14 @@ pub mod client { format!("{}.{}.client.terminal.wait_for_exit", prefix, session_id) } + pub fn ext(prefix: &str, session_id: &str, method: &str) -> String { + format!("{}.{}.client.ext.{}", prefix, session_id, method) + } + + pub fn ext_session_prompt_response(prefix: &str, session_id: &str) -> String { + ext(prefix, session_id, "session.prompt_response") + } + pub mod wildcards { pub fn all(prefix: &str) -> String { format!("{}.*.client.>", prefix) @@ -162,6 +170,14 @@ mod tests { ); } + #[test] + fn client_ext_session_prompt_response_subject() { + assert_eq!( + client::ext_session_prompt_response("acp", "s1"), + "acp.s1.client.ext.session.prompt_response" + ); + } + #[test] fn client_wildcards_all() { assert_eq!(client::wildcards::all("foo"), "foo.*.client.>"); diff --git a/rsworkspace/crates/acp-nats/src/pending_prompt_waiters.rs b/rsworkspace/crates/acp-nats/src/pending_prompt_waiters.rs index 42d19bc75..b41ac68f5 100644 --- a/rsworkspace/crates/acp-nats/src/pending_prompt_waiters.rs +++ b/rsworkspace/crates/acp-nats/src/pending_prompt_waiters.rs @@ -12,6 +12,8 @@ //! prompt calls. //! - Timed-out sessions are tracked briefly to suppress noisy duplicate timeout-related warnings //! during late-response windows. +//! - Per-prompt correlation via `PromptToken` prevents late responses for prompt A from resolving +//! a newly registered prompt B for the same session. use std::collections::HashMap; use std::sync::Mutex; @@ -24,8 +26,12 @@ use trogon_std::time::GetElapsed; type PromptResponseReceiver = oneshot::Receiver>; +/// Per-prompt correlation token. Ensures late responses for prompt A cannot resolve prompt B. +#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)] +pub(crate) struct PromptToken(pub u64); + struct WaiterEntry { - token: u64, + token: PromptToken, sender: oneshot::Sender>, } @@ -35,19 +41,19 @@ struct WaiterEntry { pub(crate) struct PromptWaiterGuard<'a, I: Copy> { waiters: &'a PendingSessionPromptResponseWaiters, session_id: SessionId, - waiter_token: u64, + prompt_token: PromptToken, } impl<'a, I: Copy> PromptWaiterGuard<'a, I> { fn new( waiters: &'a PendingSessionPromptResponseWaiters, session_id: SessionId, - waiter_token: u64, + prompt_token: PromptToken, ) -> Self { Self { waiters, session_id, - waiter_token, + prompt_token, } } } @@ -55,7 +61,7 @@ impl<'a, I: Copy> PromptWaiterGuard<'a, I> { impl<'a, I: Copy> Drop for PromptWaiterGuard<'a, I> { fn drop(&mut self) { self.waiters - .remove_waiter_if_token_matches(&self.session_id, self.waiter_token); + .remove_waiter_if_token_matches(&self.session_id, self.prompt_token); } } @@ -66,7 +72,7 @@ impl<'a, I: Copy> Drop for PromptWaiterGuard<'a, I> { pub(crate) struct PendingSessionPromptResponseWaiters { waiters: Mutex>, next_waiter_token: AtomicU64, - timed_out: Mutex>, + timed_out: Mutex>, } impl PendingSessionPromptResponseWaiters { @@ -82,38 +88,56 @@ impl PendingSessionPromptResponseWaiters { /// Registers the receiver for the next prompt response of `session_id`. /// /// Returns `Err(())` when another waiter is already active for the same session. + /// Returns `(receiver, guard, prompt_token)`; the token must be sent with the request and + /// echoed in the response for correct correlation. pub fn register_waiter( &self, session_id: SessionId, - ) -> std::result::Result<(PromptResponseReceiver, PromptWaiterGuard<'_, I>), ()> { + ) -> std::result::Result< + ( + PromptResponseReceiver, + PromptWaiterGuard<'_, I>, + PromptToken, + ), + (), + > { let (tx, rx) = oneshot::channel(); let mut waiters = self.waiters.lock().unwrap(); if waiters.contains_key(&session_id) { return Err(()); } - let waiter_token = self.next_waiter_token.fetch_add(1, Ordering::Relaxed); - self.timed_out.lock().unwrap().remove(&session_id); + let token_value = self.next_waiter_token.fetch_add(1, Ordering::Relaxed); + let prompt_token = PromptToken(token_value); + self.timed_out + .lock() + .unwrap() + .retain(|(s, _), _| s != &session_id); waiters.insert( session_id.clone(), WaiterEntry { - token: waiter_token, + token: prompt_token, sender: tx, }, ); - Ok((rx, PromptWaiterGuard::new(self, session_id, waiter_token))) + Ok(( + rx, + PromptWaiterGuard::new(self, session_id, prompt_token), + prompt_token, + )) } - /// Marks a session as timed out to suppress transient duplicate warnings for late responses. + /// Marks a prompt waiter as timed out to suppress transient duplicate warnings for late responses. pub(crate) fn mark_prompt_waiter_timed_out>( &self, session_id: SessionId, + prompt_token: PromptToken, clock: &C, ) { self.purge_expired_timed_out_waiters(clock); self.timed_out .lock() .unwrap() - .insert(session_id, clock.now()); + .insert((session_id, prompt_token), clock.now()); } /// Drops timeout-suppression markers after a short window. @@ -126,32 +150,63 @@ impl PendingSessionPromptResponseWaiters { }); } - /// Delivers a backend prompt result to the currently waiting caller for `session_id`. - #[allow(dead_code)] + /// Returns true if a late prompt response for this (session, token) should not emit a missing-waiter warning. + pub(crate) fn should_suppress_missing_waiter_warning>( + &self, + session_id: &SessionId, + prompt_token: PromptToken, + _clock: &C, + ) -> bool { + self.timed_out + .lock() + .unwrap() + .contains_key(&(session_id.clone(), prompt_token)) + } + + /// Delivers a backend prompt result to the waiting caller for `(session_id, prompt_token)`. + /// Only resolves if the token matches; late responses for a different prompt are ignored. pub fn resolve_waiter( &self, session_id: &SessionId, + prompt_token: PromptToken, response: std::result::Result, ) -> bool { - let waiter = self.waiters.lock().unwrap().remove(session_id); - self.timed_out.lock().unwrap().remove(session_id); + let mut waiters = self.waiters.lock().unwrap(); + let should_remove = waiters + .get(session_id) + .is_some_and(|e| e.token == prompt_token); + let waiter = if should_remove { + waiters.remove(session_id) + } else { + None + }; + drop(waiters); if let Some(waiter) = waiter { + self.timed_out + .lock() + .unwrap() + .remove(&(session_id.clone(), prompt_token)); waiter.sender.send(response).is_ok() } else { false } } - fn remove_waiter_if_token_matches(&self, session_id: &SessionId, waiter_token: u64) { + fn remove_waiter_if_token_matches(&self, session_id: &SessionId, prompt_token: PromptToken) { let mut waiters = self.waiters.lock().unwrap(); if waiters .get(session_id) - .is_some_and(|entry| entry.token == waiter_token) + .is_some_and(|entry| entry.token == prompt_token) { waiters.remove(session_id); } } + #[cfg(test)] + pub(crate) fn has_waiter(&self, session_id: &SessionId) -> bool { + self.waiters.lock().unwrap().contains_key(session_id) + } + #[cfg(test)] pub(crate) fn remove_waiter_for_test(&self, session_id: &SessionId) { self.waiters.lock().unwrap().remove(session_id); @@ -172,6 +227,7 @@ mod tests { let waiters = PendingSessionPromptResponseWaiters::::new(); let resolved = waiters.resolve_waiter( &SessionId::from("s1"), + PromptToken(0), Ok(PromptResponse::new(StopReason::EndTurn)), ); assert!(!resolved); @@ -181,8 +237,7 @@ mod tests { fn purge_expired_timed_out_waiters_removes_expired_markers() { let waiters = PendingSessionPromptResponseWaiters::::new(); let clock = MockClock::new(); - - waiters.mark_prompt_waiter_timed_out(SessionId::from("s1"), &clock); + waiters.mark_prompt_waiter_timed_out(SessionId::from("s1"), PromptToken(0), &clock); assert_eq!(waiters.timed_out.lock().unwrap().len(), 1); clock.advance(PROMPT_TIMEOUT_WARNING_SUPPRESSION_WINDOW + Duration::from_millis(1)); @@ -196,15 +251,23 @@ mod tests { let waiters = PendingSessionPromptResponseWaiters::::new(); let session_id = SessionId::from("s1"); - let (_rx1, guard1) = waiters.register_waiter(session_id.clone()).unwrap(); - assert!(waiters.resolve_waiter(&session_id, Ok(PromptResponse::new(StopReason::EndTurn)))); + let (_rx1, guard1, token1) = waiters.register_waiter(session_id.clone()).unwrap(); + assert!(waiters.resolve_waiter( + &session_id, + token1, + Ok(PromptResponse::new(StopReason::EndTurn)) + )); - let (rx2, _guard2) = waiters.register_waiter(session_id.clone()).unwrap(); + let (rx2, _guard2, token2) = waiters.register_waiter(session_id.clone()).unwrap(); drop(guard1); assert!( - waiters.resolve_waiter(&session_id, Ok(PromptResponse::new(StopReason::EndTurn))), + waiters.resolve_waiter( + &session_id, + token2, + Ok(PromptResponse::new(StopReason::EndTurn)) + ), "old guard must not remove the new waiter's sender" );