Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions Libraries/MLXLLM/Models/Llama.swift
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class LlamaTransformerBlock: Module {

public class LlamaModelInner: Module, LayerPartitionable {

@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
@ModuleInfo(key: "embed_tokens") public var embedTokens: Embedding

let layers: [LlamaTransformerBlock]
let norm: RMSNorm
Expand Down Expand Up @@ -162,7 +162,7 @@ public class LlamaModel: Module, LLMModel, KVCacheDimensionProvider {

public let model: LlamaModelInner

@ModuleInfo(key: "lm_head") var lmHead: Linear?
@ModuleInfo(key: "lm_head") public var lmHead: Linear?

public init(_ args: LlamaConfiguration) {
self.vocabularySize = args.vocabularySize
Expand All @@ -182,6 +182,29 @@ public class LlamaModel: Module, LLMModel, KVCacheDimensionProvider {
}
}

/// Runs the forward pass while capturing the hidden state produced by selected layers.
///
/// Behaves like the standard call, but after each transformer block whose index is in
/// `captureLayerIDs`, the post-block hidden state is recorded and returned alongside
/// the final (normed) output.
///
/// - Parameters:
///   - inputs: Token ids to embed and run through the model.
///   - cache: Optional per-layer KV caches; entries beyond the layer count are ignored
///     and missing entries are treated as `nil`.
///   - captureLayerIDs: Indices of layers whose output hidden states should be captured.
/// - Returns: The final hidden state after the output norm, and a dictionary mapping
///   each captured layer index to its hidden state.
public func callCapturing(
    _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
) -> (MLXArray, [Int: MLXArray]) {
    var hidden = model.embedTokens(inputs)
    let layerCount = model.layers.count

    // Pad/truncate the supplied caches to exactly one (possibly nil) entry per layer,
    // keeping layer <-> cache alignment even for short or absent cache arrays.
    let normalizedCache: [KVCache?] = (0 ..< layerCount).map { index in
        guard let cache, index < cache.count else { return nil }
        return cache[index]
    }

    let mask = createAttentionMask(h: hidden, cache: normalizedCache.first ?? nil)

    var capturedStates: [Int: MLXArray] = [:]
    for (index, block) in model.layers.enumerated() {
        hidden = partitionedLayerCall(index: index, gpuLayerCount: model.gpuLayerCount) {
            block(hidden, mask: mask, cache: normalizedCache[index])
        }
        if captureLayerIDs.contains(index) {
            capturedStates[index] = hidden
        }
    }

    hidden = model.norm(hidden)
    return (hidden, capturedStates)
}

public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
// Remove unused precomputed rotary frequencies
weights.filter {
Expand Down
29 changes: 26 additions & 3 deletions Libraries/MLXLLM/Models/Qwen3.swift
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,9 @@ public class Qwen3ModelInner: Module, LayerPartitionable {
// LayerPartitionable
public var gpuLayerCount: Int?
public var totalLayerCount: Int { layers.count }
@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
@ModuleInfo(key: "embed_tokens") public var embedTokens: Embedding

fileprivate let layers: [Qwen3TransformerBlock]
let layers: [Qwen3TransformerBlock]
let norm: RMSNorm

public init(_ args: Qwen3Configuration) {
Expand Down Expand Up @@ -185,7 +185,7 @@ public class Qwen3Model: Module, LLMModel, KVCacheDimensionProvider {
public let model: Qwen3ModelInner
let configuration: Qwen3Configuration

@ModuleInfo(key: "lm_head") var lmHead: Linear?
@ModuleInfo(key: "lm_head") public var lmHead: Linear?

public init(_ args: Qwen3Configuration) {
self.configuration = args
Expand All @@ -208,6 +208,29 @@ public class Qwen3Model: Module, LLMModel, KVCacheDimensionProvider {
return out
}

/// Runs the forward pass while capturing the hidden state produced by selected layers.
///
/// Behaves like the standard call, but after each transformer block whose index is in
/// `captureLayerIDs`, the post-block hidden state is recorded and returned alongside
/// the final (normed) output.
///
/// - Parameters:
///   - inputs: Token ids to embed and run through the model.
///   - cache: Optional per-layer KV caches; entries beyond the layer count are ignored
///     and missing entries are treated as `nil`.
///   - captureLayerIDs: Indices of layers whose output hidden states should be captured.
/// - Returns: The final hidden state after the output norm, and a dictionary mapping
///   each captured layer index to its hidden state.
public func callCapturing(
    _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
) -> (MLXArray, [Int: MLXArray]) {
    var hidden = model.embedTokens(inputs)
    let layerCount = model.layers.count

    // Pad/truncate the supplied caches to exactly one (possibly nil) entry per layer,
    // keeping layer <-> cache alignment even for short or absent cache arrays.
    let normalizedCache: [KVCache?] = (0 ..< layerCount).map { index in
        guard let cache, index < cache.count else { return nil }
        return cache[index]
    }

    let mask = createAttentionMask(h: hidden, cache: normalizedCache.first ?? nil)

    var capturedStates: [Int: MLXArray] = [:]
    for (index, block) in model.layers.enumerated() {
        hidden = partitionedLayerCall(index: index, gpuLayerCount: model.gpuLayerCount) {
            block(hidden, mask: mask, cache: normalizedCache[index])
        }
        if captureLayerIDs.contains(index) {
            capturedStates[index] = hidden
        }
    }

    hidden = model.norm(hidden)
    return (hidden, capturedStates)
}

public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
var weights = weights

Expand Down
31 changes: 28 additions & 3 deletions Libraries/MLXLLM/Models/Qwen3MoE.swift
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,9 @@ public class Qwen3MoEModelInner: Module, LayerPartitionable, StreamableMoE {
// StreamableMoE
public var streamExperts: Bool = false

@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
@ModuleInfo(key: "embed_tokens") public var embedTokens: Embedding

fileprivate let layers: [Qwen3MoeDecoderLayer]
let layers: [Qwen3MoeDecoderLayer]
let norm: RMSNorm
let args: Qwen3MoEConfiguration

Expand Down Expand Up @@ -238,7 +238,7 @@ public class Qwen3MoEModel: Module, LLMModel, KVCacheDimensionProvider {
public let model: Qwen3MoEModelInner
let configuration: Qwen3MoEConfiguration

@ModuleInfo(key: "lm_head") var lmHead: Linear?
@ModuleInfo(key: "lm_head") public var lmHead: Linear?

public init(_ args: Qwen3MoEConfiguration) {
self.configuration = args
Expand All @@ -261,6 +261,31 @@ public class Qwen3MoEModel: Module, LLMModel, KVCacheDimensionProvider {
return out
}

/// Runs the forward pass while capturing the hidden state produced by selected layers.
///
/// Behaves like the standard call, but after each decoder layer whose index is in
/// `captureLayerIDs`, the post-layer hidden state is recorded and returned alongside
/// the final (normed) output. Expert streaming is honored via `model.streamExperts`.
///
/// - Parameters:
///   - inputs: Token ids to embed and run through the model.
///   - cache: Optional per-layer KV caches; entries beyond the layer count are ignored
///     and missing entries are treated as `nil`.
///   - captureLayerIDs: Indices of layers whose output hidden states should be captured.
/// - Returns: The final hidden state after the output norm, and a dictionary mapping
///   each captured layer index to its hidden state.
public func callCapturing(
    _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
) -> (MLXArray, [Int: MLXArray]) {
    var hidden = model.embedTokens(inputs)
    let layerCount = model.layers.count

    // Pad/truncate the supplied caches to exactly one (possibly nil) entry per layer,
    // keeping layer <-> cache alignment even for short or absent cache arrays.
    let normalizedCache: [KVCache?] = (0 ..< layerCount).map { index in
        guard let cache, index < cache.count else { return nil }
        return cache[index]
    }

    let mask = createAttentionMask(h: hidden, cache: normalizedCache.first ?? nil)

    var capturedStates: [Int: MLXArray] = [:]
    for (index, block) in model.layers.enumerated() {
        hidden = partitionedLayerCall(
            index: index, gpuLayerCount: model.gpuLayerCount, stream: model.streamExperts
        ) {
            block(hidden, mask: mask, cache: normalizedCache[index])
        }
        if captureLayerIDs.contains(index) {
            capturedStates[index] = hidden
        }
    }

    hidden = model.norm(hidden)
    return (hidden, capturedStates)
}

public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
var sanitizedWeights = weights

Expand Down
Loading