Skip to content

Commit 18d109c

Browse files
Blaizzyclaude
andcommitted
Use mlx-vlm-style predicate chaining for quantization
Keep the local nn.quantize call but switch the class_predicate to the compose-with-model.quant_predicate pattern from mlx-vlm: chain the default skip-vision / group-size predicate with the model's own predicate, and record any per-layer dict results so the load path re-quantizes the same way. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 0b94ef9 commit 18d109c

1 file changed

Lines changed: 22 additions & 15 deletions

File tree

mlx_embeddings/convert.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -91,34 +91,41 @@ def defaults_for_mode(mode: str, group_size: int, bits: int) -> Tuple[int, int]:
9191
effective_bits = bits if bits else default_bits
9292
return effective_group_size, effective_bits
9393

94+
quantized_config = copy.deepcopy(config)
9495
effective_group_size, effective_bits = defaults_for_mode(mode, q_group_size, q_bits)
9596

96-
# Delegate to mlx_lm.utils.quantize_model (same pattern as mlx-vlm): it reads
97-
# `model.quant_predicate` and records per-layer overrides into the config,
98-
# while our wrapper adds the skip-vision / group-size sanity checks.
99-
from mlx_lm.utils import quantize_model as mlx_lm_quantize_model
100-
97+
# Predicate-chaining pattern from mlx-vlm: honor the model's `quant_predicate`
98+
# (if any) on top of the default skip-vision / group-size checks, and record
99+
# per-layer overrides so the load path re-quantizes the same way.
101100
default_predicate = get_class_predicate(
102101
skip_vision=skip_vision, q_group_size=effective_group_size
103102
)
104-
model_predicate = getattr(model, "quant_predicate", None)
103+
model_quant_predicate = getattr(model, "quant_predicate", None)
104+
overrides: Dict[str, Dict[str, int]] = {}
105105

106-
def quant_predicate(path, module):
106+
def base_quant_predicate(path, module):
107107
if not default_predicate(path, module):
108108
return False
109-
if model_predicate is not None:
110-
return model_predicate(path, module)
111-
return True
112-
113-
model, quantized_config = mlx_lm_quantize_model(
109+
if model_quant_predicate is None:
110+
return True
111+
result = model_quant_predicate(path, module)
112+
if isinstance(result, dict):
113+
overrides[path] = result
114+
return result
115+
116+
nn.quantize(
114117
model,
115-
copy.deepcopy(config),
116118
group_size=effective_group_size,
117119
bits=effective_bits,
118120
mode=mode,
119-
quant_predicate=quant_predicate,
121+
class_predicate=base_quant_predicate,
120122
)
121-
123+
quantized_config["quantization"] = {
124+
"group_size": effective_group_size,
125+
"bits": effective_bits,
126+
"mode": mode,
127+
**overrides,
128+
}
122129
if "vision_config" in quantized_config and isinstance(
123130
quantized_config["vision_config"], dict
124131
):

0 commit comments

Comments
 (0)