Commit c5fb6c7

Fix palettize_weights with enable_per_channel_scale=True crashing on ANE (macOS 26)
When OpPalettizerConfig is configured with enable_per_channel_scale=True, palettize_weights wraps the constexpr_lut_to_dense output in a constexpr_blockwise_shift_scale op (data=<dense fp16 weight>, scale=<per-channel fp16>). On macOS 26, the MPSGraph backend lowering for that constexpr op fails verification when targeting the Apple Neural Engine: 'mps.dequantize' op operand #2 must be tensor of quantized values, but got 'tensor<1xf16>' ... failed assertion `original module failed verification' The MPSGraph lowering of constexpr_blockwise_shift_scale assumes the data operand is a quantized integer tensor (it lowers to mps.dequantize); with enable_per_channel_scale=True, the data is the dense fp16 weight, which fails that assumption. CPU and GPU compute units accept the wrapper and predict correctly; only the ANE-targeted MIL -> MPSGraph dispatch is broken. Fix: bake per_channel_scale into the LUT entries at compile time and re-emit constexpr_lut_to_dense, instead of leaving the scale as a runtime constexpr. Both data and scale are fp16 and the wrapper's only effect is data * scale, so the fold is mathematically identical. The failing MPSGraph dispatch is eliminated entirely, and CPU / GPU numerics stay bit-identical with the prior behavior. Resulting graph also has one fewer runtime constexpr per palettized const. Test updated: TestPalettizeWeights::test_palettization_pcs previously asserted that the constexpr_blockwise_shift_scale wrapper was emitted; it now asserts the wrapper is absent (the LUT is pre-scaled). Numerical equivalence vs the unpalettized model is verified by the existing verify_model_outputs call on macOS 15+. Tested: - test_palettization_pcs: PASS - All 155 TestPalettizeWeights / TestJointCompressWeights: PASS - Manual: Qwen3-VL 2B stateful chunk on macOS 26 + M4 ANE: MPSGraph verification crash gone (was reproducible at every load).
1 parent 5256644 commit c5fb6c7

2 files changed: 25 additions & 9 deletions

coremltools/optimize/coreml/_quantization_passes.py

Lines changed: 21 additions & 5 deletions
@@ -1139,12 +1139,28 @@ def transform_op(self, op: Operation):
                     "Palettization with per-channel-scale is only supported since "
                     "iOS18. Please set minimum_deployment_target accordingly."
                 )
-            new_var = mb.constexpr_blockwise_shift_scale(
-                data=new_var,
-                scale=per_channel_scale,
-                offset=None,
-                before_op=op,
+            # Bake per_channel_scale into the LUT entries instead of
+            # wrapping the dense weight in a runtime
+            # constexpr_blockwise_shift_scale: that wrapper fails MPSGraph
+            # verification on Apple Neural Engine (macOS 26+) because the
+            # mps.dequantize lowering expects an integer data operand.
+            # Folding is mathematically identical (output = data * scale).
+            lut = lut_params.lut.copy()
+            # LUT has trailing dims [group, num_palette, vector_size] that
+            # are not present in per_channel_scale; broadcast across those.
+            pcs_bcast = per_channel_scale.reshape(
+                per_channel_scale.shape
+                + (1,) * (lut.ndim - per_channel_scale.ndim)
+            )
+            lut = (
+                lut.astype(np.float32) * pcs_bcast.astype(np.float32)
+            ).astype(lut.dtype)
+            new_var = frontend_utils._construct_constexpr_lut_op(
+                lut_params.indices,
+                lut,
+                lut_params.vector_axis,
                 name=op.name + "_palettized_pcs",
+                before_op=op,
             )
         else:
             decompressed_val = self.decompress(lut_params)

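The reshape in the hunk above only appends singleton axes so the per-channel scale can broadcast across the LUT's trailing [num_palette, vector_size] dims. A minimal standalone sketch of that broadcast, with shapes assumed purely for illustration:

```python
import numpy as np

# Assumed shapes for illustration: one scale per channel group [16, 1],
# LUT with trailing palette/vector dims [16, 1, 4, 1].
lut = np.ones((16, 1, 4, 1), dtype=np.float16)
pcs = np.arange(1, 17, dtype=np.float16).reshape(16, 1)

# Append singleton axes so pcs lines up with the LUT's extra trailing dims.
pcs_bcast = pcs.reshape(pcs.shape + (1,) * (lut.ndim - pcs.ndim))
assert pcs_bcast.shape == (16, 1, 1, 1)

scaled = (lut.astype(np.float32) * pcs_bcast.astype(np.float32)).astype(lut.dtype)
assert scaled.shape == lut.shape
# Every palette entry in channel group c is multiplied by pcs[c].
assert scaled[4, 0, 2, 0] == np.float16(5.0)
```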
coremltools/test/optimize/coreml/test_post_training_quantization.py

Lines changed: 4 additions & 4 deletions
@@ -1683,13 +1683,13 @@ def test_palettization_pcs(self, compute_unit, backend):
             op_type="constexpr_lut_to_dense"
         )[0]
         assert types.builtin_to_string(palettize_op.indices.dtype) == "uint4"
-        # The per-channel-scale is represented by a quant op to do scaling.
+        # per_channel_scale is folded into the LUT entries at compile time, so
+        # no runtime constexpr_blockwise_shift_scale wrapper is emitted (see
+        # palettize_weights in _quantization_passes.py for the rationale).
         quantize_ops = mlmodel_palettized._mil_program.functions["main"].find_ops(
             op_type="constexpr_blockwise_shift_scale"
         )
-        assert len(quantize_ops) > 0
-        # Order of quant and lut op is determined by canonicalize_quantized_lut_pattern graph pass.
-        assert quantize_ops[0].outputs[0].child_ops[0].op_type == "constexpr_lut_to_dense"
+        assert len(quantize_ops) == 0

         if _macos_version() >= (15, 0):
             verify_model_outputs(mlmodel, mlmodel_palettized, coreml_input_values)
