fix: move callCapturing to *ModelInner (callers use model.callCapturing) #32
Conversation
The DFlashTargetModel conformances in SwiftLM call `model.callCapturing()` where `model` is `*ModelInner`, not the outer `*Model` wrapper.
Pull request overview
Moves the callCapturing API from the outer *Model wrappers onto the corresponding *ModelInner types so downstream DFlash conformances (which hold *ModelInner) can call model.callCapturing(...) directly.
Changes:
- Add `public callCapturing(_:cache:captureLayerIDs:)` to `LlamaModelInner`, `Qwen3ModelInner`, and `Qwen3MoEModelInner`
- Remove `public callCapturing` from `LlamaModel`, `Qwen3Model`, and `Qwen3MoEModel`
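For context, a sketch of what the new inner-class method might look like. The body, the element type of `captureLayerIDs`, and the `layers`/`norm` property names are assumptions; only the `callCapturing(_:cache:captureLayerIDs:)` name and the `(MLXArray, [MLXArray])` tuple return visible in the review suggestions in this thread come from the PR itself:

```swift
// Sketch only: runs the decoder stack as usual, additionally collecting
// the hidden states of the layers listed in captureLayerIDs.
public func callCapturing(
    _ inputs: MLXArray, cache: [KVCache]?, captureLayerIDs: Set<Int>
) -> (MLXArray, [MLXArray]) {
    var h = embedTokens(inputs)
    var captured: [MLXArray] = []
    for (i, layer) in layers.enumerated() {
        h = layer(h, cache: cache?[i])
        if captureLayerIDs.contains(i) { captured.append(h) }
    }
    return (norm(h), captured)
}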
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| Libraries/MLXLLM/Models/Llama.swift | Adds callCapturing to LlamaModelInner and removes it from LlamaModel. |
| Libraries/MLXLLM/Models/Qwen3.swift | Adds callCapturing to Qwen3ModelInner and removes it from Qwen3Model. |
| Libraries/MLXLLM/Models/Qwen3MoE.swift | Adds callCapturing to Qwen3MoEModelInner and removes it from Qwen3MoEModel. |
Comments suppressed due to low confidence (1)
Libraries/MLXLLM/Models/Qwen3MoE.swift:289
`callCapturing` was removed from the public `Qwen3MoEModel` wrapper. If this API was already exposed in #31, consider reintroducing it here as a thin forwarding method to `model.callCapturing(...)` (optionally deprecated) to avoid breaking existing call sites that use the outer model type.
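A forwarding shim of the kind this comment suggests could look like the following. This is a sketch only: the signature mirrors the `Qwen3Model` suggestion further down in this review, and the deprecation message is illustrative:

```swift
// Hypothetical compatibility shim on the outer Qwen3MoEModel wrapper:
// forwards to the inner model and steers callers toward the new location.
@available(*, deprecated, message: "Use callCapturing on the inner model instead.")
public func callCapturing(_ inputs: MLXArray, cache: [KVCache]?) -> (MLXArray, [MLXArray]) {
    var (out, captured) = model.callCapturing(inputs, cache: cache)
    if let lmHead {
        out = lmHead(out)
    } else {
        out = model.embedTokens.asLinear(out)
    }
    return (out, captured)
}
```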
```swift
public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
    var out = model(inputs, cache: cache)
    if let lmHead {
        out = lmHead(out)
    } else {
        out = model.embedTokens.asLinear(out)
    }
    return out
}
```
```swift
public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
    var sanitizedWeights = weights
    if configuration.tieWordEmbeddings {
        sanitizedWeights["lm_head.weight"] = nil
```
```diff
@@ -182,29 +203,6 @@ public class LlamaModel: Module, LLMModel, KVCacheDimensionProvider {
     }
 }
```
`callCapturing` was removed from the public `LlamaModel` wrapper. Since it was introduced as part of the public API in the prior PR, consider keeping a forwarding `LlamaModel.callCapturing(...)` that delegates to `model.callCapturing(...)` (optionally marking it deprecated) to avoid a breaking change for any clients already calling the wrapper type.
```swift
@available(*, deprecated, message: "Use the underlying model.callCapturing(...) instead.")
public func callCapturing(_ inputs: MLXArray, cache: [KVCache]?) -> (MLXArray, [MLXArray]) {
    model.callCapturing(inputs, cache: cache)
}
```
```diff
@@ -208,29 +229,6 @@ public class Qwen3Model: Module, LLMModel, KVCacheDimensionProvider {
         return out
     }
```
`callCapturing` was removed from the public `Qwen3Model` wrapper. To preserve source compatibility for users who may have started calling `Qwen3Model.callCapturing(...)` after #31, consider keeping a wrapper method that forwards to `model.callCapturing(...)` (and optionally deprecate it if you want callers to move to the inner type).
```swift
public func callCapturing(_ inputs: MLXArray, cache: [KVCache]?) -> (MLXArray, [MLXArray]) {
    var (out, captured) = model.callCapturing(inputs, cache: cache)
    if let lmHead {
        out = lmHead(out)
    } else {
        out = model.embedTokens.asLinear(out)
    }
    return (out, captured)
}
```
Noted on the source compatibility concern. Since PR #31 was merged only minutes before this hotfix and …
Hotfix on top of #31. The DFlash conformances in SwiftLM call `model.callCapturing()` where `model` is `*ModelInner`, so `callCapturing` must live on the inner class, not the outer wrapper. Verified builds clean locally.
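As a usage sketch, a downstream conformance that holds the inner type can now call the API directly. The helper name and the `Set<Int>` layer-ID type here are illustrative, not taken from SwiftLM:

```swift
// Illustrative only: with callCapturing on the inner class, code that holds
// a *ModelInner no longer needs access to the outer *Model wrapper.
extension LlamaModelInner {
    func draftHiddenStates(
        for inputs: MLXArray, capturing captureLayerIDs: Set<Int>
    ) -> [MLXArray] {
        let (_, captured) = callCapturing(
            inputs, cache: nil, captureLayerIDs: captureLayerIDs)
        return captured
    }
}
```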