Skip to content
Merged
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ __pycache__/
private/
*.py[cod]
*$py.class
.claude/

# C extensions
*.so
Expand Down
27 changes: 27 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ MLX-Embeddings supports a variety of model architectures for text embedding task
- Qwen3-VL (multimodal Qwen3-VL embedding and reranking model)
- Llama Bidirectional (Llama-based bidirectional embedding models, e.g. NVIDIA NV-Embed)
- Llama Nemotron VL (multimodal vision-language embedding model with SigLIP vision + bidirectional Llama)
- OpenAI Privacy Filter (bidirectional GPT-OSS variant for PII token classification with sparse MoE, GQA + attention sinks, and YARN RoPE)

We're continuously working to expand our support for additional model architectures. Check our GitHub repository or documentation for the most up-to-date list of supported models and their specific versions.

Expand Down Expand Up @@ -177,6 +178,32 @@ for idx, logit in enumerate(predictions.tolist()):
print(f"{label}: {logit:.3f}")
```

#### Token Classification (PII detection)

`openai/privacy-filter` is a bidirectional 1.5B-parameter / 50M-active sparse-MoE token classifier that tags personally identifiable information (PII) with BIOES spans over 8 categories (person, email, phone, URL, address, date, account number, secret).

```python
from itertools import groupby
import mlx.core as mx
from mlx_embeddings.utils import load

model, tokenizer = load("openai/privacy-filter")
id2label = model.config.id2label

text = "My name is Alice Smith and my email is alice@example.com. Phone: 555-1234."
inputs = tokenizer(text, return_tensors="mlx")

outputs = model(inputs["input_ids"], attention_mask=inputs["attention_mask"])
preds = mx.argmax(outputs.logits, axis=-1)[0].tolist()

entity = lambda p: id2label[str(p)].split("-", 1)[-1] if id2label[str(p)] != "O" else None

for ent, group in groupby(zip(inputs["input_ids"][0].tolist(), preds), key=lambda x: entity(x[1])):
if ent:
span = tokenizer.decode([tid for tid, _ in group]).strip()
print(f"{ent:18s} -> {span!r}")
```

### Batch Processing

#### Multiple Texts Comparison
Expand Down
27 changes: 24 additions & 3 deletions mlx_embeddings/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,19 +94,40 @@ def defaults_for_mode(mode: str, group_size: int, bits: int) -> Tuple[int, int]:
quantized_config = copy.deepcopy(config)
effective_group_size, effective_bits = defaults_for_mode(mode, q_group_size, q_bits)

# Predicate-chaining pattern from mlx-vlm: honor the model's `quant_predicate`
# (if any) on top of the default skip-vision / group-size checks, and record
# per-layer overrides so the load path re-quantizes the same way.
default_predicate = get_class_predicate(
skip_vision=skip_vision, q_group_size=effective_group_size
)

model_quant_predicate = (
getattr(model, "quant_predicate", None) if mode == "affine" else None
)
overrides: Dict[str, Dict[str, int]] = {}

def base_quant_predicate(path, module):
if not default_predicate(path, module):
return False
if model_quant_predicate is None:
return True
result = model_quant_predicate(path, module)
if isinstance(result, dict):
overrides[path] = result
return result

nn.quantize(
model,
group_size=effective_group_size,
bits=effective_bits,
mode=mode,
class_predicate=get_class_predicate(
skip_vision=skip_vision, q_group_size=effective_group_size
),
class_predicate=base_quant_predicate,
)
quantized_config["quantization"] = {
"group_size": effective_group_size,
"bits": effective_bits,
"mode": mode,
**overrides,
}
if "vision_config" in quantized_config and isinstance(
quantized_config["vision_config"], dict
Expand Down
1 change: 1 addition & 0 deletions mlx_embeddings/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class BaseModelOutput:
pooler_output: Optional[mx.array] = None
text_embeds: Optional[mx.array] = None # mean pooled and normalized embeddings
hidden_states: Optional[List[mx.array]] = None
logits: Optional[mx.array] = None # token-level or sequence classification logits


@dataclass
Expand Down
Loading
Loading