Skip to content

Commit ef04c68

Browse files
Blaizzy and claude committed
Add OpenAI privacy-filter token classification model
Port openai/privacy-filter to mlx-embeddings: a bidirectional 1.5B/50M-active MoE token classifier for PII detection (8 BIOES span labels, 33 classes). The architecture is a bidirectional GPT-OSS variant with GQA + attention sinks, YARN RoPE (interleaved layout), 128-expert top-4 MoE, and ±128 sliding-window attention. Sanitize splits the fused concat-layout gate_up_proj into separate gate/up projections and transposes expert weights for mlx SwitchLinear. Numerical parity with the HF reference in fp32: max logit diff < 0.004, 100% prediction agreement across PII test strings. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 0a93687 commit ef04c68

3 files changed

Lines changed: 355 additions & 0 deletions

File tree

mlx_embeddings/models/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class BaseModelOutput:
2424
pooler_output: Optional[mx.array] = None
2525
text_embeds: Optional[mx.array] = None # mean pooled and normalized embeddings
2626
hidden_states: Optional[List[mx.array]] = None
27+
logits: Optional[mx.array] = None # token-level or sequence classification logits
2728

2829

2930
@dataclass
Lines changed: 326 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,326 @@
1+
import math
2+
from dataclasses import dataclass, field
3+
from typing import Any, Dict, List, Optional
4+
5+
import mlx.core as mx
6+
import mlx.nn as nn
7+
from mlx_lm.models.base import scaled_dot_product_attention
8+
from mlx_lm.models.rope_utils import initialize_rope
9+
from mlx_lm.models.switch_layers import SwitchGLU
10+
11+
from .base import BaseModelArgs, BaseModelOutput
12+
13+
14+
@dataclass
class ModelArgs(BaseModelArgs):
    """Hyperparameters for the OpenAI privacy-filter token classifier.

    Every field has a default, so the class can be built with no arguments.
    ``rope_parameters`` is filled in by ``__post_init__`` when it is left as
    ``None``.
    """

    model_type: str = "openai_privacy_filter"
    vocab_size: int = 200064
    hidden_size: int = 640
    intermediate_size: int = 640
    num_hidden_layers: int = 8
    num_attention_heads: int = 14
    num_key_value_heads: int = 2
    head_dim: int = 64
    sliding_window: int = 128
    max_position_embeddings: int = 131072
    rms_norm_eps: float = 1e-5
    attention_bias: bool = True
    attention_dropout: float = 0.0
    classifier_dropout: float = 0.0
    num_local_experts: int = 128
    num_experts_per_tok: int = 4
    tie_word_embeddings: bool = False
    pad_token_id: Optional[int] = 199999
    eos_token_id: Optional[int] = 199999
    rope_parameters: Optional[Dict[str, Any]] = None
    id2label: Optional[Dict[int, str]] = None
    label2id: Optional[Dict[str, int]] = None
    architectures: List[str] = field(
        default_factory=lambda: ["OpenAIPrivacyFilterForTokenClassification"]
    )

    def __post_init__(self):
        """Install the default YARN RoPE settings when none were supplied."""
        if self.rope_parameters is not None:
            return
        self.rope_parameters = dict(
            rope_type="yarn",
            rope_theta=150000.0,
            factor=32.0,
            beta_fast=32.0,
            beta_slow=1.0,
            original_max_position_embeddings=4096,
        )

    @property
    def num_labels(self) -> int:
        """Number of output classes: ``len(id2label)`` if set, otherwise 33."""
        return 33 if self.id2label is None else len(self.id2label)
58+
59+
60+
def _swiglu_concat(
    gate_up: mx.array, alpha: float = 1.702, limit: float = 7.0
) -> mx.array:
    """Apply the clamped SwiGLU to a fused ``[gate | up]`` tensor.

    The last axis of ``gate_up`` is split into two equal halves: the first is
    the gate, the second is the up projection. The gate is clamped from above
    at ``limit``; the up half is clamped to ``[-limit, limit]``.

    Returns:
        ``(up + 1) * gate * sigmoid(alpha * gate)``.
    """
    gate_half, up_half = mx.split(gate_up, 2, axis=-1)
    gate_half = mx.clip(gate_half, a_min=None, a_max=limit)
    up_half = mx.clip(up_half, a_min=-limit, a_max=limit)
    return (up_half + 1) * (gate_half * mx.sigmoid(gate_half * alpha))
66+
67+
68+
class PrivacyFilterSwiGLU(nn.Module):
    """Clamped SwiGLU activation taking the up and gate projections separately.

    The gate is clamped from above only, the up input is clamped on both
    sides, and the result is ``(up + 1) * gate * sigmoid(alpha * gate)``.
    """

    def __init__(self, alpha: float = 1.702, limit: float = 7.0):
        super().__init__()
        self.alpha = alpha
        self.limit = limit

    def __call__(self, x: mx.array, gate: mx.array) -> mx.array:
        """Combine the up projection ``x`` with ``gate``."""
        clipped_gate = mx.clip(gate, a_min=None, a_max=self.limit)
        clipped_up = mx.clip(x, a_min=-self.limit, a_max=self.limit)
        return (clipped_up + 1) * (
            clipped_gate * mx.sigmoid(clipped_gate * self.alpha)
        )
81+
82+
83+
class OpenAIPrivacyFilterAttention(nn.Module):
    """Self-attention with separate query / key-value head counts and sinks.

    Rotary embeddings are applied to queries and keys. No KV cache is used;
    which positions may attend to which is decided entirely by the ``mask``
    passed to ``__call__``.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        n_heads = config.num_attention_heads
        n_kv_heads = config.num_key_value_heads

        self.head_dim = config.head_dim
        self.num_attention_heads = n_heads
        self.num_key_value_heads = n_kv_heads
        self.num_key_value_groups = n_heads // n_kv_heads

        # Attention sinks; checkpoint stores them as float32.
        self.sinks = mx.zeros((n_heads,))

        hidden = config.hidden_size
        q_width = n_heads * self.head_dim
        kv_width = n_kv_heads * self.head_dim
        use_bias = config.attention_bias
        self.q_proj = nn.Linear(hidden, q_width, bias=use_bias)
        self.k_proj = nn.Linear(hidden, kv_width, bias=use_bias)
        self.v_proj = nn.Linear(hidden, kv_width, bias=use_bias)
        self.o_proj = nn.Linear(q_width, hidden, bias=use_bias)

        self.sm_scale = 1.0 / math.sqrt(self.head_dim)

        # Copy before popping so the shared config dict is not mutated.
        rope_cfg = dict(config.rope_parameters)
        theta = rope_cfg.pop("rope_theta", 150000.0)
        self.rope = initialize_rope(
            self.head_dim,
            theta,
            traditional=True,
            scaling_config=rope_cfg,
            max_position_embeddings=config.max_position_embeddings,
        )

    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
        """Attend over ``x`` of shape (batch, seq, hidden); shape is preserved."""
        batch, seq_len, _ = x.shape

        def to_heads(proj: nn.Linear) -> mx.array:
            # (B, L, H * D) -> (B, H, L, D)
            projected = proj(x).reshape(batch, seq_len, -1, self.head_dim)
            return projected.swapaxes(1, 2)

        queries = self.rope(to_heads(self.q_proj))
        keys = self.rope(to_heads(self.k_proj))
        values = to_heads(self.v_proj)

        attended = scaled_dot_product_attention(
            queries,
            keys,
            values,
            cache=None,
            scale=self.sm_scale,
            mask=mask,
            sinks=self.sinks.astype(queries.dtype),
        )

        # (B, H, L, D) -> (B, L, H * D)
        merged = attended.swapaxes(1, 2).reshape(batch, seq_len, -1)
        return self.o_proj(merged)
145+
146+
147+
class OpenAIPrivacyFilterMLP(nn.Module):
    """Top-k routed sparse MoE matching the HF reference.

    Routing weights are a softmax over the selected top-k router logits only,
    with no extra scaling.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.num_local_experts = config.num_local_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.experts = SwitchGLU(
            input_dims=config.hidden_size,
            hidden_dims=config.intermediate_size,
            num_experts=config.num_local_experts,
            activation=PrivacyFilterSwiGLU(),
            bias=True,
        )
        self.router = nn.Linear(
            config.hidden_size, config.num_local_experts, bias=True
        )

    def __call__(self, x: mx.array) -> mx.array:
        """Route each token to its top-k experts and mix their outputs.

        Args:
            x: Hidden states whose last axis is ``hidden_size``.

        Returns:
            An array with the same shape as ``x``.
        """
        # Router runs in fp32 for numerical parity with the reference. The
        # input is upcast and the layer itself is called, rather than doing a
        # manual matmul on ``router.weight``, so this keeps working if the
        # router has been swapped for a quantized linear whose ``weight`` is
        # packed integers. For a plain ``nn.Linear``, type promotion yields
        # the same fp32 logits as casting weight and bias by hand.
        router_logits = self.router(x.astype(mx.float32)).astype(mx.float32)

        k = self.num_experts_per_tok
        # argpartition puts the k largest logits (unordered) in the last k slots.
        top_idx = mx.argpartition(router_logits, kth=-k, axis=-1)[..., -k:]
        top_val = mx.take_along_axis(router_logits, top_idx, axis=-1)
        weights = mx.softmax(top_val, axis=-1).astype(x.dtype)

        y = self.experts(x, top_idx)
        y = y * mx.expand_dims(weights, axis=-1)
        # Collapse the per-expert axis into one weighted sum per token.
        return y.sum(axis=-2)
180+
181+
182+
class OpenAIPrivacyFilterEncoderLayer(nn.Module):
    """Pre-norm block: attention then MoE MLP, each with a residual add."""

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.self_attn = OpenAIPrivacyFilterAttention(config)
        self.mlp = OpenAIPrivacyFilterMLP(config)
        eps = config.rms_norm_eps
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=eps)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=eps)

    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
        """Run one layer over ``x``; ``mask`` is forwarded to attention."""
        attn_out = self.self_attn(self.input_layernorm(x), mask)
        x = x + attn_out

        mlp_out = self.mlp(self.post_attention_layernorm(x))
        return x + mlp_out
200+
201+
202+
def _bidirectional_sliding_window_mask(
    seq_len: int,
    window: int,
    attention_mask: Optional[mx.array],
    dtype: mx.Dtype,
) -> mx.array:
    """Build an additive attention mask for a symmetric local window.

    A query at position ``i`` may attend to key ``j`` when ``|i - j| <= window``.
    Allowed entries are ``0`` and blocked entries are ``-inf``, both in ``dtype``.

    Args:
        seq_len: Sequence length ``L``.
        window: Maximum allowed distance between query and key positions.
        attention_mask: Optional (B, L) array, 1 for valid tokens and 0 for
            padding. Padded keys are blocked for every query.
        dtype: Dtype of the returned mask.

    Returns:
        A (1, 1, L, L) mask when ``attention_mask`` is ``None``, otherwise a
        (B, 1, L, L) mask.
    """
    allowed = mx.array(0.0, dtype=dtype)
    blocked = mx.array(-mx.inf, dtype=dtype)

    positions = mx.arange(seq_len)
    distance = mx.abs(positions[:, None] - positions[None, :])  # (L, L)
    band = mx.where(distance <= window, allowed, blocked)[None, None, :, :]

    if attention_mask is None:
        return band

    # attention_mask: (B, L), 1 for valid, 0 for pad.
    key_is_valid = attention_mask.astype(mx.bool_)
    key_bias = mx.where(key_is_valid[:, None, :], allowed, blocked)  # (B, 1, L)
    return band + key_bias[:, None, :, :]
224+
225+
226+
class OpenAIPrivacyFilterModel(nn.Module):
    """Token embedding, a stack of encoder layers, and a final RMSNorm."""

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = [
            OpenAIPrivacyFilterEncoderLayer(config)
            for _ in range(config.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.sliding_window = config.sliding_window

    def __call__(
        self,
        input_ids: mx.array,
        attention_mask: Optional[mx.array] = None,
        input_embeddings: Optional[mx.array] = None,
    ) -> mx.array:
        """Encode a batch and return the normalized final hidden states.

        When ``input_embeddings`` is given it is used directly and
        ``input_ids`` is not embedded.
        """
        hidden = (
            self.embed_tokens(input_ids)
            if input_embeddings is None
            else input_embeddings
        )

        # The same mask is shared by every layer.
        mask = _bidirectional_sliding_window_mask(
            hidden.shape[1], self.sliding_window, attention_mask, hidden.dtype
        )

        for block in self.layers:
            hidden = block(hidden, mask)

        return self.norm(hidden)
258+
259+
260+
class Model(nn.Module):
    """Privacy-filter encoder with a linear per-token classification head."""

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.model_type = config.model_type
        self.num_labels = config.num_labels

        self.model = OpenAIPrivacyFilterModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=True)

    def __call__(
        self,
        input_ids: mx.array,
        attention_mask: Optional[mx.array] = None,
    ) -> BaseModelOutput:
        """Classify every token of a batch.

        Args:
            input_ids: 2D array of token ids.
            attention_mask: Optional array forwarded to the encoder.

        Returns:
            A ``BaseModelOutput`` carrying ``last_hidden_state`` and ``logits``.

        Raises:
            ValueError: If ``input_ids`` is not 2D.
        """
        if input_ids.ndim != 2:
            raise ValueError(f"input_ids must be 2D, got shape {input_ids.shape}")

        encoded = self.model(input_ids, attention_mask=attention_mask)
        return BaseModelOutput(
            last_hidden_state=encoded,
            logits=self.score(encoded),
        )

    def sanitize(self, weights: dict) -> dict:
        """Remap checkpoint tensors onto this module's parameter names.

        The fused ``gate_up_proj`` tensors (concatenated layout) are split in
        half along the last axis into ``gate_proj`` and ``up_proj``. Expert
        weight matrices are transposed from (E, in, out) to (E, out, in) to
        match mlx's SwitchLinear expectations. Every other tensor, including
        the attention sinks, is passed through unchanged.
        """
        remapped = {}
        for name, tensor in weights.items():
            # Skip the alternate `original/` OpenAI-format checkpoint that ships
            # alongside the transformers weights in this repo.
            if name.startswith("original."):
                continue

            # The bias check must come first: its key also contains
            # "gate_up_proj".
            if "mlp.experts.gate_up_proj_bias" in name:
                halves = mx.split(tensor, 2, axis=-1)
                targets = ("gate_proj.bias", "up_proj.bias")
                for target, half in zip(targets, halves):
                    new_name = name.replace("gate_up_proj_bias", target)
                    remapped[new_name] = mx.contiguous(half)
            elif "mlp.experts.gate_up_proj" in name:
                # (E, in, 2*out) -> split -> (E, in, out) -> transpose -> (E, out, in)
                halves = mx.split(tensor, 2, axis=-1)
                targets = ("gate_proj.weight", "up_proj.weight")
                for target, half in zip(targets, halves):
                    new_name = name.replace("gate_up_proj", target)
                    remapped[new_name] = mx.contiguous(half.swapaxes(-1, -2))
            elif name.endswith("mlp.experts.down_proj"):
                # (E, in, out) -> (E, out, in)
                remapped[name + ".weight"] = mx.contiguous(tensor.swapaxes(-1, -2))
            elif name.endswith("mlp.experts.down_proj_bias"):
                new_name = name.replace("down_proj_bias", "down_proj.bias")
                remapped[new_name] = tensor
            else:
                # Everything else (attention sinks included) keeps its name and
                # dtype.
                remapped[name] = tensor
        return remapped

    @property
    def layers(self):
        """The encoder's list of layers."""
        return self.model.layers

mlx_embeddings/tests/test_models.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,34 @@ def test_llama_bidirec_model(self):
489489
config.num_hidden_layers,
490490
)
491491

492+
def test_openai_privacy_filter_model(self):
    """Smoke-test a tiny privacy-filter model: output shapes and dtype."""
    from mlx_embeddings.models import openai_privacy_filter

    tiny_config = openai_privacy_filter.ModelArgs(
        model_type="openai_privacy_filter",
        vocab_size=64,
        hidden_size=32,
        intermediate_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        head_dim=8,
        sliding_window=16,
        max_position_embeddings=128,
        num_local_experts=4,
        num_experts_per_tok=2,
        rms_norm_eps=1e-5,
    )
    model = openai_privacy_filter.Model(tiny_config)
    # Force fp32 parameters so the dtype assertion below is meaningful.
    model.update(tree_map(lambda p: p.astype(mx.float32), model.parameters()))

    token_ids = mx.array([[0, 1, 2, 3, 4]])
    result = model(token_ids)

    self.assertEqual(
        result.last_hidden_state.shape, (1, 5, tiny_config.hidden_size)
    )
    self.assertEqual(result.logits.shape, (1, 5, tiny_config.num_labels))
    self.assertEqual(result.last_hidden_state.dtype, mx.float32)
519+
492520
def test_qwen3_model(self):
493521
from mlx_embeddings.models import qwen3
494522

0 commit comments

Comments
 (0)