@@ -134,7 +134,7 @@ mod imp {
134134 ( Self :: Disabled , other) => other,
135135 ( this, Self :: Disabled ) => this,
136136 ( Self :: Acp ( this) , Self :: Acp ( other) ) => {
137- debug_assert_eq ! (
137+ assert_eq ! (
138138 this. api, other. api,
139139 "cannot merge ACP builders with different API protocol versions; \
140140 handler chains share a single API surface",
@@ -212,7 +212,7 @@ mod imp {
212212 return self . incoming_initialize_request ( mode, message) ;
213213 }
214214
215- convert_message ( message, self . negotiated ( ) , mode. api )
215+ convert_message ( message, self . active_wire_version ( ) , mode. api )
216216 }
217217
218218 pub ( crate ) fn outgoing_message (
@@ -225,9 +225,10 @@ mod imp {
225225
226226 let wire_version = if message. method ( ) == "initialize" {
227227 set_protocol_version ( & mut message. params , mode. latest_supported ) ?;
228+ self . set_pending_initialize ( mode. latest_supported ) ;
228229 mode. latest_supported
229230 } else {
230- self . negotiated ( )
231+ self . active_wire_version ( )
231232 } ;
232233
233234 convert_message ( message, mode. api , wire_version)
@@ -247,7 +248,7 @@ mod imp {
247248 return self . incoming_initialize_response ( mode, value) ;
248249 }
249250
250- convert_response ( method, value, self . negotiated ( ) , mode. api )
251+ convert_response ( method, value, self . active_wire_version ( ) , mode. api )
251252 }
252253
253254 pub ( crate ) fn outgoing_response (
@@ -279,7 +280,7 @@ mod imp {
279280 self . set_negotiated ( negotiated) ;
280281 negotiated
281282 } else {
282- self . negotiated ( )
283+ self . active_wire_version ( )
283284 } ;
284285
285286 convert_response ( method, value, mode. api , wire_version)
@@ -306,6 +307,7 @@ mod imp {
306307 mode : AcpProtocolMode ,
307308 mut value : serde_json:: Value ,
308309 ) -> Result < serde_json:: Value , crate :: Error > {
310+ let _pending_initialize = self . take_pending_initialize ( ) ;
309311 let Some ( response_version) = protocol_version_from_value ( & value) else {
310312 return Ok ( value) ;
311313 } ;
@@ -351,11 +353,12 @@ mod imp {
351353 }
352354 }
353355
354- fn negotiated ( & self ) -> ProtocolVersionKind {
355- self . state
356+ fn active_wire_version ( & self ) -> ProtocolVersionKind {
357+ let state = self
358+ . state
356359 . lock ( )
357- . expect ( "protocol compatibility state mutex poisoned" )
358- . negotiated
360+ . expect ( "protocol compatibility state mutex poisoned" ) ;
361+ state . pending_initialize . unwrap_or ( state . negotiated )
359362 }
360363
361364 fn set_negotiated ( & self , negotiated : ProtocolVersionKind ) {
@@ -618,6 +621,77 @@ mod imp {
618621 negotiated. as_protocol_version( ) ,
619622 ) )
620623 }
624+
625+ #[ cfg( test) ]
626+ mod tests {
627+ use super :: * ;
628+
629+ fn negotiated ( compat : & ProtocolCompat ) -> ProtocolVersionKind {
630+ compat
631+ . state
632+ . lock ( )
633+ . expect ( "protocol compatibility state mutex poisoned" )
634+ . negotiated
635+ }
636+
637+ #[ test]
638+ fn initialize_request_sets_active_wire_version_before_response ( ) -> Result < ( ) , crate :: Error >
639+ {
640+ let compat = ProtocolCompat :: new ( ProtocolMode :: v2_agent ( ) ) ;
641+ assert_eq ! ( compat. active_wire_version( ) , ProtocolVersionKind :: V1 ) ;
642+
643+ compat. incoming_message ( UntypedMessage :: new (
644+ "initialize" ,
645+ v2:: InitializeRequest :: new ( ProtocolVersion :: V2 ) ,
646+ ) ?) ?;
647+
648+ assert_eq ! ( negotiated( & compat) , ProtocolVersionKind :: V1 ) ;
649+ assert_eq ! ( compat. active_wire_version( ) , ProtocolVersionKind :: V2 ) ;
650+
651+ compat. outgoing_response (
652+ "initialize" ,
653+ Ok ( serde_json:: to_value ( v2:: InitializeResponse :: new (
654+ ProtocolVersion :: V2 ,
655+ ) ) ?) ,
656+ ) ?;
657+
658+ assert_eq ! ( negotiated( & compat) , ProtocolVersionKind :: V2 ) ;
659+ assert_eq ! ( compat. active_wire_version( ) , ProtocolVersionKind :: V2 ) ;
660+ Ok ( ( ) )
661+ }
662+
663+ #[ test]
664+ fn outgoing_initialize_sets_active_wire_version_before_response ( ) -> Result < ( ) , crate :: Error >
665+ {
666+ let compat = ProtocolCompat :: new ( ProtocolMode :: v2_client ( ) ) ;
667+ assert_eq ! ( compat. active_wire_version( ) , ProtocolVersionKind :: V1 ) ;
668+
669+ compat. outgoing_message ( UntypedMessage :: new (
670+ "initialize" ,
671+ v2:: InitializeRequest :: new ( ProtocolVersion :: V1 ) ,
672+ ) ?) ?;
673+
674+ assert_eq ! ( negotiated( & compat) , ProtocolVersionKind :: V1 ) ;
675+ assert_eq ! ( compat. active_wire_version( ) , ProtocolVersionKind :: V2 ) ;
676+
677+ compat. incoming_response (
678+ "initialize" ,
679+ Ok ( serde_json:: to_value ( v2:: InitializeResponse :: new (
680+ ProtocolVersion :: V2 ,
681+ ) ) ?) ,
682+ ) ?;
683+
684+ assert_eq ! ( negotiated( & compat) , ProtocolVersionKind :: V2 ) ;
685+ assert_eq ! ( compat. active_wire_version( ) , ProtocolVersionKind :: V2 ) ;
686+ Ok ( ( ) )
687+ }
688+
689+ #[ test]
690+ #[ should_panic( expected = "cannot merge ACP builders with different API protocol versions" ) ]
691+ fn merging_different_api_protocol_modes_panics ( ) {
692+ let _ = ProtocolMode :: v1_agent ( ) . merge ( ProtocolMode :: v2_agent ( ) ) ;
693+ }
694+ }
621695}
622696
623697pub ( crate ) use imp:: { ProtocolCompat , ProtocolMode } ;
0 commit comments