diff --git a/Libraries/MLXLLM/Models/Llama.swift b/Libraries/MLXLLM/Models/Llama.swift
index 1f71b66d0..b9fb62e96 100644
--- a/Libraries/MLXLLM/Models/Llama.swift
+++ b/Libraries/MLXLLM/Models/Llama.swift
@@ -120,7 +120,7 @@ class LlamaTransformerBlock: Module {
 
 public class LlamaModelInner: Module, LayerPartitionable {
 
-    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
+    @ModuleInfo(key: "embed_tokens") public var embedTokens: Embedding
 
     let layers: [LlamaTransformerBlock]
     let norm: RMSNorm
@@ -162,7 +162,7 @@ public class LlamaModel: Module, LLMModel, KVCacheDimensionProvider {
 
     public let model: LlamaModelInner
 
-    @ModuleInfo(key: "lm_head") var lmHead: Linear?
+    @ModuleInfo(key: "lm_head") public var lmHead: Linear?
 
     public init(_ args: LlamaConfiguration) {
         self.vocabularySize = args.vocabularySize
@@ -182,6 +182,45 @@ public class LlamaModel: Module, LLMModel, KVCacheDimensionProvider {
         }
     }
 
+    /// Runs the decoder stack and records the hidden state after selected layers.
+    ///
+    /// The token embeddings are passed through every layer in order; the output of
+    /// layer `i` is stored whenever `i` is in `captureLayerIDs`. The final hidden
+    /// state is passed through `model.norm` before returning. `lmHead` is not
+    /// applied by this method.
+    ///
+    /// - Parameters:
+    ///   - inputs: Input passed to `model.embedTokens`.
+    ///   - cache: Optional per-layer caches. Missing entries are treated as `nil`
+    ///     and entries beyond the layer count are ignored.
+    ///   - captureLayerIDs: Zero-based layer indices whose outputs are captured.
+    /// - Returns: The normalized final hidden state and a dictionary mapping each
+    ///   captured layer index to that layer's output.
+    public func callCapturing(
+        _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
+    ) -> (MLXArray, [Int: MLXArray]) {
+        var h = model.embedTokens(inputs)
+        let layerCount = model.layers.count
+        // Pad or truncate the caller's cache so there is exactly one slot per layer.
+        let kvCache: [KVCache?] = {
+            guard let c = cache else { return Array(repeating: nil, count: layerCount) }
+            var normalized: [KVCache?] = Array(repeating: nil, count: layerCount)
+            for (i, v) in c.prefix(layerCount).enumerated() { normalized[i] = v }
+            return normalized
+        }()
+        // The mask is built once from the first layer's cache slot and shared by all layers.
+        let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
+        var captured: [Int: MLXArray] = [:]
+        for (i, layer) in model.layers.enumerated() {
+            h = partitionedLayerCall(index: i, gpuLayerCount: model.gpuLayerCount) {
+                layer(h, mask: mask, cache: kvCache[i])
+            }
+            if captureLayerIDs.contains(i) { captured[i] = h }
+        }
+        h = model.norm(h)
+        return (h, captured)
+    }
+
     public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
         // Remove unused precomputed rotary frequencies
         weights.filter {
diff --git a/Libraries/MLXLLM/Models/Qwen3.swift b/Libraries/MLXLLM/Models/Qwen3.swift
index 6777cb844..795b2fc86 100644
--- a/Libraries/MLXLLM/Models/Qwen3.swift
+++ b/Libraries/MLXLLM/Models/Qwen3.swift
@@ -145,9 +145,9 @@ public class Qwen3ModelInner: Module, LayerPartitionable {
     // LayerPartitionable
     public var gpuLayerCount: Int?
     public var totalLayerCount: Int { layers.count }
-    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
+    @ModuleInfo(key: "embed_tokens") public var embedTokens: Embedding
 
-    fileprivate let layers: [Qwen3TransformerBlock]
+    let layers: [Qwen3TransformerBlock]
     let norm: RMSNorm
 
     public init(_ args: Qwen3Configuration) {
@@ -185,7 +185,7 @@ public class Qwen3Model: Module, LLMModel, KVCacheDimensionProvider {
     public let model: Qwen3ModelInner
     let configuration: Qwen3Configuration
 
-    @ModuleInfo(key: "lm_head") var lmHead: Linear?
+    @ModuleInfo(key: "lm_head") public var lmHead: Linear?
 
     public init(_ args: Qwen3Configuration) {
         self.configuration = args
@@ -208,6 +208,45 @@ public class Qwen3Model: Module, LLMModel, KVCacheDimensionProvider {
         return out
     }
 
+    /// Runs the decoder stack and records the hidden state after selected layers.
+    ///
+    /// The token embeddings are passed through every layer in order; the output of
+    /// layer `i` is stored whenever `i` is in `captureLayerIDs`. The final hidden
+    /// state is passed through `model.norm` before returning. `lmHead` is not
+    /// applied by this method.
+    ///
+    /// - Parameters:
+    ///   - inputs: Input passed to `model.embedTokens`.
+    ///   - cache: Optional per-layer caches. Missing entries are treated as `nil`
+    ///     and entries beyond the layer count are ignored.
+    ///   - captureLayerIDs: Zero-based layer indices whose outputs are captured.
+    /// - Returns: The normalized final hidden state and a dictionary mapping each
+    ///   captured layer index to that layer's output.
+    public func callCapturing(
+        _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
+    ) -> (MLXArray, [Int: MLXArray]) {
+        var h = model.embedTokens(inputs)
+        let layerCount = model.layers.count
+        // Pad or truncate the caller's cache so there is exactly one slot per layer.
+        let kvCache: [KVCache?] = {
+            guard let c = cache else { return Array(repeating: nil, count: layerCount) }
+            var normalized: [KVCache?] = Array(repeating: nil, count: layerCount)
+            for (i, v) in c.prefix(layerCount).enumerated() { normalized[i] = v }
+            return normalized
+        }()
+        // The mask is built once from the first layer's cache slot and shared by all layers.
+        let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
+        var captured: [Int: MLXArray] = [:]
+        for (i, layer) in model.layers.enumerated() {
+            h = partitionedLayerCall(index: i, gpuLayerCount: model.gpuLayerCount) {
+                layer(h, mask: mask, cache: kvCache[i])
+            }
+            if captureLayerIDs.contains(i) { captured[i] = h }
+        }
+        h = model.norm(h)
+        return (h, captured)
+    }
+
     public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
         var weights = weights
 
diff --git a/Libraries/MLXLLM/Models/Qwen3MoE.swift b/Libraries/MLXLLM/Models/Qwen3MoE.swift
index c5b53fcd9..1d047b8f5 100644
--- a/Libraries/MLXLLM/Models/Qwen3MoE.swift
+++ b/Libraries/MLXLLM/Models/Qwen3MoE.swift
@@ -196,9 +196,9 @@ public class Qwen3MoEModelInner: Module, LayerPartitionable, StreamableMoE {
 
     // StreamableMoE
     public var streamExperts: Bool = false
-    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
+    @ModuleInfo(key: "embed_tokens") public var embedTokens: Embedding
 
-    fileprivate let layers: [Qwen3MoeDecoderLayer]
+    let layers: [Qwen3MoeDecoderLayer]
     let norm: RMSNorm
     let args: Qwen3MoEConfiguration
 
@@ -238,7 +238,7 @@ public class Qwen3MoEModel: Module, LLMModel, KVCacheDimensionProvider {
     public let model: Qwen3MoEModelInner
     let configuration: Qwen3MoEConfiguration
 
-    @ModuleInfo(key: "lm_head") var lmHead: Linear?
+    @ModuleInfo(key: "lm_head") public var lmHead: Linear?
 
     public init(_ args: Qwen3MoEConfiguration) {
         self.configuration = args
@@ -261,6 +261,47 @@ public class Qwen3MoEModel: Module, LLMModel, KVCacheDimensionProvider {
         return out
     }
 
+    /// Runs the decoder stack and records the hidden state after selected layers.
+    ///
+    /// The token embeddings are passed through every layer in order; the output of
+    /// layer `i` is stored whenever `i` is in `captureLayerIDs`. The final hidden
+    /// state is passed through `model.norm` before returning. `lmHead` is not
+    /// applied by this method.
+    ///
+    /// - Parameters:
+    ///   - inputs: Input passed to `model.embedTokens`.
+    ///   - cache: Optional per-layer caches. Missing entries are treated as `nil`
+    ///     and entries beyond the layer count are ignored.
+    ///   - captureLayerIDs: Zero-based layer indices whose outputs are captured.
+    /// - Returns: The normalized final hidden state and a dictionary mapping each
+    ///   captured layer index to that layer's output.
+    public func callCapturing(
+        _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
+    ) -> (MLXArray, [Int: MLXArray]) {
+        var h = model.embedTokens(inputs)
+        let layerCount = model.layers.count
+        // Pad or truncate the caller's cache so there is exactly one slot per layer.
+        let kvCache: [KVCache?] = {
+            guard let c = cache else { return Array(repeating: nil, count: layerCount) }
+            var normalized: [KVCache?] = Array(repeating: nil, count: layerCount)
+            for (i, v) in c.prefix(layerCount).enumerated() { normalized[i] = v }
+            return normalized
+        }()
+        // The mask is built once from the first layer's cache slot and shared by all layers.
+        let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
+        var captured: [Int: MLXArray] = [:]
+        for (i, layer) in model.layers.enumerated() {
+            h = partitionedLayerCall(
+                index: i, gpuLayerCount: model.gpuLayerCount, stream: model.streamExperts
+            ) {
+                layer(h, mask: mask, cache: kvCache[i])
+            }
+            if captureLayerIDs.contains(i) { captured[i] = h }
+        }
+        h = model.norm(h)
+        return (h, captured)
+    }
+
     public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
         var sanitizedWeights = weights
 