Skip to content

Commit ba57392

Browse files
authored
coreml : fix --quantize crash for mlprogram format; fix --optimize-ane label (#3868)
commit 8b92060 switched ct.convert() to mlprogram, but did not update the --quantize path. quantize_weights() from neural_network.quantization_utils only works with the legacy neuralnetwork format. Running with --quantize crashed with: Exception: MLModel of type mlProgram cannot be loaded just from the model spec object. It also needs the path to the weights file. Fix: pass compute_precision=ct.precision.FLOAT16 into ct.convert() when --quantize is set. This matches the original intent of nbits=16 (F16 storage) without changing the quantization scheme or model accuracy. Also fix the three boolean CLI flags (--encoder-only, --quantize, --optimize-ane) to use a _str_to_bool helper so that both --flag True and --flag False parse correctly. The type=bool form accepted "False" as True because bool("False") == True. Remove the "currently broken" label from --optimize-ane: the ANE path (WhisperANE with Conv2d attention and LayerNormANE) converts and loads correctly with both PyTorch 2.x and coremltools 9.x.
1 parent 84bd03a commit ba57392

1 file changed

Lines changed: 15 additions & 10 deletions

File tree

models/convert-whisper-to-coreml.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,19 @@
88
from typing import Dict
99
from typing import Optional
1010
from ane_transformers.reference.layer_norm import LayerNormANE as LayerNormANEBase
11-
from coremltools.models.neural_network.quantization_utils import quantize_weights
1211
from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions
1312
from whisper import load_model
1413

14+
15+
def _str_to_bool(v):
16+
if isinstance(v, bool):
17+
return v
18+
if v.lower() in ("true", "1", "yes"):
19+
return True
20+
if v.lower() in ("false", "0", "no"):
21+
return False
22+
raise argparse.ArgumentTypeError(f"boolean value expected, got '{v}'")
23+
1524
# Disable PyTorch Scaled Dot-Product Attention (SDPA) to avoid compatibility issues.
1625
# The Whisper implementation expects a specific behavior from
1726
# torch.nn.functional.scaled_dot_product_attention that differs between PyTorch
@@ -258,11 +267,9 @@ def convert_encoder(hparams, model, quantize=False):
258267
inputs=[ct.TensorType(name="logmel_data", shape=input_shape)],
259268
outputs=[ct.TensorType(name="output")],
260269
compute_units=ct.ComputeUnit.ALL,
270+
compute_precision=ct.precision.FLOAT16 if quantize else ct.precision.FLOAT32,
261271
)
262272

263-
if quantize:
264-
model = quantize_weights(model, nbits=16)
265-
266273
return model
267274

268275
def convert_decoder(hparams, model, quantize=False):
@@ -283,20 +290,18 @@ def convert_decoder(hparams, model, quantize=False):
283290
ct.TensorType(name="token_data", shape=tokens_shape, dtype=int),
284291
ct.TensorType(name="audio_data", shape=audio_shape)
285292
],
293+
compute_precision=ct.precision.FLOAT16 if quantize else ct.precision.FLOAT32,
286294
)
287295

288-
if quantize:
289-
model = quantize_weights(model, nbits=16)
290-
291296
return model
292297

293298

294299
if __name__ == "__main__":
295300
parser = argparse.ArgumentParser()
296301
parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2, large-v3, large-v3-turbo)", required=True)
297-
parser.add_argument("--encoder-only", type=bool, help="only convert encoder", default=False)
298-
parser.add_argument("--quantize", type=bool, help="quantize weights to F16", default=False)
299-
parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
302+
parser.add_argument("--encoder-only", type=_str_to_bool, help="only convert encoder", default=False)
303+
parser.add_argument("--quantize", type=_str_to_bool, help="quantize weights to F16", default=False)
304+
parser.add_argument("--optimize-ane", type=_str_to_bool, help="optimize for ANE execution", default=False)
300305
args = parser.parse_args()
301306

302307
if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "small.en-tdrz", "medium", "medium.en", "large-v1", "large-v2", "large-v3", "large-v3-turbo"]:

0 commit comments

Comments
 (0)