Skip to content

Commit 39015cf

Browse files
committed
make OpenAI streaming strict by default
1 parent e4a2036 commit 39015cf

3 files changed

Lines changed: 127 additions & 39 deletions

File tree

Package.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ let package = Package(
9090
.testTarget(
9191
name: "SwiftBuddyTests",
9292
dependencies: ["SwiftBuddy", "MLXInferenceCore"]
93+
),
94+
.testTarget(
95+
name: "SwiftLMTests",
96+
dependencies: ["SwiftLM"]
9397
)
9498
]
9599
)

Sources/SwiftLM/Server.swift

Lines changed: 90 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,7 @@ struct MLXServer: AsyncParsableCommand {
661661
do {
662662
let bodyData = try await collectBody(request)
663663
return try await handleChatCompletion(
664-
bodyData: bodyData, config: config, container: container, semaphore: semaphore, stats: stats, promptCache: promptCache,
664+
request: request, bodyData: bodyData, config: config, container: container, semaphore: semaphore, stats: stats, promptCache: promptCache,
665665
draftModelRef: draftModelRef, numDraftTokens: numDraftTokensConfig
666666
)
667667
} catch {
@@ -682,7 +682,7 @@ struct MLXServer: AsyncParsableCommand {
682682
do {
683683
let bodyData = try await collectBody(request)
684684
return try await handleTextCompletion(
685-
bodyData: bodyData, config: config, container: container, semaphore: semaphore, stats: stats
685+
request: request, bodyData: bodyData, config: config, container: container, semaphore: semaphore, stats: stats
686686
)
687687
} catch {
688688
let errMsg = String(describing: error).replacingOccurrences(of: "\"", with: "'")
@@ -1020,6 +1020,7 @@ func collectBody(_ request: Request) async throws -> Data {
10201020
// ── Chat Completions Handler ─────────────────────────────────────────────────
10211021

10221022
func handleChatCompletion(
1023+
request: Request,
10231024
bodyData: Data,
10241025
config: ServerConfig,
10251026
container: ModelContainer,
@@ -1032,6 +1033,7 @@ func handleChatCompletion(
10321033
let chatReq = try JSONDecoder().decode(ChatCompletionRequest.self, from: bodyData)
10331034
let isStream = chatReq.stream ?? false
10341035
let jsonMode = chatReq.responseFormat?.type == "json_object"
1036+
let emitPrefillProgress = prefillProgressEnabled(in: request)
10351037

10361038
// ── Merge per-request overrides with CLI defaults ──
10371039
let tokenLimit = chatReq.maxTokens ?? config.maxTokens
@@ -1284,7 +1286,8 @@ func handleChatCompletion(
12841286
stream: stream, modelId: modelId, stopSequences: stopSequences,
12851287
includeUsage: includeUsage, promptTokenCount: promptTokenCount,
12861288
enableThinking: enableThinking, jsonMode: jsonMode, semaphore: semaphore,
1287-
stats: stats, genStart: genStart, prefillStart: prefillStart, onPrefillDone: onPrefillDone
1289+
stats: stats, genStart: genStart, prefillStart: prefillStart,
1290+
emitPrefillProgress: emitPrefillProgress, onPrefillDone: onPrefillDone
12881291
)
12891292
} else {
12901293
return try await handleChatNonStreaming(
@@ -1384,29 +1387,32 @@ func handleChatStreaming(
13841387
stats: ServerStats,
13851388
genStart: Date,
13861389
prefillStart: Date,
1390+
emitPrefillProgress: Bool,
13871391
onPrefillDone: (() async -> Void)? = nil
13881392
) -> Response {
13891393
let (sseStream, cont) = AsyncStream<String>.makeStream()
13901394

1391-
// ── Prefill heartbeat: emit llama-server-style slot_update progress every 2 s ──
1392-
// n_past is updated by activePrefillProgressHook in LLMModel.prepare() after each
1393-
// 512-token chunk; single-chunk prompts only show elapsed_seconds.
13941395
let prefillState = PrefillState()
1395-
activePrefillProgressHook = { nPast, _ in
1396-
Task { await prefillState.update(nPast: nPast) }
1397-
}
1398-
Task {
1399-
var elapsed = 0
1400-
while await !prefillState.done {
1401-
try? await Task.sleep(for: .seconds(2))
1402-
if await !prefillState.done {
1403-
elapsed += 2
1404-
let nPast = await prefillState.nPast
1405-
_ = cont.yield(ssePrefillChunk(
1406-
modelId: modelId,
1407-
nPast: nPast,
1408-
promptTokens: promptTokenCount,
1409-
elapsedSeconds: elapsed))
1396+
activePrefillProgressHook = nil
1397+
if emitPrefillProgress {
1398+
// ── Optional prefill heartbeat: emit a named SSE event every 2 s ──
1399+
// n_past is updated by activePrefillProgressHook in LLMModel.prepare() after each
1400+
// 512-token chunk; single-chunk prompts only show elapsed_seconds.
1401+
activePrefillProgressHook = { nPast, _ in
1402+
Task { await prefillState.update(nPast: nPast) }
1403+
}
1404+
Task {
1405+
var elapsed = 0
1406+
while await !prefillState.done {
1407+
try? await Task.sleep(for: .seconds(2))
1408+
if await !prefillState.done {
1409+
elapsed += 2
1410+
let nPast = await prefillState.nPast
1411+
_ = cont.yield(ssePrefillChunk(
1412+
nPast: nPast,
1413+
promptTokens: promptTokenCount,
1414+
elapsedSeconds: elapsed))
1415+
}
14101416
}
14111417
}
14121418
}
@@ -1735,6 +1741,7 @@ func extractThinkingBlock(from text: String) -> (String?, String) {
17351741
// ── Text Completions Handler ─────────────────────────────────────────────────
17361742

17371743
func handleTextCompletion(
1744+
request: Request,
17381745
bodyData: Data,
17391746
config: ServerConfig,
17401747
container: ModelContainer,
@@ -1743,6 +1750,7 @@ func handleTextCompletion(
17431750
) async throws -> Response {
17441751
let compReq = try JSONDecoder().decode(TextCompletionRequest.self, from: bodyData)
17451752
let isStream = compReq.stream ?? false
1753+
let emitPrefillProgress = prefillProgressEnabled(in: request)
17461754

17471755
let tokenLimit = compReq.maxTokens ?? config.maxTokens
17481756
let temperature = compReq.temperature.map(Float.init) ?? config.temp
@@ -1783,7 +1791,8 @@ func handleTextCompletion(
17831791
if isStream {
17841792
return handleTextStreaming(
17851793
stream: stream, modelId: modelId, stopSequences: stopSequences,
1786-
semaphore: semaphore, stats: stats, genStart: genStart
1794+
promptTokenCount: promptTokenCount, semaphore: semaphore, stats: stats,
1795+
genStart: genStart, emitPrefillProgress: emitPrefillProgress
17871796
)
17881797
} else {
17891798
return try await handleTextNonStreaming(
@@ -1799,19 +1808,48 @@ func handleTextStreaming(
17991808
stream: AsyncStream<Generation>,
18001809
modelId: String,
18011810
stopSequences: [String],
1811+
promptTokenCount: Int,
18021812
semaphore: AsyncSemaphore,
18031813
stats: ServerStats,
1804-
genStart: Date
1814+
genStart: Date,
1815+
emitPrefillProgress: Bool
18051816
) -> Response {
18061817
let (sseStream, cont) = AsyncStream<String>.makeStream()
1818+
let prefillState = PrefillState()
1819+
activePrefillProgressHook = nil
1820+
if emitPrefillProgress {
1821+
activePrefillProgressHook = { nPast, _ in
1822+
Task { await prefillState.update(nPast: nPast) }
1823+
}
1824+
Task {
1825+
var elapsed = 0
1826+
while await !prefillState.done {
1827+
try? await Task.sleep(for: .seconds(2))
1828+
if await !prefillState.done {
1829+
elapsed += 2
1830+
let nPast = await prefillState.nPast
1831+
_ = cont.yield(ssePrefillChunk(
1832+
nPast: nPast,
1833+
promptTokens: promptTokenCount,
1834+
elapsedSeconds: elapsed))
1835+
}
1836+
}
1837+
}
1838+
}
18071839
Task {
18081840
var completionTokenCount = 0
18091841
var fullText = ""
18101842
var stopped = false
1843+
var firstToken = true
18111844
for await generation in stream {
18121845
if stopped { break }
18131846
switch generation {
18141847
case .chunk(let text, _):
1848+
if firstToken {
1849+
activePrefillProgressHook = nil
1850+
await prefillState.finish()
1851+
firstToken = false
1852+
}
18151853
completionTokenCount += 1
18161854
fullText += text
18171855
// GPU yield: prevent Metal from starving macOS WindowServer
@@ -1834,6 +1872,8 @@ func handleTextStreaming(
18341872
case .toolCall:
18351873
break
18361874
case .info(let info):
1875+
activePrefillProgressHook = nil
1876+
await prefillState.finish()
18371877
if !stopped {
18381878
var reason: String
18391879
switch info.stopReason {
@@ -1979,7 +2019,7 @@ struct CORSMiddleware<Context: RequestContext>: RouterMiddleware {
19792019
}
19802020
}
19812021
fields.append(HTTPField(name: HTTPField.Name("Access-Control-Allow-Methods")!, value: "GET, POST, OPTIONS"))
1982-
fields.append(HTTPField(name: HTTPField.Name("Access-Control-Allow-Headers")!, value: "Content-Type, Authorization"))
2022+
fields.append(HTTPField(name: HTTPField.Name("Access-Control-Allow-Headers")!, value: "Content-Type, Authorization, X-SwiftLM-Prefill-Progress"))
19832023
return HTTPFields(fields)
19842024
}
19852025
}
@@ -2032,6 +2072,22 @@ func jsonHeaders() -> HTTPFields {
20322072
HTTPFields([HTTPField(name: .contentType, value: "application/json")])
20332073
}
20342074

2075+
let prefillProgressHeaderName = HTTPField.Name("X-SwiftLM-Prefill-Progress")!
2076+
2077+
func parseTruthyHeaderValue(_ value: String?) -> Bool {
2078+
guard let value else { return false }
2079+
switch value.trimmingCharacters(in: .whitespacesAndNewlines).lowercased() {
2080+
case "1", "on", "true", "yes":
2081+
return true
2082+
default:
2083+
return false
2084+
}
2085+
}
2086+
2087+
func prefillProgressEnabled(in request: Request) -> Bool {
2088+
parseTruthyHeaderValue(request.headers[values: prefillProgressHeaderName].first)
2089+
}
2090+
20352091
func sseHeaders() -> HTTPFields {
20362092
HTTPFields([
20372093
HTTPField(name: .contentType, value: "text/event-stream"),
@@ -2074,30 +2130,25 @@ func sseChunk(modelId: String, reasoningContent: String?, content: String?, fini
20742130
return "data: \(String(data: data, encoding: .utf8)!)\r\n\r\n"
20752131
}
20762132

2077-
/// Prefill-progress heartbeat chunk — emitted every 2s while the server is processing the prompt.
2078-
/// Uses object type "prefill_progress" so clients can filter it without confusing it with real tokens.
2133+
/// Prefill-progress heartbeat chunk — emitted every 2s while the server is processing the prompt
2134+
/// when explicitly enabled via `X-SwiftLM-Prefill-Progress: true`.
2135+
/// It is sent as a named SSE event to avoid breaking strict OpenAI-compatible clients.
20792136
/// Format mirrors llama-server's slot_update event:
20802137
/// n_past : tokens evaluated so far (real value from chunked prefill, or 0 for single-chunk)
20812138
/// n_prompt_tokens : total prompt token count
20822139
/// fraction : n_past / n_prompt_tokens (0.0–1.0), useful for progress bars
20832140
/// elapsed_seconds : approximate seconds since the prefill heartbeat started (a counter advanced by 2 on each 2 s tick), not measured from request arrival
2084-
func ssePrefillChunk(modelId: String, nPast: Int = 0, promptTokens: Int, elapsedSeconds: Int) -> String {
2141+
func ssePrefillChunk(nPast: Int = 0, promptTokens: Int, elapsedSeconds: Int) -> String {
20852142
let fraction = promptTokens > 0 ? Double(nPast) / Double(promptTokens) : 0.0
20862143
let chunk: [String: Any] = [
2087-
"id": "prefill-\(UUID().uuidString)",
2088-
"object": "prefill_progress",
2089-
"created": Int(Date().timeIntervalSince1970),
2090-
"model": modelId,
2091-
"prefill": [
2092-
"status": "processing",
2093-
"n_past": nPast,
2094-
"n_prompt_tokens": promptTokens,
2095-
"fraction": fraction,
2096-
"elapsed_seconds": elapsedSeconds
2097-
]
2144+
"status": "processing",
2145+
"n_past": nPast,
2146+
"n_prompt_tokens": promptTokens,
2147+
"fraction": fraction,
2148+
"elapsed_seconds": elapsedSeconds
20982149
]
20992150
let data = try! JSONSerialization.data(withJSONObject: chunk)
2100-
return "data: \(String(data: data, encoding: .utf8)!)\r\n\r\n"
2151+
return "event: prefill_progress\r\ndata: \(String(data: data, encoding: .utf8)!)\r\n\r\n"
21012152
}
21022153

21032154
func sseUsageChunk(modelId: String, promptTokens: Int, completionTokens: Int) -> String {
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import XCTest
2+
@testable import SwiftLM
3+
4+
final class ServerSSETests: XCTestCase {
5+
func testParseTruthyHeaderValue() {
6+
XCTAssertTrue(parseTruthyHeaderValue("true"))
7+
XCTAssertTrue(parseTruthyHeaderValue("TRUE"))
8+
XCTAssertTrue(parseTruthyHeaderValue(" yes "))
9+
XCTAssertTrue(parseTruthyHeaderValue("1"))
10+
XCTAssertFalse(parseTruthyHeaderValue(nil))
11+
XCTAssertFalse(parseTruthyHeaderValue("false"))
12+
XCTAssertFalse(parseTruthyHeaderValue("0"))
13+
}
14+
15+
func testPrefillChunkUsesNamedEventAndLeanPayload() throws {
16+
let chunk = ssePrefillChunk(nPast: 32, promptTokens: 128, elapsedSeconds: 4)
17+
18+
XCTAssertTrue(chunk.hasPrefix("event: prefill_progress\r\ndata: "))
19+
XCTAssertTrue(chunk.hasSuffix("\r\n\r\n"))
20+
21+
let prefix = "event: prefill_progress\r\ndata: "
22+
let payload = String(chunk.dropFirst(prefix.count).dropLast(4))
23+
let data = try XCTUnwrap(payload.data(using: .utf8))
24+
let json = try XCTUnwrap(JSONSerialization.jsonObject(with: data) as? [String: Any])
25+
26+
XCTAssertEqual(json["status"] as? String, "processing")
27+
XCTAssertEqual(json["n_past"] as? Int, 32)
28+
XCTAssertEqual(json["n_prompt_tokens"] as? Int, 128)
29+
XCTAssertEqual(json["elapsed_seconds"] as? Int, 4)
30+
XCTAssertNil(json["object"])
31+
XCTAssertNil(json["choices"])
32+
}
33+
}

0 commit comments

Comments
 (0)