From 0a93687938d1f3422cc3cf800e90d4df6ae368ea Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 22 Apr 2026 22:52:03 +0200 Subject: [PATCH 01/11] update gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 5f9f73cc83..9753940b65 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ __pycache__/ private/ *.py[cod] *$py.class +.claude/ # C extensions *.so From ef04c683f9bc5436d90d63a396cada66dd937ad7 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 22 Apr 2026 22:52:36 +0200 Subject: [PATCH 02/11] Add OpenAI privacy-filter token classification model MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Port openai/privacy-filter to mlx-embeddings: a bidirectional 1.5B/50M-active MoE token classifier for PII detection (8 BIOES span labels, 33 classes). The architecture is a bidirectional GPT-OSS variant with GQA + attention sinks, YARN RoPE (interleaved layout), 128-expert top-4 MoE, and ±128 sliding-window attention. Sanitize splits the fused concat-layout gate_up_proj into separate gate/up projections and transposes expert weights for mlx SwitchLinear. Numerical parity with the HF reference in fp32: max logit diff < 0.004, 100% prediction agreement across PII test strings. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- mlx_embeddings/models/base.py | 1 + .../models/openai_privacy_filter.py | 326 ++++++++++++++++++ mlx_embeddings/tests/test_models.py | 28 ++ 3 files changed, 355 insertions(+) create mode 100644 mlx_embeddings/models/openai_privacy_filter.py diff --git a/mlx_embeddings/models/base.py b/mlx_embeddings/models/base.py index 3fd544399c..f77f46750a 100644 --- a/mlx_embeddings/models/base.py +++ b/mlx_embeddings/models/base.py @@ -24,6 +24,7 @@ class BaseModelOutput: pooler_output: Optional[mx.array] = None text_embeds: Optional[mx.array] = None # mean pooled and normalized embeddings hidden_states: Optional[List[mx.array]] = None + logits: Optional[mx.array] = None # token-level or sequence classification logits @dataclass diff --git a/mlx_embeddings/models/openai_privacy_filter.py b/mlx_embeddings/models/openai_privacy_filter.py new file mode 100644 index 0000000000..49ca03ad36 --- /dev/null +++ b/mlx_embeddings/models/openai_privacy_filter.py @@ -0,0 +1,326 @@ +import math +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +import mlx.core as mx +import mlx.nn as nn +from mlx_lm.models.base import scaled_dot_product_attention +from mlx_lm.models.rope_utils import initialize_rope +from mlx_lm.models.switch_layers import SwitchGLU + +from .base import BaseModelArgs, BaseModelOutput + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str = "openai_privacy_filter" + vocab_size: int = 200064 + hidden_size: int = 640 + intermediate_size: int = 640 + num_hidden_layers: int = 8 + num_attention_heads: int = 14 + num_key_value_heads: int = 2 + head_dim: int = 64 + sliding_window: int = 128 + max_position_embeddings: int = 131072 + rms_norm_eps: float = 1e-5 + attention_bias: bool = True + attention_dropout: float = 0.0 + classifier_dropout: float = 0.0 + num_local_experts: int = 128 + num_experts_per_tok: int = 4 + tie_word_embeddings: bool = False + pad_token_id: Optional[int] = 
199999 + eos_token_id: Optional[int] = 199999 + rope_parameters: Optional[Dict[str, Any]] = None + id2label: Optional[Dict[int, str]] = None + label2id: Optional[Dict[str, int]] = None + architectures: List[str] = field( + default_factory=lambda: ["OpenAIPrivacyFilterForTokenClassification"] + ) + + def __post_init__(self): + if self.rope_parameters is None: + self.rope_parameters = { + "rope_type": "yarn", + "rope_theta": 150000.0, + "factor": 32.0, + "beta_fast": 32.0, + "beta_slow": 1.0, + "original_max_position_embeddings": 4096, + } + + @property + def num_labels(self) -> int: + if self.id2label is not None: + return len(self.id2label) + return 33 + + +def _swiglu_concat(gate_up: mx.array, alpha: float = 1.702, limit: float = 7.0) -> mx.array: + gate, up = mx.split(gate_up, 2, axis=-1) + gate = mx.clip(gate, a_min=None, a_max=limit) + up = mx.clip(up, a_min=-limit, a_max=limit) + glu = gate * mx.sigmoid(gate * alpha) + return (up + 1) * glu + + +class PrivacyFilterSwiGLU(nn.Module): + """SwiGLU variant used by the privacy filter: gate clamped above, up clamped both sides, (up+1)*gate*sigmoid(alpha*gate).""" + + def __init__(self, alpha: float = 1.702, limit: float = 7.0): + super().__init__() + self.alpha = alpha + self.limit = limit + + def __call__(self, x: mx.array, gate: mx.array) -> mx.array: + gate = mx.clip(gate, a_min=None, a_max=self.limit) + x = mx.clip(x, a_min=-self.limit, a_max=self.limit) + glu = gate * mx.sigmoid(gate * self.alpha) + return (x + 1) * glu + + +class OpenAIPrivacyFilterAttention(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.head_dim = config.head_dim + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + + # Attention sinks; checkpoint stores them as float32. 
+ self.sinks = mx.zeros((config.num_attention_heads,)) + + bias = config.attention_bias + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=bias + ) + + self.sm_scale = 1.0 / math.sqrt(self.head_dim) + + scaling_config = dict(config.rope_parameters) + rope_theta = scaling_config.pop("rope_theta", 150000.0) + self.rope = initialize_rope( + self.head_dim, + rope_theta, + traditional=True, + scaling_config=scaling_config, + max_position_embeddings=config.max_position_embeddings, + ) + + def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: + B, L, _ = x.shape + D = self.head_dim + + q = self.q_proj(x).reshape(B, L, -1, D).swapaxes(1, 2) + k = self.k_proj(x).reshape(B, L, -1, D).swapaxes(1, 2) + v = self.v_proj(x).reshape(B, L, -1, D).swapaxes(1, 2) + + q = self.rope(q) + k = self.rope(k) + + out = scaled_dot_product_attention( + q, + k, + v, + cache=None, + scale=self.sm_scale, + mask=mask, + sinks=self.sinks.astype(q.dtype), + ) + + out = out.swapaxes(1, 2).reshape(B, L, -1) + return self.o_proj(out) + + +class OpenAIPrivacyFilterMLP(nn.Module): + """Top-k routed sparse MoE matching the HF reference (softmax over top-k, no extra scaling).""" + + def __init__(self, config: ModelArgs): + super().__init__() + self.num_local_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + self.experts = SwitchGLU( + input_dims=config.hidden_size, + hidden_dims=config.intermediate_size, + num_experts=config.num_local_experts, + activation=PrivacyFilterSwiGLU(), + bias=True, + ) + self.router = nn.Linear( + config.hidden_size, config.num_local_experts, bias=True 
+ ) + + def __call__(self, x: mx.array) -> mx.array: + # Router runs in fp32 for numerical parity with the reference. + x_f32 = x.astype(mx.float32) + w_f32 = self.router.weight.astype(mx.float32) + b_f32 = self.router.bias.astype(mx.float32) + router_logits = x_f32 @ w_f32.swapaxes(-1, -2) + b_f32 + + k = self.num_experts_per_tok + top_idx = mx.argpartition(router_logits, kth=-k, axis=-1)[..., -k:] + top_val = mx.take_along_axis(router_logits, top_idx, axis=-1) + weights = mx.softmax(top_val, axis=-1).astype(x.dtype) + + y = self.experts(x, top_idx) + y = y * mx.expand_dims(weights, axis=-1) + return y.sum(axis=-2) + + +class OpenAIPrivacyFilterEncoderLayer(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.self_attn = OpenAIPrivacyFilterAttention(config) + self.mlp = OpenAIPrivacyFilterMLP(config) + self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: + h = self.input_layernorm(x) + h = self.self_attn(h, mask) + x = x + h + + h = self.post_attention_layernorm(x) + h = self.mlp(h) + return x + h + + +def _bidirectional_sliding_window_mask( + seq_len: int, + window: int, + attention_mask: Optional[mx.array], + dtype: mx.Dtype, +) -> mx.array: + idx = mx.arange(seq_len) + diff = idx[:, None] - idx[None, :] + local = mx.abs(diff) <= window # (L, L) bool + local = mx.where(local, mx.array(0.0, dtype=dtype), mx.array(-mx.inf, dtype=dtype)) + + if attention_mask is None: + return local[None, None, :, :] + + # attention_mask: (B, L), 1 for valid, 0 for pad. 
+ pad = attention_mask.astype(mx.bool_) + pad_mask = mx.where( + pad[:, None, :], + mx.array(0.0, dtype=dtype), + mx.array(-mx.inf, dtype=dtype), + ) # (B, 1, L) over keys + return local[None, None, :, :] + pad_mask[:, None, :, :] + + +class OpenAIPrivacyFilterModel(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = [ + OpenAIPrivacyFilterEncoderLayer(config) + for _ in range(config.num_hidden_layers) + ] + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.sliding_window = config.sliding_window + + def __call__( + self, + input_ids: mx.array, + attention_mask: Optional[mx.array] = None, + input_embeddings: Optional[mx.array] = None, + ) -> mx.array: + if input_embeddings is not None: + h = input_embeddings + else: + h = self.embed_tokens(input_ids) + + seq_len = h.shape[1] + mask = _bidirectional_sliding_window_mask( + seq_len, self.sliding_window, attention_mask, h.dtype + ) + + for layer in self.layers: + h = layer(h, mask) + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.model_type = config.model_type + self.num_labels = config.num_labels + + self.model = OpenAIPrivacyFilterModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=True) + + def __call__( + self, + input_ids: mx.array, + attention_mask: Optional[mx.array] = None, + ) -> BaseModelOutput: + if input_ids.ndim != 2: + raise ValueError(f"input_ids must be 2D, got shape {input_ids.shape}") + + last_hidden_state = self.model(input_ids, attention_mask=attention_mask) + logits = self.score(last_hidden_state) + return BaseModelOutput( + last_hidden_state=last_hidden_state, + logits=logits, + ) + + def sanitize(self, weights: dict) -> dict: + # Split the fused gate_up_proj (concatenated layout) into separate gate and up + # 
projections, and transpose expert weights from (E, in, out) to (E, out, in) + # to match mlx's SwitchLinear expectations. + sanitized = {} + for key, value in weights.items(): + # Skip the alternate `original/` OpenAI-format checkpoint that ships alongside + # the transformers weights in this repo. + if key.startswith("original."): + continue + if "mlp.experts.gate_up_proj_bias" in key: + gate_bias, up_bias = mx.split(value, 2, axis=-1) + sanitized[key.replace("gate_up_proj_bias", "gate_proj.bias")] = ( + mx.contiguous(gate_bias) + ) + sanitized[key.replace("gate_up_proj_bias", "up_proj.bias")] = ( + mx.contiguous(up_bias) + ) + elif "mlp.experts.gate_up_proj" in key: + # (E, in, 2*out) -> split -> (E, in, out) -> transpose -> (E, out, in) + gate, up = mx.split(value, 2, axis=-1) + sanitized[key.replace("gate_up_proj", "gate_proj.weight")] = ( + mx.contiguous(gate.swapaxes(-1, -2)) + ) + sanitized[key.replace("gate_up_proj", "up_proj.weight")] = ( + mx.contiguous(up.swapaxes(-1, -2)) + ) + elif key.endswith("mlp.experts.down_proj"): + # (E, in, out) -> (E, out, in) + sanitized[key + ".weight"] = mx.contiguous(value.swapaxes(-1, -2)) + elif key.endswith("mlp.experts.down_proj_bias"): + sanitized[key.replace("down_proj_bias", "down_proj.bias")] = value + elif key.endswith("self_attn.sinks"): + # Keep sinks in the attention module dtype (float32 is fine). 
+ sanitized[key] = value + else: + sanitized[key] = value + return sanitized + + @property + def layers(self): + return self.model.layers diff --git a/mlx_embeddings/tests/test_models.py b/mlx_embeddings/tests/test_models.py index d3ab4a7abc..efe2cf9604 100644 --- a/mlx_embeddings/tests/test_models.py +++ b/mlx_embeddings/tests/test_models.py @@ -489,6 +489,34 @@ def test_llama_bidirec_model(self): config.num_hidden_layers, ) + def test_openai_privacy_filter_model(self): + from mlx_embeddings.models import openai_privacy_filter + + config = openai_privacy_filter.ModelArgs( + model_type="openai_privacy_filter", + vocab_size=64, + hidden_size=32, + intermediate_size=32, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=8, + sliding_window=16, + max_position_embeddings=128, + num_local_experts=4, + num_experts_per_tok=2, + rms_norm_eps=1e-5, + ) + model = openai_privacy_filter.Model(config) + model.update(tree_map(lambda p: p.astype(mx.float32), model.parameters())) + + inputs = mx.array([[0, 1, 2, 3, 4]]) + outputs = model(inputs) + + self.assertEqual(outputs.last_hidden_state.shape, (1, 5, config.hidden_size)) + self.assertEqual(outputs.logits.shape, (1, 5, config.num_labels)) + self.assertEqual(outputs.last_hidden_state.dtype, mx.float32) + def test_qwen3_model(self): from mlx_embeddings.models import qwen3 From 5f5971e579c3a501b16e655b3fa21c4cd5570433 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 22 Apr 2026 22:56:48 +0200 Subject: [PATCH 03/11] Document openai/privacy-filter in README Add the model to the supported architectures list and a Token Classification (PII detection) usage section with a working example. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- README.md | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/README.md b/README.md index 3bfd935e24..ebc07bdc5f 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ MLX-Embeddings supports a variety of model architectures for text embedding task - Qwen3-VL (multimodal Qwen3-VL embedding and reranking model) - Llama Bidirectional (Llama-based bidirectional embedding models, e.g. NVIDIA NV-Embed) - Llama Nemotron VL (multimodal vision-language embedding model with SigLIP vision + bidirectional Llama) +- OpenAI Privacy Filter (bidirectional GPT-OSS variant for PII token classification with sparse MoE, GQA + attention sinks, and YARN RoPE) We're continuously working to expand our support for additional model architectures. Check our GitHub repository or documentation for the most up-to-date list of supported models and their specific versions. @@ -177,6 +178,30 @@ for idx, logit in enumerate(predictions.tolist()): print(f"{label}: {logit:.3f}") ``` +#### Token Classification (PII detection) + +`openai/privacy-filter` is a bidirectional 1.5B-parameter / 50M-active sparse-MoE token classifier that tags personally identifiable information (PII) with BIOES spans over 8 categories (person, email, phone, URL, address, date, account number, secret). + +```python +import mlx.core as mx +from mlx_embeddings.utils import load + +model, tokenizer = load("openai/privacy-filter") +id2label = model.config.id2label + +text = "My name is Alice Smith and my email is alice@example.com. Phone: 555-1234." 
+inputs = tokenizer(text, return_tensors="mlx") + +outputs = model(inputs["input_ids"], attention_mask=inputs["attention_mask"]) +preds = mx.argmax(outputs.logits, axis=-1)[0].tolist() + +tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist()) +for token, pred in zip(tokens, preds): + label = id2label[str(pred)] + if label != "O": + print(f"{token!r:20s} -> {label}") +``` + ### Batch Processing #### Multiple Texts Comparison From 9f48116b8812276b920993e41d18d0013210f040 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 22 Apr 2026 23:05:42 +0200 Subject: [PATCH 04/11] Aggregate BIOES spans in README example Use itertools.groupby to collapse consecutive BIOES tokens into clean decoded spans rather than dumping per-token BPE fragments. Co-Authored-By: Claude Opus 4.7 (1M context) --- README.md | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index ebc07bdc5f..b5dc7f4e0d 100644 --- a/README.md +++ b/README.md @@ -183,6 +183,7 @@ for idx, logit in enumerate(predictions.tolist()): `openai/privacy-filter` is a bidirectional 1.5B-parameter / 50M-active sparse-MoE token classifier that tags personally identifiable information (PII) with BIOES spans over 8 categories (person, email, phone, URL, address, date, account number, secret). 
```python +from itertools import groupby import mlx.core as mx from mlx_embeddings.utils import load @@ -195,11 +196,12 @@ inputs = tokenizer(text, return_tensors="mlx") outputs = model(inputs["input_ids"], attention_mask=inputs["attention_mask"]) preds = mx.argmax(outputs.logits, axis=-1)[0].tolist() -tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist()) -for token, pred in zip(tokens, preds): - label = id2label[str(pred)] - if label != "O": - print(f"{token!r:20s} -> {label}") +entity = lambda p: id2label[str(p)].split("-", 1)[-1] if id2label[str(p)] != "O" else None + +for ent, group in groupby(zip(inputs["input_ids"][0].tolist(), preds), key=lambda x: entity(x[1])): + if ent: + span = tokenizer.decode([tid for tid, _ in group]).strip() + print(f"{ent:18s} -> {span!r}") ``` ### Batch Processing From 44616af6b4005c1055bf528aba1dd5ff45a701cf Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 22 Apr 2026 23:15:48 +0200 Subject: [PATCH 05/11] Make router compatible with quantization Call the router module directly (rather than a manual matmul on .weight/.bias) so the MoE forward works with both dense nn.Linear and QuantizedLinear weights. The softmax still runs in fp32 for numerical parity with the reference. Verified against /tmp/privacy-filter-{q4,mxfp4}: both quantizations extract the same PII spans as the bf16 checkpoint. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- mlx_embeddings/models/openai_privacy_filter.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/mlx_embeddings/models/openai_privacy_filter.py b/mlx_embeddings/models/openai_privacy_filter.py index 49ca03ad36..6aa74c594b 100644 --- a/mlx_embeddings/models/openai_privacy_filter.py +++ b/mlx_embeddings/models/openai_privacy_filter.py @@ -57,14 +57,6 @@ def num_labels(self) -> int: return 33 -def _swiglu_concat(gate_up: mx.array, alpha: float = 1.702, limit: float = 7.0) -> mx.array: - gate, up = mx.split(gate_up, 2, axis=-1) - gate = mx.clip(gate, a_min=None, a_max=limit) - up = mx.clip(up, a_min=-limit, a_max=limit) - glu = gate * mx.sigmoid(gate * alpha) - return (up + 1) * glu - - class PrivacyFilterSwiGLU(nn.Module): """SwiGLU variant used by the privacy filter: gate clamped above, up clamped both sides, (up+1)*gate*sigmoid(alpha*gate).""" @@ -163,11 +155,9 @@ def __init__(self, config: ModelArgs): ) def __call__(self, x: mx.array) -> mx.array: - # Router runs in fp32 for numerical parity with the reference. - x_f32 = x.astype(mx.float32) - w_f32 = self.router.weight.astype(mx.float32) - b_f32 = self.router.bias.astype(mx.float32) - router_logits = x_f32 @ w_f32.swapaxes(-1, -2) + b_f32 + # Go through the router module so this works with both dense and + # QuantizedLinear weights; upcast the softmax for numerical parity. 
+ router_logits = self.router(x).astype(mx.float32) k = self.num_experts_per_tok top_idx = mx.argpartition(router_logits, kth=-k, axis=-1)[..., -k:] @@ -282,6 +272,7 @@ def __call__( logits=logits, ) + def sanitize(self, weights: dict) -> dict: # Split the fused gate_up_proj (concatenated layout) into separate gate and up # projections, and transpose expert weights from (E, in, out) to (E, out, in) From 0b94ef969367fa87759963ab16ad2e9da480c502 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 22 Apr 2026 23:27:13 +0200 Subject: [PATCH 06/11] Wire quant_predicate for mixed-precision quantization Add a quant_predicate on the privacy-filter Model that keeps the MoE router at 8 bits while the rest of the weights quantize to the user's chosen bit width. The router is a small but routing-sensitive linear; a uniform 4-bit quantization of the router was measurably degrading accuracy in gpt-oss-style models, and the same applies here. Follow mlx-vlm's pattern in convert.py: delegate quantization to mlx_lm.utils.quantize_model, passing a wrapper that composes mlx-embeddings' skip-vision / group-size checks with the model's quant_predicate. mlx_lm handles recording per-layer overrides into config["quantization"][path], and the existing load path in utils.py already respects those. Verified: bf16 and q4 (uniform) both still extract the same PII spans; mixed-precision q4-experts + q8-router saves to disk with 4.52 bits/ weight, loads correctly, and extracts the same spans. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- mlx_embeddings/convert.py | 31 +++++++++++++------ .../models/openai_privacy_filter.py | 11 +++++-- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/mlx_embeddings/convert.py b/mlx_embeddings/convert.py index f4317c9771..3e7e3ff781 100644 --- a/mlx_embeddings/convert.py +++ b/mlx_embeddings/convert.py @@ -91,23 +91,34 @@ def defaults_for_mode(mode: str, group_size: int, bits: int) -> Tuple[int, int]: effective_bits = bits if bits else default_bits return effective_group_size, effective_bits - quantized_config = copy.deepcopy(config) effective_group_size, effective_bits = defaults_for_mode(mode, q_group_size, q_bits) - nn.quantize( + # Delegate to mlx_lm.utils.quantize_model (same pattern as mlx-vlm): it reads + # `model.quant_predicate` and records per-layer overrides into the config, + # while our wrapper adds the skip-vision / group-size sanity checks. + from mlx_lm.utils import quantize_model as mlx_lm_quantize_model + + default_predicate = get_class_predicate( + skip_vision=skip_vision, q_group_size=effective_group_size + ) + model_predicate = getattr(model, "quant_predicate", None) + + def quant_predicate(path, module): + if not default_predicate(path, module): + return False + if model_predicate is not None: + return model_predicate(path, module) + return True + + model, quantized_config = mlx_lm_quantize_model( model, + copy.deepcopy(config), group_size=effective_group_size, bits=effective_bits, mode=mode, - class_predicate=get_class_predicate( - skip_vision=skip_vision, q_group_size=effective_group_size - ), + quant_predicate=quant_predicate, ) - quantized_config["quantization"] = { - "group_size": effective_group_size, - "bits": effective_bits, - "mode": mode, - } + if "vision_config" in quantized_config and isinstance( quantized_config["vision_config"], dict ): diff --git a/mlx_embeddings/models/openai_privacy_filter.py b/mlx_embeddings/models/openai_privacy_filter.py index 
6aa74c594b..e510cadc1c 100644 --- a/mlx_embeddings/models/openai_privacy_filter.py +++ b/mlx_embeddings/models/openai_privacy_filter.py @@ -155,8 +155,6 @@ def __init__(self, config: ModelArgs): ) def __call__(self, x: mx.array) -> mx.array: - # Go through the router module so this works with both dense and - # QuantizedLinear weights; upcast the softmax for numerical parity. router_logits = self.router(x).astype(mx.float32) k = self.num_experts_per_tok @@ -315,3 +313,12 @@ def sanitize(self, weights: dict) -> dict: @property def layers(self): return self.model.layers + + @property + def quant_predicate(self): + def predicate(path, _): + if path.endswith("router"): + return {"group_size": 64, "bits": 8} + return True + + return predicate From 18d109c1dea907ca5a431d05bd304a9a96588fda Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 22 Apr 2026 23:33:26 +0200 Subject: [PATCH 07/11] Use mlx-vlm-style predicate chaining for quantization Keep the local nn.quantize call but switch the class_predicate to the compose-with-model.quant_predicate pattern from mlx-vlm: chain the default skip-vision / group-size predicate with the model's own predicate, and record any per-layer dict results so the load path re-quantizes the same way. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- mlx_embeddings/convert.py | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/mlx_embeddings/convert.py b/mlx_embeddings/convert.py index 3e7e3ff781..cb17ee432d 100644 --- a/mlx_embeddings/convert.py +++ b/mlx_embeddings/convert.py @@ -91,34 +91,41 @@ def defaults_for_mode(mode: str, group_size: int, bits: int) -> Tuple[int, int]: effective_bits = bits if bits else default_bits return effective_group_size, effective_bits + quantized_config = copy.deepcopy(config) effective_group_size, effective_bits = defaults_for_mode(mode, q_group_size, q_bits) - # Delegate to mlx_lm.utils.quantize_model (same pattern as mlx-vlm): it reads - # `model.quant_predicate` and records per-layer overrides into the config, - # while our wrapper adds the skip-vision / group-size sanity checks. - from mlx_lm.utils import quantize_model as mlx_lm_quantize_model - + # Predicate-chaining pattern from mlx-vlm: honor the model's `quant_predicate` + # (if any) on top of the default skip-vision / group-size checks, and record + # per-layer overrides so the load path re-quantizes the same way. 
default_predicate = get_class_predicate( skip_vision=skip_vision, q_group_size=effective_group_size ) - model_predicate = getattr(model, "quant_predicate", None) + model_quant_predicate = getattr(model, "quant_predicate", None) + overrides: Dict[str, Dict[str, int]] = {} - def quant_predicate(path, module): + def base_quant_predicate(path, module): if not default_predicate(path, module): return False - if model_predicate is not None: - return model_predicate(path, module) - return True - - model, quantized_config = mlx_lm_quantize_model( + if model_quant_predicate is None: + return True + result = model_quant_predicate(path, module) + if isinstance(result, dict): + overrides[path] = result + return result + + nn.quantize( model, - copy.deepcopy(config), group_size=effective_group_size, bits=effective_bits, mode=mode, - quant_predicate=quant_predicate, + class_predicate=base_quant_predicate, ) - + quantized_config["quantization"] = { + "group_size": effective_group_size, + "bits": effective_bits, + "mode": mode, + **overrides, + } if "vision_config" in quantized_config and isinstance( quantized_config["vision_config"], dict ): From b469f0261289f06fd5fd247a528b460e1901730c Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 22 Apr 2026 23:39:17 +0200 Subject: [PATCH 08/11] format --- mlx_embeddings/models/openai_privacy_filter.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/mlx_embeddings/models/openai_privacy_filter.py b/mlx_embeddings/models/openai_privacy_filter.py index e510cadc1c..ac358b59ad 100644 --- a/mlx_embeddings/models/openai_privacy_filter.py +++ b/mlx_embeddings/models/openai_privacy_filter.py @@ -150,9 +150,7 @@ def __init__(self, config: ModelArgs): activation=PrivacyFilterSwiGLU(), bias=True, ) - self.router = nn.Linear( - config.hidden_size, config.num_local_experts, bias=True - ) + self.router = nn.Linear(config.hidden_size, config.num_local_experts, bias=True) def __call__(self, x: mx.array) -> mx.array: router_logits 
= self.router(x).astype(mx.float32) @@ -270,7 +268,6 @@ def __call__( logits=logits, ) - def sanitize(self, weights: dict) -> dict: # Split the fused gate_up_proj (concatenated layout) into separate gate and up # projections, and transpose expert weights from (E, in, out) to (E, out, in) From a1169804e830d295eca9ee821f87c0933f6833b1 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 22 Apr 2026 23:42:00 +0200 Subject: [PATCH 09/11] Update quantization logic to conditionally set model_quant_predicate based on mode Refactor the quantize_model function to ensure model_quant_predicate is only set when the mode is "affine", improving clarity and functionality in the quantization process. --- mlx_embeddings/convert.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mlx_embeddings/convert.py b/mlx_embeddings/convert.py index cb17ee432d..66339d0211 100644 --- a/mlx_embeddings/convert.py +++ b/mlx_embeddings/convert.py @@ -100,7 +100,10 @@ def defaults_for_mode(mode: str, group_size: int, bits: int) -> Tuple[int, int]: default_predicate = get_class_predicate( skip_vision=skip_vision, q_group_size=effective_group_size ) - model_quant_predicate = getattr(model, "quant_predicate", None) + + model_quant_predicate = ( + getattr(model, "quant_predicate", None) if mode == "affine" else None + ) overrides: Dict[str, Dict[str, int]] = {} def base_quant_predicate(path, module): From 4abda99068f01cc097fbe331fa64d8c5f9a80d88 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 22 Apr 2026 23:43:28 +0200 Subject: [PATCH 10/11] Update console scripts in pyproject.toml to improve entry point definitions Refactor the entry points for console scripts by changing the format to specify the module paths directly. This includes adding a new entry point for 'mlx_embeddings.convert', enhancing the CLI functionality. 
--- pyproject.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8e9c5795cc..aadb533400 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,9 @@ classifiers = [ ] [project.entry-points."console_scripts"] -mlx_embeddings = "mlx_embeddings.cli:main" +"mlx_embeddings.cli" = "mlx_embeddings.cli:main" +"mlx_embeddings.convert" = "mlx_embeddings.convert:main" + [project.optional-dependencies] all = [ From 170489556187b5ad0e35dae1bffcf0b2d9d54c07 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 22 Apr 2026 23:43:34 +0200 Subject: [PATCH 11/11] Update version to 0.1.1 --- mlx_embeddings/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx_embeddings/version.py b/mlx_embeddings/version.py index 3dc1f76bc6..485f44ac21 100644 --- a/mlx_embeddings/version.py +++ b/mlx_embeddings/version.py @@ -1 +1 @@ -__version__ = "0.1.0" +__version__ = "0.1.1"