Skip to content

Commit 63da149

Browse files
author
Aegis-AI
committed
fix: address Copilot review — safe cache normalization + partitionedLayerCall
- Replace compactMap{} cache handling with explicit normalization that preserves layer↔cache index alignment (nil-padding to layers.count).
- Wrap per-layer calls in partitionedLayerCall() to respect gpuLayerCount CPU/GPU routing (matches callAsFunction behavior).
- Qwen3MoE additionally passes stream: model.streamExperts to match its Qwen3MoEModelInner.callAsFunction routing.
1 parent bcdecc8 commit 63da149

3 files changed

Lines changed: 35 additions & 9 deletions

File tree

Libraries/MLXLLM/Models/Llama.swift

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,11 +186,19 @@ public class LlamaModel: Module, LLMModel, KVCacheDimensionProvider {
186186
_ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
187187
) -> (MLXArray, [Int: MLXArray]) {
188188
var h = model.embedTokens(inputs)
189-
let kvCache = cache?.compactMap { $0 }
190-
let mask = createAttentionMask(h: h, cache: kvCache?.first)
189+
let layerCount = model.layers.count
190+
let kvCache: [KVCache?] = {
191+
guard let c = cache else { return Array(repeating: nil, count: layerCount) }
192+
var normalized: [KVCache?] = Array(repeating: nil, count: layerCount)
193+
for (i, v) in c.prefix(layerCount).enumerated() { normalized[i] = v }
194+
return normalized
195+
}()
196+
let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
191197
var captured: [Int: MLXArray] = [:]
192198
for (i, layer) in model.layers.enumerated() {
193-
h = layer(h, mask: mask, cache: kvCache?[i])
199+
h = partitionedLayerCall(index: i, gpuLayerCount: model.gpuLayerCount) {
200+
layer(h, mask: mask, cache: kvCache[i])
201+
}
194202
if captureLayerIDs.contains(i) { captured[i] = h }
195203
}
196204
h = model.norm(h)

Libraries/MLXLLM/Models/Qwen3.swift

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,19 @@ public class Qwen3Model: Module, LLMModel, KVCacheDimensionProvider {
212212
_ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
213213
) -> (MLXArray, [Int: MLXArray]) {
214214
var h = model.embedTokens(inputs)
215-
let kvCache = cache?.compactMap { $0 }
216-
let mask = createAttentionMask(h: h, cache: kvCache?.first)
215+
let layerCount = model.layers.count
216+
let kvCache: [KVCache?] = {
217+
guard let c = cache else { return Array(repeating: nil, count: layerCount) }
218+
var normalized: [KVCache?] = Array(repeating: nil, count: layerCount)
219+
for (i, v) in c.prefix(layerCount).enumerated() { normalized[i] = v }
220+
return normalized
221+
}()
222+
let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
217223
var captured: [Int: MLXArray] = [:]
218224
for (i, layer) in model.layers.enumerated() {
219-
h = layer(h, mask: mask, cache: kvCache?[i])
225+
h = partitionedLayerCall(index: i, gpuLayerCount: model.gpuLayerCount) {
226+
layer(h, mask: mask, cache: kvCache[i])
227+
}
220228
if captureLayerIDs.contains(i) { captured[i] = h }
221229
}
222230
h = model.norm(h)

Libraries/MLXLLM/Models/Qwen3MoE.swift

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -265,11 +265,21 @@ public class Qwen3MoEModel: Module, LLMModel, KVCacheDimensionProvider {
265265
_ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
266266
) -> (MLXArray, [Int: MLXArray]) {
267267
var h = model.embedTokens(inputs)
268-
let kvCache = cache?.compactMap { $0 }
269-
let mask = createAttentionMask(h: h, cache: kvCache?.first)
268+
let layerCount = model.layers.count
269+
let kvCache: [KVCache?] = {
270+
guard let c = cache else { return Array(repeating: nil, count: layerCount) }
271+
var normalized: [KVCache?] = Array(repeating: nil, count: layerCount)
272+
for (i, v) in c.prefix(layerCount).enumerated() { normalized[i] = v }
273+
return normalized
274+
}()
275+
let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
270276
var captured: [Int: MLXArray] = [:]
271277
for (i, layer) in model.layers.enumerated() {
272-
h = layer(h, mask: mask, cache: kvCache?[i])
278+
h = partitionedLayerCall(
279+
index: i, gpuLayerCount: model.gpuLayerCount, stream: model.streamExperts
280+
) {
281+
layer(h, mask: mask, cache: kvCache[i])
282+
}
273283
if captureLayerIDs.contains(i) { captured[i] = h }
274284
}
275285
h = model.norm(h)

0 commit comments

Comments
 (0)