feat: expose embedTokens, lmHead, callCapturing for DFlash conformance on Llama/Qwen3/Qwen3MoE #31
Conversation
feat: expose embedTokens, lmHead, callCapturing for DFlash conformance on Llama/Qwen3/Qwen3MoE
Enables Sources/SwiftLM/{Llama,Qwen3,Qwen3MoE}+DFlash.swift to conform
to DFlashTargetModel by exposing:
- LlamaModelInner.embedTokens (public)
- LlamaModel.lmHead (public)
- LlamaModel.callCapturing(_:cache:captureLayerIDs:)
- Qwen3ModelInner.embedTokens + layers (public)
- Qwen3Model.lmHead (public)
- Qwen3Model.callCapturing(_:cache:captureLayerIDs:)
- Qwen3MoEModelInner.embedTokens + layers (public)
- Qwen3MoEModel.lmHead (public)
- Qwen3MoEModel.callCapturing(_:cache:captureLayerIDs:)
Authored by: 0xClandestine (original design in unreachable commit b5762584)
Much love ser! Mind including Kimi support as well? 0xClandestine/SwiftLM#1
Pull request overview
Exposes a small public API surface on Llama/Qwen3/Qwen3MoE models to support downstream DFlashTargetModel conformance (per SwiftLM PR ml-explore#78).
Changes:
- Make `embedTokens` publicly accessible on `*ModelInner`.
- Make `lmHead` publicly accessible on `*Model`.
- Add a new public `callCapturing(_:cache:captureLayerIDs:)` API to return final hidden states plus selected intermediate layer outputs (see the usage sketch below).
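A hedged caller-side sketch of the new API. The variable names and captured layer IDs are placeholders, and treating `lmHead` as directly callable on the hidden states is an assumption, not something this PR shows:

```swift
// Hypothetical usage: run a forward pass while capturing layers 3 and 7.
let (hidden, captured) = model.callCapturing(
    tokens,                       // MLXArray of token IDs (placeholder)
    cache: nil,                   // no KV cache for this single pass
    captureLayerIDs: [3, 7]
)
let layer7 = captured[7]          // hidden state after transformer layer 7
let logits = model.lmHead(hidden) // assumed: lmHead maps hidden states to logits
```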
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| Libraries/MLXLLM/Models/Llama.swift | Makes embedTokens/lmHead public and adds callCapturing for capturing intermediate states. |
| Libraries/MLXLLM/Models/Qwen3.swift | Makes embedTokens/lmHead public, exposes layers, and adds callCapturing. |
| Libraries/MLXLLM/Models/Qwen3MoE.swift | Makes embedTokens/lmHead public, exposes layers, and adds callCapturing for MoE model. |
```diff
  @ModuleInfo(key: "embed_tokens") public var embedTokens: Embedding

- fileprivate let layers: [Qwen3TransformerBlock]
+ public let layers: [Qwen3TransformerBlock]
```
layers is declared public, but its element type (Qwen3TransformerBlock) is internal (no access modifier), which is not allowed for a public API and will fail to compile. Either keep layers non-public (e.g., fileprivate as before) or make the layer type(s) public as well (larger API surface).
```diff
- public let layers: [Qwen3TransformerBlock]
+ let layers: [Qwen3TransformerBlock]
```
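For readers unfamiliar with the access-control rule being cited, a minimal standalone sketch (hypothetical types, not from this repo):

```swift
final class Block {}              // no modifier, so internal by default

public final class Model {
    // error: property cannot be declared public because its type
    // '[Block]' uses an internal type
    // public let layers: [Block] = []

    let layers: [Block] = []      // internal property over internal type: OK
}
```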
```swift
let kvCache = cache?.compactMap { $0 }
let mask = createAttentionMask(h: h, cache: kvCache?.first)
var captured: [Int: MLXArray] = [:]
for (i, layer) in model.layers.enumerated() {
    h = layer(h, mask: mask, cache: kvCache?[i])
```
callCapturing builds kvCache via compactMap { $0 }, which (1) drops nil entries and breaks layer↔cache alignment and (2) makes kvCache?[i] unsafe because standard array subscripting will trap when i is out of range. Normalize to an array with layers.count entries (padding with nils) and use safe indexing, similar to Qwen35TextModelInner.callCapturing.
```diff
- let kvCache = cache?.compactMap { $0 }
- let mask = createAttentionMask(h: h, cache: kvCache?.first)
- var captured: [Int: MLXArray] = [:]
- for (i, layer) in model.layers.enumerated() {
-     h = layer(h, mask: mask, cache: kvCache?[i])
+ let kvCache: [KVCache?] = {
+     var normalized = Array(repeating: Optional<KVCache>.none, count: model.layers.count)
+     if let cache {
+         for (i, value) in cache.prefix(model.layers.count).enumerated() {
+             normalized[i] = value
+         }
+     }
+     return normalized
+ }()
+ let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
+ var captured: [Int: MLXArray] = [:]
+ for (i, layer) in model.layers.enumerated() {
+     let layerCache = i < kvCache.count ? kvCache[i] : nil
+     h = layer(h, mask: mask, cache: layerCache)
```
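The misalignment is easy to reproduce with plain optionals; a minimal sketch with `String?` placeholders standing in for per-layer caches:

```swift
let layerCount = 3
let cache: [String?] = ["c0", nil]       // caller supplied fewer entries

let compacted = cache.compactMap { $0 }  // ["c0"]: the nil for layer 1 is gone,
// so every lookup past index 0 is misaligned and compacted[1] traps.

// Nil-padded normalization keeps position i meaning "cache for layer i":
var normalized = Array(cache.prefix(layerCount))
normalized.append(contentsOf: Array(repeating: nil,
                                    count: layerCount - normalized.count))
// normalized == ["c0", nil, nil]
```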
```swift
let mask = createAttentionMask(h: h, cache: kvCache?.first)
var captured: [Int: MLXArray] = [:]
for (i, layer) in model.layers.enumerated() {
    h = layer(h, mask: mask, cache: kvCache?[i])
```
callCapturing loops through layers without partitionedLayerCall(...), so it ignores gpuLayerCount routing (and any per-layer streaming behavior used elsewhere). This makes callCapturing behave differently from callAsFunction and can break CPU/GPU partitioning. Wrap the layer forward pass in partitionedLayerCall(index:gpuLayerCount:) (and stream flags where applicable).
```diff
- h = layer(h, mask: mask, cache: kvCache?[i])
+ h = layer.partitionedLayerCall(index: i, gpuLayerCount: configuration.gpuLayerCount) {
+     layer(h, mask: mask, cache: kvCache?[i])
+ }
```
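`partitionedLayerCall` is internal to this repo and its exact signature isn't shown in the thread, so the following is only a guessed stand-in (device switching stubbed out) to illustrate what skipping it forfeits:

```swift
// Hypothetical stand-in, not the repo's implementation: the first
// gpuLayerCount layers take the GPU path, later layers are routed elsewhere.
func partitionedLayerCall<T>(index: Int, gpuLayerCount: Int,
                             _ body: () -> T) -> T {
    if index < gpuLayerCount {
        return body()  // GPU path
    } else {
        // The real helper presumably switches the compute device/stream here
        // before invoking the layer; a plain call stands in for that.
        return body()
    }
}
```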
```diff
  @ModuleInfo(key: "embed_tokens") public var embedTokens: Embedding

- fileprivate let layers: [Qwen3MoeDecoderLayer]
+ public let layers: [Qwen3MoeDecoderLayer]
```
layers is declared public, but its element type (Qwen3MoeDecoderLayer) is internal, which is not allowed for a public API and will fail to compile. Either keep layers non-public (e.g., fileprivate as before) or make the layer type(s) public as well (larger API surface).
```diff
- public let layers: [Qwen3MoeDecoderLayer]
+ let layers: [Qwen3MoeDecoderLayer]
```
```swift
public func callCapturing(
    _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
) -> (MLXArray, [Int: MLXArray]) {
    var h = model.embedTokens(inputs)
    let kvCache = cache?.compactMap { $0 }
    let mask = createAttentionMask(h: h, cache: kvCache?.first)
    var captured: [Int: MLXArray] = [:]
    for (i, layer) in model.layers.enumerated() {
        h = layer(h, mask: mask, cache: kvCache?[i])
```
callCapturing builds kvCache via compactMap { $0 }, which (1) drops nil entries and breaks layer↔cache alignment and (2) makes kvCache?[i] unsafe because standard array subscripting will trap when i is out of range. Normalize to an array with model.layers.count entries (padding with nils) and use safe indexing.
```diff
- public func callCapturing(
-     _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
- ) -> (MLXArray, [Int: MLXArray]) {
-     var h = model.embedTokens(inputs)
-     let kvCache = cache?.compactMap { $0 }
-     let mask = createAttentionMask(h: h, cache: kvCache?.first)
-     var captured: [Int: MLXArray] = [:]
-     for (i, layer) in model.layers.enumerated() {
-         h = layer(h, mask: mask, cache: kvCache?[i])
+ private func normalizedKVCache(_ cache: [KVCache?]?) -> [KVCache?] {
+     let layerCount = model.layers.count
+     guard var normalized = cache else {
+         return Array(repeating: nil, count: layerCount)
+     }
+     if normalized.count > layerCount {
+         normalized = Array(normalized.prefix(layerCount))
+     } else if normalized.count < layerCount {
+         normalized.append(contentsOf: Array(repeating: nil, count: layerCount - normalized.count))
+     }
+     return normalized
+ }
+ private func cacheAt(_ cache: [KVCache?], index: Int) -> KVCache? {
+     cache.indices.contains(index) ? cache[index] : nil
+ }
+ public func callCapturing(
+     _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
+ ) -> (MLXArray, [Int: MLXArray]) {
+     var h = model.embedTokens(inputs)
+     let kvCache = normalizedKVCache(cache)
+     let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
+     var captured: [Int: MLXArray] = [:]
+     for (i, layer) in model.layers.enumerated() {
+         h = layer(h, mask: mask, cache: cacheAt(kvCache, index: i))
```
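One subtlety worth noting in the suggested code: on a `[KVCache?]`, `.first` yields a doubly wrapped optional, which is what the `?? nil` flattens. A minimal illustration with a placeholder element type:

```swift
let caches: [String?] = ["c0", nil]
let doubly: String?? = caches.first       // Optional(Optional("c0"))
let flat: String? = caches.first ?? nil   // "c0", flattened one level
```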
```swift
let mask = createAttentionMask(h: h, cache: kvCache?.first)
var captured: [Int: MLXArray] = [:]
for (i, layer) in model.layers.enumerated() {
    h = layer(h, mask: mask, cache: kvCache?[i])
```
callCapturing does not use partitionedLayerCall(index:gpuLayerCount:stream:cacheToEval:), so it ignores both CPU/GPU layer partitioning and streamExperts behavior that Qwen3MoEModelInner.callAsFunction relies on. Wrap the per-layer call using partitionedLayerCall(..., stream: model.streamExperts, cacheToEval: ...) to match the normal forward path.
```diff
- h = layer(h, mask: mask, cache: kvCache?[i])
+ h = layer.partitionedLayerCall(
+     index: i,
+     gpuLayerCount: configuration.gpuLayerCount,
+     stream: model.streamExperts,
+     cacheToEval: kvCache?[i]
+ ) { layer, cache in
+     layer(h, mask: mask, cache: cache)
+ }
```
```swift
var h = model.embedTokens(inputs)
let kvCache = cache?.compactMap { $0 }
let mask = createAttentionMask(h: h, cache: kvCache?.first)
var captured: [Int: MLXArray] = [:]
for (i, layer) in model.layers.enumerated() {
    h = layer(h, mask: mask, cache: kvCache?[i])
    if captureLayerIDs.contains(i) { captured[i] = h }
```
callCapturing builds kvCache via compactMap { $0 }, which (1) drops nil entries and breaks layer↔cache alignment and (2) makes kvCache?[i] unsafe because standard array subscripting will trap when i is out of range. Normalize to an array with model.layers.count entries (padding with nils) and use safe indexing.
```swift
let mask = createAttentionMask(h: h, cache: kvCache?.first)
var captured: [Int: MLXArray] = [:]
for (i, layer) in model.layers.enumerated() {
    h = layer(h, mask: mask, cache: kvCache?[i])
```
callCapturing bypasses partitionedLayerCall(index:gpuLayerCount:), so it ignores gpuLayerCount CPU/GPU routing used in LlamaModelInner.callAsFunction. Wrap the per-layer forward pass in partitionedLayerCall to keep behavior consistent with the normal path.
```diff
- h = layer(h, mask: mask, cache: kvCache?[i])
+ h = model.partitionedLayerCall(index: i, gpuLayerCount: model.args.gpuLayerCount) {
+     layer(h, mask: mask, cache: kvCache?[i])
+ }
```
…compile error

Qwen3TransformerBlock and Qwen3MoeDecoderLayer are internal types. Declaring `public let layers` of those types violates Swift's access control rules. `callCapturing()` is defined in the same module, so it can access the internal `layers` directly; there is no need to expose them publicly.
Thanks @0xClandestine! Fixed the build error. Re: Kimi — …
Likewise ser, appreciate the quick responses and upstream work. Also are you sure Llama+DFlash.swift specifically supports K2.5->K2.6?
Correction on Kimi: I was wrong. Kimi K2.5/K2.6 uses …
…partitionedLayerCall

- Replace compactMap{} cache handling with explicit normalization that preserves layer↔cache index alignment (nil-padding to layers.count)
- Wrap per-layer calls in partitionedLayerCall() to respect gpuLayerCount CPU/GPU routing (matches callAsFunction behavior)
- Qwen3MoE additionally passes stream: model.streamExperts to match its Qwen3MoEModelInner.callAsFunction routing
@0xClandestine I did a wrong force push (a push should never be forced...). Once CI goes green and a maintainer merges this, I'll bring d9f824b's 3 files back into SwiftLM PR ml-explore#78 with the correct submodule pin.
No problem, I'm happy to take this on unless you've got it?

Sure, please go ahead. I don't have it yet. :)
Summary

Exposes the minimum public API surface needed by Sources/SwiftLM/{Llama,Qwen3,Qwen3MoE}+DFlash.swift in SwiftLM PR ml-explore#78 to conform to DFlashTargetModel.

Changes

- LlamaModelInner.embedTokens → public
- LlamaModel.lmHead → public
- LlamaModel.callCapturing(_:cache:captureLayerIDs:) — new public method
- Qwen3ModelInner, Qwen3Model, Qwen3MoEModelInner, Qwen3MoEModel: the same changes

Context

Originally designed by @0xClandestine in commit b5762584 (local only, never pushed to a public remote). This PR makes those APIs available from SharpAI/mlx-swift-lm main so the submodule pin can be updated cleanly.

Required by: SharpAI/SwiftLM#78
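To close the loop, here is a hedged sketch of the kind of downstream conformance this API unlocks. The protocol body is invented for illustration; the real DFlashTargetModel lives in SwiftLM PR ml-explore#78 and its requirements may differ:

```swift
import MLX           // module names assume the mlx-swift ecosystem layout
import MLXLMCommon   // assumed source of KVCache; verify against the repo

// Hypothetical protocol shape, not the real DFlashTargetModel definition.
public protocol DFlashTargetModel {
    func forwardCapturing(
        _ tokens: MLXArray, cache: [KVCache?]?, captureLayerIDs: Set<Int>
    ) -> (MLXArray, [Int: MLXArray])
}

// With embedTokens, lmHead, and callCapturing public, the conformance
// reduces to a thin forwarding shim.
extension LlamaModel: DFlashTargetModel {
    public func forwardCapturing(
        _ tokens: MLXArray, cache: [KVCache?]?, captureLayerIDs: Set<Int>
    ) -> (MLXArray, [Int: MLXArray]) {
        callCapturing(tokens, cache: cache, captureLayerIDs: captureLayerIDs)
    }
}
```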