Skip to content

Commit d0f02ad

Browse files
committed
fix(acp): Use pending version during initialize
1 parent 63133e6 commit d0f02ad

2 files changed

Lines changed: 84 additions & 13 deletions

File tree

src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,7 @@ pub(super) async fn outgoing_protocol_actor(
6868
} => match protocol_compat.outgoing_response(&method, response) {
6969
Ok(value) => {
7070
tracing::debug!(?id, "Sending success response");
71-
jsonrpcmsg::Message::Response(jsonrpcmsg::Response::success_v2(
72-
value,
73-
Some(id),
74-
))
71+
jsonrpcmsg::Message::Response(jsonrpcmsg::Response::success_v2(value, Some(id)))
7572
}
7673
Err(error) => {
7774
tracing::warn!(?id, %method, ?error, "Sending error response");

src/agent-client-protocol/src/jsonrpc/protocol_compat.rs

Lines changed: 83 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ mod imp {
134134
(Self::Disabled, other) => other,
135135
(this, Self::Disabled) => this,
136136
(Self::Acp(this), Self::Acp(other)) => {
137-
debug_assert_eq!(
137+
assert_eq!(
138138
this.api, other.api,
139139
"cannot merge ACP builders with different API protocol versions; \
140140
handler chains share a single API surface",
@@ -212,7 +212,7 @@ mod imp {
212212
return self.incoming_initialize_request(mode, message);
213213
}
214214

215-
convert_message(message, self.negotiated(), mode.api)
215+
convert_message(message, self.active_wire_version(), mode.api)
216216
}
217217

218218
pub(crate) fn outgoing_message(
@@ -225,9 +225,10 @@ mod imp {
225225

226226
let wire_version = if message.method() == "initialize" {
227227
set_protocol_version(&mut message.params, mode.latest_supported)?;
228+
self.set_pending_initialize(mode.latest_supported);
228229
mode.latest_supported
229230
} else {
230-
self.negotiated()
231+
self.active_wire_version()
231232
};
232233

233234
convert_message(message, mode.api, wire_version)
@@ -247,7 +248,7 @@ mod imp {
247248
return self.incoming_initialize_response(mode, value);
248249
}
249250

250-
convert_response(method, value, self.negotiated(), mode.api)
251+
convert_response(method, value, self.active_wire_version(), mode.api)
251252
}
252253

253254
pub(crate) fn outgoing_response(
@@ -279,7 +280,7 @@ mod imp {
279280
self.set_negotiated(negotiated);
280281
negotiated
281282
} else {
282-
self.negotiated()
283+
self.active_wire_version()
283284
};
284285

285286
convert_response(method, value, mode.api, wire_version)
@@ -306,6 +307,7 @@ mod imp {
306307
mode: AcpProtocolMode,
307308
mut value: serde_json::Value,
308309
) -> Result<serde_json::Value, crate::Error> {
310+
let _pending_initialize = self.take_pending_initialize();
309311
let Some(response_version) = protocol_version_from_value(&value) else {
310312
return Ok(value);
311313
};
@@ -351,11 +353,12 @@ mod imp {
351353
}
352354
}
353355

354-
fn negotiated(&self) -> ProtocolVersionKind {
355-
self.state
356+
fn active_wire_version(&self) -> ProtocolVersionKind {
357+
let state = self
358+
.state
356359
.lock()
357-
.expect("protocol compatibility state mutex poisoned")
358-
.negotiated
360+
.expect("protocol compatibility state mutex poisoned");
361+
state.pending_initialize.unwrap_or(state.negotiated)
359362
}
360363

361364
fn set_negotiated(&self, negotiated: ProtocolVersionKind) {
@@ -618,6 +621,77 @@ mod imp {
618621
negotiated.as_protocol_version(),
619622
))
620623
}
624+
625+
#[cfg(test)]
626+
mod tests {
627+
use super::*;
628+
629+
fn negotiated(compat: &ProtocolCompat) -> ProtocolVersionKind {
630+
compat
631+
.state
632+
.lock()
633+
.expect("protocol compatibility state mutex poisoned")
634+
.negotiated
635+
}
636+
637+
#[test]
638+
fn initialize_request_sets_active_wire_version_before_response() -> Result<(), crate::Error>
639+
{
640+
let compat = ProtocolCompat::new(ProtocolMode::v2_agent());
641+
assert_eq!(compat.active_wire_version(), ProtocolVersionKind::V1);
642+
643+
compat.incoming_message(UntypedMessage::new(
644+
"initialize",
645+
v2::InitializeRequest::new(ProtocolVersion::V2),
646+
)?)?;
647+
648+
assert_eq!(negotiated(&compat), ProtocolVersionKind::V1);
649+
assert_eq!(compat.active_wire_version(), ProtocolVersionKind::V2);
650+
651+
compat.outgoing_response(
652+
"initialize",
653+
Ok(serde_json::to_value(v2::InitializeResponse::new(
654+
ProtocolVersion::V2,
655+
))?),
656+
)?;
657+
658+
assert_eq!(negotiated(&compat), ProtocolVersionKind::V2);
659+
assert_eq!(compat.active_wire_version(), ProtocolVersionKind::V2);
660+
Ok(())
661+
}
662+
663+
#[test]
664+
fn outgoing_initialize_sets_active_wire_version_before_response() -> Result<(), crate::Error>
665+
{
666+
let compat = ProtocolCompat::new(ProtocolMode::v2_client());
667+
assert_eq!(compat.active_wire_version(), ProtocolVersionKind::V1);
668+
669+
compat.outgoing_message(UntypedMessage::new(
670+
"initialize",
671+
v2::InitializeRequest::new(ProtocolVersion::V1),
672+
)?)?;
673+
674+
assert_eq!(negotiated(&compat), ProtocolVersionKind::V1);
675+
assert_eq!(compat.active_wire_version(), ProtocolVersionKind::V2);
676+
677+
compat.incoming_response(
678+
"initialize",
679+
Ok(serde_json::to_value(v2::InitializeResponse::new(
680+
ProtocolVersion::V2,
681+
))?),
682+
)?;
683+
684+
assert_eq!(negotiated(&compat), ProtocolVersionKind::V2);
685+
assert_eq!(compat.active_wire_version(), ProtocolVersionKind::V2);
686+
Ok(())
687+
}
688+
689+
#[test]
690+
#[should_panic(expected = "cannot merge ACP builders with different API protocol versions")]
691+
fn merging_different_api_protocol_modes_panics() {
692+
let _ = ProtocolMode::v1_agent().merge(ProtocolMode::v2_agent());
693+
}
694+
}
621695
}
622696

623697
pub(crate) use imp::{ProtocolCompat, ProtocolMode};

0 commit comments

Comments
 (0)