Skip to content

Commit a3ae80a

Browse files
committed
fix(openai): preserve upload cancellation during retry
1 parent da7900a commit a3ae80a

2 files changed

Lines changed: 55 additions & 0 deletions

File tree

Sources/AgentRunKit/LLM/Providers/OpenAIChat/OpenAIClientStreaming.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ extension OpenAIClient {
7777
do {
7878
(data, response) = try await session.upload(for: urlRequest, fromFile: bodyFileURL)
7979
} catch {
80+
if HTTPRetry.isCancellation(error) {
81+
throw CancellationError()
82+
}
8083
lastError = TransportError.networkError(error)
8184
continue
8285
}

Tests/AgentRunKitTests/LLM/Core/LLMClientTests.swift

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,30 @@ private let malformedToolHistory: [ChatMessage] = [
113113
)),
114114
]
115115

116+
/// @unchecked Sendable justification: URLProtocol is Foundation infrastructure and has no
117+
/// shared mutable state.
118+
private final class CancelledUploadURLProtocol: URLProtocol, @unchecked Sendable {
119+
static func configuration() -> URLSessionConfiguration {
120+
let configuration = URLSessionConfiguration.ephemeral
121+
configuration.protocolClasses = [CancelledUploadURLProtocol.self]
122+
return configuration
123+
}
124+
125+
override static func canInit(with request: URLRequest) -> Bool {
126+
request.url != nil
127+
}
128+
129+
override static func canonicalRequest(for request: URLRequest) -> URLRequest {
130+
request
131+
}
132+
133+
override func startLoading() {
134+
client?.urlProtocol(self, didFailWithError: URLError(.cancelled))
135+
}
136+
137+
override func stopLoading() {}
138+
}
139+
116140
struct ProviderHistoryValidationTests {
117141
private func assertGenerateRejectsMalformedHistory(client: any LLMClient) async {
118142
await #expect(throws: AgentError.malformedHistory(.unfinishedToolCallBatch(ids: ["call_1"]))) {
@@ -599,6 +623,34 @@ struct OpenAIClientURLRequestTests {
599623
#expect(multipartPart(named: "model", parts: parts)?.body == "whisper-1")
600624
#expect(multipartPart(named: "file", parts: parts)?.body == "audio-data")
601625
}
626+
627+
@Test
628+
func transcribeFileUploadCancellationPropagatesCancellationError() async throws {
629+
let session = URLSession(configuration: CancelledUploadURLProtocol.configuration())
630+
defer { session.invalidateAndCancel() }
631+
let client = try OpenAIClient(
632+
apiKey: "test-key",
633+
model: "test/model",
634+
baseURL: #require(URL(string: "https://cancelled-upload.test/v1")),
635+
session: session
636+
)
637+
let audioURL = FileManager.default.temporaryDirectory
638+
.appendingPathComponent("swiftagent-cancelled-upload-\(UUID().uuidString).wav")
639+
try Data("audio-data".utf8).write(to: audioURL)
640+
defer { try? FileManager.default.removeItem(at: audioURL) }
641+
642+
do {
643+
_ = try await client.transcribe(
644+
audioFileURL: audioURL,
645+
format: .wav,
646+
model: "whisper-1"
647+
)
648+
Issue.record("Expected CancellationError")
649+
} catch is CancellationError {
650+
} catch {
651+
Issue.record("Expected CancellationError, got \(type(of: error)): \(error)")
652+
}
653+
}
602654
}
603655

604656
struct ReasoningConfigTests {

0 commit comments

Comments
 (0)