Skip to content

Commit 694806d

Browse files
authored
Merge pull request #32 from SharpAI/feat/dflash-public-api-v2
fix: move callCapturing to *ModelInner (callers use model.callCapturing)
2 parents adbfbb2 + 8eb7c0e commit 694806d

3 files changed

Lines changed: 63 additions & 71 deletions

File tree

Libraries/MLXLLM/Models/Llama.swift

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,27 @@ public class LlamaModelInner: Module, LayerPartitionable {
152152

153153
return norm(h)
154154
}
155+
156+
public func callCapturing(
157+
_ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
158+
) -> (MLXArray, [Int: MLXArray]) {
159+
var h = embedTokens(inputs)
160+
let kvCache: [KVCache?] = {
161+
guard let c = cache else { return Array(repeating: nil, count: layers.count) }
162+
var normalized: [KVCache?] = Array(repeating: nil, count: layers.count)
163+
for (i, v) in c.prefix(layers.count).enumerated() { normalized[i] = v }
164+
return normalized
165+
}()
166+
let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
167+
var captured: [Int: MLXArray] = [:]
168+
for (i, layer) in layers.enumerated() {
169+
h = partitionedLayerCall(index: i, gpuLayerCount: gpuLayerCount) {
170+
layer(h, mask: mask, cache: kvCache[i])
171+
}
172+
if captureLayerIDs.contains(i) { captured[i] = h }
173+
}
174+
return (norm(h), captured)
175+
}
155176
}
156177

157178
/// Model for Llama and Mistral model types.
@@ -182,29 +203,6 @@ public class LlamaModel: Module, LLMModel, KVCacheDimensionProvider {
182203
}
183204
}
184205

185-
public func callCapturing(
186-
_ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
187-
) -> (MLXArray, [Int: MLXArray]) {
188-
var h = model.embedTokens(inputs)
189-
let layerCount = model.layers.count
190-
let kvCache: [KVCache?] = {
191-
guard let c = cache else { return Array(repeating: nil, count: layerCount) }
192-
var normalized: [KVCache?] = Array(repeating: nil, count: layerCount)
193-
for (i, v) in c.prefix(layerCount).enumerated() { normalized[i] = v }
194-
return normalized
195-
}()
196-
let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
197-
var captured: [Int: MLXArray] = [:]
198-
for (i, layer) in model.layers.enumerated() {
199-
h = partitionedLayerCall(index: i, gpuLayerCount: model.gpuLayerCount) {
200-
layer(h, mask: mask, cache: kvCache[i])
201-
}
202-
if captureLayerIDs.contains(i) { captured[i] = h }
203-
}
204-
h = model.norm(h)
205-
return (h, captured)
206-
}
207-
208206
public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
209207
// Remove unused precomputed rotary frequencies
210208
weights.filter {

Libraries/MLXLLM/Models/Qwen3.swift

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,27 @@ public class Qwen3ModelInner: Module, LayerPartitionable {
176176

177177
return norm(h)
178178
}
179+
180+
public func callCapturing(
181+
_ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
182+
) -> (MLXArray, [Int: MLXArray]) {
183+
var h = embedTokens(inputs)
184+
let kvCache: [KVCache?] = {
185+
guard let c = cache else { return Array(repeating: nil, count: layers.count) }
186+
var normalized: [KVCache?] = Array(repeating: nil, count: layers.count)
187+
for (i, v) in c.prefix(layers.count).enumerated() { normalized[i] = v }
188+
return normalized
189+
}()
190+
let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
191+
var captured: [Int: MLXArray] = [:]
192+
for (i, layer) in layers.enumerated() {
193+
h = partitionedLayerCall(index: i, gpuLayerCount: gpuLayerCount) {
194+
layer(h, mask: mask, cache: kvCache[i])
195+
}
196+
if captureLayerIDs.contains(i) { captured[i] = h }
197+
}
198+
return (norm(h), captured)
199+
}
179200
}
180201

181202
public class Qwen3Model: Module, LLMModel, KVCacheDimensionProvider {
@@ -208,29 +229,6 @@ public class Qwen3Model: Module, LLMModel, KVCacheDimensionProvider {
208229
return out
209230
}
210231

211-
public func callCapturing(
212-
_ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
213-
) -> (MLXArray, [Int: MLXArray]) {
214-
var h = model.embedTokens(inputs)
215-
let layerCount = model.layers.count
216-
let kvCache: [KVCache?] = {
217-
guard let c = cache else { return Array(repeating: nil, count: layerCount) }
218-
var normalized: [KVCache?] = Array(repeating: nil, count: layerCount)
219-
for (i, v) in c.prefix(layerCount).enumerated() { normalized[i] = v }
220-
return normalized
221-
}()
222-
let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
223-
var captured: [Int: MLXArray] = [:]
224-
for (i, layer) in model.layers.enumerated() {
225-
h = partitionedLayerCall(index: i, gpuLayerCount: model.gpuLayerCount) {
226-
layer(h, mask: mask, cache: kvCache[i])
227-
}
228-
if captureLayerIDs.contains(i) { captured[i] = h }
229-
}
230-
h = model.norm(h)
231-
return (h, captured)
232-
}
233-
234232
public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
235233
var weights = weights
236234

Libraries/MLXLLM/Models/Qwen3MoE.swift

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,27 @@ public class Qwen3MoEModelInner: Module, LayerPartitionable, StreamableMoE {
229229

230230
return norm(h)
231231
}
232+
233+
public func callCapturing(
234+
_ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
235+
) -> (MLXArray, [Int: MLXArray]) {
236+
var h = embedTokens(inputs)
237+
let kvCache: [KVCache?] = {
238+
guard let c = cache else { return Array(repeating: nil, count: layers.count) }
239+
var normalized: [KVCache?] = Array(repeating: nil, count: layers.count)
240+
for (i, v) in c.prefix(layers.count).enumerated() { normalized[i] = v }
241+
return normalized
242+
}()
243+
let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
244+
var captured: [Int: MLXArray] = [:]
245+
for (i, layer) in layers.enumerated() {
246+
h = partitionedLayerCall(index: i, gpuLayerCount: gpuLayerCount, stream: streamExperts) {
247+
layer(h, mask: mask, cache: kvCache[i])
248+
}
249+
if captureLayerIDs.contains(i) { captured[i] = h }
250+
}
251+
return (norm(h), captured)
252+
}
232253
}
233254

234255
public class Qwen3MoEModel: Module, LLMModel, KVCacheDimensionProvider {
@@ -261,31 +282,6 @@ public class Qwen3MoEModel: Module, LLMModel, KVCacheDimensionProvider {
261282
return out
262283
}
263284

264-
public func callCapturing(
265-
_ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
266-
) -> (MLXArray, [Int: MLXArray]) {
267-
var h = model.embedTokens(inputs)
268-
let layerCount = model.layers.count
269-
let kvCache: [KVCache?] = {
270-
guard let c = cache else { return Array(repeating: nil, count: layerCount) }
271-
var normalized: [KVCache?] = Array(repeating: nil, count: layerCount)
272-
for (i, v) in c.prefix(layerCount).enumerated() { normalized[i] = v }
273-
return normalized
274-
}()
275-
let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil)
276-
var captured: [Int: MLXArray] = [:]
277-
for (i, layer) in model.layers.enumerated() {
278-
h = partitionedLayerCall(
279-
index: i, gpuLayerCount: model.gpuLayerCount, stream: model.streamExperts
280-
) {
281-
layer(h, mask: mask, cache: kvCache[i])
282-
}
283-
if captureLayerIDs.contains(i) { captured[i] = h }
284-
}
285-
h = model.norm(h)
286-
return (h, captured)
287-
}
288-
289285
public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
290286
var sanitizedWeights = weights
291287

0 commit comments

Comments
 (0)