
Commit 5dfdcd1

Simplified pooling code, generalized pooling options
1 parent 1e3b58e commit 5dfdcd1

2 files changed: 63 additions & 28 deletions


model2vec/distill/inference.py

Lines changed: 46 additions & 25 deletions
@@ -24,11 +24,22 @@
 
 
 class PoolingType(str, Enum):
-    """Pooling strategies for embedding creation."""
+    """
+    Pooling strategies for embedding creation.
+
+    - MEAN: masked mean over all tokens (ignores padding).
+    - LAST: last non-padding token (often EOS, common in decoder-style models).
+    - FIRST: first token hidden state (position 0). In BERT-style encoders,
+      this corresponds to the [CLS] token representation.
+    - POOLER: use the model's `pooler_output`. In BERT-like models this is
+      computed as the hidden state at [CLS], passed through a learned
+      dense layer + activation. Not all models provide this.
+    """
 
     MEAN = "mean"
     LAST = "last"
-    CLS = "cls"
+    FIRST = "first"
+    POOLER = "pooler"
 
 
 def create_embeddings(
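
For reference, a minimal standalone sketch of what the four strategies compute, given hidden states of shape (batch, seq_len, dim), a 0/1 attention mask, and an optional pooler output. This is illustrative only; the function and variable names below are not part of the repository.

import torch

def pool(hidden: torch.Tensor, mask: torch.Tensor, pooler_output: torch.Tensor | None, strategy: str) -> torch.Tensor:
    """Illustrative pooling over hidden states of shape (batch, seq_len, dim)."""
    mask = mask.bool()
    if strategy == "mean":
        # Masked mean: sum non-padding hidden states, divide by true lengths.
        m = mask.unsqueeze(-1).to(hidden.dtype)
        return (hidden * m).sum(dim=1) / m.sum(dim=1).clamp_min(1.0)
    if strategy == "last":
        # Index of the last non-padding token in each sequence.
        last_idx = (mask.sum(dim=1) - 1).clamp_min(0)
        return hidden[torch.arange(hidden.size(0)), last_idx]
    if strategy == "first":
        # Hidden state at position 0 ([CLS] in BERT-style encoders).
        return hidden[:, 0]
    if strategy == "pooler":
        if pooler_output is None:
            raise ValueError("model did not return pooler_output")
        return pooler_output
    raise ValueError(f"Unknown pooling: {strategy}")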
@@ -74,16 +85,13 @@ def create_embeddings(
     encoded = {}
     encoded["input_ids"] = pad_sequence(batch, batch_first=True, padding_value=pad_token_id)
 
-    if pooling == PoolingType.MEAN:
-        # For mean pooling, mask out padding tokens
-        encoded["attention_mask"] = encoded["input_ids"] != pad_token_id
-    else:
-        # For "last"/"cls": build mask directly from true lengths to ensure
-        # the last non-pad token and CLS positions are chosen correctly
-        seq_len = encoded["input_ids"].size(1)
-        batch_lengths = torch.tensor([len(x) for x in batch_list], device=encoded["input_ids"].device)
-        token_positions = torch.arange(seq_len, device=encoded["input_ids"].device)
-        encoded["attention_mask"] = token_positions.unsqueeze(0) < batch_lengths.unsqueeze(1)
+    # Create attention mask by using the lengths of each sequence
+    seq_len = encoded["input_ids"].size(1)
+    batch_lengths = torch.tensor([len(x) for x in batch_list], device=encoded["input_ids"].device)
+    token_positions = torch.arange(seq_len, device=encoded["input_ids"].device)
+    # Mark padding tokens with 0, and non-padding tokens with 1
+    attention_mask = token_positions.unsqueeze(0) < batch_lengths.unsqueeze(1)
+    encoded["attention_mask"] = attention_mask.to(dtype=torch.long)
 
     if add_token_type_ids:
         # Add token_type_ids for models that support it
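
The length-based mask above can be checked in isolation. A toy example (values are illustrative, not repository code):

import torch

# Toy batch: two sequences of lengths 3 and 1, padded to seq_len = 3.
batch_list = [[10, 11, 12], [20]]
seq_len = 3

batch_lengths = torch.tensor([len(x) for x in batch_list])   # tensor([3, 1])
token_positions = torch.arange(seq_len)                       # tensor([0, 1, 2])
attention_mask = token_positions.unsqueeze(0) < batch_lengths.unsqueeze(1)

print(attention_mask.to(dtype=torch.long))
# tensor([[1, 1, 1],
#         [1, 0, 0]])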
@@ -93,8 +101,10 @@ def create_embeddings(
         out = _encode_mean_with_model(model, encoded)
     elif pooling == PoolingType.LAST:
         out = _encode_last_with_model(model, encoded)
-    elif pooling == PoolingType.CLS:
-        out = _encode_cls_with_model(model, encoded)
+    elif pooling == PoolingType.FIRST:
+        out = _encode_first_with_model(model, encoded)
+    elif pooling == PoolingType.POOLER:
+        out = _encode_pooler_with_model(model, encoded)
     else:
         raise ValueError(f"Unknown pooling: {pooling}")
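
A usage sketch of the updated dispatch, mirroring the keyword arguments used in the new test further below. Here `model` is assumed to be an already-loaded PreTrainedModel, the token ids are made up, and create_embeddings may accept further parameters not shown.

from model2vec.distill.inference import PoolingType, create_embeddings

# Pre-tokenized input ids; 0 is assumed to be the padding id here.
tokenized = [[101, 2023, 2003, 102], [101, 102]]

embeddings = create_embeddings(
    model=model,               # a loaded PreTrainedModel (assumed to exist)
    tokenized=tokenized,
    device="cpu",
    pad_token_id=0,
    pooling=PoolingType.MEAN,  # or LAST, FIRST, POOLER
)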

@@ -163,30 +173,41 @@ def _encode_last_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
     :return: The last hidden state for each token.
     """
     hidden, _, encodings_on_device = _encode_with_model(model, encodings)
-    # Get the last hidden state for each token
     mask = encodings_on_device["attention_mask"].bool()
     last_idx = (mask.sum(dim=1) - 1).clamp_min(0).long()
-    b = torch.arange(hidden.size(0), device=hidden.device)
-    return hidden[b, last_idx, :].cpu()
+    batch_indices = torch.arange(hidden.size(0), device=hidden.device)
+    return hidden[batch_indices, last_idx, :].cpu()
 
 
 @torch.inference_mode()
-def _encode_cls_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
+def _encode_first_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
     """
-    Encode a batch of tokens using CLS pooling.
-
-    If the model has a pooler_output, use that, otherwise, use the first token's hidden state.
+    Encode a batch of tokens using first token (CLS) pooling.
 
     :param model: The model to use.
     :param encodings: The encoded tokens to turn into features.
-    :return: The [CLS] token representation for each token.
+    :return: The first token representation for each token.
     """
-    hidden, pooler, _ = _encode_with_model(model, encodings)
-    if pooler is not None:
-        return pooler.cpu()
+    hidden, _, _ = _encode_with_model(model, encodings)
     return hidden[:, 0, :].cpu()
 
 
+@torch.inference_mode()
+def _encode_pooler_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
+    """
+    Encode a batch of tokens using pooler output.
+
+    :param model: The model to use.
+    :param encodings: The encoded tokens to turn into features.
+    :return: The pooler output for each token.
+    :raises ValueError: If the model does not return pooler_output.
+    """
+    _, pooler, _ = _encode_with_model(model, encodings)
+    if pooler is None:
+        raise ValueError("POOLER pooling requested, but model did not return pooler_output.")
+    return pooler.cpu()
+
+
 def post_process_embeddings(
     embeddings: np.ndarray, pca_dims: PCADimType, sif_coefficient: float | None = 1e-4
 ) -> tuple[np.ndarray, np.ndarray]:
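
To see why the advanced indexing in _encode_last_with_model selects the last non-padding position, here is a small standalone check with toy tensors (illustrative only); the values line up with the LAST row of the parametrized test below.

import torch

batch, seq_len, dim = 2, 3, 2
# Hidden states where hidden[b, t] == t, for easy inspection.
hidden = torch.arange(seq_len, dtype=torch.float32).view(1, seq_len, 1).expand(batch, seq_len, dim)
attention_mask = torch.tensor([[1, 1, 1],
                               [1, 0, 0]])

mask = attention_mask.bool()
last_idx = (mask.sum(dim=1) - 1).clamp_min(0).long()   # tensor([2, 0])
batch_indices = torch.arange(hidden.size(0))
last_pooled = hidden[batch_indices, last_idx, :]        # rows: [2., 2.] and [0., 0.]

first_pooled = hidden[:, 0, :]                          # rows: [0., 0.] and [0., 0.]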

tests/test_distillation.py

Lines changed: 17 additions & 3 deletions
@@ -294,10 +294,10 @@ def test_clean_and_create_vocabulary(
 @pytest.mark.parametrize(
     "pooling,with_pooler,expected_rows",
     [
-        (PoolingType.MEAN, False, [1.0, 0.0]),  # len=3: mean(0,1,2)=1; len=1: mean(0) = 0
+        (PoolingType.MEAN, False, [1.0, 0.0]),  # len=3: mean(0,1,2)=1; len=1: mean(0)=0
         (PoolingType.LAST, False, [2.0, 0.0]),  # last of 3: 2; last of 1: 0
-        (PoolingType.CLS, False, [0.0, 0.0]),  # first position: 0
-        (PoolingType.CLS, True, [7.0, 7.0]),  # pooler_output is used
+        (PoolingType.FIRST, False, [0.0, 0.0]),  # first position: 0
+        (PoolingType.POOLER, True, [7.0, 7.0]),  # pooler_output used
     ],
 )
 def test_pooling_strategies(mock_transformer, pooling, with_pooler, expected_rows) -> None:
@@ -314,3 +314,17 @@ def test_pooling_strategies(mock_transformer, pooling, with_pooler, expected_rows) -> None:
     dim = out.shape[1]
     expected = np.stack([np.full((dim,), v, dtype=np.float32) for v in expected_rows])
     assert np.allclose(out, expected, rtol=1e-6, atol=0.0)
+
+
+def test_pooler_raises_without_pooler_output(mock_transformer) -> None:
+    """POOLER should raise when the model doesn't expose pooler_output."""
+    mock_transformer.with_pooler = False
+    tokenized = [[10, 11, 12], [20]]
+    with pytest.raises(ValueError, match="pooler_output"):
+        _ = create_embeddings(
+            model=mock_transformer,
+            tokenized=tokenized,
+            device="cpu",
+            pad_token_id=0,
+            pooling=PoolingType.POOLER,
+        )
