@@ -51,6 +51,9 @@ struct MLXServer: AsyncParsableCommand {
5151 @Option ( name: . long, help: " Number of parallel request slots " )
5252 var parallel : Int = 1
5353
54+ @Flag ( name: . long, help: " Enable thinking/reasoning mode (Qwen3.5 etc). Default: disabled " )
55+ var thinking : Bool = false
56+
5457 mutating func run( ) async throws {
5558 print ( " [mlx-server] Loading model: \( model) " )
5659 let modelId = model
@@ -72,6 +75,7 @@ struct MLXServer: AsyncParsableCommand {
7275 let defaultTemp = self . temp
7376 let defaultTopP = self . topP
7477 let defaultRepeatPenalty = self . repeatPenalty
78+ let thinkingEnabled = self . thinking
7579 let parallelSlots = self . parallel
7680
7781 // ── Concurrency limiter ──
@@ -138,27 +142,48 @@ struct MLXServer: AsyncParsableCommand {
138142 }
139143 }
140144
145+ // Convert OpenAI tools format → [String: any Sendable] for UserInput
146+ let toolSpecs : [ [ String : any Sendable ] ] ? = chatReq. tools? . map { tool in
147+ var spec : [ String : any Sendable ] = [ " type " : tool. type]
148+ var fn : [ String : any Sendable ] = [ " name " : tool. function. name]
149+ if let desc = tool. function. description { fn [ " description " ] = desc }
150+ if let params = tool. function. parameters {
151+ fn [ " parameters " ] = params. mapValues { $0. value }
152+ }
153+ spec [ " function " ] = fn
154+ return spec
155+ }
156+
141157 // ── Acquire slot (concurrency limiter) ──
142158 await semaphore. wait ( )
143159
144- let userInput = UserInput ( chat: chatMessages)
160+ // Pass enable_thinking to the Jinja chat template via additionalContext
161+ // (mirrors llama-server's --chat-template-kwargs '{"enable_thinking":false}')
162+ let templateContext : [ String : any Sendable ] ? = thinkingEnabled ? nil : [ " enable_thinking " : false ]
163+ let userInput = UserInput ( chat: chatMessages, tools: toolSpecs, additionalContext: templateContext)
145164 let lmInput = try await container. prepare ( input: userInput)
146165 let stream = try await container. generate ( input: lmInput, parameters: params)
147166
148167 if isStream {
149168 // SSE streaming
150169 let ( sseStream, cont) = AsyncStream< String> . makeStream( )
151170 Task {
171+ var hasToolCalls = false
172+ var toolCallIndex = 0
152173 for await generation in stream {
153174 switch generation {
154175 case . chunk( let text) :
155176 cont. yield ( sseChunk ( modelId: modelId, delta: text, finishReason: nil ) )
177+ case . toolCall( let tc) :
178+ hasToolCalls = true
179+ let argsJson = serializeToolCallArgs ( tc. function. arguments)
180+ cont. yield ( sseToolCallChunk ( modelId: modelId, index: toolCallIndex, name: tc. function. name, arguments: argsJson) )
181+ toolCallIndex += 1
156182 case . info:
157- cont. yield ( sseChunk ( modelId: modelId, delta: " " , finishReason: " stop " ) )
183+ let reason = hasToolCalls ? " tool_calls " : " stop "
184+ cont. yield ( sseChunk ( modelId: modelId, delta: " " , finishReason: reason) )
158185 cont. yield ( " data: [DONE] \n \n " )
159186 cont. finish ( )
160- case . toolCall:
161- break
162187 }
163188 }
164189 cont. finish ( )
@@ -170,15 +195,25 @@ struct MLXServer: AsyncParsableCommand {
170195 body: . init( asyncSequence: sseStream. map { ByteBuffer ( string: $0) } )
171196 )
172197 } else {
173- // Non-streaming: collect all chunks
198+ // Non-streaming: collect all chunks and tool calls
174199 var fullText = " "
175200 var completionTokenCount = 0
201+ var collectedToolCalls : [ ToolCallResponse ] = [ ]
202+ var tcIndex = 0
176203 for await generation in stream {
177204 switch generation {
178205 case . chunk( let text) :
179206 fullText += text
180207 completionTokenCount += 1
181- case . info, . toolCall:
208+ case . toolCall( let tc) :
209+ let argsJson = serializeToolCallArgs ( tc. function. arguments)
210+ collectedToolCalls. append ( ToolCallResponse (
211+ id: " call_ \( UUID ( ) . uuidString. prefix ( 8 ) ) " ,
212+ type: " function " ,
213+ function: ToolCallFunction ( name: tc. function. name, arguments: argsJson)
214+ ) )
215+ tcIndex += 1
216+ case . info:
182217 break
183218 }
184219 }
@@ -189,15 +224,20 @@ struct MLXServer: AsyncParsableCommand {
189224 let estimatedPromptTokens = max ( 1 , promptText. count / 4 )
190225 let totalTokens = estimatedPromptTokens + completionTokenCount
191226
227+ let hasToolCalls = !collectedToolCalls. isEmpty
192228 let resp = ChatCompletionResponse (
193229 id: " chatcmpl- \( UUID ( ) . uuidString) " ,
194230 model: modelId,
195231 created: Int ( Date ( ) . timeIntervalSince1970) ,
196232 choices: [
197233 Choice (
198234 index: 0 ,
199- message: AssistantMessage ( role: " assistant " , content: fullText) ,
200- finishReason: " stop "
235+ message: AssistantMessage (
236+ role: " assistant " ,
237+ content: fullText. isEmpty && hasToolCalls ? nil : fullText,
238+ toolCalls: hasToolCalls ? collectedToolCalls : nil
239+ ) ,
240+ finishReason: hasToolCalls ? " tool_calls " : " stop "
201241 )
202242 ] ,
203243 usage: TokenUsage ( promptTokens: estimatedPromptTokens, completionTokens: completionTokenCount, totalTokens: totalTokens)
@@ -312,23 +352,68 @@ func sseChunk(modelId: String, delta: String, finishReason: String?) -> String {
312352 return " data: \( String ( data: data, encoding: . utf8) !) \n \n "
313353}
314354
/// Builds an OpenAI-compatible SSE frame announcing a single streamed tool call.
///
/// - Parameters:
///   - modelId: Model identifier echoed back in the chunk's `model` field.
///   - index: Zero-based position of this tool call within the assistant turn.
///   - name: Function name the model wants to invoke.
///   - arguments: JSON-encoded argument string for the call.
/// - Returns: A complete `data: …\n\n` SSE frame containing one
///   `chat.completion.chunk` whose delta carries a `tool_calls` entry.
func sseToolCallChunk(modelId: String, index: Int, name: String, arguments: String) -> String {
    let chunk: [String: Any] = [
        "id": "chatcmpl-\(UUID().uuidString)",
        "object": "chat.completion.chunk",
        "created": Int(Date().timeIntervalSince1970),
        "model": modelId,
        "choices": [[
            "index": 0,
            "delta": [
                "role": "assistant",
                "tool_calls": [[
                    "index": index,
                    "id": "call_\(UUID().uuidString.prefix(8))",
                    "type": "function",
                    "function": [
                        "name": name,
                        "arguments": arguments,
                    ] as [String: Any],
                ] as [String: Any]],
            ] as [String: Any],
        ] as [String: Any]],
    ]
    // This dictionary only ever holds plain JSON-serializable values, so
    // serialization cannot realistically fail — but avoid `try!` so a future
    // edit that sneaks in a non-serializable value degrades gracefully instead
    // of crashing the server mid-stream.
    guard let data = try? JSONSerialization.data(withJSONObject: chunk),
          let json = String(data: data, encoding: .utf8) else {
        return "data: {}\n\n"
    }
    return "data: \(json)\n\n"
}
380+
/// Serialize ToolCall arguments (`[String: JSONValue]`) to a compact JSON string.
///
/// - Parameter args: Decoded tool-call arguments keyed by parameter name.
/// - Returns: A JSON object string, or `"{}"` when serialization fails.
func serializeToolCallArgs(_ args: [String: JSONValue]) -> String {
    let anyDict = args.mapValues { $0.anyValue }
    // .sortedKeys makes the emitted argument string deterministic across runs,
    // keeping SSE payloads stable for logging, caching, and tests.
    guard let data = try? JSONSerialization.data(withJSONObject: anyDict, options: [.sortedKeys]) else {
        return "{}"
    }
    return String(data: data, encoding: .utf8) ?? "{}"
}
389+
315390// ── OpenAI-compatible types ───────────────────────────────────────────────────
316391
317392struct ChatCompletionRequest : Decodable {
318393 struct Message : Decodable {
319394 let role : String
320395 let content : String
321396 }
397+ struct ToolDef : Decodable {
398+ let type : String
399+ let function : ToolFuncDef
400+ }
401+ struct ToolFuncDef : Decodable {
402+ let name : String
403+ let description : String ?
404+ let parameters : [ String : AnyCodable ] ?
405+ }
322406 let model : String ?
323407 let messages : [ Message ]
324408 let stream : Bool ?
325409 let maxTokens : Int ?
326410 let temperature : Double ?
327411 let topP : Double ?
328412 let repetitionPenalty : Double ?
413+ let tools : [ ToolDef ] ?
329414
330415 enum CodingKeys : String , CodingKey {
331- case model, messages, stream, temperature
416+ case model, messages, stream, temperature, tools
332417 case maxTokens = " max_tokens "
333418 case topP = " top_p "
334419 case repetitionPenalty = " repetition_penalty "
@@ -357,7 +442,45 @@ struct Choice: Encodable {
357442
/// Assistant message in an OpenAI-compatible chat completion response.
///
/// `content` and `toolCalls` are optional: a pure tool-call turn carries
/// `toolCalls` with no `content`, and synthesized encoding omits nil keys.
struct AssistantMessage: Encodable {
    let role: String
    let content: String?
    let toolCalls: [ToolCallResponse]?

    enum CodingKeys: String, CodingKey {
        case role
        case content
        case toolCalls = "tool_calls"
    }
}
453+
/// One tool call emitted by the model, in OpenAI response shape
/// (`{"id": …, "type": "function", "function": {…}}`).
struct ToolCallResponse: Encodable {
    let id: String
    let type: String
    let function: ToolCallFunction
}
459+
/// Function payload of a tool call: the function's name plus its
/// arguments pre-serialized as a JSON string (per the OpenAI wire format).
struct ToolCallFunction: Encodable {
    let name: String
    let arguments: String
}
464+
/// AnyCodable: decode arbitrary JSON for tool-parameter pass-through.
///
/// Stores the decoded value as `Any` (Bool/Int/Double/String, arrays,
/// string-keyed dictionaries, or NSNull). `@unchecked Sendable` is justified
/// because the value tree is built once at decode time and never mutated.
struct AnyCodable: Decodable, @unchecked Sendable {
    let value: Any

    init(from decoder: Decoder) throws {
        let c = try decoder.singleValueContainer()
        // Try the scalar types in order; Bool before Int so `true`/`false`
        // are not lost, Int before Double so integral numbers stay integral.
        if c.decodeNil() { value = NSNull() }
        else if let b = try? c.decode(Bool.self) { value = b }
        else if let i = try? c.decode(Int.self) { value = i }
        else if let d = try? c.decode(Double.self) { value = d }
        else if let s = try? c.decode(String.self) { value = s }
        else if let a = try? c.decode([AnyCodable].self) { value = a.map { $0.value } }
        else if let d = try? c.decode([String: AnyCodable].self) { value = d.mapValues { $0.value } }
        else { value = NSNull() }
    }

    /// Convert back to `[String: any Sendable]` for ToolSpec usage.
    ///
    /// Note: the previous implementation used `$0.value as! any Sendable`,
    /// but conditional/forced casts to the marker protocol `Sendable` are not
    /// permitted (marker conformances cannot be checked at runtime). Map each
    /// known JSON shape explicitly instead.
    static func toSendable(_ dict: [String: AnyCodable]?) -> [String: any Sendable]? {
        guard let dict else { return nil }
        return dict.mapValues { sendableValue($0.value) }
    }

    /// Recursively re-wraps a decoded JSON value as `any Sendable`.
    /// Only the shapes produced by `init(from:)` can occur here.
    private static func sendableValue(_ value: Any) -> any Sendable {
        switch value {
        case let b as Bool: return b
        case let i as Int: return i
        case let d as Double: return d
        case let s as String: return s
        case let a as [Any]: return a.map { sendableValue($0) }
        case let m as [String: Any]: return m.mapValues { sendableValue($0) }
        default: return NSNull()
        }
    }
}
362485
363486struct TokenUsage : Encodable {
0 commit comments