Skip to content

Commit 3bb7acf

Browse files
committed
fix(mcp): always disconnect transport on shutdown
1 parent 4f9eb71 commit 3bb7acf

3 files changed

Lines changed: 41 additions & 27 deletions

File tree

Sources/AgentRunKit/MCP/MCPClient.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ public actor MCPClient {
163163
}
164164

165165
func shutdown() async {
166-
guard drainPendingRequests() else { return }
166+
_ = drainPendingRequests()
167167
readerTask?.cancel()
168168
readerTask = nil
169169
await transport.disconnect()

Tests/AgentRunKitTests/MCP/DynamicMCPTransport.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ actor DynamicMCPTransport: MCPTransport {
66
private let stream: AsyncThrowingStream<Data, Error>
77
private let continuation: AsyncThrowingStream<Data, Error>.Continuation
88
private var connected = false
9+
private var effectiveDisconnectCount = 0
910

1011
init(handler: @escaping @Sendable (Data) async throws -> Data?) {
1112
self.handler = handler
@@ -19,6 +20,9 @@ actor DynamicMCPTransport: MCPTransport {
1920
}
2021

2122
func disconnect() async {
23+
if connected {
24+
effectiveDisconnectCount += 1
25+
}
2226
connected = false
2327
continuation.finish()
2428
}
@@ -45,4 +49,8 @@ actor DynamicMCPTransport: MCPTransport {
4549
func terminateStreamWithError(_ error: any Error) {
4650
continuation.finish(throwing: error)
4751
}
52+
53+
func effectiveDisconnectCallCount() -> Int {
54+
effectiveDisconnectCount
55+
}
4856
}

Tests/AgentRunKitTests/MCP/MCPClientTests.swift

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,14 @@ struct MCPClientTests {
5555

5656
@Test
5757
func initializeVersionMismatch() async throws {
58-
let transport = ScriptedMCPTransport(responses: [
59-
MCPTestHelpers.encodeResponse(
60-
id: 1,
58+
let transport = DynamicMCPTransport { data in
59+
guard let request = decodeRequest(data) else { return nil }
60+
let idValue: Int = if case let .int(val) = request.id { val } else { 0 }
61+
return MCPTestHelpers.encodeResponse(
62+
id: idValue,
6163
result: MCPTestHelpers.initializeResult(protocolVersion: "1999-01-01")
62-
),
63-
])
64+
)
65+
}
6466
let client = MCPClient(serverName: "test", transport: transport)
6567
do {
6668
try await client.connectAndInitialize()
@@ -74,6 +76,7 @@ struct MCPClientTests {
7476
#expect(supported == "1999-01-01")
7577
}
7678
await client.shutdown()
79+
await #expect(transport.effectiveDisconnectCallCount() == 1)
7780
}
7881

7982
@Test
@@ -247,27 +250,6 @@ struct MCPClientTests {
247250
await client.shutdown()
248251
}
249252

250-
@Test
251-
func transportClosedMidCall() async throws {
252-
let transport = DynamicMCPTransport(handler: standardHandler { _ in nil })
253-
let client = MCPClient(serverName: "test", transport: transport)
254-
try await client.connectAndInitialize()
255-
256-
let callTask = Task {
257-
try await client.callTool(name: "test", arguments: Data("{}".utf8))
258-
}
259-
260-
try await Task.sleep(for: .milliseconds(50))
261-
await transport.terminateStream()
262-
263-
do {
264-
_ = try await callTask.value
265-
Issue.record("Expected transportClosed")
266-
} catch let error as MCPError {
267-
#expect(error == .transportClosed)
268-
}
269-
}
270-
271253
@Test
272254
func requestTimeout() async throws {
273255
let transport = DynamicMCPTransport(handler: standardHandler { _ in nil })
@@ -469,13 +451,37 @@ struct MCPClientTests {
469451
}
470452

471453
struct MCPClientEdgeCaseTests {
454+
@Test
455+
func transportClosedMidCall() async throws {
456+
let transport = DynamicMCPTransport(handler: standardHandler { _ in nil })
457+
let client = MCPClient(serverName: "test", transport: transport)
458+
try await client.connectAndInitialize()
459+
460+
let callTask = Task {
461+
try await client.callTool(name: "test", arguments: Data("{}".utf8))
462+
}
463+
464+
try await Task.sleep(for: .milliseconds(50))
465+
await transport.terminateStream()
466+
467+
do {
468+
_ = try await callTask.value
469+
Issue.record("Expected transportClosed")
470+
} catch let error as MCPError {
471+
#expect(error == .transportClosed)
472+
}
473+
await client.shutdown()
474+
await #expect(transport.effectiveDisconnectCallCount() == 1)
475+
}
476+
472477
@Test
473478
func shutdownIdempotent() async throws {
474479
let transport = DynamicMCPTransport(handler: standardHandler())
475480
let client = MCPClient(serverName: "test", transport: transport)
476481
try await client.connectAndInitialize()
477482
await client.shutdown()
478483
await client.shutdown()
484+
await #expect(transport.effectiveDisconnectCallCount() == 1)
479485

480486
await #expect(throws: MCPError.transportClosed) {
481487
try await client.callTool(name: "test", arguments: Data("{}".utf8))

0 commit comments

Comments
 (0)