Commit 9b28270

leonnoirclerc, claude, and Blaizzy authored
Dispatch sentence-transformers pooling for bert + xlm_roberta (fix CLS-pooled checkpoints) (#63)
* sentence-transformers pooling dispatch for bert + xlm_roberta

bert.py and xlm_roberta.py both hard-coded `mean_pooling` for `text_embeds`, ignoring sentence-transformers' `1_Pooling/config.json`. CLS-pooled checkpoints (bge-base-en-v1.5, snowflake-arctic-embed-l-v2.0, the rest of the bge family, mxbai, …) silently returned mean-pooled vectors instead of CLS, with measurable cosine drift against the SentenceTransformer reference.

Changes:

- base.py: add cls_pooling, max_pooling, lasttoken_pooling translated from sentence_transformers/sentence_transformer/modules/pooling.py @ 8151750. Add `_normalize_pooling_config` (port of `_convert_legacy_pooling_kwargs`) and `pool_by_config` dispatcher. Modes outside {cls, mean, max, lasttoken} raise NotImplementedError; tuple multi-mode and include_prompt=False raise too. Empirical mode coverage: ~100% of top-60 ST checkpoints on the Hub.
- utils.py: `_read_pooling_config` loads `1_Pooling/config.json` when present and injects it into the config dict, so the existing `model_config` override mechanism keeps "caller wins" precedence intact.
- bert.py / xlm_roberta.py: add `pooling_config` field on ModelArgs with default `{"pooling_mode": "mean"}` (visible in the dataclass signature), swap the hard-coded `mean_pooling` for `pool_by_config(...)`.
- tests/test_base.py: ports of the five HF unit tests that map to our supported surface (cls right-pad, cls left-pad, max-respects-mask, lasttoken finds last, lasttoken all-padding-zeros), plus the gold-standard `test_pooling_exact_values` (HF's shared fixture, all four supported modes) and the two `_convert_legacy_pooling_kwargs` conversion tests. Each carries a line-pinned reference back to the HF source at commit 8151750.

Verified end-to-end with the mlx-embeddings-tests harness:

    bge-base-en-v1.5 (bert, CLS): cos_sim 0.9587 -> 1.000000
    snowflake-arctic-embed-l-v2.0 (xlm-r, CLS): cos_sim 0.8156 -> 1.000000
    all-MiniLM-L6-v2 (bert, MEAN, control): cos_sim 1.000000 (unchanged)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* extract pooling module, tighten dispatcher, preserve 1_Pooling in convert

- Move pool helpers + dispatcher from `models/base.py` to a new `models/pooling.py`. `base.py` goes back to its pre-PR shape (data classes + `normalize_embeddings`).
- Update imports across all model files that use `mean_pooling`: `bert`, `xlm_roberta`, `modernbert`, `gemma3_text`, `llama_bidirec`, `lfm2`, `llama_nemotron_vl`.
- Tighten `pool_by_config` / `_normalize_pooling_config` signatures from `Optional[Dict[str, Any]]` to `Dict[str, Any]`; remove the now-unreachable `if cfg is None` fallback (the dataclass default factory guarantees a dict reaches the dispatcher).
- Tighten `bert.ModelArgs.pooling_config` and `xlm_roberta.ModelArgs.pooling_config` from `Optional[dict]` to `dict`.
- `convert.py`: preserve `1_Pooling/` subdirectory when converting an HF checkpoint. The top-level `*.json` glob doesn't recurse; without this, converted `mlx-community/*` variants lose the pooling sidecar and the loader silently falls back to mean.
- Split `tests/test_base.py` -> `tests/test_pooling.py` for the new pool-helper / dispatcher / config-normalization tests; replace per-test "Port of test_xxx" annotations with the upstream comments verbatim and add the two HF-port dispatcher tests (`test_forward_all_modes`, `test_invalid_mode_raises`).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* style: apply pre-commit (black + isort)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* test: avoid optional backend guard in qwen3 vl loader test

---------

Co-authored-by: Leon Noirclerc <leonnoirclerc@users.noreply.github.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Prince Canuma <prince.gdt@gmail.com>
1 parent ea6d739 commit 9b28270

14 files changed

Lines changed: 380 additions & 64 deletions
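
The commit message describes `_read_pooling_config` in utils.py, but that file's hunk is not shown on this page. The following is only a rough sketch of the behaviour the message describes (load the `1_Pooling/config.json` sidecar when present, then let caller overrides win). The names `read_pooling_sidecar` and `build_config` and the `overrides` parameter are hypothetical stand-ins, not the project's actual API.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional


def read_pooling_sidecar(model_path: Path) -> Optional[Dict[str, Any]]:
    """Return the parsed 1_Pooling/config.json, or None when it is absent."""
    sidecar = model_path / "1_Pooling" / "config.json"
    if not sidecar.is_file():
        return None
    with sidecar.open(encoding="utf-8") as f:
        return json.load(f)


def build_config(
    model_path: Path,
    base_config: Dict[str, Any],
    overrides: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Hypothetical merge order: base config, then sidecar, then caller overrides."""
    merged = dict(base_config)
    pooling = read_pooling_sidecar(model_path)
    if pooling is not None:
        merged["pooling_config"] = pooling
    # Applied last, so an explicit caller value replaces the sidecar ("caller wins").
    merged.update(overrides or {})
    return merged


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "1_Pooling").mkdir()
    (root / "1_Pooling" / "config.json").write_text(
        json.dumps({"pooling_mode_cls_token": True, "pooling_mode_mean_tokens": False}),
        encoding="utf-8",
    )

    from_sidecar = build_config(root, {"model_type": "bert"})
    from_caller = build_config(
        root, {"model_type": "bert"}, {"pooling_config": {"pooling_mode": "mean"}}
    )
    without_sidecar = build_config(root / "missing", {"model_type": "bert"})
```

When no sidecar exists the key is simply left out, which is where a dataclass default such as `{"pooling_mode": "mean"}` would take over.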

mlx_embeddings/convert.py

Lines changed: 4 additions & 0 deletions
@@ -233,6 +233,10 @@ def convert(
     for file in files:
         shutil.copy(file, mlx_path)

+    src_pooling = model_path / "1_Pooling"
+    if src_pooling.is_dir():
+        shutil.copytree(src_pooling, mlx_path / "1_Pooling", dirs_exist_ok=True)
+
     tokenizer.save_pretrained(mlx_path)

     save_config(config, config_path=mlx_path / "config.json")
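
To illustrate why this hunk is needed, here is a small standard-library sketch, assuming the surrounding copy loop is driven by a non-recursive `*.json` glob as the commit message states. The helper name `copy_metadata` and the temporary directory layout are made up for the demonstration; only the `copytree(..., dirs_exist_ok=True)` step mirrors the diff.

```python
import shutil
import tempfile
from pathlib import Path


def copy_metadata(model_path: Path, mlx_path: Path) -> None:
    """Hypothetical stand-in for the metadata-copy step of a converter."""
    mlx_path.mkdir(parents=True, exist_ok=True)
    # A top-level glob only sees files directly under model_path.
    for file in model_path.glob("*.json"):
        shutil.copy(file, mlx_path)
    # The sidecar lives one level down, so it is copied explicitly.
    src_pooling = model_path / "1_Pooling"
    if src_pooling.is_dir():
        # dirs_exist_ok=True (Python 3.8+) lets a repeated conversion overwrite in place.
        shutil.copytree(src_pooling, mlx_path / "1_Pooling", dirs_exist_ok=True)


with tempfile.TemporaryDirectory() as tmp:
    src, dst = Path(tmp) / "hf", Path(tmp) / "mlx"
    (src / "1_Pooling").mkdir(parents=True)
    (src / "config.json").write_text("{}", encoding="utf-8")
    (src / "1_Pooling" / "config.json").write_text(
        '{"pooling_mode_cls_token": true}', encoding="utf-8"
    )

    top_level = sorted(p.relative_to(src).as_posix() for p in src.glob("*.json"))
    copy_metadata(src, dst)
    copy_metadata(src, dst)  # a second run must not raise FileExistsError
    copied = sorted(p.relative_to(dst).as_posix() for p in dst.rglob("*.json"))
```

Here `top_level` contains only `config.json`, while `copied` also includes `1_Pooling/config.json`.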

mlx_embeddings/models/base.py

Lines changed: 0 additions & 10 deletions
@@ -38,16 +38,6 @@ class ViTModelOutput:
     vision_model_output: Optional[mx.array] = None


-def mean_pooling(token_embeddings: mx.array, attention_mask: mx.array):
-    input_mask_expanded = mx.expand_dims(attention_mask, -1)
-    input_mask_expanded = mx.broadcast_to(
-        input_mask_expanded, token_embeddings.shape
-    ).astype(mx.float32)
-    sum_embeddings = mx.sum(token_embeddings * input_mask_expanded, axis=1)
-    sum_mask = mx.maximum(mx.sum(input_mask_expanded, axis=1), 1e-9)
-    return sum_embeddings / sum_mask
-
-
 def normalize_embeddings(embeddings, p=2, axis=-1, keepdims=True, eps=1e-9):
     return embeddings / mx.maximum(
         mx.linalg.norm(embeddings, ord=p, axis=axis, keepdims=keepdims), eps

mlx_embeddings/models/bert.py

Lines changed: 7 additions & 4 deletions
@@ -1,11 +1,12 @@
 import math
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Optional, Tuple

 import mlx.core as mx
 import mlx.nn as nn

-from .base import BaseModelArgs, BaseModelOutput, mean_pooling, normalize_embeddings
+from .base import BaseModelArgs, BaseModelOutput, normalize_embeddings
+from .pooling import pool_by_config


 @dataclass
@@ -22,6 +23,7 @@ class ModelArgs(BaseModelArgs):
     initializer_range: float = 0.02
     layer_norm_eps: float = 1e-12
     vocab_size: int = 30522
+    pooling_config: dict = field(default_factory=lambda: {"pooling_mode": "mean"})


 class BertEmbeddings(nn.Module):
@@ -224,8 +226,9 @@ def __call__(self, input_ids, token_type_ids=None, attention_mask=None):
         sequence_output = encoder_outputs
         pooled_output = self.pooler(sequence_output)

-        # normalized features
-        text_embeds = mean_pooling(sequence_output, attention_mask)
+        text_embeds = pool_by_config(
+            sequence_output, attention_mask, self.config.pooling_config
+        )
         text_embeds = normalize_embeddings(text_embeds)

         return BaseModelOutput(
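
The `pooling_config` field uses `field(default_factory=...)` instead of a literal dict default. A minimal sketch of why, using a hypothetical `ArgsSketch` class in place of the real `ModelArgs`: each instance gets its own dict, and the `dataclasses` module rejects a bare mutable default outright.

```python
from dataclasses import dataclass, field


@dataclass
class ArgsSketch:
    """Hypothetical stand-in for ModelArgs; only the two fields below are modelled."""

    vocab_size: int = 30522
    # The factory runs once per instance, so the dict is never shared.
    pooling_config: dict = field(default_factory=lambda: {"pooling_mode": "mean"})


a = ArgsSketch()
b = ArgsSketch()
a.pooling_config["pooling_mode"] = "cls"  # mutating a's dict must not affect b

# A literal dict default is refused when the class is created.
try:

    @dataclass
    class BadArgs:
        pooling_config: dict = {"pooling_mode": "mean"}

    mutable_default_error = None
except ValueError as exc:
    mutable_default_error = str(exc)
```

After this runs, `b.pooling_config` still holds the `"mean"` default, and `mutable_default_error` carries the message from the rejected class definition.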

mlx_embeddings/models/gemma3_text.py

Lines changed: 2 additions & 1 deletion
@@ -6,7 +6,8 @@
 from mlx_lm.models.base import create_attention_mask
 from mlx_lm.models.gemma3_text import ModelArgs, RMSNorm, TransformerBlock

-from .base import BaseModelOutput, mean_pooling, normalize_embeddings
+from .base import BaseModelOutput, normalize_embeddings
+from .pooling import mean_pooling


 class Gemma3Model(nn.Module):

mlx_embeddings/models/lfm2.py

Lines changed: 2 additions & 1 deletion
@@ -9,7 +9,8 @@
 from mlx_lm.models.lfm2 import Lfm2DecoderLayer
 from mlx_lm.models.lfm2 import ModelArgs as Lfm2ModelArgs

-from .base import BaseModelOutput, mean_pooling, normalize_embeddings
+from .base import BaseModelOutput, normalize_embeddings
+from .pooling import mean_pooling


 @dataclass

mlx_embeddings/models/llama_bidirec.py

Lines changed: 2 additions & 1 deletion
@@ -5,7 +5,8 @@
 import mlx.nn as nn
 from mlx_lm.models.llama import TransformerBlock

-from .base import BaseModelArgs, BaseModelOutput, mean_pooling, normalize_embeddings
+from .base import BaseModelArgs, BaseModelOutput, normalize_embeddings
+from .pooling import mean_pooling


 @dataclass

mlx_embeddings/models/llama_nemotron_vl/model.py

Lines changed: 2 additions & 1 deletion
@@ -6,9 +6,10 @@
 import mlx.nn as nn
 import numpy as np

-from ..base import BaseModelArgs, BaseModelOutput, mean_pooling, normalize_embeddings
+from ..base import BaseModelArgs, BaseModelOutput, normalize_embeddings
 from ..llama_bidirec import LlamaBidirectionalModel
 from ..llama_bidirec import ModelArgs as LlamaBidirectModelArgs
+from ..pooling import mean_pooling
 from ..siglip import SiglipVisionTransformer
 from ..siglip import VisionConfig as SiglipVisionConfig

mlx_embeddings/models/modernbert.py

Lines changed: 2 additions & 1 deletion
@@ -5,7 +5,8 @@
 import mlx.core as mx
 import mlx.nn as nn

-from .base import BaseModelArgs, BaseModelOutput, mean_pooling, normalize_embeddings
+from .base import BaseModelArgs, BaseModelOutput, normalize_embeddings
+from .pooling import mean_pooling


 @dataclass

mlx_embeddings/models/pooling.py

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
+from typing import Any, Dict
+
+import mlx.core as mx
+
+
+def mean_pooling(token_embeddings: mx.array, attention_mask: mx.array):
+    input_mask_expanded = mx.expand_dims(attention_mask, -1)
+    input_mask_expanded = mx.broadcast_to(
+        input_mask_expanded, token_embeddings.shape
+    ).astype(mx.float32)
+    sum_embeddings = mx.sum(token_embeddings * input_mask_expanded, axis=1)
+    sum_mask = mx.maximum(mx.sum(input_mask_expanded, axis=1), 1e-9)
+    return sum_embeddings / sum_mask
+
+
+def cls_pooling(token_embeddings: mx.array, attention_mask: mx.array) -> mx.array:
+    first_indices = mx.argmax(attention_mask, axis=1)
+    batch_size = token_embeddings.shape[0]
+    hidden_dim = token_embeddings.shape[-1]
+    gather_idx = mx.broadcast_to(
+        first_indices[:, None, None], (batch_size, 1, hidden_dim)
+    )
+    return mx.squeeze(mx.take_along_axis(token_embeddings, gather_idx, axis=1), axis=1)
+
+
+def max_pooling(token_embeddings: mx.array, attention_mask: mx.array) -> mx.array:
+    mask = mx.expand_dims(attention_mask, -1)
+    mask = mx.broadcast_to(mask, token_embeddings.shape).astype(token_embeddings.dtype)
+    masked = mx.where(mask == 0, -float("inf"), token_embeddings)
+    return mx.max(masked, axis=1)
+
+
+def lasttoken_pooling(token_embeddings: mx.array, attention_mask: mx.array) -> mx.array:
+    batch_size, seq_len, hidden_dim = token_embeddings.shape
+    flipped = attention_mask[:, ::-1]
+    flip_indices = mx.argmax(flipped, axis=1)
+    has_any_real = mx.max(flipped, axis=1)
+    flip_indices = mx.where(has_any_real == 0, seq_len - 1, flip_indices)
+    last_indices = seq_len - flip_indices - 1
+    gather_idx = mx.broadcast_to(
+        last_indices[:, None, None], (batch_size, 1, hidden_dim)
+    )
+    mask = mx.broadcast_to(attention_mask[:, :, None], token_embeddings.shape).astype(
+        token_embeddings.dtype
+    )
+    return mx.squeeze(
+        mx.take_along_axis(token_embeddings * mask, gather_idx, axis=1), axis=1
+    )
+
+
+_LEGACY_POOLING_MODE_KWARGS = {
+    "pooling_mode_cls_token": "cls",
+    "pooling_mode_max_tokens": "max",
+    "pooling_mode_mean_tokens": "mean",
+    "pooling_mode_mean_sqrt_len_tokens": "mean_sqrt_len_tokens",
+    "pooling_mode_weightedmean_tokens": "weightedmean",
+    "pooling_mode_lasttoken": "lasttoken",
+}
+
+_SUPPORTED_POOL_MODES = {"cls", "mean", "max", "lasttoken"}
+_KNOWN_UNSUPPORTED_POOL_MODES = {"weightedmean", "mean_sqrt_len_tokens"}
+
+
+def _normalize_pooling_config(
+    pooling_config: Dict[str, Any],
+) -> Dict[str, Any]:
+    cfg = dict(pooling_config)
+    found = [k for k in _LEGACY_POOLING_MODE_KWARGS if k in cfg]
+    if not found:
+        return cfg
+    if "pooling_mode" not in cfg:
+        active = tuple(
+            name
+            for key, name in _LEGACY_POOLING_MODE_KWARGS.items()
+            if cfg.get(key, False)
+        )
+        if not active:
+            active = ("mean",)
+        cfg["pooling_mode"] = active[0] if len(active) == 1 else active
+    for k in found:
+        del cfg[k]
+    return cfg
+
+
+def pool_by_config(
+    token_embeddings: mx.array,
+    attention_mask: mx.array,
+    pooling_config: Dict[str, Any],
+) -> mx.array:
+    cfg = _normalize_pooling_config(pooling_config)
+    mode = cfg["pooling_mode"]
+    if not cfg.get("include_prompt", True):
+        raise NotImplementedError(
+            "Prompt-aware pooling (include_prompt=False) is not supported. "
+            "This affects INSTRUCTOR-style models."
+        )
+    if isinstance(mode, (tuple, list)):
+        raise NotImplementedError(
+            f"Concatenated pooling mode {mode!r} is not supported; "
+            "only a single pooling mode is allowed."
+        )
+    if mode in _KNOWN_UNSUPPORTED_POOL_MODES:
+        raise NotImplementedError(
+            f"Pooling mode {mode!r} is not supported. "
+            f"Supported modes: {sorted(_SUPPORTED_POOL_MODES)}."
+        )
+
+    if mode == "cls":
+        return cls_pooling(token_embeddings, attention_mask)
+    if mode == "max":
+        return max_pooling(token_embeddings, attention_mask)
+    if mode == "lasttoken":
+        return lasttoken_pooling(token_embeddings, attention_mask)
+    if mode == "mean":
+        return mean_pooling(token_embeddings, attention_mask)
+    raise ValueError(
+        f"Unknown pooling mode {mode!r}. "
+        f"Supported modes: {sorted(_SUPPORTED_POOL_MODES)}."
+    )
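
The four helpers in `pooling.py` differ only in which unmasked positions they read. As a rough, dependency-free illustration, the sketch below applies the same rules to a single left-padded sequence using plain lists instead of `mx.array`. The `*_pool` function names and the sample values are invented for this example; the all-padding branches follow my reading of the MLX code above rather than anything stated by the commit.

```python
from typing import List

Vector = List[float]


def cls_pool(tokens: List[Vector], mask: List[int]) -> Vector:
    # First unmasked position; falls back to index 0 when everything is padding.
    idx = mask.index(1) if 1 in mask else 0
    return tokens[idx]


def mean_pool(tokens: List[Vector], mask: List[int]) -> Vector:
    dim = len(tokens[0])
    totals = [sum(t[d] * m for t, m in zip(tokens, mask)) for d in range(dim)]
    count = max(sum(mask), 1e-9)  # same floor as the diff, avoids division by zero
    return [x / count for x in totals]


def max_pool(tokens: List[Vector], mask: List[int]) -> Vector:
    dim = len(tokens[0])
    # Padding positions are excluded, which matches replacing them with -inf.
    return [
        max((t[d] for t, m in zip(tokens, mask) if m), default=float("-inf"))
        for d in range(dim)
    ]


def lasttoken_pool(tokens: List[Vector], mask: List[int]) -> Vector:
    if 1 not in mask:
        return [0.0] * len(tokens[0])  # all-padding rows come out as zeros
    idx = len(mask) - 1 - mask[::-1].index(1)
    return tokens[idx]


# One left-padded sequence: position 0 is padding and must never leak into the result.
tokens = [[9.0, 9.0], [1.0, 4.0], [3.0, 2.0]]
mask = [0, 1, 1]

cls_vec = cls_pool(tokens, mask)         # [1.0, 4.0]
mean_vec = mean_pool(tokens, mask)       # [2.0, 3.0]
max_vec = max_pool(tokens, mask)         # [3.0, 4.0]
last_vec = lasttoken_pool(tokens, mask)  # [3.0, 2.0]
```

With left padding, taking `tokens[0]` as the CLS vector would return the padding row, which is why the first unmasked index is looked up instead.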

mlx_embeddings/models/xlm_roberta.py

Lines changed: 7 additions & 5 deletions
@@ -1,11 +1,12 @@
 import math
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Optional, Tuple

 import mlx.core as mx
 import mlx.nn as nn

-from .base import BaseModelArgs, BaseModelOutput, mean_pooling, normalize_embeddings
+from .base import BaseModelArgs, BaseModelOutput, normalize_embeddings
+from .pooling import pool_by_config


 @dataclass
@@ -25,7 +26,7 @@ class ModelArgs(BaseModelArgs):
     output_past: bool = True
     pad_token_id: int = 1
     position_embedding_type: str = "absolute"
-    pooling_config: dict = None
+    pooling_config: dict = field(default_factory=lambda: {"pooling_mode": "mean"})


 class XLMRobertaEmbeddings(nn.Module):
@@ -352,8 +353,9 @@ def __call__(
             self.pooler(sequence_output) if self.pooler is not None else None
         )

-        # normalized features
-        text_embeds = mean_pooling(sequence_output, attention_mask)
+        text_embeds = pool_by_config(
+            sequence_output, attention_mask, self.config.pooling_config
+        )
         text_embeds = normalize_embeddings(text_embeds)

         return BaseModelOutput(
