@@ -24,11 +24,22 @@
 
 
 class PoolingType(str, Enum):
-    """Pooling strategies for embedding creation."""
+    """
+    Pooling strategies for embedding creation.
+
+    - MEAN: masked mean over all tokens (ignores padding).
+    - LAST: last non-padding token (often EOS, common in decoder-style models).
+    - FIRST: first token hidden state (position 0). In BERT-style encoders,
+      this corresponds to the [CLS] token representation.
+    - POOLER: use the model's `pooler_output`. In BERT-like models this is
+      computed as the hidden state at [CLS], passed through a learned
+      dense layer + activation. Not all models provide this.
+    """
 
     MEAN = "mean"
     LAST = "last"
-    CLS = "cls"
+    FIRST = "first"
+    POOLER = "pooler"
 
 
 def create_embeddings(
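For readers comparing the four strategies, here is a minimal, self-contained sketch of what each one computes. The tensor shapes and values are illustrative only, not taken from this module:

```python
import torch

# Toy batch: 2 sequences, padded to length 4, hidden size 3.
hidden = torch.randn(2, 4, 3)            # (batch, seq_len, hidden)
mask = torch.tensor([[1, 1, 1, 0],       # sequence 1 has 3 real tokens
                     [1, 1, 0, 0]])      # sequence 2 has 2 real tokens

# MEAN: average over real tokens only (padding excluded by the mask).
m = mask.unsqueeze(-1).float()           # (batch, seq_len, 1)
mean_pooled = (hidden * m).sum(dim=1) / m.sum(dim=1)

# LAST: hidden state of the last non-padding token.
last_idx = mask.sum(dim=1) - 1           # tensor([2, 1])
last_pooled = hidden[torch.arange(hidden.size(0)), last_idx]

# FIRST: hidden state at position 0 ([CLS] in BERT-style encoders).
first_pooled = hidden[:, 0, :]

# POOLER: whatever the model exposes as `pooler_output`; for BERT this is
# roughly tanh(dense(first_pooled)) with learned weights, so it cannot be
# reproduced from `hidden` alone.
```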
@@ -74,16 +85,13 @@ def create_embeddings(
     encoded = {}
     encoded["input_ids"] = pad_sequence(batch, batch_first=True, padding_value=pad_token_id)
 
-    if pooling == PoolingType.MEAN:
-        # For mean pooling, mask out padding tokens
-        encoded["attention_mask"] = encoded["input_ids"] != pad_token_id
-    else:
-        # For "last"/"cls": build mask directly from true lengths to ensure
-        # the last non-pad token and CLS positions are chosen correctly
-        seq_len = encoded["input_ids"].size(1)
-        batch_lengths = torch.tensor([len(x) for x in batch_list], device=encoded["input_ids"].device)
-        token_positions = torch.arange(seq_len, device=encoded["input_ids"].device)
-        encoded["attention_mask"] = token_positions.unsqueeze(0) < batch_lengths.unsqueeze(1)
+    # Build the attention mask from the true length of each sequence
+    seq_len = encoded["input_ids"].size(1)
+    batch_lengths = torch.tensor([len(x) for x in batch_list], device=encoded["input_ids"].device)
+    token_positions = torch.arange(seq_len, device=encoded["input_ids"].device)
+    # Mark padding tokens with 0 and non-padding tokens with 1
+    attention_mask = token_positions.unsqueeze(0) < batch_lengths.unsqueeze(1)
+    encoded["attention_mask"] = attention_mask.to(dtype=torch.long)
 
     if add_token_type_ids:
         # Add token_type_ids for models that support it
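The broadcast on the comparison line is easy to misread, so here is a tiny worked example (values are illustrative):

```python
import torch

# Suppose the batch holds sequences of length 3 and 1, padded to seq_len = 3.
batch_lengths = torch.tensor([3, 1])
token_positions = torch.arange(3)        # tensor([0, 1, 2])

# (1, seq_len) < (batch, 1) broadcasts to a (batch, seq_len) boolean mask.
mask = token_positions.unsqueeze(0) < batch_lengths.unsqueeze(1)
print(mask.to(dtype=torch.long))
# tensor([[1, 1, 1],
#         [1, 0, 0]])
```

Deriving the mask from true lengths rather than from `input_ids != pad_token_id` also keeps LAST pooling correct when the pad token id coincides with a real token id such as EOS, which the removed comment alluded to.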
@@ -93,8 +101,10 @@ def create_embeddings(
         out = _encode_mean_with_model(model, encoded)
     elif pooling == PoolingType.LAST:
         out = _encode_last_with_model(model, encoded)
-    elif pooling == PoolingType.CLS:
-        out = _encode_cls_with_model(model, encoded)
+    elif pooling == PoolingType.FIRST:
+        out = _encode_first_with_model(model, encoded)
+    elif pooling == PoolingType.POOLER:
+        out = _encode_pooler_with_model(model, encoded)
     else:
         raise ValueError(f"Unknown pooling: {pooling}")
 
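One behavioral note: the removed CLS branch silently fell back to the first token's hidden state when no `pooler_output` was available, whereas POOLER now raises. Callers that want the old behavior can reproduce it explicitly. A hedged sketch only, since `create_embeddings`' full argument list is not shown in this diff:

```python
# Hypothetical call site: argument names other than `pooling` are assumptions.
try:
    out = create_embeddings(model, texts, pooling=PoolingType.POOLER)
except ValueError:
    # Model exposes no pooler_output; fall back to the raw first-token state,
    # mirroring what the old CLS mode did implicitly.
    out = create_embeddings(model, texts, pooling=PoolingType.FIRST)
```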
@@ -163,30 +173,41 @@ def _encode_last_with_model(model: PreTrainedModel, encodings: dict[str, torch.T
     :return: The last hidden state for each token.
     """
     hidden, _, encodings_on_device = _encode_with_model(model, encodings)
-    # Get the last hidden state for each token
     mask = encodings_on_device["attention_mask"].bool()
     last_idx = (mask.sum(dim=1) - 1).clamp_min(0).long()
-    b = torch.arange(hidden.size(0), device=hidden.device)
-    return hidden[b, last_idx, :].cpu()
+    batch_indices = torch.arange(hidden.size(0), device=hidden.device)
+    return hidden[batch_indices, last_idx, :].cpu()
 
 
 @torch.inference_mode()
-def _encode_cls_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
+def _encode_first_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
     """
-    Encode a batch of tokens using CLS pooling.
-
-    If the model has a pooler_output, use that, otherwise, use the first token's hidden state.
+    Encode a batch of tokens using first token (CLS) pooling.
 
     :param model: The model to use.
     :param encodings: The encoded tokens to turn into features.
-    :return: The [CLS] token representation for each token.
+    :return: The first token representation for each sequence.
     """
-    hidden, pooler, _ = _encode_with_model(model, encodings)
-    if pooler is not None:
-        return pooler.cpu()
+    hidden, _, _ = _encode_with_model(model, encodings)
     return hidden[:, 0, :].cpu()
 
 
+@torch.inference_mode()
+def _encode_pooler_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
+    """
+    Encode a batch of tokens using the model's pooler output.
+
+    :param model: The model to use.
+    :param encodings: The encoded tokens to turn into features.
+    :return: The pooler output for each sequence.
+    :raises ValueError: If the model does not return pooler_output.
+    """
+    _, pooler, _ = _encode_with_model(model, encodings)
+    if pooler is None:
+        raise ValueError("POOLER pooling requested, but model did not return pooler_output.")
+    return pooler.cpu()
+
+
 def post_process_embeddings(
     embeddings: np.ndarray, pca_dims: PCADimType, sif_coefficient: float | None = 1e-4
 ) -> tuple[np.ndarray, np.ndarray]:
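To see why FIRST and POOLER are distinct strategies rather than aliases, compare the two outputs on a BERT-style checkpoint. The model name below is illustrative, not part of this PR:

```python
import torch
from transformers import AutoModel, AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

enc = tok(["an example sentence"], return_tensors="pt")
with torch.inference_mode():
    out = model(**enc)

first = out.last_hidden_state[:, 0, :]   # FIRST: raw [CLS] hidden state
pooler = out.pooler_output               # POOLER: learned dense + tanh on [CLS]
print(torch.allclose(first, pooler))     # False: the pooler transforms [CLS]
```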