From 5cf37097f3bb24d295516498e7fb83a8db69c27a Mon Sep 17 00:00:00 2001 From: Aegis-AI Date: Thu, 23 Apr 2026 20:33:59 -0700 Subject: [PATCH 1/3] feat: expose embedTokens, lmHead, callCapturing for DFlash conformance on Llama/Qwen3/Qwen3MoE Enables Sources/SwiftLM/{Llama,Qwen3,Qwen3MoE}+DFlash.swift to conform to DFlashTargetModel by exposing: - LlamaModelInner.embedTokens (public) - LlamaModel.lmHead (public) - LlamaModel.callCapturing(_:cache:captureLayerIDs:) - Qwen3ModelInner.embedTokens + layers (public) - Qwen3Model.lmHead (public) - Qwen3Model.callCapturing(_:cache:captureLayerIDs:) - Qwen3MoEModelInner.embedTokens + layers (public) - Qwen3MoEModel.lmHead (public) - Qwen3MoEModel.callCapturing(_:cache:captureLayerIDs:) Authored by: 0xClandestine (original design in unreachable commit b5762584) --- Libraries/MLXLLM/Models/Llama.swift | 19 +++++++++++++++++-- Libraries/MLXLLM/Models/Qwen3.swift | 21 ++++++++++++++++++--- Libraries/MLXLLM/Models/Qwen3MoE.swift | 21 ++++++++++++++++++--- 3 files changed, 53 insertions(+), 8 deletions(-) diff --git a/Libraries/MLXLLM/Models/Llama.swift b/Libraries/MLXLLM/Models/Llama.swift index 1f71b66d0..ccd9c0e09 100644 --- a/Libraries/MLXLLM/Models/Llama.swift +++ b/Libraries/MLXLLM/Models/Llama.swift @@ -120,7 +120,7 @@ class LlamaTransformerBlock: Module { public class LlamaModelInner: Module, LayerPartitionable { - @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding + @ModuleInfo(key: "embed_tokens") public var embedTokens: Embedding let layers: [LlamaTransformerBlock] let norm: RMSNorm @@ -162,7 +162,7 @@ public class LlamaModel: Module, LLMModel, KVCacheDimensionProvider { public let model: LlamaModelInner - @ModuleInfo(key: "lm_head") var lmHead: Linear? + @ModuleInfo(key: "lm_head") public var lmHead: Linear? 
public init(_ args: LlamaConfiguration) { self.vocabularySize = args.vocabularySize @@ -182,6 +182,21 @@ public class LlamaModel: Module, LLMModel, KVCacheDimensionProvider { } } + public func callCapturing( + _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int> + ) -> (MLXArray, [Int: MLXArray]) { + var h = model.embedTokens(inputs) + let kvCache = cache?.compactMap { $0 } + let mask = createAttentionMask(h: h, cache: kvCache?.first) + var captured: [Int: MLXArray] = [:] + for (i, layer) in model.layers.enumerated() { + h = layer(h, mask: mask, cache: kvCache?[i]) + if captureLayerIDs.contains(i) { captured[i] = h } + } + h = model.norm(h) + return (h, captured) + } + public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { // Remove unused precomputed rotary frequencies weights.filter { diff --git a/Libraries/MLXLLM/Models/Qwen3.swift b/Libraries/MLXLLM/Models/Qwen3.swift index 6777cb844..cba1aa4b0 100644 --- a/Libraries/MLXLLM/Models/Qwen3.swift +++ b/Libraries/MLXLLM/Models/Qwen3.swift @@ -145,9 +145,9 @@ public class Qwen3ModelInner: Module, LayerPartitionable { // LayerPartitionable public var gpuLayerCount: Int? public var totalLayerCount: Int { layers.count } - @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding + @ModuleInfo(key: "embed_tokens") public var embedTokens: Embedding - fileprivate let layers: [Qwen3TransformerBlock] + public let layers: [Qwen3TransformerBlock] let norm: RMSNorm public init(_ args: Qwen3Configuration) { @@ -185,7 +185,7 @@ public class Qwen3Model: Module, LLMModel, KVCacheDimensionProvider { public let model: Qwen3ModelInner let configuration: Qwen3Configuration - @ModuleInfo(key: "lm_head") var lmHead: Linear? + @ModuleInfo(key: "lm_head") public var lmHead: Linear?
public init(_ args: Qwen3Configuration) { self.configuration = args @@ -208,6 +208,21 @@ public class Qwen3Model: Module, LLMModel, KVCacheDimensionProvider { return out } + public func callCapturing( + _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int> + ) -> (MLXArray, [Int: MLXArray]) { + var h = model.embedTokens(inputs) + let kvCache = cache?.compactMap { $0 } + let mask = createAttentionMask(h: h, cache: kvCache?.first) + var captured: [Int: MLXArray] = [:] + for (i, layer) in model.layers.enumerated() { + h = layer(h, mask: mask, cache: kvCache?[i]) + if captureLayerIDs.contains(i) { captured[i] = h } + } + h = model.norm(h) + return (h, captured) + } + public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { var weights = weights diff --git a/Libraries/MLXLLM/Models/Qwen3MoE.swift b/Libraries/MLXLLM/Models/Qwen3MoE.swift index c5b53fcd9..29dd0398f 100644 --- a/Libraries/MLXLLM/Models/Qwen3MoE.swift +++ b/Libraries/MLXLLM/Models/Qwen3MoE.swift @@ -196,9 +196,9 @@ public class Qwen3MoEModelInner: Module, LayerPartitionable, StreamableMoE { // StreamableMoE public var streamExperts: Bool = false - @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding + @ModuleInfo(key: "embed_tokens") public var embedTokens: Embedding - fileprivate let layers: [Qwen3MoeDecoderLayer] + public let layers: [Qwen3MoeDecoderLayer] let norm: RMSNorm let args: Qwen3MoEConfiguration @@ -238,7 +238,7 @@ public class Qwen3MoEModel: Module, LLMModel, KVCacheDimensionProvider { public let model: Qwen3MoEModelInner let configuration: Qwen3MoEConfiguration - @ModuleInfo(key: "lm_head") var lmHead: Linear? + @ModuleInfo(key: "lm_head") public var lmHead: Linear? public init(_ args: Qwen3MoEConfiguration) { self.configuration = args @@ -261,6 +261,21 @@ public class Qwen3MoEModel: Module, LLMModel, KVCacheDimensionProvider { return out } + public func callCapturing( + _ inputs: MLXArray, cache: [KVCache?]?
= nil, captureLayerIDs: Set<Int> + ) -> (MLXArray, [Int: MLXArray]) { + var h = model.embedTokens(inputs) + let kvCache = cache?.compactMap { $0 } + let mask = createAttentionMask(h: h, cache: kvCache?.first) + var captured: [Int: MLXArray] = [:] + for (i, layer) in model.layers.enumerated() { + h = layer(h, mask: mask, cache: kvCache?[i]) + if captureLayerIDs.contains(i) { captured[i] = h } + } + h = model.norm(h) + return (h, captured) + } + public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { var sanitizedWeights = weights From bcdecc8a8c62198ffcc8a4e2697efd9d5d9fb511 Mon Sep 17 00:00:00 2001 From: Aegis-AI Date: Thu, 23 Apr 2026 20:44:27 -0700 Subject: [PATCH 2/3] fix: revert layers to internal to fix public-property-of-internal-type error MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Qwen3TransformerBlock and Qwen3MoeDecoderLayer are internal types. Declaring public let layers of those types violates Swift's access control rules. callCapturing() is defined in the same module so can access internal layers directly — no need to expose them publicly.
--- Libraries/MLXLLM/Models/Qwen3.swift | 2 +- Libraries/MLXLLM/Models/Qwen3MoE.swift | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Libraries/MLXLLM/Models/Qwen3.swift b/Libraries/MLXLLM/Models/Qwen3.swift index cba1aa4b0..729edbb76 100644 --- a/Libraries/MLXLLM/Models/Qwen3.swift +++ b/Libraries/MLXLLM/Models/Qwen3.swift @@ -147,7 +147,7 @@ public class Qwen3ModelInner: Module, LayerPartitionable { public var totalLayerCount: Int { layers.count } @ModuleInfo(key: "embed_tokens") public var embedTokens: Embedding - public let layers: [Qwen3TransformerBlock] + let layers: [Qwen3TransformerBlock] let norm: RMSNorm public init(_ args: Qwen3Configuration) { diff --git a/Libraries/MLXLLM/Models/Qwen3MoE.swift b/Libraries/MLXLLM/Models/Qwen3MoE.swift index 29dd0398f..916f88044 100644 --- a/Libraries/MLXLLM/Models/Qwen3MoE.swift +++ b/Libraries/MLXLLM/Models/Qwen3MoE.swift @@ -198,7 +198,7 @@ public class Qwen3MoEModelInner: Module, LayerPartitionable, StreamableMoE { @ModuleInfo(key: "embed_tokens") public var embedTokens: Embedding - public let layers: [Qwen3MoeDecoderLayer] + let layers: [Qwen3MoeDecoderLayer] let norm: RMSNorm let args: Qwen3MoEConfiguration From 63da149c41f7434466c44d7fd43f76c340351740 Mon Sep 17 00:00:00 2001 From: Aegis-AI Date: Thu, 23 Apr 2026 21:00:39 -0700 Subject: [PATCH 3/3] =?UTF-8?q?fix:=20address=20Copilot=20review=20?= =?UTF-8?q?=E2=80=94=20safe=20cache=20normalization=20+=20partitionedLayer?= =?UTF-8?q?Call?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace compactMap{} cache handling with explicit normalization that preserves layer↔cache index alignment (nil-padding to layers.count) - Wrap per-layer calls in partitionedLayerCall() to respect gpuLayerCount CPU/GPU routing (matches callAsFunction behavior) - Qwen3MoE additionally passes stream:model.streamExperts to match its Qwen3MoEModelInner.callAsFunction routing --- Libraries/MLXLLM/Models/Llama.swift | 14 
+++++++++++--- Libraries/MLXLLM/Models/Qwen3.swift | 14 +++++++++++--- Libraries/MLXLLM/Models/Qwen3MoE.swift | 16 +++++++++++++--- 3 files changed, 35 insertions(+), 9 deletions(-) diff --git a/Libraries/MLXLLM/Models/Llama.swift b/Libraries/MLXLLM/Models/Llama.swift index ccd9c0e09..b9fb62e96 100644 --- a/Libraries/MLXLLM/Models/Llama.swift +++ b/Libraries/MLXLLM/Models/Llama.swift @@ -186,11 +186,19 @@ public class LlamaModel: Module, LLMModel, KVCacheDimensionProvider { _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int> ) -> (MLXArray, [Int: MLXArray]) { var h = model.embedTokens(inputs) - let kvCache = cache?.compactMap { $0 } - let mask = createAttentionMask(h: h, cache: kvCache?.first) + let layerCount = model.layers.count + let kvCache: [KVCache?] = { + guard let c = cache else { return Array(repeating: nil, count: layerCount) } + var normalized: [KVCache?] = Array(repeating: nil, count: layerCount) + for (i, v) in c.prefix(layerCount).enumerated() { normalized[i] = v } + return normalized + }() + let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil) var captured: [Int: MLXArray] = [:] for (i, layer) in model.layers.enumerated() { - h = layer(h, mask: mask, cache: kvCache?[i]) + h = partitionedLayerCall(index: i, gpuLayerCount: model.gpuLayerCount) { + layer(h, mask: mask, cache: kvCache[i]) + } if captureLayerIDs.contains(i) { captured[i] = h } } h = model.norm(h) diff --git a/Libraries/MLXLLM/Models/Qwen3.swift b/Libraries/MLXLLM/Models/Qwen3.swift index 729edbb76..795b2fc86 100644 --- a/Libraries/MLXLLM/Models/Qwen3.swift +++ b/Libraries/MLXLLM/Models/Qwen3.swift @@ -212,11 +212,19 @@ public class Qwen3Model: Module, LLMModel, KVCacheDimensionProvider { _ inputs: MLXArray, cache: [KVCache?]?
= nil, captureLayerIDs: Set<Int> ) -> (MLXArray, [Int: MLXArray]) { var h = model.embedTokens(inputs) - let kvCache = cache?.compactMap { $0 } - let mask = createAttentionMask(h: h, cache: kvCache?.first) + let layerCount = model.layers.count + let kvCache: [KVCache?] = { + guard let c = cache else { return Array(repeating: nil, count: layerCount) } + var normalized: [KVCache?] = Array(repeating: nil, count: layerCount) + for (i, v) in c.prefix(layerCount).enumerated() { normalized[i] = v } + return normalized + }() + let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil) var captured: [Int: MLXArray] = [:] for (i, layer) in model.layers.enumerated() { - h = layer(h, mask: mask, cache: kvCache?[i]) + h = partitionedLayerCall(index: i, gpuLayerCount: model.gpuLayerCount) { + layer(h, mask: mask, cache: kvCache[i]) + } if captureLayerIDs.contains(i) { captured[i] = h } } h = model.norm(h) diff --git a/Libraries/MLXLLM/Models/Qwen3MoE.swift b/Libraries/MLXLLM/Models/Qwen3MoE.swift index 916f88044..1d047b8f5 100644 --- a/Libraries/MLXLLM/Models/Qwen3MoE.swift +++ b/Libraries/MLXLLM/Models/Qwen3MoE.swift @@ -265,11 +265,21 @@ public class Qwen3MoEModel: Module, LLMModel, KVCacheDimensionProvider { _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int> ) -> (MLXArray, [Int: MLXArray]) { var h = model.embedTokens(inputs) - let kvCache = cache?.compactMap { $0 } - let mask = createAttentionMask(h: h, cache: kvCache?.first) + let layerCount = model.layers.count + let kvCache: [KVCache?] = { + guard let c = cache else { return Array(repeating: nil, count: layerCount) } + var normalized: [KVCache?] = Array(repeating: nil, count: layerCount) + for (i, v) in c.prefix(layerCount).enumerated() { normalized[i] = v } + return normalized + }() + let mask = createAttentionMask(h: h, cache: kvCache.first ??
nil) var captured: [Int: MLXArray] = [:] for (i, layer) in model.layers.enumerated() { - h = layer(h, mask: mask, cache: kvCache?[i]) + h = partitionedLayerCall( + index: i, gpuLayerCount: model.gpuLayerCount, stream: model.streamExperts + ) { + layer(h, mask: mask, cache: kvCache[i]) + } if captureLayerIDs.contains(i) { captured[i] = h } } h = model.norm(h)