
feat: expose embedTokens, lmHead, callCapturing for DFlash conformance on Llama/Qwen3/Qwen3MoE (#31)

Merged: solderzzc merged 3 commits into main from feat/dflash-public-api on Apr 24, 2026

Conversation

@solderzzc (Member)

Summary

Exposes the minimum public API surface needed by Sources/SwiftLM/{Llama,Qwen3,Qwen3MoE}+DFlash.swift in SwiftLM PR ml-explore#78 to conform to DFlashTargetModel.

Changes

  • LlamaModelInner.embedTokens: now public
  • LlamaModel.lmHead: now public
  • LlamaModel.callCapturing(_:cache:captureLayerIDs:) — new public method (usage sketched below)
  • Same pattern applied to Qwen3ModelInner, Qwen3Model, Qwen3MoEModelInner, Qwen3MoEModel
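
For orientation, a minimal sketch of how a downstream DFlashTargetModel conformance might drive this surface; target, draftTokens, kvCache, and the layer IDs are illustrative assumptions, not names taken from SwiftLM:

    // Hypothetical call site: target is a LlamaModel, draftTokens an MLXArray.
    let (hidden, captured) = target.callCapturing(
        draftTokens, cache: kvCache, captureLayerIDs: [4, 8, 12]
    )
    let logits = target.lmHead(hidden)  // project final hidden states with the now-public head
    let layer8 = captured[8]            // intermediate hidden state, present only if requested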

Context

Originally designed by @0xClandestine in commit b5762584 (local only, never pushed to a public remote). This PR makes those APIs available from SharpAI/mlx-swift-lm main so the submodule pin can be updated cleanly.

Required by: SharpAI/SwiftLM#78

…e on Llama/Qwen3/Qwen3MoE

Enables Sources/SwiftLM/{Llama,Qwen3,Qwen3MoE}+DFlash.swift to conform
to DFlashTargetModel by exposing:
  - LlamaModelInner.embedTokens (public)
  - LlamaModel.lmHead (public)
  - LlamaModel.callCapturing(_:cache:captureLayerIDs:)
  - Qwen3ModelInner.embedTokens + layers (public)
  - Qwen3Model.lmHead (public)
  - Qwen3Model.callCapturing(_:cache:captureLayerIDs:)
  - Qwen3MoEModelInner.embedTokens + layers (public)
  - Qwen3MoEModel.lmHead (public)
  - Qwen3MoEModel.callCapturing(_:cache:captureLayerIDs:)

Authored by: 0xClandestine (original design in unreachable commit b5762584)
Copilot AI review requested due to automatic review settings April 24, 2026 03:34
@0xClandestine

Much love ser! Mind including Kimi support as well? 0xClandestine/SwiftLM#1

Copilot AI left a comment

Pull request overview

Exposes a small public API surface on Llama/Qwen3/Qwen3MoE models to support downstream DFlashTargetModel conformance (per SwiftLM PR ml-explore#78).

Changes:

  • Make embedTokens publicly accessible on *ModelInner.
  • Make lmHead publicly accessible on *Model.
  • Add a new public callCapturing(_:cache:captureLayerIDs:) API to return final hidden states plus selected intermediate layer outputs.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 8 comments.

Libraries/MLXLLM/Models/Llama.swift: Makes embedTokens/lmHead public and adds callCapturing for capturing intermediate states.
Libraries/MLXLLM/Models/Qwen3.swift: Makes embedTokens/lmHead public, exposes layers, and adds callCapturing.
Libraries/MLXLLM/Models/Qwen3MoE.swift: Makes embedTokens/lmHead public, exposes layers, and adds callCapturing for the MoE model.


Comment thread: Libraries/MLXLLM/Models/Qwen3.swift (Outdated)
      @ModuleInfo(key: "embed_tokens") public var embedTokens: Embedding

    - fileprivate let layers: [Qwen3TransformerBlock]
    + public let layers: [Qwen3TransformerBlock]

Copilot AI Apr 24, 2026

layers is declared public, but its element type (Qwen3TransformerBlock) is internal (no access modifier), which is not allowed for a public API and will fail to compile. Either keep layers non-public (e.g., fileprivate as before) or make the layer type(s) public as well (larger API surface).

    Suggested change:
    - public let layers: [Qwen3TransformerBlock]
    + let layers: [Qwen3TransformerBlock]

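For illustration, the access-control rule being flagged here can be reproduced in isolation (these types are made up, not from the codebase):

    struct Inner {}  // internal by default

    public struct Outer {
        // error: property cannot be declared public because its type
        // uses an internal type
        public let items: [Inner]
    }
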
Comment thread: Libraries/MLXLLM/Models/Qwen3.swift (Outdated)
Comment on lines +215 to +219
        let kvCache = cache?.compactMap { $0 }
        let mask = createAttentionMask(h: h, cache: kvCache?.first)
        var captured: [Int: MLXArray] = [:]
        for (i, layer) in model.layers.enumerated() {
            h = layer(h, mask: mask, cache: kvCache?[i])

Copilot AI Apr 24, 2026

callCapturing builds kvCache via compactMap { $0 }, which (1) drops nil entries and breaks layer↔cache alignment and (2) makes kvCache?[i] unsafe because standard array subscripting will trap when i is out of range. Normalize to an array with layers.count entries (padding with nils) and use safe indexing, similar to Qwen35TextModelInner.callCapturing.

    Suggested change:
    -     let kvCache = cache?.compactMap { $0 }
    -     let mask = createAttentionMask(h: h, cache: kvCache?.first)
    -     var captured: [Int: MLXArray] = [:]
    -     for (i, layer) in model.layers.enumerated() {
    -         h = layer(h, mask: mask, cache: kvCache?[i])
    +     let kvCache: [KVCache?] = {
    +         var normalized = Array(repeating: Optional<KVCache>.none, count: model.layers.count)
    +         if let cache {
    +             for (i, value) in cache.prefix(model.layers.count).enumerated() {
    +                 normalized[i] = value
    +             }
    +         }
    +         return normalized
    +     }()
    +     let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
    +     var captured: [Int: MLXArray] = [:]
    +     for (i, layer) in model.layers.enumerated() {
    +         let layerCache = i < kvCache.count ? kvCache[i] : nil
    +         h = layer(h, mask: mask, cache: layerCache)

Comment thread: Libraries/MLXLLM/Models/Qwen3.swift (Outdated)
        let mask = createAttentionMask(h: h, cache: kvCache?.first)
        var captured: [Int: MLXArray] = [:]
        for (i, layer) in model.layers.enumerated() {
            h = layer(h, mask: mask, cache: kvCache?[i])

Copilot AI Apr 24, 2026

callCapturing loops through layers without partitionedLayerCall(...), so it ignores gpuLayerCount routing (and any per-layer streaming behavior used elsewhere). This makes callCapturing behave differently from callAsFunction and can break CPU/GPU partitioning. Wrap the layer forward pass in partitionedLayerCall(index:gpuLayerCount:) (and stream flags where applicable).

    Suggested change:
    -     h = layer(h, mask: mask, cache: kvCache?[i])
    +     h = layer.partitionedLayerCall(index: i, gpuLayerCount: configuration.gpuLayerCount) {
    +         layer(h, mask: mask, cache: kvCache?[i])
    +     }

Comment thread: Libraries/MLXLLM/Models/Qwen3MoE.swift (Outdated)
      @ModuleInfo(key: "embed_tokens") public var embedTokens: Embedding

    - fileprivate let layers: [Qwen3MoeDecoderLayer]
    + public let layers: [Qwen3MoeDecoderLayer]

Copilot AI Apr 24, 2026

layers is declared public, but its element type (Qwen3MoeDecoderLayer) is internal, which is not allowed for a public API and will fail to compile. Either keep layers non-public (e.g., fileprivate as before) or make the layer type(s) public as well (larger API surface).

    Suggested change:
    - public let layers: [Qwen3MoeDecoderLayer]
    + let layers: [Qwen3MoeDecoderLayer]

Comment thread: Libraries/MLXLLM/Models/Qwen3MoE.swift (Outdated)
Comment on lines +264 to +272
        public func callCapturing(
            _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
        ) -> (MLXArray, [Int: MLXArray]) {
            var h = model.embedTokens(inputs)
            let kvCache = cache?.compactMap { $0 }
            let mask = createAttentionMask(h: h, cache: kvCache?.first)
            var captured: [Int: MLXArray] = [:]
            for (i, layer) in model.layers.enumerated() {
                h = layer(h, mask: mask, cache: kvCache?[i])

Copilot AI Apr 24, 2026

callCapturing builds kvCache via compactMap { $0 }, which (1) drops nil entries and breaks layer↔cache alignment and (2) makes kvCache?[i] unsafe because standard array subscripting will trap when i is out of range. Normalize to an array with model.layers.count entries (padding with nils) and use safe indexing.

    Suggested change:
    -     public func callCapturing(
    -         _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
    -     ) -> (MLXArray, [Int: MLXArray]) {
    -         var h = model.embedTokens(inputs)
    -         let kvCache = cache?.compactMap { $0 }
    -         let mask = createAttentionMask(h: h, cache: kvCache?.first)
    -         var captured: [Int: MLXArray] = [:]
    -         for (i, layer) in model.layers.enumerated() {
    -             h = layer(h, mask: mask, cache: kvCache?[i])
    +     private func normalizedKVCache(_ cache: [KVCache?]?) -> [KVCache?] {
    +         let layerCount = model.layers.count
    +         guard var normalized = cache else {
    +             return Array(repeating: nil, count: layerCount)
    +         }
    +         if normalized.count > layerCount {
    +             normalized = Array(normalized.prefix(layerCount))
    +         } else if normalized.count < layerCount {
    +             normalized.append(contentsOf: Array(repeating: nil, count: layerCount - normalized.count))
    +         }
    +         return normalized
    +     }
    +
    +     private func cacheAt(_ cache: [KVCache?], index: Int) -> KVCache? {
    +         cache.indices.contains(index) ? cache[index] : nil
    +     }
    +
    +     public func callCapturing(
    +         _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
    +     ) -> (MLXArray, [Int: MLXArray]) {
    +         var h = model.embedTokens(inputs)
    +         let kvCache = normalizedKVCache(cache)
    +         let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
    +         var captured: [Int: MLXArray] = [:]
    +         for (i, layer) in model.layers.enumerated() {
    +             h = layer(h, mask: mask, cache: cacheAt(kvCache, index: i))

Comment thread: Libraries/MLXLLM/Models/Qwen3MoE.swift (Outdated)
        let mask = createAttentionMask(h: h, cache: kvCache?.first)
        var captured: [Int: MLXArray] = [:]
        for (i, layer) in model.layers.enumerated() {
            h = layer(h, mask: mask, cache: kvCache?[i])

Copilot AI Apr 24, 2026

callCapturing does not use partitionedLayerCall(index:gpuLayerCount:stream:cacheToEval:), so it ignores both CPU/GPU layer partitioning and streamExperts behavior that Qwen3MoEModelInner.callAsFunction relies on. Wrap the per-layer call using partitionedLayerCall(..., stream: model.streamExperts, cacheToEval: ...) to match the normal forward path.

    Suggested change:
    -     h = layer(h, mask: mask, cache: kvCache?[i])
    +     h = layer.partitionedLayerCall(
    +         index: i,
    +         gpuLayerCount: configuration.gpuLayerCount,
    +         stream: model.streamExperts,
    +         cacheToEval: kvCache?[i]
    +     ) { layer, cache in
    +         layer(h, mask: mask, cache: cache)
    +     }

Comment on lines +188 to +194
        var h = model.embedTokens(inputs)
        let kvCache = cache?.compactMap { $0 }
        let mask = createAttentionMask(h: h, cache: kvCache?.first)
        var captured: [Int: MLXArray] = [:]
        for (i, layer) in model.layers.enumerated() {
            h = layer(h, mask: mask, cache: kvCache?[i])
            if captureLayerIDs.contains(i) { captured[i] = h }

Copilot AI Apr 24, 2026

callCapturing builds kvCache via compactMap { $0 }, which (1) drops nil entries and breaks layer↔cache alignment and (2) makes kvCache?[i] unsafe because standard array subscripting will trap when i is out of range. Normalize to an array with model.layers.count entries (padding with nils) and use safe indexing.
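
To make the misalignment concrete, a sketch of the failure mode (the cache contents are illustrative):

    // Suppose cache = [c0, nil, c2] for a 3-layer model.
    let kvCache = cache?.compactMap { $0 }  // [c0, c2]: the nil entry is dropped
    // Layer 1 now receives c2 (the wrong layer's cache), and
    // kvCache?[2] traps: index 2 is out of range for a 2-element array.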

Comment thread: Libraries/MLXLLM/Models/Llama.swift (Outdated)
        let mask = createAttentionMask(h: h, cache: kvCache?.first)
        var captured: [Int: MLXArray] = [:]
        for (i, layer) in model.layers.enumerated() {
            h = layer(h, mask: mask, cache: kvCache?[i])

Copilot AI Apr 24, 2026

callCapturing bypasses partitionedLayerCall(index:gpuLayerCount:), so it ignores gpuLayerCount CPU/GPU routing used in LlamaModelInner.callAsFunction. Wrap the per-layer forward pass in partitionedLayerCall to keep behavior consistent with the normal path.

    Suggested change:
    -     h = layer(h, mask: mask, cache: kvCache?[i])
    +     h = model.partitionedLayerCall(index: i, gpuLayerCount: model.args.gpuLayerCount) {
    +         layer(h, mask: mask, cache: kvCache?[i])
    +     }

…e error

Qwen3TransformerBlock and Qwen3MoeDecoderLayer are internal types.
Declaring public let layers of those types violates Swift's access
control rules. callCapturing() is defined in the same module, so it can
access internal layers directly — no need to expose them publicly.
@solderzzc (Member, Author)

Thanks @0xClandestine! Fixed the build error — Qwen3TransformerBlock and Qwen3MoeDecoderLayer are internal types, so layers can't be public. Since callCapturing() lives in the same module it can access them directly without the property being public.

Re: Kimi — LlamaModel already covers Kimi/Moonshot models since they use the same architecture. Llama+DFlash.swift in SwiftLM PR ml-explore#78 handles them.

@0xClandestine

> Thanks @0xClandestine! Fixed the build error — Qwen3TransformerBlock and Qwen3MoeDecoderLayer are internal types, so layers can't be public. Since callCapturing() lives in the same module it can access them directly without the property being public.
>
> Re: Kimi — LlamaModel already covers Kimi/Moonshot models since they use the same architecture. Llama+DFlash.swift in SwiftLM PR ml-explore#78 handles them.

Likewise ser, appreciate the quick responses and upstream work.

Also are you sure Llama+DFlash.swift specifically supports K2.5->K2.6?

@solderzzc (Member, Author)

Correction on Kimi: I was wrong. Kimi K2.5/K2.6 uses model_type: kimi_k2 with DeepseekV3ForCausalLM architecture — not Llama. So Llama+DFlash.swift does not cover them. A separate DeepseekV3+DFlash.swift conformance would be needed, which in turn requires DeepseekV3Model and DeepseekV3ModelInner to have their APIs exposed. That's a valid follow-up. Sorry for the incorrect claim.

…ayerCall

- Replace compactMap{} cache handling with explicit normalization that
  preserves layer↔cache index alignment (nil-padding to layers.count)
- Wrap per-layer calls in partitionedLayerCall() to respect gpuLayerCount
  CPU/GPU routing (matches callAsFunction behavior)
- Qwen3MoE additionally passes stream:model.streamExperts to match
  its Qwen3MoEModelInner.callAsFunction routing
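
Putting the two fixes together, the per-layer loop in the MoE variant ends up roughly as below; this is condensed from the suggestions above and the commit summary, a sketch rather than the exact merged code:

    let kvCache = normalizedKVCache(cache)  // nil-padded to model.layers.count
    let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
    var captured: [Int: MLXArray] = [:]
    for (i, layer) in model.layers.enumerated() {
        h = layer.partitionedLayerCall(
            index: i,
            gpuLayerCount: configuration.gpuLayerCount,
            stream: model.streamExperts,    // MoE expert-streaming routing
            cacheToEval: kvCache[i]
        ) { layer, cache in
            layer(h, mask: mask, cache: cache)
        }
        if captureLayerIDs.contains(i) { captured[i] = h }
    }
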
@solderzzc (Member, Author)

@0xClandestine I did a bad force push (pushes should never be forced...). Once CI goes green and a maintainer merges this, I'll bring d9f824b's 3 files back into SwiftLM PR ml-explore#78 with the correct submodule pin.

@0xClandestine

> Correction on Kimi: I was wrong. Kimi K2.5/K2.6 uses model_type: kimi_k2 with DeepseekV3ForCausalLM architecture — not Llama. So Llama+DFlash.swift does not cover them. A separate DeepseekV3+DFlash.swift conformance would be needed, which in turn requires DeepseekV3Model and DeepseekV3ModelInner to have their APIs exposed. That's a valid follow-up. Sorry for the incorrect claim.

No problem, I'm happy to take this on unless you've got it?

@solderzzc (Member, Author)

> > Correction on Kimi: I was wrong. Kimi K2.5/K2.6 uses model_type: kimi_k2 with DeepseekV3ForCausalLM architecture — not Llama. So Llama+DFlash.swift does not cover them. A separate DeepseekV3+DFlash.swift conformance would be needed, which in turn requires DeepseekV3Model and DeepseekV3ModelInner to have their APIs exposed. That's a valid follow-up. Sorry for the incorrect claim.
>
> No problem, I'm happy to take this on unless you've got it?

Sure, please go ahead. I don't have it yet. :)

@solderzzc solderzzc merged commit adbfbb2 into main Apr 24, 2026
6 checks passed
@solderzzc solderzzc deleted the feat/dflash-public-api branch April 24, 2026 04:16