From 8eb7c0ef437e6371a28f459f08fd1fbbaec449d3 Mon Sep 17 00:00:00 2001
From: Aegis-AI
Date: Thu, 23 Apr 2026 21:19:56 -0700
Subject: [PATCH] fix: move callCapturing to *ModelInner where callers expect it

The DFlashTargetModel conformances in SwiftLM call model.callCapturing()
where model is *ModelInner, not the outer *Model wrapper.
---
Review notes (below the three-dash line, so not part of the commit message):

* This patch arrived with its line breaks collapsed. They are restored here
  so that every hunk matches its @@ line counts and the diffstat below.
  Indentation inside the hunks is reconstructed assuming 4-space Swift
  style -- TODO confirm against the tree, or apply with
  `git apply --ignore-whitespace`.
* The `captureLayerIDs` parameter had lost its generic argument (`Set`).
  It is restored as `Set<Int>`, which is what `captureLayerIDs.contains(i)`
  with the `enumerated()` index requires.
* The From: header carries no e-mail address -- presumably lost the same
  way as the generic argument; restore it before running `git am`.
* callCapturing, as moved onto the *ModelInner classes: embeds `inputs`,
  pads/truncates `cache` to exactly `layers.count` entries (all nil when
  `cache` is nil), builds one attention mask from the first cache entry,
  runs each layer through partitionedLayerCall, records `h` after every
  layer whose index is in `captureLayerIDs`, and returns
  `(norm(h), captured)`.

 Libraries/MLXLLM/Models/Llama.swift    | 44 ++++++++++++------------
 Libraries/MLXLLM/Models/Qwen3.swift    | 44 ++++++++++++------------
 Libraries/MLXLLM/Models/Qwen3MoE.swift | 46 ++++++++++++--------------
 3 files changed, 63 insertions(+), 71 deletions(-)

diff --git a/Libraries/MLXLLM/Models/Llama.swift b/Libraries/MLXLLM/Models/Llama.swift
index b9fb62e96..b8a93451d 100644
--- a/Libraries/MLXLLM/Models/Llama.swift
+++ b/Libraries/MLXLLM/Models/Llama.swift
@@ -152,6 +152,27 @@ public class LlamaModelInner: Module, LayerPartitionable {
 
         return norm(h)
     }
+
+    public func callCapturing(
+        _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
+    ) -> (MLXArray, [Int: MLXArray]) {
+        var h = embedTokens(inputs)
+        let kvCache: [KVCache?] = {
+            guard let c = cache else { return Array(repeating: nil, count: layers.count) }
+            var normalized: [KVCache?] = Array(repeating: nil, count: layers.count)
+            for (i, v) in c.prefix(layers.count).enumerated() { normalized[i] = v }
+            return normalized
+        }()
+        let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
+        var captured: [Int: MLXArray] = [:]
+        for (i, layer) in layers.enumerated() {
+            h = partitionedLayerCall(index: i, gpuLayerCount: gpuLayerCount) {
+                layer(h, mask: mask, cache: kvCache[i])
+            }
+            if captureLayerIDs.contains(i) { captured[i] = h }
+        }
+        return (norm(h), captured)
+    }
 }
 
 /// Model for Llama and Mistral model types.
@@ -182,29 +203,6 @@ public class LlamaModel: Module, LLMModel, KVCacheDimensionProvider {
         }
     }
 
-    public func callCapturing(
-        _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
-    ) -> (MLXArray, [Int: MLXArray]) {
-        var h = model.embedTokens(inputs)
-        let layerCount = model.layers.count
-        let kvCache: [KVCache?] = {
-            guard let c = cache else { return Array(repeating: nil, count: layerCount) }
-            var normalized: [KVCache?] = Array(repeating: nil, count: layerCount)
-            for (i, v) in c.prefix(layerCount).enumerated() { normalized[i] = v }
-            return normalized
-        }()
-        let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
-        var captured: [Int: MLXArray] = [:]
-        for (i, layer) in model.layers.enumerated() {
-            h = partitionedLayerCall(index: i, gpuLayerCount: model.gpuLayerCount) {
-                layer(h, mask: mask, cache: kvCache[i])
-            }
-            if captureLayerIDs.contains(i) { captured[i] = h }
-        }
-        h = model.norm(h)
-        return (h, captured)
-    }
-
     public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
         // Remove unused precomputed rotary frequencies
         weights.filter {
diff --git a/Libraries/MLXLLM/Models/Qwen3.swift b/Libraries/MLXLLM/Models/Qwen3.swift
index 795b2fc86..35ed5df97 100644
--- a/Libraries/MLXLLM/Models/Qwen3.swift
+++ b/Libraries/MLXLLM/Models/Qwen3.swift
@@ -176,6 +176,27 @@ public class Qwen3ModelInner: Module, LayerPartitionable {
 
         return norm(h)
     }
+
+    public func callCapturing(
+        _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
+    ) -> (MLXArray, [Int: MLXArray]) {
+        var h = embedTokens(inputs)
+        let kvCache: [KVCache?] = {
+            guard let c = cache else { return Array(repeating: nil, count: layers.count) }
+            var normalized: [KVCache?] = Array(repeating: nil, count: layers.count)
+            for (i, v) in c.prefix(layers.count).enumerated() { normalized[i] = v }
+            return normalized
+        }()
+        let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
+        var captured: [Int: MLXArray] = [:]
+        for (i, layer) in layers.enumerated() {
+            h = partitionedLayerCall(index: i, gpuLayerCount: gpuLayerCount) {
+                layer(h, mask: mask, cache: kvCache[i])
+            }
+            if captureLayerIDs.contains(i) { captured[i] = h }
+        }
+        return (norm(h), captured)
+    }
 }
 
 public class Qwen3Model: Module, LLMModel, KVCacheDimensionProvider {
@@ -208,29 +229,6 @@ public class Qwen3Model: Module, LLMModel, KVCacheDimensionProvider {
         return out
     }
 
-    public func callCapturing(
-        _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
-    ) -> (MLXArray, [Int: MLXArray]) {
-        var h = model.embedTokens(inputs)
-        let layerCount = model.layers.count
-        let kvCache: [KVCache?] = {
-            guard let c = cache else { return Array(repeating: nil, count: layerCount) }
-            var normalized: [KVCache?] = Array(repeating: nil, count: layerCount)
-            for (i, v) in c.prefix(layerCount).enumerated() { normalized[i] = v }
-            return normalized
-        }()
-        let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
-        var captured: [Int: MLXArray] = [:]
-        for (i, layer) in model.layers.enumerated() {
-            h = partitionedLayerCall(index: i, gpuLayerCount: model.gpuLayerCount) {
-                layer(h, mask: mask, cache: kvCache[i])
-            }
-            if captureLayerIDs.contains(i) { captured[i] = h }
-        }
-        h = model.norm(h)
-        return (h, captured)
-    }
-
     public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
         var weights = weights
 
diff --git a/Libraries/MLXLLM/Models/Qwen3MoE.swift b/Libraries/MLXLLM/Models/Qwen3MoE.swift
index 1d047b8f5..5b72a33a8 100644
--- a/Libraries/MLXLLM/Models/Qwen3MoE.swift
+++ b/Libraries/MLXLLM/Models/Qwen3MoE.swift
@@ -229,6 +229,27 @@ public class Qwen3MoEModelInner: Module, LayerPartitionable, StreamableMoE {
 
         return norm(h)
     }
+
+    public func callCapturing(
+        _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
+    ) -> (MLXArray, [Int: MLXArray]) {
+        var h = embedTokens(inputs)
+        let kvCache: [KVCache?] = {
+            guard let c = cache else { return Array(repeating: nil, count: layers.count) }
+            var normalized: [KVCache?] = Array(repeating: nil, count: layers.count)
+            for (i, v) in c.prefix(layers.count).enumerated() { normalized[i] = v }
+            return normalized
+        }()
+        let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
+        var captured: [Int: MLXArray] = [:]
+        for (i, layer) in layers.enumerated() {
+            h = partitionedLayerCall(index: i, gpuLayerCount: gpuLayerCount, stream: streamExperts) {
+                layer(h, mask: mask, cache: kvCache[i])
+            }
+            if captureLayerIDs.contains(i) { captured[i] = h }
+        }
+        return (norm(h), captured)
+    }
 }
 
 public class Qwen3MoEModel: Module, LLMModel, KVCacheDimensionProvider {
@@ -261,31 +282,6 @@ public class Qwen3MoEModel: Module, LLMModel, KVCacheDimensionProvider {
         return out
     }
 
-    public func callCapturing(
-        _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
-    ) -> (MLXArray, [Int: MLXArray]) {
-        var h = model.embedTokens(inputs)
-        let layerCount = model.layers.count
-        let kvCache: [KVCache?] = {
-            guard let c = cache else { return Array(repeating: nil, count: layerCount) }
-            var normalized: [KVCache?] = Array(repeating: nil, count: layerCount)
-            for (i, v) in c.prefix(layerCount).enumerated() { normalized[i] = v }
-            return normalized
-        }()
-        let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
-        var captured: [Int: MLXArray] = [:]
-        for (i, layer) in model.layers.enumerated() {
-            h = partitionedLayerCall(
-                index: i, gpuLayerCount: model.gpuLayerCount, stream: model.streamExperts
-            ) {
-                layer(h, mask: mask, cache: kvCache[i])
-            }
-            if captureLayerIDs.contains(i) { captured[i] = h }
-        }
-        h = model.norm(h)
-        return (h, captured)
-    }
-
     public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
         var sanitizedWeights = weights
 