Skip to content

Commit 25bea00

Browse files
committed
fix(acp): disarm cancellation for buffered responses
1 parent ddb19ba commit 25bea00

4 files changed

Lines changed: 146 additions & 20 deletions

File tree

src/agent-client-protocol/src/jsonrpc.rs

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1358,6 +1358,9 @@ enum ReplyMessage {
13581358
method: String,
13591359

13601360
sender: oneshot::Sender<ResponsePayload>,
1361+
1362+
#[cfg(feature = "unstable_cancel_request")]
1363+
cancellation_disarm: SentRequestCancellationDisarm,
13611364
},
13621365
}
13631366

@@ -1654,6 +1657,9 @@ enum OutgoingMessage {
16541657

16551658
/// where to send the response when it arrives (includes ack channel)
16561659
response_tx: oneshot::Sender<ResponsePayload>,
1660+
1661+
#[cfg(feature = "unstable_cancel_request")]
1662+
cancellation_disarm: SentRequestCancellationDisarm,
16571663
},
16581664

16591665
/// Send a notification to the server.
@@ -2003,6 +2009,8 @@ impl<Counterpart: Role> ConnectionTo<Counterpart> {
20032009
role_id,
20042010
untyped,
20052011
response_tx,
2012+
#[cfg(feature = "unstable_cancel_request")]
2013+
cancellation_disarm: cancellation.disarm_handle(),
20062014
};
20072015

20082016
match self.message_tx.unbounded_send(message) {
@@ -2455,6 +2463,8 @@ impl ResponseRouter<serde_json::Value> {
24552463
id: jsonrpcmsg::Id,
24562464
role_id: RoleId,
24572465
sender: oneshot::Sender<ResponsePayload>,
2466+
#[cfg(feature = "unstable_cancel_request")]
2467+
cancellation_disarm: SentRequestCancellationDisarm,
24582468
) -> Self {
24592469
let response_method = method.clone();
24602470
let response_id = id.clone();
@@ -2475,6 +2485,9 @@ impl ResponseRouter<serde_json::Value> {
24752485
id = ?response_id,
24762486
"dropped response because local receiver was gone"
24772487
);
2488+
} else {
2489+
#[cfg(feature = "unstable_cancel_request")]
2490+
cancellation_disarm.disarm();
24782491
}
24792492
Ok(())
24802493
}),
@@ -3184,16 +3197,34 @@ fn jsonrpc_id_to_request_id(id: &jsonrpcmsg::Id) -> Result<crate::schema::Reques
31843197
}
31853198

31863199
#[cfg(feature = "unstable_cancel_request")]
3187-
#[derive(Clone)]
3200+
#[derive(Clone, Debug)]
3201+
pub(crate) struct SentRequestCancellationDisarm {
3202+
armed: Arc<AtomicBool>,
3203+
}
3204+
3205+
#[cfg(feature = "unstable_cancel_request")]
3206+
impl SentRequestCancellationDisarm {
3207+
fn new() -> Self {
3208+
Self {
3209+
armed: Arc::new(AtomicBool::new(true)),
3210+
}
3211+
}
3212+
3213+
fn disarm(&self) {
3214+
self.armed.store(false, Ordering::Release);
3215+
}
3216+
}
3217+
3218+
#[cfg(feature = "unstable_cancel_request")]
31883219
enum SentRequestCancellation {
31893220
Send {
31903221
message_tx: OutgoingMessageTx,
31913222
notification: UntypedMessage,
3192-
armed: Arc<AtomicBool>,
3223+
disarm: SentRequestCancellationDisarm,
31933224
},
31943225
Failed {
31953226
error: String,
3196-
armed: Arc<AtomicBool>,
3227+
disarm: SentRequestCancellationDisarm,
31973228
},
31983229
}
31993230

@@ -3211,25 +3242,25 @@ impl SentRequestCancellation {
32113242
)
32123243
})
32133244
.map_err(|error| error.to_string());
3245+
let disarm = SentRequestCancellationDisarm::new();
32143246

32153247
match notification {
32163248
Ok(notification) => Self::Send {
32173249
message_tx,
32183250
notification,
3219-
armed: Arc::new(AtomicBool::new(true)),
3220-
},
3221-
Err(error) => Self::Failed {
3222-
error,
3223-
armed: Arc::new(AtomicBool::new(true)),
3251+
disarm,
32243252
},
3253+
Err(error) => Self::Failed { error, disarm },
32253254
}
32263255
}
32273256

32283257
fn disarm(&self) {
3258+
self.disarm_handle().disarm();
3259+
}
3260+
3261+
fn disarm_handle(&self) -> SentRequestCancellationDisarm {
32293262
match self {
3230-
Self::Send { armed, .. } | Self::Failed { armed, .. } => {
3231-
armed.store(false, Ordering::Release);
3232-
}
3263+
Self::Send { disarm, .. } | Self::Failed { disarm, .. } => disarm.clone(),
32333264
}
32343265
}
32353266

@@ -3238,9 +3269,9 @@ impl SentRequestCancellation {
32383269
Self::Send {
32393270
message_tx,
32403271
notification,
3241-
armed,
3272+
disarm,
32423273
} => {
3243-
if !armed.swap(false, Ordering::AcqRel) {
3274+
if !disarm.armed.swap(false, Ordering::AcqRel) {
32443275
return Ok(());
32453276
}
32463277

@@ -3251,8 +3282,8 @@ impl SentRequestCancellation {
32513282
},
32523283
)
32533284
}
3254-
Self::Failed { error, armed } => {
3255-
if !armed.swap(false, Ordering::AcqRel) {
3285+
Self::Failed { error, disarm } => {
3286+
if !disarm.armed.swap(false, Ordering::AcqRel) {
32563287
return Ok(());
32573288
}
32583289

@@ -3279,17 +3310,17 @@ impl Debug for SentRequestCancellation {
32793310
match self {
32803311
Self::Send {
32813312
notification,
3282-
armed,
3313+
disarm,
32833314
..
32843315
} => f
32853316
.debug_struct("SentRequestCancellation")
32863317
.field("notification", notification)
3287-
.field("armed", &armed.load(Ordering::Acquire))
3318+
.field("armed", &disarm.armed.load(Ordering::Acquire))
32883319
.finish(),
3289-
Self::Failed { error, armed } => f
3320+
Self::Failed { error, disarm } => f
32903321
.debug_struct("SentRequestCancellation")
32913322
.field("error", error)
3292-
.field("armed", &armed.load(Ordering::Acquire))
3323+
.field("armed", &disarm.armed.load(Ordering::Acquire))
32933324
.finish(),
32943325
}
32953326
}

src/agent-client-protocol/src/jsonrpc/incoming_actor.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ struct PendingReply {
3030
method: String,
3131
role_id: RoleId,
3232
sender: oneshot::Sender<crate::jsonrpc::ResponsePayload>,
33+
#[cfg(feature = "unstable_cancel_request")]
34+
cancellation_disarm: super::SentRequestCancellationDisarm,
3335
}
3436

3537
/// Incoming protocol actor: The central dispatch loop for a connection.
@@ -78,6 +80,8 @@ pub(super) async fn incoming_protocol_actor<Counterpart: Role>(
7880
role_id,
7981
method,
8082
sender,
83+
#[cfg(feature = "unstable_cancel_request")]
84+
cancellation_disarm,
8185
} => {
8286
tracing::trace!(?id, %method, "incoming_actor: subscribing to response");
8387
let id = serde_json::to_value(&id).unwrap();
@@ -87,6 +91,8 @@ pub(super) async fn incoming_protocol_actor<Counterpart: Role>(
8791
method,
8892
role_id,
8993
sender,
94+
#[cfg(feature = "unstable_cancel_request")]
95+
cancellation_disarm,
9096
},
9197
);
9298
}
@@ -260,10 +266,19 @@ fn dispatch_from_response(
260266
method,
261267
role_id,
262268
sender,
269+
#[cfg(feature = "unstable_cancel_request")]
270+
cancellation_disarm,
263271
} = pending_reply;
264272

265273
// Create a Dispatch::Response with a ResponseRouter that routes to the oneshot
266-
let router = ResponseRouter::new(method.clone(), id.clone(), role_id, sender);
274+
let router = ResponseRouter::new(
275+
method.clone(),
276+
id.clone(),
277+
role_id,
278+
sender,
279+
#[cfg(feature = "unstable_cancel_request")]
280+
cancellation_disarm,
281+
);
267282
Dispatch::Response(result, router)
268283
}
269284

src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ pub(super) async fn outgoing_protocol_actor(
4141
method,
4242
untyped,
4343
response_tx,
44+
#[cfg(feature = "unstable_cancel_request")]
45+
cancellation_disarm,
4446
} => {
4547
let request = match protocol_compat
4648
.outgoing_message(untyped)
@@ -49,6 +51,8 @@ pub(super) async fn outgoing_protocol_actor(
4951
Ok(request) => request,
5052
Err(error) => {
5153
tracing::warn!(?id, %method, ?error, "Failed to convert outgoing request");
54+
#[cfg(feature = "unstable_cancel_request")]
55+
cancellation_disarm.disarm();
5256
complete_request_with_error(response_tx, error);
5357
continue;
5458
}
@@ -61,6 +65,8 @@ pub(super) async fn outgoing_protocol_actor(
6165
role_id,
6266
method,
6367
sender: response_tx,
68+
#[cfg(feature = "unstable_cancel_request")]
69+
cancellation_disarm,
6470
})
6571
.map_err(crate::Error::into_internal_error)?;
6672

@@ -167,6 +173,8 @@ mod tests {
167173
method: "session/new".into(),
168174
untyped: malformed_v2_known_method()?,
169175
response_tx,
176+
#[cfg(feature = "unstable_cancel_request")]
177+
cancellation_disarm: crate::jsonrpc::SentRequestCancellationDisarm::new(),
170178
})
171179
.map_err(crate::Error::into_internal_error)?;
172180
drop(outgoing_tx);

src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,78 @@ async fn late_response_after_dropped_sent_request_does_not_close_connection() {
628628
.await;
629629
}
630630

631+
#[tokio::test(flavor = "current_thread")]
632+
async fn response_buffered_before_drop_disarms_auto_cancellation() {
633+
use tokio::task::LocalSet;
634+
635+
let local = LocalSet::new();
636+
637+
local
638+
.run_until(async {
639+
let received = Arc::new(Mutex::new(Vec::new()));
640+
let received_for_handler = received.clone();
641+
642+
let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams();
643+
let server_transport =
644+
agent_client_protocol::ByteStreams::new(server_writer, server_reader);
645+
let server = UntypedRole
646+
.builder()
647+
.on_receive_request(
648+
async |request: SimpleRequest,
649+
responder: Responder<SimpleResponse>,
650+
_connection: ConnectionTo<UntypedRole>| {
651+
responder.respond(SimpleResponse {
652+
result: format!("echo: {}", request.message),
653+
})
654+
},
655+
agent_client_protocol::on_receive_request!(),
656+
)
657+
.on_receive_notification(
658+
async move |notification: CancelRequestNotification,
659+
_connection: ConnectionTo<UntypedRole>| {
660+
received_for_handler
661+
.lock()
662+
.unwrap()
663+
.push(notification.request_id);
664+
Ok(())
665+
},
666+
agent_client_protocol::on_receive_notification!(),
667+
);
668+
669+
tokio::task::spawn_local(async move {
670+
if let Err(error) = server.connect_to(server_transport).await {
671+
panic!("server should stay alive: {error:?}");
672+
}
673+
});
674+
675+
let client_transport =
676+
agent_client_protocol::ByteStreams::new(client_writer, client_reader);
677+
let response = UntypedRole
678+
.builder()
679+
.connect_with(client_transport, async |cx| {
680+
let request: SentRequest<SimpleResponse> = cx.send_request(SimpleRequest {
681+
message: "buffered".into(),
682+
});
683+
684+
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
685+
drop(request);
686+
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
687+
688+
cx.send_request(SimpleRequest {
689+
message: "after buffered".into(),
690+
})
691+
.block_task()
692+
.await
693+
})
694+
.await
695+
.unwrap();
696+
697+
assert_eq!(response.result, "echo: after buffered");
698+
assert!(received.lock().unwrap().is_empty());
699+
})
700+
.await;
701+
}
702+
631703
#[tokio::test(flavor = "current_thread")]
632704
async fn completed_sent_request_does_not_send_cancellation_on_drop() {
633705
use tokio::task::LocalSet;

0 commit comments

Comments
 (0)