Skip to content

Commit 19717cf

Browse files
simbasimba
authored and committed
feat: add full OpenAI-compatible tool calling support
1 parent 91ee743 commit 19717cf

File tree

1 file changed

+126
-10
lines changed

1 file changed

+126
-10
lines changed

Sources/mlx-server/main.swift

Lines changed: 126 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -142,30 +142,48 @@ struct MLXServer: AsyncParsableCommand {
142142
}
143143
}
144144

145+
// Convert OpenAI tools format → [String: any Sendable] for UserInput
146+
let toolSpecs: [[String: any Sendable]]? = chatReq.tools?.map { tool in
147+
var spec: [String: any Sendable] = ["type": tool.type]
148+
var fn: [String: any Sendable] = ["name": tool.function.name]
149+
if let desc = tool.function.description { fn["description"] = desc }
150+
if let params = tool.function.parameters {
151+
fn["parameters"] = params.mapValues { $0.value }
152+
}
153+
spec["function"] = fn
154+
return spec
155+
}
156+
145157
// ── Acquire slot (concurrency limiter) ──
146158
await semaphore.wait()
147159

148160
// Pass enable_thinking to the Jinja chat template via additionalContext
149161
// (mirrors llama-server's --chat-template-kwargs '{"enable_thinking":false}')
150162
let templateContext: [String: any Sendable]? = thinkingEnabled ? nil : ["enable_thinking": false]
151-
let userInput = UserInput(chat: chatMessages, additionalContext: templateContext)
163+
let userInput = UserInput(chat: chatMessages, tools: toolSpecs, additionalContext: templateContext)
152164
let lmInput = try await container.prepare(input: userInput)
153165
let stream = try await container.generate(input: lmInput, parameters: params)
154166

155167
if isStream {
156168
// SSE streaming
157169
let (sseStream, cont) = AsyncStream<String>.makeStream()
158170
Task {
171+
var hasToolCalls = false
172+
var toolCallIndex = 0
159173
for await generation in stream {
160174
switch generation {
161175
case .chunk(let text):
162176
cont.yield(sseChunk(modelId: modelId, delta: text, finishReason: nil))
177+
case .toolCall(let tc):
178+
hasToolCalls = true
179+
let argsJson = serializeToolCallArgs(tc.function.arguments)
180+
cont.yield(sseToolCallChunk(modelId: modelId, index: toolCallIndex, name: tc.function.name, arguments: argsJson))
181+
toolCallIndex += 1
163182
case .info:
164-
cont.yield(sseChunk(modelId: modelId, delta: "", finishReason: "stop"))
183+
let reason = hasToolCalls ? "tool_calls" : "stop"
184+
cont.yield(sseChunk(modelId: modelId, delta: "", finishReason: reason))
165185
cont.yield("data: [DONE]\n\n")
166186
cont.finish()
167-
case .toolCall:
168-
break
169187
}
170188
}
171189
cont.finish()
@@ -177,15 +195,25 @@ struct MLXServer: AsyncParsableCommand {
177195
body: .init(asyncSequence: sseStream.map { ByteBuffer(string: $0) })
178196
)
179197
} else {
180-
// Non-streaming: collect all chunks
198+
// Non-streaming: collect all chunks and tool calls
181199
var fullText = ""
182200
var completionTokenCount = 0
201+
var collectedToolCalls: [ToolCallResponse] = []
202+
var tcIndex = 0
183203
for await generation in stream {
184204
switch generation {
185205
case .chunk(let text):
186206
fullText += text
187207
completionTokenCount += 1
188-
case .info, .toolCall:
208+
case .toolCall(let tc):
209+
let argsJson = serializeToolCallArgs(tc.function.arguments)
210+
collectedToolCalls.append(ToolCallResponse(
211+
id: "call_\(UUID().uuidString.prefix(8))",
212+
type: "function",
213+
function: ToolCallFunction(name: tc.function.name, arguments: argsJson)
214+
))
215+
tcIndex += 1
216+
case .info:
189217
break
190218
}
191219
}
@@ -196,15 +224,20 @@ struct MLXServer: AsyncParsableCommand {
196224
let estimatedPromptTokens = max(1, promptText.count / 4)
197225
let totalTokens = estimatedPromptTokens + completionTokenCount
198226

227+
let hasToolCalls = !collectedToolCalls.isEmpty
199228
let resp = ChatCompletionResponse(
200229
id: "chatcmpl-\(UUID().uuidString)",
201230
model: modelId,
202231
created: Int(Date().timeIntervalSince1970),
203232
choices: [
204233
Choice(
205234
index: 0,
206-
message: AssistantMessage(role: "assistant", content: fullText),
207-
finishReason: "stop"
235+
message: AssistantMessage(
236+
role: "assistant",
237+
content: fullText.isEmpty && hasToolCalls ? nil : fullText,
238+
toolCalls: hasToolCalls ? collectedToolCalls : nil
239+
),
240+
finishReason: hasToolCalls ? "tool_calls" : "stop"
208241
)
209242
],
210243
usage: TokenUsage(promptTokens: estimatedPromptTokens, completionTokens: completionTokenCount, totalTokens: totalTokens)
@@ -319,23 +352,68 @@ func sseChunk(modelId: String, delta: String, finishReason: String?) -> String {
319352
return "data: \(String(data: data, encoding: .utf8)!)\n\n"
320353
}
321354

355+
/// Builds a single OpenAI-compatible SSE frame announcing one complete tool call.
///
/// Emits a `chat.completion.chunk` whose delta carries the tool call in full
/// (index, synthetic id, function name, and the already-serialized `arguments`
/// string) — no incremental argument streaming.
/// - Parameters:
///   - modelId: Model identifier echoed back in the chunk.
///   - index: Zero-based position of this tool call within the response.
///   - name: Name of the function the model wants invoked.
///   - arguments: JSON object serialized to a string.
/// - Returns: An SSE frame of the form `data: {json}\n\n`.
func sseToolCallChunk(modelId: String, index: Int, name: String, arguments: String) -> String {
    let toolCall: [String: Any] = [
        "index": index,
        "id": "call_\(UUID().uuidString.prefix(8))",
        "type": "function",
        "function": [
            "name": name,
            "arguments": arguments,
        ] as [String: Any],
    ]
    let chunk: [String: Any] = [
        "id": "chatcmpl-\(UUID().uuidString)",
        "object": "chat.completion.chunk",
        "created": Int(Date().timeIntervalSince1970),
        "model": modelId,
        "choices": [[
            "index": 0,
            // Non-final chunk: finish_reason is null; the closing chunk
            // (see sseChunk) carries "tool_calls".
            "finish_reason": NSNull(),
            "delta": [
                "role": "assistant",
                "tool_calls": [toolCall],
            ] as [String: Any],
        ] as [String: Any]],
    ]
    // All keys/values above are JSON-safe (String/Int/NSNull), so serialization
    // should never fail — but fall back to an empty object instead of `try!`,
    // which would crash the whole streaming task.
    guard let data = try? JSONSerialization.data(withJSONObject: chunk),
          let json = String(data: data, encoding: .utf8) else {
        return "data: {}\n\n"
    }
    return "data: \(json)\n\n"
}
380+
381+
/// Serialize ToolCall arguments ([String: JSONValue]) to a JSON object string.
/// Returns "{}" when the unwrapped values cannot be represented as JSON.
func serializeToolCallArgs(_ args: [String: JSONValue]) -> String {
    // Unwrap each JSONValue into its plain Swift representation first.
    let plain = args.mapValues(\.anyValue)
    if let encoded = try? JSONSerialization.data(withJSONObject: plain),
       let text = String(data: encoded, encoding: .utf8) {
        return text
    }
    return "{}"
}
389+
322390
// ── OpenAI-compatible types ───────────────────────────────────────────────────
323391

324392
struct ChatCompletionRequest: Decodable {
325393
struct Message: Decodable {
326394
let role: String
327395
let content: String
328396
}
397+
struct ToolDef: Decodable {
398+
let type: String
399+
let function: ToolFuncDef
400+
}
401+
struct ToolFuncDef: Decodable {
402+
let name: String
403+
let description: String?
404+
let parameters: [String: AnyCodable]?
405+
}
329406
let model: String?
330407
let messages: [Message]
331408
let stream: Bool?
332409
let maxTokens: Int?
333410
let temperature: Double?
334411
let topP: Double?
335412
let repetitionPenalty: Double?
413+
let tools: [ToolDef]?
336414

337415
enum CodingKeys: String, CodingKey {
338-
case model, messages, stream, temperature
416+
case model, messages, stream, temperature, tools
339417
case maxTokens = "max_tokens"
340418
case topP = "top_p"
341419
case repetitionPenalty = "repetition_penalty"
@@ -364,7 +442,45 @@ struct Choice: Encodable {
364442

365443
/// Assistant message payload inside a chat-completion choice.
/// Mirrors the OpenAI schema: `content` may be nil when the reply consists
/// solely of tool calls, and tool calls are serialized under `tool_calls`.
struct AssistantMessage: Encodable {
    // Always "assistant" for generated replies (set by the caller).
    let role: String
    // Text reply; nil when the model only emitted tool calls.
    let content: String?
    // Tool invocations requested by the model; key omitted from JSON when nil.
    let toolCalls: [ToolCallResponse]?

    enum CodingKeys: String, CodingKey {
        case role, content
        // snake_case on the wire, per the OpenAI API.
        case toolCalls = "tool_calls"
    }
}
453+
454+
/// One entry of an assistant message's `tool_calls` array (OpenAI format).
struct ToolCallResponse: Encodable {
    // Synthetic call identifier; callers generate "call_<8 uuid chars>".
    let id: String
    // Call kind; this server only ever emits "function".
    let type: String
    // The function name plus its JSON-encoded arguments string.
    let function: ToolCallFunction
}
459+
460+
/// Function name/arguments pair nested inside a tool call.
struct ToolCallFunction: Encodable {
    // Name of the tool (function) the model wants invoked.
    let name: String
    // Argument object serialized to a JSON string, per the OpenAI convention.
    let arguments: String
}
464+
465+
/// AnyCodable: decode arbitrary JSON for tool parameters pass-through.
///
/// Stores the decoded payload as native Swift JSON types: `Bool`, `Int`,
/// `Double`, `String`, `[Any]`, `[String: Any]`, or `NSNull` for JSON null.
/// `@unchecked Sendable` is required because `value` is typed `Any`; it is
/// sound here since `value` is assigned exactly once in `init(from:)` and only
/// ever holds the immutable JSON primitives above.
struct AnyCodable: Decodable, @unchecked Sendable {
    let value: Any

    init(from decoder: Decoder) throws {
        let c = try decoder.singleValueContainer()
        if c.decodeNil() { value = NSNull() }
        else if let b = try? c.decode(Bool.self) { value = b }
        // Int is tried before Double so whole numbers keep integer identity.
        else if let i = try? c.decode(Int.self) { value = i }
        else if let d = try? c.decode(Double.self) { value = d }
        else if let s = try? c.decode(String.self) { value = s }
        else if let a = try? c.decode([AnyCodable].self) { value = a.map { $0.value } }
        else if let d = try? c.decode([String: AnyCodable].self) { value = d.mapValues { $0.value } }
        else { value = NSNull() }
    }

    /// Convert back to [String: any Sendable] for ToolSpec usage.
    ///
    /// Rebuilds the value tree with a typed switch instead of the original
    /// `as! any Sendable`: `Sendable` is a marker protocol and cannot be the
    /// target of a runtime cast (SE-0302).
    static func toSendable(_ dict: [String: AnyCodable]?) -> [String: any Sendable]? {
        guard let dict else { return nil }
        return dict.mapValues { sendable($0.value) }
    }

    /// Maps a decoded JSON value onto a concrete `Sendable` representation.
    /// Input always comes from `init(from:)`, so only the types stored there
    /// can appear here.
    private static func sendable(_ raw: Any) -> any Sendable {
        switch raw {
        // Exact-type guards on the scalar cases: Darwin's NSNumber bridging
        // lets dynamic casts convert between Bool/Int/Double (e.g. 1 → true),
        // which would silently rewrite values without them.
        case let b as Bool where type(of: raw) == Bool.self: return b
        case let i as Int where type(of: raw) == Int.self: return i
        case let d as Double where type(of: raw) == Double.self: return d
        case let s as String: return s
        case let a as [Any]: return a.map { sendable($0) }
        case let m as [String: Any]: return m.mapValues { sendable($0) }
        default: return NSNull() // JSON null, or anything unexpected
        }
    }
}
369485

370486
struct TokenUsage: Encodable {

0 commit comments

Comments
 (0)