Commit 0b94ef9

Blaizzy and claude committed
Wire quant_predicate for mixed-precision quantization
Add a quant_predicate on the privacy-filter Model that keeps the MoE router at 8 bits while the rest of the weights quantize to the user's chosen bit width. The router is a small but routing-sensitive linear; a uniform 4-bit quantization of the router was measurably degrading accuracy in gpt-oss-style models, and the same applies here.

Follow mlx-vlm's pattern in convert.py: delegate quantization to mlx_lm.utils.quantize_model, passing a wrapper that composes mlx-embeddings' skip-vision / group-size checks with the model's quant_predicate. mlx_lm handles recording per-layer overrides into config["quantization"][path], and the existing load path in utils.py already respects those.

Verified: bf16 and q4 (uniform) both still extract the same PII spans; mixed-precision q4-experts + q8-router saves to disk with 4.52 bits/weight, loads correctly, and extracts the same spans.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
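As a rough sanity check on the 4.52 bits/weight figure, the arithmetic can be sketched as below. This assumes MLX-style affine quantization that stores one 16-bit scale and one 16-bit bias per group, and a group size of 64 for both the experts and the router (the commit only states 64 for the router). It also ignores any tensors left unquantized, so the implied router share is an estimate under those assumptions, not a measured value.

```python
def affine_bits_per_weight(bits: int, group_size: int, overhead_bits: int = 32) -> float:
    """Storage per weight: the packed value plus a per-group scale and bias.

    Assumption: 16-bit scale + 16-bit bias = 32 overhead bits per group.
    """
    return bits + overhead_bits / group_size


def implied_high_precision_fraction(observed: float, low: float, high: float) -> float:
    """Fraction of weights stored at `high` bits/weight that yields `observed`."""
    return (observed - low) / (high - low)


experts_bpw = affine_bits_per_weight(bits=4, group_size=64)  # 4-bit layers
router_bpw = affine_bits_per_weight(bits=8, group_size=64)   # 8-bit router

# Share of weights that would need to sit in the 8-bit router to reach 4.52.
router_fraction = implied_high_precision_fraction(4.52, experts_bpw, router_bpw)
```

Under these assumptions a uniform q4 model lands at 4.5 bits/weight, so the reported 4.52 is consistent with the router being a very small share of the parameters.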
1 parent 44616af commit 0b94ef9

2 files changed

Lines changed: 30 additions & 12 deletions

mlx_embeddings/convert.py

Lines changed: 21 additions & 10 deletions
@@ -91,23 +91,34 @@ def defaults_for_mode(mode: str, group_size: int, bits: int) -> Tuple[int, int]:
         effective_bits = bits if bits else default_bits
         return effective_group_size, effective_bits
 
-    quantized_config = copy.deepcopy(config)
     effective_group_size, effective_bits = defaults_for_mode(mode, q_group_size, q_bits)
 
-    nn.quantize(
+    # Delegate to mlx_lm.utils.quantize_model (same pattern as mlx-vlm): it reads
+    # `model.quant_predicate` and records per-layer overrides into the config,
+    # while our wrapper adds the skip-vision / group-size sanity checks.
+    from mlx_lm.utils import quantize_model as mlx_lm_quantize_model
+
+    default_predicate = get_class_predicate(
+        skip_vision=skip_vision, q_group_size=effective_group_size
+    )
+    model_predicate = getattr(model, "quant_predicate", None)
+
+    def quant_predicate(path, module):
+        if not default_predicate(path, module):
+            return False
+        if model_predicate is not None:
+            return model_predicate(path, module)
+        return True
+
+    model, quantized_config = mlx_lm_quantize_model(
         model,
+        copy.deepcopy(config),
         group_size=effective_group_size,
         bits=effective_bits,
         mode=mode,
-        class_predicate=get_class_predicate(
-            skip_vision=skip_vision, q_group_size=effective_group_size
-        ),
+        quant_predicate=quant_predicate,
     )
-    quantized_config["quantization"] = {
-        "group_size": effective_group_size,
-        "bits": effective_bits,
-        "mode": mode,
-    }
+
     if "vision_config" in quantized_config and isinstance(
         quantized_config["vision_config"], dict
     ):
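The composition in this hunk can be exercised without MLX. The sketch below is a simplified stand-in, not mlx_lm's code: `compose` mirrors the wrapper from the diff, while `record_overrides`, the two toy predicates, and the module paths are hypothetical names chosen for illustration. It assumes a predicate may return `False` (skip), `True` (use the global settings), or a dict (per-layer settings), and that only dict results are recorded under the layer's path.

```python
from typing import Callable, Dict, Iterable, Optional, Union

Decision = Union[bool, Dict[str, int]]
Predicate = Callable[[str, object], Decision]


def compose(default_predicate: Predicate, model_predicate: Optional[Predicate] = None) -> Predicate:
    """Same shape as the wrapper in the diff: the default check can veto a
    layer, otherwise the model's predicate (if any) decides."""
    def quant_predicate(path: str, module: object) -> Decision:
        if not default_predicate(path, module):
            return False
        if model_predicate is not None:
            return model_predicate(path, module)
        return True
    return quant_predicate


def record_overrides(paths: Iterable[str], predicate: Predicate, group_size: int, bits: int) -> dict:
    """Hypothetical stand-in for the bookkeeping: global settings at the top
    level, plus one entry per layer whose predicate returned a dict."""
    quantization: dict = {"group_size": group_size, "bits": bits}
    for path in paths:
        decision = predicate(path, None)
        if isinstance(decision, dict):
            quantization[path] = decision
    return quantization


# Toy predicates and paths, for illustration only.
def skip_vision(path: str, _module: object) -> bool:
    return "vision" not in path


def router_at_8_bits(path: str, _module: object) -> Decision:
    if path.endswith("router"):
        return {"group_size": 64, "bits": 8}
    return True


paths = [
    "model.layers.0.mlp.router",
    "model.layers.0.mlp.experts.gate_proj",
    "vision_tower.proj",
]
predicate = compose(skip_vision, router_at_8_bits)
decisions = {p: predicate(p, None) for p in paths}
config_quantization = record_overrides(paths, predicate, group_size=64, bits=4)
```

The ordering matters: because the default predicate runs first, a model-level override cannot re-enable a layer that the skip-vision or group-size check has already rejected.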

mlx_embeddings/models/openai_privacy_filter.py

Lines changed: 9 additions & 2 deletions
@@ -155,8 +155,6 @@ def __init__(self, config: ModelArgs):
         )
 
     def __call__(self, x: mx.array) -> mx.array:
-        # Go through the router module so this works with both dense and
-        # QuantizedLinear weights; upcast the softmax for numerical parity.
         router_logits = self.router(x).astype(mx.float32)
 
         k = self.num_experts_per_tok
@@ -315,3 +313,12 @@ def sanitize(self, weights: dict) -> dict:
     @property
     def layers(self):
         return self.model.layers
+
+    @property
+    def quant_predicate(self):
+        def predicate(path, _):
+            if path.endswith("router"):
+                return {"group_size": 64, "bits": 8}
+            return True
+
+        return predicate
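The dict returned for the router ends up as a per-path entry in config["quantization"], which the load path then has to turn back into per-layer settings. A minimal sketch of that lookup follows, assuming the config layout described in the commit message. `resolve_quant_params` and the layer paths are hypothetical and are not the actual helper in utils.py.

```python
from typing import Dict


def resolve_quant_params(quantization: dict, path: str) -> Dict[str, int]:
    """Return the settings to use for `path`.

    Assumption: a per-path dict overrides the top-level group_size/bits;
    layers without an entry fall back to the global values.
    """
    override = quantization.get(path)
    if isinstance(override, dict):
        return override
    return {"group_size": quantization["group_size"], "bits": quantization["bits"]}


# Example config as it might look after a mixed q4 + q8-router conversion.
quantization = {
    "group_size": 64,
    "bits": 4,
    "mode": "affine",
    "model.layers.0.mlp.router": {"group_size": 64, "bits": 8},
}

router_params = resolve_quant_params(quantization, "model.layers.0.mlp.router")
expert_params = resolve_quant_params(quantization, "model.layers.0.mlp.experts.gate_proj")
```

Keying overrides by full module path keeps the saved config self-describing, so a loader does not need to know about the model's quant_predicate to rebuild the same layer layout.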
