@@ -661,7 +661,7 @@ struct MLXServer: AsyncParsableCommand {
661661 do {
662662 let bodyData = try await collectBody ( request)
663663 return try await handleChatCompletion (
664- bodyData: bodyData, config: config, container: container, semaphore: semaphore, stats: stats, promptCache: promptCache,
664+ request : request , bodyData: bodyData, config: config, container: container, semaphore: semaphore, stats: stats, promptCache: promptCache,
665665 draftModelRef: draftModelRef, numDraftTokens: numDraftTokensConfig
666666 )
667667 } catch {
@@ -682,7 +682,7 @@ struct MLXServer: AsyncParsableCommand {
682682 do {
683683 let bodyData = try await collectBody ( request)
684684 return try await handleTextCompletion (
685- bodyData: bodyData, config: config, container: container, semaphore: semaphore, stats: stats
685+ request : request , bodyData: bodyData, config: config, container: container, semaphore: semaphore, stats: stats
686686 )
687687 } catch {
688688 let errMsg = String ( describing: error) . replacingOccurrences ( of: " \" " , with: " ' " )
@@ -1020,6 +1020,7 @@ func collectBody(_ request: Request) async throws -> Data {
10201020// ── Chat Completions Handler ─────────────────────────────────────────────────
10211021
10221022func handleChatCompletion(
1023+ request: Request,
10231024 bodyData: Data,
10241025 config: ServerConfig,
10251026 container: ModelContainer,
@@ -1032,6 +1033,7 @@ func handleChatCompletion(
10321033 let chatReq = try JSONDecoder ( ) . decode ( ChatCompletionRequest . self, from: bodyData)
10331034 let isStream = chatReq. stream ?? false
10341035 let jsonMode = chatReq. responseFormat? . type == " json_object "
1036+ let emitPrefillProgress = prefillProgressEnabled ( in: request)
10351037
10361038 // ── Merge per-request overrides with CLI defaults ──
10371039 let tokenLimit = chatReq. maxTokens ?? config. maxTokens
@@ -1284,7 +1286,8 @@ func handleChatCompletion(
12841286 stream: stream, modelId: modelId, stopSequences: stopSequences,
12851287 includeUsage: includeUsage, promptTokenCount: promptTokenCount,
12861288 enableThinking: enableThinking, jsonMode: jsonMode, semaphore: semaphore,
1287- stats: stats, genStart: genStart, prefillStart: prefillStart, onPrefillDone: onPrefillDone
1289+ stats: stats, genStart: genStart, prefillStart: prefillStart,
1290+ emitPrefillProgress: emitPrefillProgress, onPrefillDone: onPrefillDone
12881291 )
12891292 } else {
12901293 return try await handleChatNonStreaming (
@@ -1384,29 +1387,32 @@ func handleChatStreaming(
13841387 stats: ServerStats,
13851388 genStart: Date,
13861389 prefillStart: Date,
1390+ emitPrefillProgress: Bool,
13871391 onPrefillDone: ( ( ) async -> Void) ? = nil
13881392) -> Response {
13891393 let ( sseStream, cont) = AsyncStream< String> . makeStream( )
13901394
1391- // ── Prefill heartbeat: emit llama-server-style slot_update progress every 2 s ──
1392- // n_past is updated by activePrefillProgressHook in LLMModel.prepare() after each
1393- // 512-token chunk; single-chunk prompts only show elapsed_seconds.
13941395 let prefillState = PrefillState ( )
1395- activePrefillProgressHook = { nPast, _ in
1396- Task { await prefillState. update ( nPast: nPast) }
1397- }
1398- Task {
1399- var elapsed = 0
1400- while await !prefillState . done {
1401- try ? await Task . sleep ( for: . seconds( 2 ) )
1402- if await !prefillState. done {
1403- elapsed += 2
1404- let nPast = await prefillState. nPast
1405- _ = cont. yield ( ssePrefillChunk (
1406- modelId: modelId,
1407- nPast: nPast,
1408- promptTokens: promptTokenCount,
1409- elapsedSeconds: elapsed) )
1396+ activePrefillProgressHook = nil
1397+ if emitPrefillProgress {
1398+ // ── Optional prefill heartbeat: emit a named SSE event every 2 s ──
1399+ // n_past is updated by activePrefillProgressHook in LLMModel.prepare() after each
1400+ // 512-token chunk; single-chunk prompts only show elapsed_seconds.
1401+ activePrefillProgressHook = { nPast, _ in
1402+ Task { await prefillState. update ( nPast: nPast) }
1403+ }
1404+ Task {
1405+ var elapsed = 0
1406+ while await !prefillState . done {
1407+ try ? await Task . sleep ( for: . seconds( 2 ) )
1408+ if await !prefillState. done {
1409+ elapsed += 2
1410+ let nPast = await prefillState. nPast
1411+ _ = cont. yield ( ssePrefillChunk (
1412+ nPast: nPast,
1413+ promptTokens: promptTokenCount,
1414+ elapsedSeconds: elapsed) )
1415+ }
14101416 }
14111417 }
14121418 }
@@ -1735,6 +1741,7 @@ func extractThinkingBlock(from text: String) -> (String?, String) {
17351741// ── Text Completions Handler ─────────────────────────────────────────────────
17361742
17371743func handleTextCompletion(
1744+ request: Request ,
17381745 bodyData: Data ,
17391746 config: ServerConfig ,
17401747 container: ModelContainer ,
@@ -1743,6 +1750,7 @@ func handleTextCompletion(
17431750) async throws -> Response {
17441751 let compReq = try JSONDecoder ( ) . decode ( TextCompletionRequest . self, from: bodyData)
17451752 let isStream = compReq. stream ?? false
1753+ let emitPrefillProgress = prefillProgressEnabled ( in: request)
17461754
17471755 let tokenLimit = compReq. maxTokens ?? config. maxTokens
17481756 let temperature = compReq. temperature. map ( Float . init) ?? config. temp
@@ -1783,7 +1791,8 @@ func handleTextCompletion(
17831791 if isStream {
17841792 return handleTextStreaming (
17851793 stream: stream, modelId: modelId, stopSequences: stopSequences,
1786- semaphore: semaphore, stats: stats, genStart: genStart
1794+ promptTokenCount: promptTokenCount, semaphore: semaphore, stats: stats,
1795+ genStart: genStart, emitPrefillProgress: emitPrefillProgress
17871796 )
17881797 } else {
17891798 return try await handleTextNonStreaming (
@@ -1799,19 +1808,48 @@ func handleTextStreaming(
17991808 stream: AsyncStream < Generation > ,
18001809 modelId: String ,
18011810 stopSequences: [ String ] ,
1811+ promptTokenCount: Int ,
18021812 semaphore: AsyncSemaphore ,
18031813 stats: ServerStats ,
1804- genStart: Date
1814+ genStart: Date ,
1815+ emitPrefillProgress: Bool
18051816) -> Response {
18061817 let ( sseStream, cont) = AsyncStream< String> . makeStream( )
1818+ let prefillState = PrefillState ( )
1819+ activePrefillProgressHook = nil
1820+ if emitPrefillProgress {
1821+ activePrefillProgressHook = { nPast, _ in
1822+ Task { await prefillState. update ( nPast: nPast) }
1823+ }
1824+ Task {
1825+ var elapsed = 0
1826+ while await !prefillState . done {
1827+ try ? await Task . sleep ( for: . seconds( 2 ) )
1828+ if await !prefillState. done {
1829+ elapsed += 2
1830+ let nPast = await prefillState. nPast
1831+ _ = cont. yield ( ssePrefillChunk (
1832+ nPast: nPast,
1833+ promptTokens: promptTokenCount,
1834+ elapsedSeconds: elapsed) )
1835+ }
1836+ }
1837+ }
1838+ }
18071839 Task {
18081840 var completionTokenCount = 0
18091841 var fullText = " "
18101842 var stopped = false
1843+ var firstToken = true
18111844 for await generation in stream {
18121845 if stopped { break }
18131846 switch generation {
18141847 case . chunk( let text, _) :
1848+ if firstToken {
1849+ activePrefillProgressHook = nil
1850+ await prefillState. finish ( )
1851+ firstToken = false
1852+ }
18151853 completionTokenCount += 1
18161854 fullText += text
18171855 // GPU yield: prevent Metal from starving macOS WindowServer
@@ -1834,6 +1872,8 @@ func handleTextStreaming(
18341872 case . toolCall:
18351873 break
18361874 case . info( let info) :
1875+ activePrefillProgressHook = nil
1876+ await prefillState. finish ( )
18371877 if !stopped {
18381878 var reason : String
18391879 switch info. stopReason {
@@ -1979,7 +2019,7 @@ struct CORSMiddleware<Context: RequestContext>: RouterMiddleware {
19792019 }
19802020 }
19812021 fields. append ( HTTPField ( name: HTTPField . Name ( " Access-Control-Allow-Methods " ) !, value: " GET, POST, OPTIONS " ) )
1982- fields. append ( HTTPField ( name: HTTPField . Name ( " Access-Control-Allow-Headers " ) !, value: " Content-Type, Authorization " ) )
2022+ fields. append ( HTTPField ( name: HTTPField . Name ( " Access-Control-Allow-Headers " ) !, value: " Content-Type, Authorization, X-SwiftLM-Prefill-Progress " ) )
19832023 return HTTPFields ( fields)
19842024 }
19852025}
@@ -2032,6 +2072,22 @@ func jsonHeaders() -> HTTPFields {
20322072 HTTPFields ( [ HTTPField ( name: . contentType, value: " application/json " ) ] )
20332073}
20342074
2075+ let prefillProgressHeaderName = HTTPField . Name ( " X-SwiftLM-Prefill-Progress " ) !
2076+
2077+ func parseTruthyHeaderValue( _ value: String ? ) -> Bool {
2078+ guard let value else { return false }
2079+ switch value. trimmingCharacters ( in: . whitespacesAndNewlines) . lowercased ( ) {
2080+ case " 1 " , " on " , " true " , " yes " :
2081+ return true
2082+ default :
2083+ return false
2084+ }
2085+ }
2086+
2087+ func prefillProgressEnabled( in request: Request ) -> Bool {
2088+ parseTruthyHeaderValue ( request. headers [ values: prefillProgressHeaderName] . first)
2089+ }
2090+
20352091func sseHeaders( ) -> HTTPFields {
20362092 HTTPFields ( [
20372093 HTTPField ( name: . contentType, value: " text/event-stream " ) ,
@@ -2074,30 +2130,25 @@ func sseChunk(modelId: String, reasoningContent: String?, content: String?, fini
20742130 return " data: \( String ( data: data, encoding: . utf8) !) \r \n \r \n "
20752131}
20762132
2077- /// Prefill-progress heartbeat chunk — emitted every 2s while the server is processing the prompt.
2078- /// Uses object type "prefill_progress" so clients can filter it without confusing it with real tokens.
2133+ /// Prefill-progress heartbeat chunk — emitted every 2s while the server is processing the prompt
2134+ /// when explicitly enabled via `X-SwiftLM-Prefill-Progress: true`.
2135+ /// It is sent as a named SSE event to avoid breaking strict OpenAI-compatible clients.
20792136/// Format mirrors llama-server's slot_update event:
20802137/// n_past : tokens evaluated so far (real value from chunked prefill, or 0 for single-chunk)
20812138/// n_prompt_tokens : total prompt token count
20822139/// fraction : n_past / n_prompt_tokens (0.0–1.0), useful for progress bars
20832140/// elapsed_seconds : wall-clock time since the request started
2084- func ssePrefillChunk( modelId : String , nPast: Int = 0 , promptTokens: Int , elapsedSeconds: Int ) -> String {
2141+ func ssePrefillChunk( nPast: Int = 0 , promptTokens: Int , elapsedSeconds: Int ) -> String {
20852142 let fraction = promptTokens > 0 ? Double ( nPast) / Double( promptTokens) : 0.0
20862143 let chunk : [ String : Any ] = [
2087- " id " : " prefill- \( UUID ( ) . uuidString) " ,
2088- " object " : " prefill_progress " ,
2089- " created " : Int ( Date ( ) . timeIntervalSince1970) ,
2090- " model " : modelId,
2091- " prefill " : [
2092- " status " : " processing " ,
2093- " n_past " : nPast,
2094- " n_prompt_tokens " : promptTokens,
2095- " fraction " : fraction,
2096- " elapsed_seconds " : elapsedSeconds
2097- ]
2144+ " status " : " processing " ,
2145+ " n_past " : nPast,
2146+ " n_prompt_tokens " : promptTokens,
2147+ " fraction " : fraction,
2148+ " elapsed_seconds " : elapsedSeconds
20982149 ]
20992150 let data = try ! JSONSerialization . data ( withJSONObject: chunk)
2100- return " data : \( String ( data: data, encoding: . utf8) !) \r \n \r \n "
2151+ return " event: prefill_progress \r \n data : \( String ( data: data, encoding: . utf8) !) \r \n \r \n "
21012152}
21022153
21032154func sseUsageChunk( modelId: String , promptTokens: Int , completionTokens: Int ) -> String {
0 commit comments