Skip to content

Commit adbfbb2

Browse files
authored
Merge pull request #31 from SharpAI/feat/dflash-public-api
feat: expose embedTokens, lmHead, callCapturing for DFlash conformance on Llama/Qwen3/Qwen3MoE
2 parents ef3318e + 63da149 commit adbfbb2

3 files changed

Lines changed: 79 additions & 8 deletions

File tree

Libraries/MLXLLM/Models/Llama.swift

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ class LlamaTransformerBlock: Module {
120120

121121
public class LlamaModelInner: Module, LayerPartitionable {
122122

123-
@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
123+
@ModuleInfo(key: "embed_tokens") public var embedTokens: Embedding
124124

125125
let layers: [LlamaTransformerBlock]
126126
let norm: RMSNorm
@@ -162,7 +162,7 @@ public class LlamaModel: Module, LLMModel, KVCacheDimensionProvider {
162162

163163
public let model: LlamaModelInner
164164

165-
@ModuleInfo(key: "lm_head") var lmHead: Linear?
165+
@ModuleInfo(key: "lm_head") public var lmHead: Linear?
166166

167167
public init(_ args: LlamaConfiguration) {
168168
self.vocabularySize = args.vocabularySize
@@ -182,6 +182,29 @@ public class LlamaModel: Module, LLMModel, KVCacheDimensionProvider {
182182
}
183183
}
184184

185+
public func callCapturing(
186+
_ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
187+
) -> (MLXArray, [Int: MLXArray]) {
188+
var h = model.embedTokens(inputs)
189+
let layerCount = model.layers.count
190+
let kvCache: [KVCache?] = {
191+
guard let c = cache else { return Array(repeating: nil, count: layerCount) }
192+
var normalized: [KVCache?] = Array(repeating: nil, count: layerCount)
193+
for (i, v) in c.prefix(layerCount).enumerated() { normalized[i] = v }
194+
return normalized
195+
}()
196+
let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
197+
var captured: [Int: MLXArray] = [:]
198+
for (i, layer) in model.layers.enumerated() {
199+
h = partitionedLayerCall(index: i, gpuLayerCount: model.gpuLayerCount) {
200+
layer(h, mask: mask, cache: kvCache[i])
201+
}
202+
if captureLayerIDs.contains(i) { captured[i] = h }
203+
}
204+
h = model.norm(h)
205+
return (h, captured)
206+
}
207+
185208
public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
186209
// Remove unused precomputed rotary frequencies
187210
weights.filter {

Libraries/MLXLLM/Models/Qwen3.swift

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,9 @@ public class Qwen3ModelInner: Module, LayerPartitionable {
145145
// LayerPartitionable
146146
public var gpuLayerCount: Int?
147147
public var totalLayerCount: Int { layers.count }
148-
@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
148+
@ModuleInfo(key: "embed_tokens") public var embedTokens: Embedding
149149

150-
fileprivate let layers: [Qwen3TransformerBlock]
150+
let layers: [Qwen3TransformerBlock]
151151
let norm: RMSNorm
152152

153153
public init(_ args: Qwen3Configuration) {
@@ -185,7 +185,7 @@ public class Qwen3Model: Module, LLMModel, KVCacheDimensionProvider {
185185
public let model: Qwen3ModelInner
186186
let configuration: Qwen3Configuration
187187

188-
@ModuleInfo(key: "lm_head") var lmHead: Linear?
188+
@ModuleInfo(key: "lm_head") public var lmHead: Linear?
189189

190190
public init(_ args: Qwen3Configuration) {
191191
self.configuration = args
@@ -208,6 +208,29 @@ public class Qwen3Model: Module, LLMModel, KVCacheDimensionProvider {
208208
return out
209209
}
210210

211+
public func callCapturing(
212+
_ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
213+
) -> (MLXArray, [Int: MLXArray]) {
214+
var h = model.embedTokens(inputs)
215+
let layerCount = model.layers.count
216+
let kvCache: [KVCache?] = {
217+
guard let c = cache else { return Array(repeating: nil, count: layerCount) }
218+
var normalized: [KVCache?] = Array(repeating: nil, count: layerCount)
219+
for (i, v) in c.prefix(layerCount).enumerated() { normalized[i] = v }
220+
return normalized
221+
}()
222+
let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
223+
var captured: [Int: MLXArray] = [:]
224+
for (i, layer) in model.layers.enumerated() {
225+
h = partitionedLayerCall(index: i, gpuLayerCount: model.gpuLayerCount) {
226+
layer(h, mask: mask, cache: kvCache[i])
227+
}
228+
if captureLayerIDs.contains(i) { captured[i] = h }
229+
}
230+
h = model.norm(h)
231+
return (h, captured)
232+
}
233+
211234
public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
212235
var weights = weights
213236

Libraries/MLXLLM/Models/Qwen3MoE.swift

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,9 @@ public class Qwen3MoEModelInner: Module, LayerPartitionable, StreamableMoE {
196196
// StreamableMoE
197197
public var streamExperts: Bool = false
198198

199-
@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
199+
@ModuleInfo(key: "embed_tokens") public var embedTokens: Embedding
200200

201-
fileprivate let layers: [Qwen3MoeDecoderLayer]
201+
let layers: [Qwen3MoeDecoderLayer]
202202
let norm: RMSNorm
203203
let args: Qwen3MoEConfiguration
204204

@@ -238,7 +238,7 @@ public class Qwen3MoEModel: Module, LLMModel, KVCacheDimensionProvider {
238238
public let model: Qwen3MoEModelInner
239239
let configuration: Qwen3MoEConfiguration
240240

241-
@ModuleInfo(key: "lm_head") var lmHead: Linear?
241+
@ModuleInfo(key: "lm_head") public var lmHead: Linear?
242242

243243
public init(_ args: Qwen3MoEConfiguration) {
244244
self.configuration = args
@@ -261,6 +261,31 @@ public class Qwen3MoEModel: Module, LLMModel, KVCacheDimensionProvider {
261261
return out
262262
}
263263

264+
public func callCapturing(
265+
_ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
266+
) -> (MLXArray, [Int: MLXArray]) {
267+
var h = model.embedTokens(inputs)
268+
let layerCount = model.layers.count
269+
let kvCache: [KVCache?] = {
270+
guard let c = cache else { return Array(repeating: nil, count: layerCount) }
271+
var normalized: [KVCache?] = Array(repeating: nil, count: layerCount)
272+
for (i, v) in c.prefix(layerCount).enumerated() { normalized[i] = v }
273+
return normalized
274+
}()
275+
let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
276+
var captured: [Int: MLXArray] = [:]
277+
for (i, layer) in model.layers.enumerated() {
278+
h = partitionedLayerCall(
279+
index: i, gpuLayerCount: model.gpuLayerCount, stream: model.streamExperts
280+
) {
281+
layer(h, mask: mask, cache: kvCache[i])
282+
}
283+
if captureLayerIDs.contains(i) { captured[i] = h }
284+
}
285+
h = model.norm(h)
286+
return (h, captured)
287+
}
288+
264289
public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
265290
var sanitizedWeights = weights
266291

0 commit comments

Comments (0)