@@ -196,9 +196,9 @@ public class Qwen3MoEModelInner: Module, LayerPartitionable, StreamableMoE {
196196 // StreamableMoE
197197 public var streamExperts : Bool = false
198198
199- @ModuleInfo ( key: " embed_tokens " ) var embedTokens : Embedding
199+ @ModuleInfo ( key: " embed_tokens " ) public var embedTokens : Embedding
200200
201- fileprivate let layers : [ Qwen3MoeDecoderLayer ]
201+ let layers : [ Qwen3MoeDecoderLayer ]
202202 let norm : RMSNorm
203203 let args : Qwen3MoEConfiguration
204204
@@ -238,7 +238,7 @@ public class Qwen3MoEModel: Module, LLMModel, KVCacheDimensionProvider {
238238 public let model : Qwen3MoEModelInner
239239 let configuration : Qwen3MoEConfiguration
240240
241- @ModuleInfo ( key: " lm_head " ) var lmHead : Linear ?
241+ @ModuleInfo ( key: " lm_head " ) public var lmHead : Linear ?
242242
243243 public init ( _ args: Qwen3MoEConfiguration ) {
244244 self . configuration = args
@@ -261,6 +261,31 @@ public class Qwen3MoEModel: Module, LLMModel, KVCacheDimensionProvider {
261261 return out
262262 }
263263
/// Runs the decoder layers over `inputs` and records the hidden state emitted by selected layers.
///
/// - Parameters:
///   - inputs: Array handed to `model.embedTokens` to produce the initial hidden state.
///   - cache: Optional per-layer caches. Entries past the number of layers are ignored,
///     and layers without a corresponding entry are called with `nil`.
///   - captureLayerIDs: Zero-based layer indices whose output should be recorded.
/// - Returns: The hidden state after `model.norm` (no `lm_head` projection is applied here),
///   together with a dictionary mapping each requested layer index to the hidden state
///   taken immediately after that layer ran.
public func callCapturing(
    _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set<Int>
) -> (MLXArray, [Int: MLXArray]) {
    var hidden = model.embedTokens(inputs)
    let decoderLayers = model.layers

    // Normalise the caller's cache to exactly one optional slot per layer:
    // start from all-nil and copy over as many entries as both sides have.
    var layerCaches = [KVCache?](repeating: nil, count: decoderLayers.count)
    if let supplied = cache {
        for (slot, entry) in zip(0..<decoderLayers.count, supplied) {
            layerCaches[slot] = entry
        }
    }

    // The mask is built once, from the first layer's cache slot (nil when there is none).
    let firstCache: KVCache? = layerCaches.first ?? nil
    let mask = createAttentionMask(h: hidden, cache: firstCache)

    var snapshots: [Int: MLXArray] = [:]
    for index in decoderLayers.indices {
        let layer = decoderLayers[index]
        hidden = partitionedLayerCall(
            index: index, gpuLayerCount: model.gpuLayerCount, stream: model.streamExperts
        ) {
            layer(hidden, mask: mask, cache: layerCaches[index])
        }
        // Snapshot the post-layer activation only for the layers the caller asked for.
        if captureLayerIDs.contains(index) {
            snapshots[index] = hidden
        }
    }

    return (model.norm(hidden), snapshots)
}
288+
264289 public func sanitize( weights: [ String : MLXArray ] ) -> [ String : MLXArray ] {
265290 var sanitizedWeights = weights
266291
0 commit comments