@@ -124,7 +124,6 @@ struct MLXServer: AsyncParsableCommand {
124124 case . layerPartitioned:
125125 Memory . cacheLimit = plan. recommendedCacheLimit
126126 print ( " [mlx-server] \( plan. strategy. emoji) Memory strategy: LAYER PARTITIONED ( \( plan. recommendedGPULayers) / \( plan. totalLayers) GPU layers) " )
127- print ( " [mlx-server] Note: GPU/CPU layer split requires --gpu-layers support (coming soon) " )
128127 for w in plan. warnings { print ( " [mlx-server] \( w) " ) }
129128 case . tooLarge:
130129 Memory . cacheLimit = plan. recommendedCacheLimit
@@ -136,13 +135,26 @@ struct MLXServer: AsyncParsableCommand {
136135 return
137136 }
138137
139- // --gpu-layers validation (accept now, prepare for Phase 2)
140- if let gpuLayersArg = self . gpuLayers, gpuLayersArg != " auto " {
141- if let n = Int ( gpuLayersArg) {
142- print ( " [mlx-server] --gpu-layers \( n) requested. Note: layer-level CPU/GPU split is under development. " )
138+ // ── Determine GPU layer count ──
139+ // Priority: 1) explicit --gpu-layers flag, 2) partition plan auto, 3) nil (all GPU)
140+ var requestedGPULayers : Int ? = nil
141+ if let gpuLayersArg = self . gpuLayers {
142+ if gpuLayersArg == " auto " {
143+ // Use partition plan recommendation if available
144+ requestedGPULayers = partitionPlan? . recommendedGPULayers
145+ print ( " [mlx-server] --gpu-layers auto → \( requestedGPULayers. map ( String . init) ?? " all " ) layers on GPU " )
146+ } else if let n = Int ( gpuLayersArg) {
147+ requestedGPULayers = n
148+ print ( " [mlx-server] --gpu-layers \( n) → \( n) layers on GPU " )
143149 } else {
144- print ( " [mlx-server] Warning: --gpu-layers must be 'auto' or an integer, got ' \( gpuLayersArg) '. Using auto . " )
150+ print ( " [mlx-server] Warning: --gpu-layers must be 'auto' or an integer, got ' \( gpuLayersArg) '. Using all GPU . " )
145151 }
152+ } else if let plan = partitionPlan,
153+ ( plan. strategy == . layerPartitioned || plan. strategy == . swapAssisted) ,
154+ plan. overcommitRatio > 1.0 {
155+ // Auto-partition when model exceeds available RAM (no flag needed)
156+ requestedGPULayers = plan. recommendedGPULayers
157+ print ( " [mlx-server] Auto-partitioning: \( plan. recommendedGPULayers) / \( plan. totalLayers) layers on GPU " )
146158 }
147159
148160 let isVision = self . vision
@@ -164,6 +176,20 @@ struct MLXServer: AsyncParsableCommand {
164176 }
165177 }
166178
179+ // ── Apply GPU/CPU layer partitioning ──
180+ if let gpuCount = requestedGPULayers {
181+ let actual = await container. setGPULayers ( gpuCount)
182+ if let actual {
183+ let total = partitionPlan? . totalLayers ?? actual
184+ let cpuCount = total - actual
185+ print ( " [mlx-server] 🔀 Layer split active: \( actual) GPU / \( cpuCount) CPU " )
186+ // Update the partition plan to reflect actual split
187+ partitionPlan? . gpuLayers = actual
188+ } else {
189+ print ( " [mlx-server] ⚠️ Model does not support layer partitioning (architecture not yet adapted) " )
190+ }
191+ }
192+
167193 print ( " [mlx-server] Model loaded. Starting HTTP server on \( host) : \( port) " )
168194
169195 // ── Capture CLI defaults into a shared config ──
@@ -721,7 +747,7 @@ func handleChatStreaming(
721747 for await generation in stream {
722748 if stopped { break }
723749 switch generation {
724- case . chunk( let text) :
750+ case . chunk( let text, _ ) :
725751 completionTokenCount += 1
726752 fullText += text
727753 // GPU yield: prevent Metal from starving macOS WindowServer
@@ -792,7 +818,7 @@ func handleChatNonStreaming(
792818 var tcIndex = 0
793819 for await generation in stream {
794820 switch generation {
795- case . chunk( let text) :
821+ case . chunk( let text, _ ) :
796822 fullText += text
797823 completionTokenCount += 1
798824 // GPU yield: prevent Metal from starving macOS WindowServer
@@ -936,7 +962,7 @@ func handleTextStreaming(
936962 for await generation in stream {
937963 if stopped { break }
938964 switch generation {
939- case . chunk( let text) :
965+ case . chunk( let text, _ ) :
940966 completionTokenCount += 1
941967 fullText += text
942968 // GPU yield: prevent Metal from starving macOS WindowServer
@@ -993,7 +1019,7 @@ func handleTextNonStreaming(
9931019 var completionTokenCount = 0
9941020 for await generation in stream {
9951021 switch generation {
996- case . chunk( let text) :
1022+ case . chunk( let text, _ ) :
9971023 fullText += text
9981024 completionTokenCount += 1
9991025 // GPU yield: prevent Metal from starving macOS WindowServer
0 commit comments