@@ -142,30 +142,48 @@ struct MLXServer: AsyncParsableCommand {
142142 }
143143 }
144144
145+ // Convert OpenAI tools format → [String: any Sendable] for UserInput
146+ let toolSpecs : [ [ String : any Sendable ] ] ? = chatReq. tools? . map { tool in
147+ var spec : [ String : any Sendable ] = [ " type " : tool. type]
148+ var fn : [ String : any Sendable ] = [ " name " : tool. function. name]
149+ if let desc = tool. function. description { fn [ " description " ] = desc }
150+ if let params = tool. function. parameters {
151+ fn [ " parameters " ] = params. mapValues { $0. value }
152+ }
153+ spec [ " function " ] = fn
154+ return spec
155+ }
156+
145157 // ── Acquire slot (concurrency limiter) ──
146158 await semaphore. wait ( )
147159
148160 // Pass enable_thinking to the Jinja chat template via additionalContext
149161 // (mirrors llama-server's --chat-template-kwargs '{"enable_thinking":false}')
150162 let templateContext : [ String : any Sendable ] ? = thinkingEnabled ? nil : [ " enable_thinking " : false ]
151- let userInput = UserInput ( chat: chatMessages, additionalContext: templateContext)
163+ let userInput = UserInput ( chat: chatMessages, tools : toolSpecs , additionalContext: templateContext)
152164 let lmInput = try await container. prepare ( input: userInput)
153165 let stream = try await container. generate ( input: lmInput, parameters: params)
154166
155167 if isStream {
156168 // SSE streaming
157169 let ( sseStream, cont) = AsyncStream< String> . makeStream( )
158170 Task {
171+ var hasToolCalls = false
172+ var toolCallIndex = 0
159173 for await generation in stream {
160174 switch generation {
161175 case . chunk( let text) :
162176 cont. yield ( sseChunk ( modelId: modelId, delta: text, finishReason: nil ) )
177+ case . toolCall( let tc) :
178+ hasToolCalls = true
179+ let argsJson = serializeToolCallArgs ( tc. function. arguments)
180+ cont. yield ( sseToolCallChunk ( modelId: modelId, index: toolCallIndex, name: tc. function. name, arguments: argsJson) )
181+ toolCallIndex += 1
163182 case . info:
164- cont. yield ( sseChunk ( modelId: modelId, delta: " " , finishReason: " stop " ) )
183+ let reason = hasToolCalls ? " tool_calls " : " stop "
184+ cont. yield ( sseChunk ( modelId: modelId, delta: " " , finishReason: reason) )
165185 cont. yield ( " data: [DONE] \n \n " )
166186 cont. finish ( )
167- case . toolCall:
168- break
169187 }
170188 }
171189 cont. finish ( )
@@ -177,15 +195,25 @@ struct MLXServer: AsyncParsableCommand {
177195 body: . init( asyncSequence: sseStream. map { ByteBuffer ( string: $0) } )
178196 )
179197 } else {
180- // Non-streaming: collect all chunks
198+ // Non-streaming: collect all chunks and tool calls
181199 var fullText = " "
182200 var completionTokenCount = 0
201+ var collectedToolCalls : [ ToolCallResponse ] = [ ]
202+ var tcIndex = 0
183203 for await generation in stream {
184204 switch generation {
185205 case . chunk( let text) :
186206 fullText += text
187207 completionTokenCount += 1
188- case . info, . toolCall:
208+ case . toolCall( let tc) :
209+ let argsJson = serializeToolCallArgs ( tc. function. arguments)
210+ collectedToolCalls. append ( ToolCallResponse (
211+ id: " call_ \( UUID ( ) . uuidString. prefix ( 8 ) ) " ,
212+ type: " function " ,
213+ function: ToolCallFunction ( name: tc. function. name, arguments: argsJson)
214+ ) )
215+ tcIndex += 1
216+ case . info:
189217 break
190218 }
191219 }
@@ -196,15 +224,20 @@ struct MLXServer: AsyncParsableCommand {
196224 let estimatedPromptTokens = max ( 1 , promptText. count / 4 )
197225 let totalTokens = estimatedPromptTokens + completionTokenCount
198226
227+ let hasToolCalls = !collectedToolCalls. isEmpty
199228 let resp = ChatCompletionResponse (
200229 id: " chatcmpl- \( UUID ( ) . uuidString) " ,
201230 model: modelId,
202231 created: Int ( Date ( ) . timeIntervalSince1970) ,
203232 choices: [
204233 Choice (
205234 index: 0 ,
206- message: AssistantMessage ( role: " assistant " , content: fullText) ,
207- finishReason: " stop "
235+ message: AssistantMessage (
236+ role: " assistant " ,
237+ content: fullText. isEmpty && hasToolCalls ? nil : fullText,
238+ toolCalls: hasToolCalls ? collectedToolCalls : nil
239+ ) ,
240+ finishReason: hasToolCalls ? " tool_calls " : " stop "
208241 )
209242 ] ,
210243 usage: TokenUsage ( promptTokens: estimatedPromptTokens, completionTokens: completionTokenCount, totalTokens: totalTokens)
@@ -319,23 +352,68 @@ func sseChunk(modelId: String, delta: String, finishReason: String?) -> String {
319352 return " data: \( String ( data: data, encoding: . utf8) !) \n \n "
320353}
321354
/// Build one SSE `chat.completion.chunk` line carrying a single tool call,
/// mirroring OpenAI's streaming shape: `choices[0].delta.tool_calls[...]`.
/// - Parameters:
///   - modelId: model name echoed back in the chunk.
///   - index: position of this tool call within the assistant turn.
///   - name: function name the model wants invoked.
///   - arguments: arguments as an already-serialized JSON object string.
/// - Returns: a `data: {...}\n\n` SSE frame.
func sseToolCallChunk(modelId: String, index: Int, name: String, arguments: String) -> String {
    // Build the payload bottom-up with typed locals instead of a single
    // deeply nested literal full of `as [String: Any]` casts.
    let function: [String: Any] = [
        "name": name,
        "arguments": arguments,
    ]
    let toolCall: [String: Any] = [
        "index": index,
        "id": "call_\(UUID().uuidString.prefix(8))",
        "type": "function",
        "function": function,
    ]
    let delta: [String: Any] = [
        "role": "assistant",
        "tool_calls": [toolCall],
    ]
    let chunk: [String: Any] = [
        "id": "chatcmpl-\(UUID().uuidString)",
        "object": "chat.completion.chunk",
        "created": Int(Date().timeIntervalSince1970),
        "model": modelId,
        "choices": [["index": 0, "delta": delta] as [String: Any]],
    ]
    // Avoid `try!`/`!`: the dictionary contains only JSON-safe types, so the
    // fallback is effectively unreachable, but a malformed frame is still
    // better than crashing the server mid-stream.
    guard let data = try? JSONSerialization.data(withJSONObject: chunk),
          let json = String(data: data, encoding: .utf8) else {
        return "data: {}\n\n"
    }
    return "data: \(json)\n\n"
}
380+
/// Serialize ToolCall arguments (`[String: JSONValue]`) to a JSON object string.
/// - Parameter args: decoded tool-call arguments; each `JSONValue` is unwrapped
///   via its `anyValue` bridge before serialization.
/// - Returns: a JSON object string, or `"{}"` if the values cannot be
///   represented as JSON.
func serializeToolCallArgs(_ args: [String: JSONValue]) -> String {
    let anyDict = args.mapValues { $0.anyValue }
    // .sortedKeys makes the output deterministic across runs, which keeps
    // streamed tool-call frames stable for logging, diffing, and tests.
    guard let data = try? JSONSerialization.data(withJSONObject: anyDict, options: [.sortedKeys]),
          let json = String(data: data, encoding: .utf8) else {
        return "{}"
    }
    return json
}
389+
322390// ── OpenAI-compatible types ───────────────────────────────────────────────────
323391
324392struct ChatCompletionRequest : Decodable {
    /// One chat message from the incoming OpenAI-style request body.
    struct Message: Decodable {
        let role: String    // sender role string as provided by the client
        let content: String // message text; assumes the plain-string form, not the content-parts array — TODO confirm
    }
    /// One entry of the request's `tools` array (OpenAI tool definition).
    struct ToolDef: Decodable {
        let type: String           // tool kind; presumably always "function" — not validated here
        let function: ToolFuncDef  // the function declaration carried by this tool
    }
    /// The `function` object inside a tool definition.
    struct ToolFuncDef: Decodable {
        let name: String                       // function name the model may call
        let description: String?               // optional human-readable description
        let parameters: [String: AnyCodable]?  // JSON-Schema-like parameter object, passed through untyped
    }
329406 let model : String ?
330407 let messages : [ Message ]
331408 let stream : Bool ?
332409 let maxTokens : Int ?
333410 let temperature : Double ?
334411 let topP : Double ?
335412 let repetitionPenalty : Double ?
413+ let tools : [ ToolDef ] ?
336414
337415 enum CodingKeys : String , CodingKey {
338- case model, messages, stream, temperature
416+ case model, messages, stream, temperature, tools
339417 case maxTokens = " max_tokens "
340418 case topP = " top_p "
341419 case repetitionPenalty = " repetition_penalty "
@@ -364,7 +442,45 @@ struct Choice: Encodable {
364442
/// Assistant message inside a non-streaming chat completion response.
/// `content` is optional so a tool-call-only reply can omit it; synthesized
/// Encodable uses encodeIfPresent, so nil fields are dropped from the JSON.
struct AssistantMessage: Encodable {
    let role: String                      // "assistant" at the visible call site
    let content: String?                  // reply text; nil when the turn is only tool calls
    let toolCalls: [ToolCallResponse]?    // present only when the model emitted tool calls

    enum CodingKeys: String, CodingKey {
        case role, content
        case toolCalls = "tool_calls"     // snake_case wire name expected by OpenAI clients
    }
}
453+
/// One tool call in an OpenAI-compatible response (`choices[].message.tool_calls[]`).
struct ToolCallResponse: Encodable {
    let id: String               // caller-generated identifier, e.g. "call_" + UUID prefix
    let type: String             // "function" at the visible call site
    let function: ToolCallFunction
}

/// The function name plus JSON-encoded argument string of a tool call.
struct ToolCallFunction: Encodable {
    let name: String       // function the model wants invoked
    let arguments: String  // arguments as an already-serialized JSON object string
}
464+
/// AnyCodable: decodes arbitrary JSON for tool `parameters` pass-through.
///
/// `value` is typed `any Sendable` rather than `Any` so the decoded tree can
/// legitimately satisfy the struct's `Sendable` conformance and cross
/// task/actor boundaries. The previous `toSendable` force-cast
/// (`$0.value as! any Sendable`) would trap at runtime for nested containers,
/// because `[Any]` and `[String: Any]` do not conform to `Sendable`; storing
/// `any Sendable` throughout removes the cast entirely.
struct AnyCodable: Decodable, Sendable {
    let value: any Sendable

    init(from decoder: Decoder) throws {
        let c = try decoder.singleValueContainer()
        // Probe the scalar types first (Bool before the numeric types so JSON
        // booleans decode as Bool), then recurse into arrays and objects.
        if c.decodeNil() { value = NSNull() }
        else if let b = try? c.decode(Bool.self) { value = b }
        else if let i = try? c.decode(Int.self) { value = i }
        else if let d = try? c.decode(Double.self) { value = d }
        else if let s = try? c.decode(String.self) { value = s }
        else if let a = try? c.decode([AnyCodable].self) { value = a.map { $0.value } }
        else if let d = try? c.decode([String: AnyCodable].self) { value = d.mapValues { $0.value } }
        else { value = NSNull() }
    }

    /// Convert back to `[String: any Sendable]` for ToolSpec usage.
    /// No force-cast needed now that `value` is already `any Sendable`.
    static func toSendable(_ dict: [String: AnyCodable]?) -> [String: any Sendable]? {
        dict?.mapValues { $0.value }
    }
}
369485
370486struct TokenUsage : Encodable {
0 commit comments