
Add OpenAI privacy-filter token classification model #60

Merged
Blaizzy merged 11 commits into main from pc/add-openai-privacy-filter on Apr 22, 2026

@Blaizzy Blaizzy commented Apr 22, 2026

Summary

  • Port openai/privacy-filter (bidirectional GPT-OSS variant, 1.5B total / 50M active MoE token classifier for PII detection).
  • 8 transformer blocks, GQA (14 q-heads / 2 kv-heads, head_dim=64) with per-head attention sinks via mx.fast.scaled_dot_product_attention(sinks=…), YARN RoPE with interleaved layout, 128-expert top-4 MoE using SwitchGLU + a custom PrivacyFilterSwiGLU activation (up+1)·gate·σ(α·gate), bidirectional ±128 sliding-window mask, and a 33-class BIOES token-classification head (score).
  • sanitize() drops the parallel original/ OpenAI-format checkpoint, splits the fused concat-layout gate_up_proj into separate gate_proj/up_proj, and transposes expert weight matrices from (E, in, out) to (E, out, in) to match SwitchLinear.
  • Adds a logits field to BaseModelOutput so token-classification outputs have a natural home.
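
Two of the less standard pieces in the list above, the bidirectional ±128 sliding-window mask and the `(up+1)·gate·σ(α·gate)` activation, can be sketched in plain Python. This is a minimal illustration, not the code from this PR: it assumes the window is inclusive (`|i - j| <= 128`), and the default `alpha=1.702` is an assumed value that the PR text does not state.

```python
import math


def sliding_window_mask(seq_len: int, window: int = 128) -> list[list[bool]]:
    """Bidirectional band mask: token i may attend to token j iff |i - j| <= window.

    Assumption: the +-128 window is inclusive on both sides.
    """
    return [[abs(i - j) <= window for j in range(seq_len)] for i in range(seq_len)]


def privacy_filter_swiglu(gate: float, up: float, alpha: float = 1.702) -> float:
    """Scalar form of (up + 1) * gate * sigmoid(alpha * gate).

    alpha=1.702 is an assumed default, not a value taken from the PR.
    """
    sigmoid = 1.0 / (1.0 + math.exp(-alpha * gate))
    return (up + 1.0) * gate * sigmoid


# Tiny window so the band structure is visible.
mask = sliding_window_mask(5, window=1)
```

Unlike a causal mask, the band is symmetric, so each token sees neighbours on both sides.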

Numerical parity

fp32 vs HF reference on four PII prompts:

| Prompt | Max logit diff | Prediction agreement |
|---|---|---|
| My name is Alice Smith and I live at 123 Main Street. | 0.00182 | 100% |
| Email alice@example.com or phone 555-123-4567. | 0.00207 | 100% |
| Visit https://example.com. My SSN is 123-45-6789. | 0.00169 | 100% |
| Hello, my account number is 9876543210 on 2024-03-15. | 0.00312 | 100% |
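
The PR does not say how the two columns were computed. One plausible reading, sketched here under that assumption, is the maximum absolute element-wise logit difference and the fraction of tokens whose argmax class matches. The function name and the toy logits are hypothetical.

```python
def parity_metrics(ref_logits, test_logits):
    """Compare two (seq_len x num_classes) logit tables given as nested lists.

    Returns (max absolute logit difference, fraction of tokens with equal argmax).
    """
    max_diff = max(
        abs(r - t)
        for ref_row, test_row in zip(ref_logits, test_logits)
        for r, t in zip(ref_row, test_row)
    )

    def argmax(row):
        return max(range(len(row)), key=row.__getitem__)

    agree = sum(argmax(r) == argmax(t) for r, t in zip(ref_logits, test_logits))
    return max_diff, agree / len(ref_logits)


# Toy data: two tokens, three classes.
ref = [[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]]
test = [[0.101, 1.998, -1.0], [2.999, 0.0, 0.503]]
max_diff, agreement = parity_metrics(ref, test)
```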

Test plan

  • pytest mlx_embeddings/tests/test_models.py — all 16 tests pass, including the new test_openai_privacy_filter_model
  • End-to-end load via mlx_embeddings.utils.load("openai/privacy-filter") + forward pass
  • fp32 numerical closeness vs transformers reference (attn_implementation="eager")
  • Optional: convert + upload to HF hub via mlx_embeddings.convert

🤖 Generated with Claude Code

Blaizzy and others added 11 commits April 22, 2026 22:52
Port openai/privacy-filter to mlx-embeddings: a bidirectional 1.5B/50M-active
MoE token classifier for PII detection (8 BIOES span labels, 33 classes). The
architecture is a bidirectional GPT-OSS variant with GQA + attention sinks,
YARN RoPE (interleaved layout), 128-expert top-4 MoE, and ±128 sliding-window
attention. Sanitize splits the fused concat-layout gate_up_proj into separate
gate/up projections and transposes expert weights for mlx SwitchLinear.

Numerical parity with the HF reference in fp32: max logit diff < 0.004,
100% prediction agreement across PII test strings.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
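
The split-and-transpose step described in this commit can be sketched with nested lists standing in for arrays. This is an illustration under stated assumptions rather than the PR's `sanitize()`: it assumes the fused tensor has shape `(E, in, 2*out)` and that the gate half comes first in the concat layout. On real arrays the same operation is a slice on the last axis followed by a swap of the last two axes.

```python
def split_fused_gate_up(gate_up):
    """Split a fused (E, in, 2*out) tensor into gate/up, each returned as (E, out, in).

    Assumptions: concat layout [gate | up] on the last axis, gate half first.
    """
    gate_proj, up_proj = [], []
    for expert in gate_up:
        half = len(expert[0]) // 2
        gate = [row[:half] for row in expert]  # (in, out)
        up = [row[half:] for row in expert]    # (in, out)
        # zip(*m) transposes a 2-D nested list: (in, out) -> (out, in).
        gate_proj.append([list(col) for col in zip(*gate)])
        up_proj.append([list(col) for col in zip(*up)])
    return gate_proj, up_proj


# One expert, in=2, out=3; values 1..6 are the gate half, 10..60 the up half.
fused = [[[1, 2, 3, 10, 20, 30],
          [4, 5, 6, 40, 50, 60]]]
gate_proj, up_proj = split_fused_gate_up(fused)
```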
Add the model to the supported architectures list and a Token
Classification (PII detection) usage section with a working example.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Use itertools.groupby to collapse consecutive BIOES tokens into clean
decoded spans rather than dumping per-token BPE fragments.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
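
A minimal sketch of the `itertools.groupby` approach from this commit, assuming tokens arrive as already-decoded string fragments and labels as `B-`/`I-`/`E-`/`S-` prefixed strings plus `O`. The label names (`NAME`, `EMAIL`) are hypothetical. Grouping on the entity type alone is a simplification: two adjacent entities of the same type would be merged into one span.

```python
from itertools import groupby


def entity_type(label: str) -> str:
    """Strip the BIOES prefix: 'B-EMAIL' -> 'EMAIL', 'O' -> 'O'."""
    return label.split("-", 1)[1] if "-" in label else label


def collapse_spans(tokens, labels):
    """Collapse consecutive tokens of the same entity type into (type, text) spans."""
    spans = []
    pairs = zip(tokens, labels)
    for etype, group in groupby(pairs, key=lambda pair: entity_type(pair[1])):
        if etype == "O":
            continue  # skip tokens outside any entity
        text = "".join(tok for tok, _ in group).strip()
        spans.append((etype, text))
    return spans


tokens = ["My", " name", " is", " Alice", " Smith", ".",
          " Email", " al", "ice", "@example.com"]
labels = ["O", "O", "O", "B-NAME", "E-NAME", "O",
          "O", "B-EMAIL", "I-EMAIL", "E-EMAIL"]
spans = collapse_spans(tokens, labels)
# spans == [("NAME", "Alice Smith"), ("EMAIL", "alice@example.com")]
```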
Call the router module directly (rather than a manual matmul on
.weight/.bias) so the MoE forward works with both dense nn.Linear
and QuantizedLinear weights. The softmax still runs in fp32 for
numerical parity with the reference.

Verified against /tmp/privacy-filter-{q4,mxfp4}: both quantizations
extract the same PII spans as the bf16 checkpoint.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
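
The idea of calling the router as a module can be sketched without MLX: the routing code only needs a callable that maps a hidden vector to per-expert logits, so it does not care whether the layer is dense or quantized. The sketch assumes the softmax is taken over the selected top-k logits, which the commit message does not specify. Python floats stand in for the fp32 upcast, and `DenseRouter` is a hypothetical toy layer.

```python
import math


class DenseRouter:
    """Toy linear router: logits[e] = dot(weight[e], x) + bias[e]."""

    def __init__(self, weight, bias):
        self.weight, self.bias = weight, bias

    def __call__(self, x):
        return [sum(w * v for w, v in zip(row, x)) + b
                for row, b in zip(self.weight, self.bias)]


def route_top_k(router, x, k=4):
    """Pick the k highest-scoring experts and softmax-normalise their logits.

    `router` is any callable; no access to .weight or .bias is needed here.
    Assumption: softmax runs over the k selected logits only.
    """
    logits = router(x)
    top = sorted(range(len(logits)), key=lambda e: logits[e], reverse=True)[:k]
    peak = max(logits[e] for e in top)
    exps = [math.exp(logits[e] - peak) for e in top]  # subtract max for stability
    total = sum(exps)
    return top, [v / total for v in exps]


router = DenseRouter([[1, 0], [0, 1], [1, 1], [-1, 0]], [0, 0, 0, 0])
experts, weights = route_top_k(router, [1.0, 2.0], k=2)
```

Swapping `DenseRouter` for any other callable with the same input and output shapes leaves `route_top_k` unchanged, which is the property the commit relies on.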
Add a quant_predicate on the privacy-filter Model that keeps the MoE
router at 8 bits while the rest of the weights quantize to the user's
chosen bit width. The router is a small but routing-sensitive linear;
a uniform 4-bit quantization of the router was measurably degrading
accuracy in gpt-oss-style models, and the same applies here.

Follow mlx-vlm's pattern in convert.py: delegate quantization to
mlx_lm.utils.quantize_model, passing a wrapper that composes
mlx-embeddings' skip-vision / group-size checks with the model's
quant_predicate. mlx_lm handles recording per-layer overrides into
config["quantization"][path], and the existing load path in utils.py
already respects those.

Verified: bf16 and q4 (uniform) both still extract the same PII spans;
mixed-precision q4-experts + q8-router saves to disk with 4.52 bits/
weight, loads correctly, and extracts the same spans.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
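
A per-layer predicate of the kind described in this commit might look like the sketch below. It is not the PR's implementation: the `(path, module)` signature, the `router` path suffix, and `group_size=64` are all assumptions made for illustration. Returning `True` means "use the user's global settings", and returning a dict means "override for this layer".

```python
def quant_predicate(path: str, module=None):
    """Keep the MoE router at 8 bits; let everything else use the global setting.

    Assumptions: router layers have paths ending in 'router', group size is 64.
    """
    if path.endswith("router"):
        return {"group_size": 64, "bits": 8}
    return True


router_cfg = quant_predicate("model.layers.0.mlp.router")
expert_cfg = quant_predicate("model.layers.0.mlp.experts.gate_proj")
```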
Keep the local nn.quantize call but switch the class_predicate to the
compose-with-model.quant_predicate pattern from mlx-vlm: chain the
default skip-vision / group-size predicate with the model's own
predicate, and record any per-layer dict results so the load path
re-quantizes the same way.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
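
The compose-and-record pattern from this commit can be sketched in plain Python. Both predicates and the `FakeLinear` class below are hypothetical stand-ins, and the group size of 64 is assumed. The point is only the control flow: the default check runs first, the model's predicate decides the rest, and dict results are recorded by layer path so a loader could re-apply them.

```python
class FakeLinear:
    """Stand-in for a linear layer; only the input width matters here."""

    def __init__(self, in_dims):
        self.in_dims = in_dims


def default_predicate(path, module):
    """Stand-in default: skip vision layers and widths not divisible by 64."""
    if "vision" in path:
        return False
    return module.in_dims % 64 == 0


def model_predicate(path, module):
    """Stand-in model predicate: 8-bit router, global settings otherwise."""
    if path.endswith("router"):
        return {"group_size": 64, "bits": 8}
    return True


def compose(default_pred, model_pred, overrides):
    """Chain both predicates and record dict results keyed by layer path."""
    def predicate(path, module):
        if not default_pred(path, module):
            return False
        result = model_pred(path, module)
        if isinstance(result, dict):
            overrides[path] = result  # remembered for the saved config
        return result
    return predicate


overrides = {}
predicate = compose(default_predicate, model_predicate, overrides)
results = [
    predicate("model.layers.0.mlp.router", FakeLinear(640)),
    predicate("model.layers.0.mlp.experts.up_proj", FakeLinear(640)),
    predicate("vision_tower.proj", FakeLinear(640)),
    predicate("model.odd_width", FakeLinear(100)),
]
```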
…based on mode

Refactor the quantize_model function so that model_quant_predicate is only set when the mode is "affine".
…itions

Refactor the console-script entry points to specify module paths directly, and add a new entry point for 'mlx_embeddings.convert'.
@Blaizzy Blaizzy merged commit ea6d739 into main Apr 22, 2026
1 check failed