Skip to content

Commit 29e1d64

Browse files
simbasimba
authored and committed
feat: wire GPU/CPU layer partitioning to --gpu-layers flag
Phase 2 integration — connects the mlx-swift-lm fork's new LayerPartitionable protocol to mlx-server's CLI and profiler:
- --gpu-layers N: explicitly place N layers on the GPU and the rest on the CPU
- --gpu-layers auto: use the partition plan's recommendation
- Auto-partition: when the model exceeds available RAM (overcommit > 1.0), automatically apply the recommended GPU layer count
- PartitionPlan: added a mutable gpuLayers field (updated after actual partitioning) and cpu_layers in the /health response
- Fixed the .chunk API change in the latest fork (now returns a tokenId tuple)
- Updated the Package.swift comment to note partitioning support
1 parent 1416ac6 commit 29e1d64

5 files changed

Lines changed: 45 additions & 23 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,3 +13,4 @@ DerivedData/
1313
# IDE
1414
.vscode/
1515
.idea/
16+
mlx-swift-lm

Package.resolved

Lines changed: 0 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Package.swift

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,8 +7,7 @@ let package = Package(
77
dependencies: [
88
// Apple MLX Swift — core inference engine (Apple-maintained, tagged releases)
99
.package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.30.6")),
10-
// Apple's LLM library built on MLX Swift (SharpAI fork)
11-
// Pinned to main branch for Qwen3.5 support (PRs #97, #120, #129, #133, #135 — not yet in a release tag)
10+
// Apple's LLM library built on MLX Swift (SharpAI fork — with GPU/CPU layer partitioning)
1211
.package(url: "https://github.com/SharpAI/mlx-swift-lm", branch: "main"),
1312
// HuggingFace tokenizers + model download
1413
.package(url: "https://github.com/huggingface/swift-transformers", .upToNextMinor(from: "1.2.0")),

Sources/mlx-server/ModelProfiler.swift

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -136,6 +136,9 @@ struct PartitionPlan: Sendable {
136136
let estimatedTokensPerSec: Double
137137
let warnings: [String]
138138

139+
/// Actual GPU layers after partitioning (updated by server after model load)
140+
var gpuLayers: Int
141+
139142
var fitsInMemory: Bool { strategy == .fullGPU }
140143

141144
/// JSON-compatible dictionary for the /health endpoint
@@ -147,7 +150,8 @@ struct PartitionPlan: Sendable {
147150
"kv_cache_gb": round(kvCacheMemoryGB * 10) / 10,
148151
"total_required_gb": round(totalRequiredGB * 10) / 10,
149152
"system_ram_gb": round(systemRAMGB * 10) / 10,
150-
"gpu_layers": recommendedGPULayers,
153+
"gpu_layers": gpuLayers,
154+
"cpu_layers": totalLayers - gpuLayers,
151155
"total_layers": totalLayers,
152156
"estimated_tok_s": round(estimatedTokensPerSec * 10) / 10,
153157
]
@@ -409,7 +413,8 @@ enum ModelProfiler {
409413
recommendedMemoryLimit: memoryLimit,
410414
recommendedCacheLimit: cacheLimit,
411415
estimatedTokensPerSec: estimatedSpeed,
412-
warnings: warnings
416+
warnings: warnings,
417+
gpuLayers: gpuLayers // Initially same as recommended; updated after actual partitioning
413418
)
414419
}
415420

Sources/mlx-server/Server.swift

Lines changed: 36 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -124,7 +124,6 @@ struct MLXServer: AsyncParsableCommand {
124124
case .layerPartitioned:
125125
Memory.cacheLimit = plan.recommendedCacheLimit
126126
print("[mlx-server] \(plan.strategy.emoji) Memory strategy: LAYER PARTITIONED (\(plan.recommendedGPULayers)/\(plan.totalLayers) GPU layers)")
127-
print("[mlx-server] Note: GPU/CPU layer split requires --gpu-layers support (coming soon)")
128127
for w in plan.warnings { print("[mlx-server] \(w)") }
129128
case .tooLarge:
130129
Memory.cacheLimit = plan.recommendedCacheLimit
@@ -136,13 +135,26 @@ struct MLXServer: AsyncParsableCommand {
136135
return
137136
}
138137

139-
// --gpu-layers validation (accept now, prepare for Phase 2)
140-
if let gpuLayersArg = self.gpuLayers, gpuLayersArg != "auto" {
141-
if let n = Int(gpuLayersArg) {
142-
print("[mlx-server] --gpu-layers \(n) requested. Note: layer-level CPU/GPU split is under development.")
138+
// ── Determine GPU layer count ──
139+
// Priority: 1) explicit --gpu-layers flag, 2) partition plan auto, 3) nil (all GPU)
140+
var requestedGPULayers: Int? = nil
141+
if let gpuLayersArg = self.gpuLayers {
142+
if gpuLayersArg == "auto" {
143+
// Use partition plan recommendation if available
144+
requestedGPULayers = partitionPlan?.recommendedGPULayers
145+
print("[mlx-server] --gpu-layers auto → \(requestedGPULayers.map(String.init) ?? "all") layers on GPU")
146+
} else if let n = Int(gpuLayersArg) {
147+
requestedGPULayers = n
148+
print("[mlx-server] --gpu-layers \(n)\(n) layers on GPU")
143149
} else {
144-
print("[mlx-server] Warning: --gpu-layers must be 'auto' or an integer, got '\(gpuLayersArg)'. Using auto.")
150+
print("[mlx-server] Warning: --gpu-layers must be 'auto' or an integer, got '\(gpuLayersArg)'. Using all GPU.")
145151
}
152+
} else if let plan = partitionPlan,
153+
(plan.strategy == .layerPartitioned || plan.strategy == .swapAssisted),
154+
plan.overcommitRatio > 1.0 {
155+
// Auto-partition when model exceeds available RAM (no flag needed)
156+
requestedGPULayers = plan.recommendedGPULayers
157+
print("[mlx-server] Auto-partitioning: \(plan.recommendedGPULayers)/\(plan.totalLayers) layers on GPU")
146158
}
147159

148160
let isVision = self.vision
@@ -164,6 +176,20 @@ struct MLXServer: AsyncParsableCommand {
164176
}
165177
}
166178

179+
// ── Apply GPU/CPU layer partitioning ──
180+
if let gpuCount = requestedGPULayers {
181+
let actual = await container.setGPULayers(gpuCount)
182+
if let actual {
183+
let total = partitionPlan?.totalLayers ?? actual
184+
let cpuCount = total - actual
185+
print("[mlx-server] 🔀 Layer split active: \(actual) GPU / \(cpuCount) CPU")
186+
// Update the partition plan to reflect actual split
187+
partitionPlan?.gpuLayers = actual
188+
} else {
189+
print("[mlx-server] ⚠️ Model does not support layer partitioning (architecture not yet adapted)")
190+
}
191+
}
192+
167193
print("[mlx-server] Model loaded. Starting HTTP server on \(host):\(port)")
168194

169195
// ── Capture CLI defaults into a shared config ──
@@ -721,7 +747,7 @@ func handleChatStreaming(
721747
for await generation in stream {
722748
if stopped { break }
723749
switch generation {
724-
case .chunk(let text):
750+
case .chunk(let text, _):
725751
completionTokenCount += 1
726752
fullText += text
727753
// GPU yield: prevent Metal from starving macOS WindowServer
@@ -792,7 +818,7 @@ func handleChatNonStreaming(
792818
var tcIndex = 0
793819
for await generation in stream {
794820
switch generation {
795-
case .chunk(let text):
821+
case .chunk(let text, _):
796822
fullText += text
797823
completionTokenCount += 1
798824
// GPU yield: prevent Metal from starving macOS WindowServer
@@ -936,7 +962,7 @@ func handleTextStreaming(
936962
for await generation in stream {
937963
if stopped { break }
938964
switch generation {
939-
case .chunk(let text):
965+
case .chunk(let text, _):
940966
completionTokenCount += 1
941967
fullText += text
942968
// GPU yield: prevent Metal from starving macOS WindowServer
@@ -993,7 +1019,7 @@ func handleTextNonStreaming(
9931019
var completionTokenCount = 0
9941020
for await generation in stream {
9951021
switch generation {
996-
case .chunk(let text):
1022+
case .chunk(let text, _):
9971023
fullText += text
9981024
completionTokenCount += 1
9991025
// GPU yield: prevent Metal from starving macOS WindowServer

0 commit comments

Comments
 (0)