26 changes: 21 additions & 5 deletions coremltools/optimize/coreml/_quantization_passes.py
@@ -1139,12 +1139,28 @@ def transform_op(self, op: Operation):
"Palettization with per-channel-scale is only supported since "
"iOS18. Please set minimum_deployment_target accordingly."
)
new_var = mb.constexpr_blockwise_shift_scale(
data=new_var,
scale=per_channel_scale,
offset=None,
before_op=op,
# Bake per_channel_scale into the LUT entries instead of
# wrapping the dense weight in a runtime
# constexpr_blockwise_shift_scale: that wrapper fails MPSGraph
# verification on Apple Neural Engine (macOS 26+) because the
# mps.dequantize lowering expects an integer data operand.
# Folding is mathematically identical (output = data * scale).
lut = lut_params.lut.copy()
# LUT has trailing dims [group, num_palette, vector_size] that
# are not present in per_channel_scale; broadcast across those.
pcs_bcast = per_channel_scale.reshape(
per_channel_scale.shape
+ (1,) * (lut.ndim - per_channel_scale.ndim)
)
lut = (
lut.astype(np.float32) * pcs_bcast.astype(np.float32)
).astype(lut.dtype)
new_var = frontend_utils._construct_constexpr_lut_op(
lut_params.indices,
lut,
lut_params.vector_axis,
name=op.name + "_palettized_pcs",
before_op=op,
)
else:
decompressed_val = self.decompress(lut_params)
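Why the fold is safe, in a minimal NumPy sketch (illustrative shapes and variable names; this is not the coremltools decompression kernel): scaling the LUT entries before the table lookup gives the same result as applying per_channel_scale to the decompressed weight afterwards, including the trailing-dims broadcast used above.

import numpy as np

rng = np.random.default_rng(0)

# Illustrative shapes: a (4, 8) weight, 2-bit palette (num_palette=4), one
# scalar LUT shared across the tensor: [1, 1, num_palette, vector_size=1].
lut = rng.standard_normal((1, 1, 4, 1)).astype(np.float16)
indices = rng.integers(0, 4, size=(4, 8))
per_channel_scale = rng.standard_normal((4, 1)).astype(np.float16)

# Old behavior: decompress via table lookup, then scale at runtime.
reference = lut.reshape(-1)[indices].astype(np.float32) * per_channel_scale.astype(
    np.float32
)

# New behavior: fold the scale into the LUT ahead of time. The reshape adds
# trailing singleton dims so broadcasting produces one scaled palette per
# output channel: folded_lut has shape (4, 1, 4, 1).
pcs_bcast = per_channel_scale.reshape(
    per_channel_scale.shape + (1,) * (lut.ndim - per_channel_scale.ndim)
)
folded_lut = lut.astype(np.float32) * pcs_bcast.astype(np.float32)

# Decompress using each output channel's own (scaled) table.
folded = folded_lut[np.arange(4)[:, None], 0, indices, 0]

assert np.allclose(reference, folded)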
@@ -1683,13 +1683,13 @@ def test_palettization_pcs(self, compute_unit, backend):
op_type="constexpr_lut_to_dense"
)[0]
assert types.builtin_to_string(palettize_op.indices.dtype) == "uint4"
# The per-channel-scale is represented by a quant op to do scaling.
# per_channel_scale is folded into the LUT entries at compile time, so
# no runtime constexpr_blockwise_shift_scale wrapper is emitted (see
# palettize_weights in _quantization_passes.py for the rationale).
quantize_ops = mlmodel_palettized._mil_program.functions["main"].find_ops(
op_type="constexpr_blockwise_shift_scale"
)
assert len(quantize_ops) > 0
# Order of quant and lut op is determined by canonicalize_quantized_lut_pattern graph pass.
assert quantize_ops[0].outputs[0].child_ops[0].op_type == "constexpr_lut_to_dense"
assert len(quantize_ops) == 0

if _macos_version() >= (15, 0):
verify_model_outputs(mlmodel, mlmodel_palettized, coreml_input_values)
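For end-to-end context, a hedged usage sketch of the behavior the updated assertions pin down. mb.program, ct.convert, palettize_weights, OpPalettizerConfig, and OptimizationConfig are public coremltools APIs; the enable_per_channel_scale flag and the concrete parameter values are assumptions recalled from coremltools 8.x docs, not taken from this PR.

import numpy as np
import coremltools as ct
import coremltools.optimize.coreml as cto
from coremltools.converters.mil import Builder as mb

# A tiny ML Program with one linear layer; 64x64 = 4096 weights clears the
# default weight_threshold of 2048 so the weight actually gets palettized.
@mb.program(
    input_specs=[mb.TensorSpec(shape=(1, 64))], opset_version=ct.target.iOS18
)
def prog(x):
    weight = np.random.rand(64, 64).astype(np.float32)
    return mb.linear(x=x, weight=weight)

mlmodel = ct.convert(
    prog, convert_to="mlprogram", minimum_deployment_target=ct.target.iOS18
)

# Assumed flag: enable_per_channel_scale is believed to be what produces
# lut_params.per_channel_scale in the pass patched above.
op_config = cto.OpPalettizerConfig(mode="kmeans", nbits=4, enable_per_channel_scale=True)
config = cto.OptimizationConfig(global_config=op_config)
mlmodel_palettized = cto.palettize_weights(mlmodel, config)

# Mirror the test: the LUT op is present, and no runtime
# constexpr_blockwise_shift_scale wrapper is emitted for the scale.
main_fn = mlmodel_palettized._mil_program.functions["main"]
assert len(main_fn.find_ops(op_type="constexpr_lut_to_dense")) > 0
assert len(main_fn.find_ops(op_type="constexpr_blockwise_shift_scale")) == 0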