Skip to content

Commit 8809805

Browse files
authored
Merge pull request #1 from SharpAI/feature/qwen35-support
Feature/qwen35 support
2 parents 2dce45c + 19717cf commit 8809805

File tree

3 files changed

+162
-20
lines changed

3 files changed

+162
-20
lines changed

Package.resolved

Lines changed: 23 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Package.swift

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@ let package = Package(
55
name: "mlx-server",
66
platforms: [.macOS(.v14)],
77
dependencies: [
8-
// Apple MLX Swift — core inference engine
9-
.package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.30.3")),
10-
// Apple's LLM library built on MLX Swift (Qwen, Llama, Mistral, Gemma etc.)
11-
.package(url: "https://github.com/ml-explore/mlx-swift-lm", from: "2.0.0"),
8+
// Apple MLX Swift — core inference engine (Apple-maintained, tagged releases)
9+
.package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.30.6")),
10+
// Apple's LLM library built on MLX Swift (SharpAI fork)
11+
// Pinned to main branch for Qwen3.5 support (PRs #97, #120, #129, #133, #135 — not yet in a release tag)
12+
.package(url: "https://github.com/SharpAI/mlx-swift-lm", branch: "main"),
1213
// HuggingFace tokenizers + model download
13-
.package(url: "https://github.com/huggingface/swift-transformers", .upToNextMinor(from: "1.1.0")),
14+
.package(url: "https://github.com/huggingface/swift-transformers", .upToNextMinor(from: "1.2.0")),
1415
// Lightweight HTTP server (Apple-backed Swift server project)
1516
.package(url: "https://github.com/hummingbird-project/hummingbird", from: "2.0.0"),
1617
// Async argument parser (for CLI flags: --model, --port)

Sources/mlx-server/main.swift

Lines changed: 133 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ struct MLXServer: AsyncParsableCommand {
5151
@Option(name: .long, help: "Number of parallel request slots")
5252
var parallel: Int = 1
5353

54+
@Flag(name: .long, help: "Enable thinking/reasoning mode (Qwen3.5 etc). Default: disabled")
55+
var thinking: Bool = false
56+
5457
mutating func run() async throws {
5558
print("[mlx-server] Loading model: \(model)")
5659
let modelId = model
@@ -72,6 +75,7 @@ struct MLXServer: AsyncParsableCommand {
7275
let defaultTemp = self.temp
7376
let defaultTopP = self.topP
7477
let defaultRepeatPenalty = self.repeatPenalty
78+
let thinkingEnabled = self.thinking
7579
let parallelSlots = self.parallel
7680

7781
// ── Concurrency limiter ──
@@ -138,27 +142,48 @@ struct MLXServer: AsyncParsableCommand {
138142
}
139143
}
140144

145+
// Convert OpenAI tools format → [String: any Sendable] for UserInput
146+
let toolSpecs: [[String: any Sendable]]? = chatReq.tools?.map { tool in
147+
var spec: [String: any Sendable] = ["type": tool.type]
148+
var fn: [String: any Sendable] = ["name": tool.function.name]
149+
if let desc = tool.function.description { fn["description"] = desc }
150+
if let params = tool.function.parameters {
151+
fn["parameters"] = params.mapValues { $0.value }
152+
}
153+
spec["function"] = fn
154+
return spec
155+
}
156+
141157
// ── Acquire slot (concurrency limiter) ──
142158
await semaphore.wait()
143159

144-
let userInput = UserInput(chat: chatMessages)
160+
// Pass enable_thinking to the Jinja chat template via additionalContext
161+
// (mirrors llama-server's --chat-template-kwargs '{"enable_thinking":false}')
162+
let templateContext: [String: any Sendable]? = thinkingEnabled ? nil : ["enable_thinking": false]
163+
let userInput = UserInput(chat: chatMessages, tools: toolSpecs, additionalContext: templateContext)
145164
let lmInput = try await container.prepare(input: userInput)
146165
let stream = try await container.generate(input: lmInput, parameters: params)
147166

148167
if isStream {
149168
// SSE streaming
150169
let (sseStream, cont) = AsyncStream<String>.makeStream()
151170
Task {
171+
var hasToolCalls = false
172+
var toolCallIndex = 0
152173
for await generation in stream {
153174
switch generation {
154175
case .chunk(let text):
155176
cont.yield(sseChunk(modelId: modelId, delta: text, finishReason: nil))
177+
case .toolCall(let tc):
178+
hasToolCalls = true
179+
let argsJson = serializeToolCallArgs(tc.function.arguments)
180+
cont.yield(sseToolCallChunk(modelId: modelId, index: toolCallIndex, name: tc.function.name, arguments: argsJson))
181+
toolCallIndex += 1
156182
case .info:
157-
cont.yield(sseChunk(modelId: modelId, delta: "", finishReason: "stop"))
183+
let reason = hasToolCalls ? "tool_calls" : "stop"
184+
cont.yield(sseChunk(modelId: modelId, delta: "", finishReason: reason))
158185
cont.yield("data: [DONE]\n\n")
159186
cont.finish()
160-
case .toolCall:
161-
break
162187
}
163188
}
164189
cont.finish()
@@ -170,15 +195,25 @@ struct MLXServer: AsyncParsableCommand {
170195
body: .init(asyncSequence: sseStream.map { ByteBuffer(string: $0) })
171196
)
172197
} else {
173-
// Non-streaming: collect all chunks
198+
// Non-streaming: collect all chunks and tool calls
174199
var fullText = ""
175200
var completionTokenCount = 0
201+
var collectedToolCalls: [ToolCallResponse] = []
202+
var tcIndex = 0
176203
for await generation in stream {
177204
switch generation {
178205
case .chunk(let text):
179206
fullText += text
180207
completionTokenCount += 1
181-
case .info, .toolCall:
208+
case .toolCall(let tc):
209+
let argsJson = serializeToolCallArgs(tc.function.arguments)
210+
collectedToolCalls.append(ToolCallResponse(
211+
id: "call_\(UUID().uuidString.prefix(8))",
212+
type: "function",
213+
function: ToolCallFunction(name: tc.function.name, arguments: argsJson)
214+
))
215+
tcIndex += 1
216+
case .info:
182217
break
183218
}
184219
}
@@ -189,15 +224,20 @@ struct MLXServer: AsyncParsableCommand {
189224
let estimatedPromptTokens = max(1, promptText.count / 4)
190225
let totalTokens = estimatedPromptTokens + completionTokenCount
191226

227+
let hasToolCalls = !collectedToolCalls.isEmpty
192228
let resp = ChatCompletionResponse(
193229
id: "chatcmpl-\(UUID().uuidString)",
194230
model: modelId,
195231
created: Int(Date().timeIntervalSince1970),
196232
choices: [
197233
Choice(
198234
index: 0,
199-
message: AssistantMessage(role: "assistant", content: fullText),
200-
finishReason: "stop"
235+
message: AssistantMessage(
236+
role: "assistant",
237+
content: fullText.isEmpty && hasToolCalls ? nil : fullText,
238+
toolCalls: hasToolCalls ? collectedToolCalls : nil
239+
),
240+
finishReason: hasToolCalls ? "tool_calls" : "stop"
201241
)
202242
],
203243
usage: TokenUsage(promptTokens: estimatedPromptTokens, completionTokens: completionTokenCount, totalTokens: totalTokens)
@@ -312,23 +352,68 @@ func sseChunk(modelId: String, delta: String, finishReason: String?) -> String {
312352
return "data: \(String(data: data, encoding: .utf8)!)\n\n"
313353
}
314354

355+
/// Build a single OpenAI-compatible SSE chunk carrying one streamed tool call.
/// The whole call (id, type, function name and serialized arguments) is emitted
/// in one `delta.tool_calls` entry rather than split across argument deltas.
/// - Parameters:
///   - modelId: model identifier echoed back in the chunk.
///   - index: position of this call within the response's tool_calls array.
///   - name: function name the model chose to invoke.
///   - arguments: function arguments, already serialized as a JSON string.
/// - Returns: a `data: {...}\n\n` SSE frame.
/// NOTE(review): per OpenAI's streaming format, `role` normally appears only
/// in the first delta of a response; it is repeated here on every tool-call
/// chunk — confirm clients tolerate this.
func sseToolCallChunk(modelId: String, index: Int, name: String, arguments: String) -> String {
    let functionPayload: [String: Any] = [
        "name": name,
        "arguments": arguments,
    ]
    let toolCallEntry: [String: Any] = [
        "index": index,
        "id": "call_\(UUID().uuidString.prefix(8))",
        "type": "function",
        "function": functionPayload,
    ]
    let delta: [String: Any] = [
        "role": "assistant",
        "tool_calls": [toolCallEntry],
    ]
    let choice: [String: Any] = [
        "index": 0,
        "delta": delta,
    ]
    let chunk: [String: Any] = [
        "id": "chatcmpl-\(UUID().uuidString)",
        "object": "chat.completion.chunk",
        "created": Int(Date().timeIntervalSince1970),
        "model": modelId,
        "choices": [choice],
    ]
    // Safe to force-try: the payload is built entirely from JSON-compatible types.
    let data = try! JSONSerialization.data(withJSONObject: chunk)
    return "data: \(String(data: data, encoding: .utf8)!)\n\n"
}
380+
381+
/// Serialize a tool call's decoded arguments ([String: JSONValue]) into a
/// compact JSON object string for the OpenAI `function.arguments` field.
/// Falls back to "{}" when the values cannot be serialized.
func serializeToolCallArgs(_ args: [String: JSONValue]) -> String {
    let plainValues = args.mapValues(\.anyValue)
    if let encoded = try? JSONSerialization.data(withJSONObject: plainValues),
       let text = String(data: encoded, encoding: .utf8) {
        return text
    }
    return "{}"
}
389+
315390
// ── OpenAI-compatible types ───────────────────────────────────────────────────
316391

317392
struct ChatCompletionRequest: Decodable {
318393
struct Message: Decodable {
319394
let role: String
320395
let content: String
321396
}
397+
/// One entry of the request's `tools` array (OpenAI tool definition).
struct ToolDef: Decodable {
    let type: String
    let function: ToolFuncDef
}

/// Function declaration nested inside a ToolDef.
/// `parameters` holds an arbitrary JSON-schema object, decoded loosely via
/// AnyCodable so it can be passed through to the model untouched.
struct ToolFuncDef: Decodable {
    let name: String
    let description: String?
    let parameters: [String: AnyCodable]?
}
322406
let model: String?
323407
let messages: [Message]
324408
let stream: Bool?
325409
let maxTokens: Int?
326410
let temperature: Double?
327411
let topP: Double?
328412
let repetitionPenalty: Double?
413+
let tools: [ToolDef]?
329414

330415
enum CodingKeys: String, CodingKey {
331-
case model, messages, stream, temperature
416+
case model, messages, stream, temperature, tools
332417
case maxTokens = "max_tokens"
333418
case topP = "top_p"
334419
case repetitionPenalty = "repetition_penalty"
@@ -357,7 +442,45 @@ struct Choice: Encodable {
357442

358443
/// Assistant reply payload inside a Choice.
/// `content` is optional so a pure tool-call response can omit text, and
/// `toolCalls` serializes under the OpenAI wire key "tool_calls".
struct AssistantMessage: Encodable {
    let role: String
    let content: String?
    let toolCalls: [ToolCallResponse]?

    enum CodingKeys: String, CodingKey {
        case role, content
        case toolCalls = "tool_calls"
    }
}
453+
454+
/// One element of the OpenAI `tool_calls` array returned to the client:
/// an id (e.g. "call_..."), a type (always "function" here), and the call.
struct ToolCallResponse: Encodable {
    let id: String
    let type: String
    let function: ToolCallFunction
}
459+
460+
/// The `function` object inside a tool call: the function's name plus its
/// arguments pre-serialized as a raw JSON string (OpenAI wire format).
struct ToolCallFunction: Encodable {
    let name: String
    let arguments: String
}
464+
465+
/// AnyCodable: decode arbitrary JSON for tool parameters pass-through
466+
struct AnyCodable: Decodable, Sendable {
467+
let value: Any
468+
init(from decoder: Decoder) throws {
469+
let c = try decoder.singleValueContainer()
470+
if c.decodeNil() { value = NSNull() }
471+
else if let b = try? c.decode(Bool.self) { value = b }
472+
else if let i = try? c.decode(Int.self) { value = i }
473+
else if let d = try? c.decode(Double.self) { value = d }
474+
else if let s = try? c.decode(String.self) { value = s }
475+
else if let a = try? c.decode([AnyCodable].self) { value = a.map { $0.value } }
476+
else if let d = try? c.decode([String: AnyCodable].self) { value = d.mapValues { $0.value } }
477+
else { value = NSNull() }
478+
}
479+
// Convert back to [String: any Sendable] for ToolSpec usage
480+
static func toSendable(_ dict: [String: AnyCodable]?) -> [String: any Sendable]? {
481+
guard let dict else { return nil }
482+
return dict.mapValues { $0.value as! any Sendable }
483+
}
361484
}
362485

363486
struct TokenUsage: Encodable {

0 commit comments

Comments
 (0)