Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 21 additions & 23 deletions Libraries/MLXLLM/Models/Llama.swift
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,27 @@ public class LlamaModelInner: Module, LayerPartitionable {

return norm(h)
}

/// Forward pass that also captures intermediate hidden states.
///
/// Behaves exactly like the plain call, but after each transformer layer whose
/// index appears in `captureLayerIDs`, the post-layer hidden state is recorded.
///
/// - Parameters:
///   - inputs: Token ids fed to the embedding layer.
///   - cache: Optional per-layer KV caches. A short or missing array is padded
///     with `nil`; extra entries beyond the layer count are ignored.
///   - captureLayerIDs: Layer indices whose output hidden state to capture.
/// - Returns: The final normalized hidden state plus the captured states keyed
///   by layer index.
public func callCapturing(
    _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
) -> (MLXArray, [Int: MLXArray]) {
    // Normalize the caller-supplied cache to exactly one (optional) slot per layer.
    var perLayerCache = [KVCache?](repeating: nil, count: layers.count)
    if let cache {
        for (index, entry) in cache.prefix(layers.count).enumerated() {
            perLayerCache[index] = entry
        }
    }

    var hidden = embedTokens(inputs)
    let mask = createAttentionMask(h: hidden, cache: perLayerCache.first ?? nil)

    var capturedStates: [Int: MLXArray] = [:]
    for (index, layer) in layers.enumerated() {
        hidden = partitionedLayerCall(index: index, gpuLayerCount: gpuLayerCount) {
            layer(hidden, mask: mask, cache: perLayerCache[index])
        }
        if captureLayerIDs.contains(index) {
            capturedStates[index] = hidden
        }
    }
    return (norm(hidden), capturedStates)
}
}

/// Model for Llama and Mistral model types.
Expand Down Expand Up @@ -182,29 +203,6 @@ public class LlamaModel: Module, LLMModel, KVCacheDimensionProvider {
}
}

Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

callCapturing was removed from the public LlamaModel wrapper. Since it was introduced as part of the public API in the prior PR, consider keeping a forwarding LlamaModel.callCapturing(...) that delegates to model.callCapturing(...) (optionally marking it deprecated) to avoid a breaking change for any clients already calling the wrapper type.

Suggested change
@available(*, deprecated, message: "Use the underlying model.callCapturing(...) instead.")
public func callCapturing(_ inputs: MLXArray, cache: [KVCache]?) -> Any {
model.callCapturing(inputs, cache: cache)
}

Copilot uses AI. Check for mistakes.
/// Forward pass that also captures intermediate hidden states.
///
/// Delegates to the inner model's `callCapturing(_:cache:captureLayerIDs:)` so
/// the layer-traversal logic lives in one place instead of being duplicated in
/// this wrapper. The result is the inner model's normalized hidden state — no
/// lm-head projection is applied, matching the previous wrapper behavior.
///
/// - Parameters:
///   - inputs: Token ids fed to the embedding layer.
///   - cache: Optional per-layer KV caches. A short or missing array is padded
///     with `nil`; extra entries beyond the layer count are ignored.
///   - captureLayerIDs: Layer indices whose output hidden state to capture.
/// - Returns: The final normalized hidden state plus the captured states keyed
///   by layer index.
public func callCapturing(
    _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
) -> (MLXArray, [Int: MLXArray]) {
    model.callCapturing(inputs, cache: cache, captureLayerIDs: captureLayerIDs)
}

public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
// Remove unused precomputed rotary frequencies
weights.filter {
Expand Down
44 changes: 21 additions & 23 deletions Libraries/MLXLLM/Models/Qwen3.swift
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,27 @@ public class Qwen3ModelInner: Module, LayerPartitionable {

return norm(h)
}

/// Forward pass that also captures intermediate hidden states.
///
/// Identical to the plain call except that, after each transformer layer whose
/// index is listed in `captureLayerIDs`, the post-layer hidden state is saved.
///
/// - Parameters:
///   - inputs: Token ids fed to the embedding layer.
///   - cache: Optional per-layer KV caches. A short or missing array is padded
///     with `nil`; extra entries beyond the layer count are ignored.
///   - captureLayerIDs: Layer indices whose output hidden state to capture.
/// - Returns: The final normalized hidden state plus the captured states keyed
///   by layer index.
public func callCapturing(
    _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
) -> (MLXArray, [Int: MLXArray]) {
    // Pad or truncate the supplied caches to one (optional) slot per layer.
    var perLayerCache = [KVCache?](repeating: nil, count: layers.count)
    if let cache {
        for (index, entry) in cache.prefix(layers.count).enumerated() {
            perLayerCache[index] = entry
        }
    }

    var hidden = embedTokens(inputs)
    let mask = createAttentionMask(h: hidden, cache: perLayerCache.first ?? nil)

    var capturedStates: [Int: MLXArray] = [:]
    for (index, layer) in layers.enumerated() {
        hidden = partitionedLayerCall(index: index, gpuLayerCount: gpuLayerCount) {
            layer(hidden, mask: mask, cache: perLayerCache[index])
        }
        if captureLayerIDs.contains(index) {
            capturedStates[index] = hidden
        }
    }
    return (norm(hidden), capturedStates)
}
}

public class Qwen3Model: Module, LLMModel, KVCacheDimensionProvider {
Expand Down Expand Up @@ -208,29 +229,6 @@ public class Qwen3Model: Module, LLMModel, KVCacheDimensionProvider {
return out
}

Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

callCapturing was removed from the public Qwen3Model wrapper. To preserve source compatibility for users that may have started calling Qwen3Model.callCapturing(...) after #31, consider keeping a wrapper method that forwards to model.callCapturing(...) (and optionally deprecate it if you want callers to move to the inner type).

Suggested change
public func callCapturing(_ inputs: MLXArray, cache: [KVCache]?) -> (MLXArray, [MLXArray]) {
var (out, captured) = model.callCapturing(inputs, cache: cache)
if let lmHead {
out = lmHead(out)
} else {
out = model.embedTokens.asLinear(out)
}
return (out, captured)
}

Copilot uses AI. Check for mistakes.
/// Forward pass that also captures intermediate hidden states.
///
/// Delegates to the inner model's `callCapturing(_:cache:captureLayerIDs:)` so
/// the layer-traversal logic lives in one place instead of being duplicated in
/// this wrapper. The result is the inner model's normalized hidden state — no
/// lm-head / tied-embedding projection is applied, matching the previous
/// wrapper behavior.
///
/// - Parameters:
///   - inputs: Token ids fed to the embedding layer.
///   - cache: Optional per-layer KV caches. A short or missing array is padded
///     with `nil`; extra entries beyond the layer count are ignored.
///   - captureLayerIDs: Layer indices whose output hidden state to capture.
/// - Returns: The final normalized hidden state plus the captured states keyed
///   by layer index.
public func callCapturing(
    _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
) -> (MLXArray, [Int: MLXArray]) {
    model.callCapturing(inputs, cache: cache, captureLayerIDs: captureLayerIDs)
}

public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
var weights = weights

Expand Down
46 changes: 21 additions & 25 deletions Libraries/MLXLLM/Models/Qwen3MoE.swift
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,27 @@ public class Qwen3MoEModelInner: Module, LayerPartitionable, StreamableMoE {

return norm(h)
}

/// Forward pass that also captures intermediate hidden states.
///
/// Same as the plain call — including expert streaming via `streamExperts` —
/// except that, after each transformer layer whose index appears in
/// `captureLayerIDs`, the post-layer hidden state is recorded.
///
/// - Parameters:
///   - inputs: Token ids fed to the embedding layer.
///   - cache: Optional per-layer KV caches. A short or missing array is padded
///     with `nil`; extra entries beyond the layer count are ignored.
///   - captureLayerIDs: Layer indices whose output hidden state to capture.
/// - Returns: The final normalized hidden state plus the captured states keyed
///   by layer index.
public func callCapturing(
    _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
) -> (MLXArray, [Int: MLXArray]) {
    // Pad or truncate the supplied caches to one (optional) slot per layer.
    var perLayerCache = [KVCache?](repeating: nil, count: layers.count)
    if let cache {
        for (index, entry) in cache.prefix(layers.count).enumerated() {
            perLayerCache[index] = entry
        }
    }

    var hidden = embedTokens(inputs)
    let mask = createAttentionMask(h: hidden, cache: perLayerCache.first ?? nil)

    var capturedStates: [Int: MLXArray] = [:]
    for (index, layer) in layers.enumerated() {
        hidden = partitionedLayerCall(
            index: index, gpuLayerCount: gpuLayerCount, stream: streamExperts
        ) {
            layer(hidden, mask: mask, cache: perLayerCache[index])
        }
        if captureLayerIDs.contains(index) {
            capturedStates[index] = hidden
        }
    }
    return (norm(hidden), capturedStates)
}
}

public class Qwen3MoEModel: Module, LLMModel, KVCacheDimensionProvider {
Expand Down Expand Up @@ -261,31 +282,6 @@ public class Qwen3MoEModel: Module, LLMModel, KVCacheDimensionProvider {
return out
}

/// Forward pass that also captures intermediate hidden states.
///
/// Delegates to the inner model's `callCapturing(_:cache:captureLayerIDs:)` so
/// the layer-traversal and expert-streaming logic live in one place instead of
/// being duplicated in this wrapper. The result is the inner model's normalized
/// hidden state — no lm-head projection is applied, matching the previous
/// wrapper behavior.
///
/// - Parameters:
///   - inputs: Token ids fed to the embedding layer.
///   - cache: Optional per-layer KV caches. A short or missing array is padded
///     with `nil`; extra entries beyond the layer count are ignored.
///   - captureLayerIDs: Layer indices whose output hidden state to capture.
/// - Returns: The final normalized hidden state plus the captured states keyed
///   by layer index.
public func callCapturing(
    _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
) -> (MLXArray, [Int: MLXArray]) {
    model.callCapturing(inputs, cache: cache, captureLayerIDs: captureLayerIDs)
}

public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
var sanitizedWeights = weights

Expand Down
Loading