-
Notifications
You must be signed in to change notification settings - Fork 121
Expand file tree
/
Copy pathinference.py
More file actions
252 lines (205 loc) · 10.1 KB
/
inference.py
File metadata and controls
252 lines (205 loc) · 10.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
from __future__ import annotations
import inspect
import logging
from enum import Enum
from pathlib import Path
from typing import Literal
import numpy as np
import torch
from sklearn.decomposition import PCA
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
from transformers.modeling_utils import PreTrainedModel
logger = logging.getLogger(__name__)
PathLike = Path | str
PCADimType = int | None | float | Literal["auto"]
_DEFAULT_BATCH_SIZE = 256
class PoolingMode(str, Enum):
    """
    Strategy for reducing per-token hidden states to one vector per sequence.

    Because this is a ``str`` enum, plain strings such as ``"mean"`` compare
    equal to the matching member and ``PoolingMode("mean")`` looks it up.

    - MEAN: average of the hidden states over non-padding positions.
    - LAST: hidden state at the last non-padding position (often the EOS
      token in decoder-style models).
    - FIRST: hidden state at position 0; for BERT-style encoders this is the
      [CLS] token representation.
    - POOLER: the model's ``pooler_output``. In BERT-like models this is the
      [CLS] hidden state passed through a learned dense layer + activation.
      Not every model returns it.
    """

    # Member order is part of the public behavior (iteration order); keep it.
    MEAN = "mean"  # masked mean over tokens
    LAST = "last"  # last real (non-padding) token
    FIRST = "first"  # token at position 0
    POOLER = "pooler"  # model-provided pooler_output
def create_embeddings(
    model: PreTrainedModel,
    tokenized: list[list[int]],
    device: str,
    pad_token_id: int,
    pooling: PoolingMode | str = PoolingMode.MEAN,
    batch_size: int = _DEFAULT_BATCH_SIZE,
) -> np.ndarray:
    """
    Create output embeddings for a bunch of tokenized sequences using a pretrained model.

    Sequences are sorted by length so that each batch needs as little padding as
    possible. They are run through the model in batches of ``batch_size`` and pooled
    into one vector per sequence. The rows are returned in the original input order.

    :param model: The model to use.
        This should be a transformers model.
    :param tokenized: The token ids to embed, one list of ids per sequence.
    :param device: The torch device to use.
    :param pad_token_id: The pad token id. Used to pad sequences.
    :param pooling: The pooling mode to use, as a `PoolingMode` or its string value.
    :param batch_size: Number of sequences per forward pass. Must be positive.
    :return: The output embeddings, one row per entry of `tokenized`, in input order.
        NaN/inf values are replaced via `np.nan_to_num`.
    :raises ValueError: If the pooling mode is unknown, `tokenized` is empty,
        or `batch_size` is not positive.
    """
    # Validate everything before touching the model. This means a bad argument fails
    # fast instead of after moving the model or inside the batch loop.
    try:
        pooling_mode = PoolingMode(pooling)
    except ValueError as err:
        raise ValueError(f"Unknown pooling: {pooling}") from err
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")
    if not tokenized:
        raise ValueError("Cannot create embeddings: `tokenized` is empty.")

    encoders = {
        PoolingMode.MEAN: _encode_mean_with_model,
        PoolingMode.LAST: _encode_last_with_model,
        PoolingMode.FIRST: _encode_first_with_model,
        PoolingMode.POOLER: _encode_pooler_with_model,
    }
    encode = encoders[pooling_mode]

    model = model.to(device).eval()  # type: ignore # Transformers error
    # Add token_type_ids only if the model's forward signature accepts it.
    add_token_type_ids = "token_type_ids" in inspect.getfullargspec(model.forward).args

    # Process sequences shortest-first so each batch is padded as little as possible.
    lengths = np.asarray([len(sequence) for sequence in tokenized])
    sort_order = np.argsort(lengths)
    sorted_tokenized = [tokenized[i] for i in sort_order]

    batch_outputs: list[np.ndarray] = []
    # Context manager guarantees the progress bar is closed even if a batch raises.
    with tqdm(total=len(sorted_tokenized), desc="Encoding tokens", unit=" tokens") as pbar:
        for batch_idx in range(0, len(sorted_tokenized), batch_size):
            batch_list = sorted_tokenized[batch_idx : batch_idx + batch_size]
            batch = [torch.tensor(x, dtype=torch.long) for x in batch_list]
            input_ids = pad_sequence(batch, batch_first=True, padding_value=pad_token_id)

            # The mask is built from the true sequence lengths, not from pad_token_id.
            # It therefore stays correct even if pad_token_id also occurs as a real token.
            # 1 marks real tokens, 0 marks padding.
            seq_len = input_ids.size(1)
            batch_lengths = torch.tensor([len(x) for x in batch_list], device=input_ids.device)
            token_positions = torch.arange(seq_len, device=input_ids.device)
            attention_mask = token_positions.unsqueeze(0) < batch_lengths.unsqueeze(1)

            encoded = {
                "input_ids": input_ids,
                "attention_mask": attention_mask.to(dtype=torch.long),
            }
            if add_token_type_ids:
                # Single-segment input: every token belongs to segment 0.
                encoded["token_type_ids"] = torch.zeros_like(input_ids)

            batch_outputs.append(encode(model, encoded).numpy())
            pbar.update(len(batch))

    stacked = np.concatenate(batch_outputs, axis=0)
    # Row j of `stacked` belongs to input sequence sort_order[j].
    # Scatter the rows back into input order.
    out_weights = np.empty_like(stacked)
    out_weights[sort_order] = stacked
    return np.nan_to_num(out_weights)
def _encode_with_model(
    model: PreTrainedModel, encodings: dict[str, torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor | None, dict[str, torch.Tensor]]:
    """
    Run one forward pass and return float32 outputs plus the device-moved inputs.

    Every tensor in ``encodings`` is moved to ``model.device`` before the call.
    The moved dict is returned as well, so callers can reuse e.g. the attention
    mask without moving it again.

    :param model: The model to use.
    :param encodings: The encoded tokens to turn into features.
    :return: ``(hidden, pooler, encodings_on_device)`` where
        - hidden is ``last_hidden_state`` cast to float32,
        - pooler is ``pooler_output`` cast to float32, or None if the output
          has no such attribute (or it is None),
        - encodings_on_device is the dict of inputs after moving to the model device.
    """
    target_device = model.device
    moved = {name: tensor.to(target_device) for name, tensor in encodings.items()}

    outputs: BaseModelOutputWithPoolingAndCrossAttentions = model(**moved)

    # numpy has no bfloat16 (https://github.com/numpy/numpy/issues/19808), so
    # upcast to float32 here, before anything calls `.numpy()` on the result.
    last_hidden = outputs.last_hidden_state.float()  # type: ignore # False positive

    pooled = getattr(outputs, "pooler_output", None)
    pooled = None if pooled is None else pooled.float()

    return last_hidden, pooled, moved
@torch.inference_mode()
def _encode_mean_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Encode a batch of tokens using masked mean pooling.

    Each position is weighted by ``mask / mask.sum()``, so padding positions
    contribute nothing. A row whose mask is entirely zero gets a zero vector,
    because the denominator is clamped to at least 1.

    :param model: The model to use.
    :param encodings: The encoded tokens to turn into features; must contain
        an ``attention_mask`` entry.
    :return: The mean of the output for each sequence, on CPU.
    """
    hidden, _, encodings_on_device = _encode_with_model(model, encodings)

    # Build the averaging weights directly on hidden's device and in hidden's dtype.
    # The previous code moved the mask to CPU and then back to the device on every batch.
    weights = encodings_on_device["attention_mask"].to(device=hidden.device, dtype=hidden.dtype)
    weights = weights / weights.sum(dim=1, keepdim=True).clamp_min(1.0)

    # (batch, 1, seq) @ (batch, seq, dim) -> (batch, 1, dim) -> (batch, dim)
    return torch.bmm(weights.unsqueeze(1), hidden).squeeze(1).cpu()
@torch.inference_mode()
def _encode_last_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Encode a batch of tokens by taking the hidden state of the last real token.

    The last real position is computed as ``(number of mask ones) - 1``, clamped
    at 0. This is only the last non-padding token when the real tokens occupy a
    prefix of each row, i.e. when sequences are right-padded.

    :param model: The model to use.
    :param encodings: The encoded tokens to turn into features; must contain
        an ``attention_mask`` entry.
    :return: The last hidden state for each sequence, on CPU.
    """
    hidden, _, moved = _encode_with_model(model, encodings)

    real_token_count = moved["attention_mask"].bool().sum(dim=1)
    final_position = (real_token_count - 1).clamp_min(0).long()

    rows = torch.arange(hidden.size(0), device=hidden.device)
    selected = hidden[rows, final_position, :]
    return selected.cpu()
@torch.inference_mode()
def _encode_first_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Encode a batch of tokens using the hidden state at position 0.

    For BERT-style encoders, position 0 holds the [CLS] token.

    :param model: The model to use.
    :param encodings: The encoded tokens to turn into features.
    :return: The position-0 representation for each sequence, on CPU.
    """
    last_hidden_state = _encode_with_model(model, encodings)[0]
    first_position = last_hidden_state[:, 0, :]
    return first_position.cpu()
@torch.inference_mode()
def _encode_pooler_with_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Encode a batch of tokens using the model's own ``pooler_output``.

    :param model: The model to use.
    :param encodings: The encoded tokens to turn into features.
    :return: The pooler output for each sequence, on CPU.
    :raises ValueError: If the model output carries no pooler_output.
    """
    pooled = _encode_with_model(model, encodings)[1]
    if pooled is not None:
        return pooled.cpu()
    raise ValueError("POOLER pooling requested, but model did not return pooler_output.")
def post_process_embeddings(
    embeddings: np.ndarray, pca_dims: PCADimType, sif_coefficient: float | None = 1e-4
) -> tuple[np.ndarray, np.ndarray]:
    """
    Post process embeddings by applying PCA and SIF weighting by estimating the frequencies through Zipf's law.

    :param embeddings: Embedding matrix with one row per token (indexed as
        ``shape[0]`` = tokens, ``shape[1]`` = dimensions).
    :param pca_dims: PCA configuration.
        - None: skip PCA.
        - "auto": use the current dimensionality.
        - int: number of components, clamped to the current dimensionality.
        - float: passed to sklearn's PCA as ``n_components``, which it treats as
          an explained-variance target.
        PCA is skipped with a warning when the target is not smaller than the
        number of rows.
    :param sif_coefficient: The SIF constant ``a``. Each row gets weight
        ``a / (a + p)``, where ``p`` is a Zipf-estimated probability. If None,
        all weights are 1.
    :return: ``(embeddings, weight)``. ``embeddings`` is the (possibly
        PCA-transformed) matrix. ``weight`` has one entry per row.
    """
    if pca_dims is not None:
        if pca_dims == "auto":
            pca_dims = embeddings.shape[1]
        if pca_dims > embeddings.shape[1]:
            logger.warning(
                f"PCA dimension ({pca_dims}) is larger than the number of dimensions in the embeddings ({embeddings.shape[1]}). "
                "Applying PCA, but not reducing dimensionality. If this is not desired, please set `pca_dims` to None. "
                "Applying PCA will probably improve performance, so consider just leaving it."
            )
            pca_dims = embeddings.shape[1]
        if pca_dims >= embeddings.shape[0]:
            logger.warning(
                f"PCA dimension ({pca_dims}) is larger than the number of tokens in the vocabulary ({embeddings.shape[0]}). Not applying PCA."
            )
        else:
            # After the clamp above, pca_dims <= embeddings.shape[1] always holds here.
            if isinstance(pca_dims, float):
                logger.info(f"Applying PCA with {pca_dims} explained variance.")
            else:
                logger.info(f"Applying PCA with n_components {pca_dims}")
            orig_dims = embeddings.shape[1]
            p = PCA(n_components=pca_dims, svd_solver="full")
            embeddings = p.fit_transform(embeddings)
            if embeddings.shape[1] < orig_dims:
                explained_variance_ratio = np.sum(p.explained_variance_ratio_)
                explained_variance = np.sum(p.explained_variance_)
                logger.info(f"Reduced dimensionality from {orig_dims} to {embeddings.shape[1]}.")
                logger.info(f"Explained variance ratio: {explained_variance_ratio:.3f}.")
                logger.info(f"Explained variance: {explained_variance:.3f}.")

    if sif_coefficient is not None:
        logger.info("Estimating word frequencies using Zipf's law, and then applying SIF.")
        # Zipf estimate: row i is treated as having rank i + 2, so earlier rows get
        # higher probability.
        # NOTE(review): assumes rows are ordered from most to least frequent token — confirm with caller.
        inv_rank = 1 / (np.arange(2, embeddings.shape[0] + 2))
        proba = inv_rank / np.sum(inv_rank)
        weight = sif_coefficient / (sif_coefficient + proba)
    else:
        weight = np.ones(embeddings.shape[0])
    return embeddings, weight