@@ -229,6 +229,27 @@ public class Qwen3MoEModelInner: Module, LayerPartitionable, StreamableMoE {
229229
230230 return norm ( h)
231231 }
232+
233+ public func callCapturing(
234+ _ inputs: MLXArray , cache: [ KVCache ? ] ? = nil , captureLayerIDs: Set < Int >
235+ ) -> ( MLXArray , [ Int : MLXArray ] ) {
236+ var h = embedTokens ( inputs)
237+ let kvCache : [ KVCache ? ] = {
238+ guard let c = cache else { return Array ( repeating: nil , count: layers. count) }
239+ var normalized : [ KVCache ? ] = Array ( repeating: nil , count: layers. count)
240+ for (i, v) in c. prefix ( layers. count) . enumerated ( ) { normalized [ i] = v }
241+ return normalized
242+ } ( )
243+ let mask = createAttentionMask ( h: h, cache: kvCache. first ?? nil )
244+ var captured : [ Int : MLXArray ] = [ : ]
245+ for (i, layer) in layers. enumerated ( ) {
246+ h = partitionedLayerCall ( index: i, gpuLayerCount: gpuLayerCount, stream: streamExperts) {
247+ layer ( h, mask: mask, cache: kvCache [ i] )
248+ }
249+ if captureLayerIDs. contains ( i) { captured [ i] = h }
250+ }
251+ return ( norm ( h) , captured)
252+ }
232253}
233254
234255public class Qwen3MoEModel : Module , LLMModel , KVCacheDimensionProvider {
@@ -261,31 +282,6 @@ public class Qwen3MoEModel: Module, LLMModel, KVCacheDimensionProvider {
261282 return out
262283 }
263284
264- public func callCapturing(
265- _ inputs: MLXArray , cache: [ KVCache ? ] ? = nil , captureLayerIDs: Set < Int >
266- ) -> ( MLXArray , [ Int : MLXArray ] ) {
267- var h = model. embedTokens ( inputs)
268- let layerCount = model. layers. count
269- let kvCache : [ KVCache ? ] = {
270- guard let c = cache else { return Array ( repeating: nil , count: layerCount) }
271- var normalized : [ KVCache ? ] = Array ( repeating: nil , count: layerCount)
272- for (i, v) in c. prefix ( layerCount) . enumerated ( ) { normalized [ i] = v }
273- return normalized
274- } ( )
275- let mask = createAttentionMask ( h: h, cache: kvCache. first ?? nil )
276- var captured : [ Int : MLXArray ] = [ : ]
277- for (i, layer) in model. layers. enumerated ( ) {
278- h = partitionedLayerCall (
279- index: i, gpuLayerCount: model. gpuLayerCount, stream: model. streamExperts
280- ) {
281- layer ( h, mask: mask, cache: kvCache [ i] )
282- }
283- if captureLayerIDs. contains ( i) { captured [ i] = h }
284- }
285- h = model. norm ( h)
286- return ( h, captured)
287- }
288-
289285 public func sanitize( weights: [ String : MLXArray ] ) -> [ String : MLXArray ] {
290286 var sanitizedWeights = weights
291287
0 commit comments