Skip to content

Commit d9f824b

Browse files
committed
feat: add DFlashTargetModel conformance for Qwen3, Qwen3MoE, and Llama
Adds Sources/SwiftLM/{Qwen3,Qwen3MoE,Llama}+DFlash.swift — each declares the DFlashTargetModel protocol conformance and delegates to the model's public callCapturing / embedTokens / lmHead (exposed in mlx-swift-lm commit f4cb110). Coverage: Qwen3Model → Qwen3-8B and similar dense Qwen3 variants Qwen3MoEModel → Qwen3-Coder-30B-A3B and other Qwen3 MoE variants LlamaModel → Meta-Llama-3.x, Mistral, and Llama-family models Qwen35MoEModel → already covered via Qwen35Model inheritance Qwen36MoE → no separate Swift class found; uses Qwen35MoE path
1 parent 4c042a6 commit d9f824b

4 files changed

Lines changed: 103 additions & 1 deletion

File tree

Sources/SwiftLM/Llama+DFlash.swift

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright 2026 SwiftLM Contributors
2+
// MIT License — see LICENSE file
3+
// Bridge: LlamaModel (and Mistral) conform to DFlashTargetModel
4+
5+
import DFlash
6+
import MLX
7+
import MLXLLM
8+
import MLXLMCommon
9+
10+
extension LlamaModel: DFlashTargetModel {
11+
public func dflashEmbedTokens(_ tokens: MLXArray) -> MLXArray {
12+
model.embedTokens(tokens)
13+
}
14+
15+
public func dflashLmHeadLogits(_ hiddenStates: MLXArray) -> MLXArray {
16+
if let lmHead {
17+
return lmHead(hiddenStates)
18+
}
19+
return model.embedTokens.asLinear(hiddenStates)
20+
}
21+
22+
public func dflashForwardWithCapture(
23+
inputIDs: MLXArray,
24+
cache: [KVCache],
25+
captureLayerIDs: Set<Int>
26+
) -> (MLXArray, [Int: MLXArray]) {
27+
let cacheOpt: [KVCache?] = cache.map { $0 }
28+
let (hiddenStates, captured) = model.callCapturing(
29+
inputIDs, cache: cacheOpt, captureLayerIDs: captureLayerIDs)
30+
return (dflashLmHeadLogits(hiddenStates), captured)
31+
}
32+
33+
public var dflashIsHybridGDN: Bool { false }
34+
}

Sources/SwiftLM/Qwen3+DFlash.swift

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright 2026 SwiftLM Contributors
2+
// MIT License — see LICENSE file
3+
// Bridge: Qwen3 dense models conform to DFlashTargetModel
4+
5+
import DFlash
6+
import MLX
7+
import MLXLLM
8+
import MLXLMCommon
9+
10+
extension Qwen3Model: DFlashTargetModel {
11+
public func dflashEmbedTokens(_ tokens: MLXArray) -> MLXArray {
12+
model.embedTokens(tokens)
13+
}
14+
15+
public func dflashLmHeadLogits(_ hiddenStates: MLXArray) -> MLXArray {
16+
if let lmHead {
17+
return lmHead(hiddenStates)
18+
}
19+
return model.embedTokens.asLinear(hiddenStates)
20+
}
21+
22+
public func dflashForwardWithCapture(
23+
inputIDs: MLXArray,
24+
cache: [KVCache],
25+
captureLayerIDs: Set<Int>
26+
) -> (MLXArray, [Int: MLXArray]) {
27+
let cacheOpt: [KVCache?] = cache.map { $0 }
28+
let (hiddenStates, captured) = model.callCapturing(
29+
inputIDs, cache: cacheOpt, captureLayerIDs: captureLayerIDs)
30+
return (dflashLmHeadLogits(hiddenStates), captured)
31+
}
32+
33+
public var dflashIsHybridGDN: Bool { false }
34+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright 2026 SwiftLM Contributors
2+
// MIT License — see LICENSE file
3+
// Bridge: Qwen3 MoE models conform to DFlashTargetModel
4+
5+
import DFlash
6+
import MLX
7+
import MLXLLM
8+
import MLXLMCommon
9+
10+
extension Qwen3MoEModel: DFlashTargetModel {
11+
public func dflashEmbedTokens(_ tokens: MLXArray) -> MLXArray {
12+
model.embedTokens(tokens)
13+
}
14+
15+
public func dflashLmHeadLogits(_ hiddenStates: MLXArray) -> MLXArray {
16+
if let lmHead {
17+
return lmHead(hiddenStates)
18+
}
19+
return model.embedTokens.asLinear(hiddenStates)
20+
}
21+
22+
public func dflashForwardWithCapture(
23+
inputIDs: MLXArray,
24+
cache: [KVCache],
25+
captureLayerIDs: Set<Int>
26+
) -> (MLXArray, [Int: MLXArray]) {
27+
let cacheOpt: [KVCache?] = cache.map { $0 }
28+
let (hiddenStates, captured) = model.callCapturing(
29+
inputIDs, cache: cacheOpt, captureLayerIDs: captureLayerIDs)
30+
return (dflashLmHeadLogits(hiddenStates), captured)
31+
}
32+
33+
public var dflashIsHybridGDN: Bool { false }
34+
}

mlx-swift-lm

0 commit comments

Comments
 (0)