Skip to content

Commit 048d1a6

Browse files
authored
feat: add frozen-weights option to training (#327)
* improve training * update lock file * simplify
1 parent c4cc294 commit 048d1a6

3 files changed

Lines changed: 41 additions & 13 deletions

File tree

model2vec/train/base.py

Lines changed: 37 additions & 13 deletions
Original file line number · Diff line number · Diff line change
@@ -46,6 +46,7 @@ def __init__(
4646
weights: torch.Tensor | None = None,
4747
freeze: bool = False,
4848
normalize: bool = True,
49+
freeze_weights: bool = False,
4950
) -> None:
5051
"""Initialize a trainable StaticModel from a StaticModel.
5152
@@ -59,6 +60,7 @@ def __init__(
5960
:param weights: The weights of the model. If None, the weights are initialized to zeros.
6061
:param freeze: Whether to freeze the embeddings. This should be set to False in most cases.
6162
:param normalize: Whether to normalize the embeddings.
63+
:param freeze_weights: Whether to freeze the learned token weights.
6264
"""
6365
super().__init__()
6466
self.pad_id = pad_id
@@ -67,6 +69,7 @@ def __init__(
6769
self.hidden_dim = hidden_dim
6870
self.n_layers = n_layers
6971
self.normalize = normalize
72+
self.freeze_weights = freeze_weights
7073

7174
self.vectors = vectors
7275
if self.vectors.dtype != torch.float32:
@@ -92,10 +95,10 @@ def construct_weights(self) -> nn.Parameter:
9295
"""Construct the weights for the model."""
9396
if self._weights is not None:
9497
w = logit(self._weights)
95-
return nn.Parameter(w.float(), requires_grad=True)
96-
weights = torch.zeros(len(self.token_mapping))
97-
weights[self.pad_id] = -10_000
98-
return nn.Parameter(weights, requires_grad=not self.freeze)
98+
else:
99+
w = torch.zeros(len(self.token_mapping)).float()
100+
w[self.pad_id] = -10_000
101+
return nn.Parameter(w, requires_grad=not self.freeze_weights)
99102

100103
def construct_head(self) -> nn.Sequential:
101104
"""Constructs a simple classifier head."""
@@ -136,7 +139,11 @@ def _initialize(self) -> None:
136139

137140
@classmethod
138141
def from_pretrained(
139-
cls: type[ModelType], path: str = "minishlab/potion-base-32m", *, token: str | None = None, **kwargs: Any
142+
cls: type[ModelType],
143+
path: str = "minishlab/potion-base-32m",
144+
*,
145+
token: str | None = None,
146+
**kwargs: Any,
140147
) -> ModelType:
141148
"""Load the model from a pretrained model2vec model."""
142149
if model_name := kwargs.pop("model_name", None):
@@ -147,7 +154,11 @@ def from_pretrained(
147154

148155
@classmethod
149156
def from_static_model(
150-
cls: type[ModelType], *, model: StaticModel, pad_token: str | None = None, **kwargs: Any
157+
cls: type[ModelType],
158+
*,
159+
model: StaticModel,
160+
pad_token: str | None = None,
161+
**kwargs: Any,
151162
) -> ModelType:
152163
"""Load the model from a static model."""
153164
model.embedding = np.nan_to_num(model.embedding)
@@ -179,14 +190,15 @@ def _encode(self, input_ids: torch.Tensor) -> torch.Tensor:
179190
:param input_ids: A 2D tensor of input ids. All input ids have to be within bounds.
180191
:return: The mean over the input ids, weighted by token weights.
181192
"""
182-
w = self.w[input_ids]
183-
w = torch.sigmoid(w)
184193
zeros = (input_ids != self.pad_id).float()
185-
w = w * zeros
186194
# Add a small epsilon to avoid division by zero
187195
length = zeros.sum(1) + 1e-16
188196
input_ids_embeddings = self.token_mapping[input_ids]
189197
embedded = self.embeddings(input_ids_embeddings)
198+
199+
w = self.w[input_ids]
200+
w = torch.sigmoid(w)
201+
w = w * zeros
190202
# Weigh each token
191203
embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
192204
# Mean pooling by dividing by the length
@@ -235,13 +247,20 @@ def device(self) -> torch.device:
235247

236248
def to_static_model(self) -> StaticModel:
237249
"""Convert the model to a static model."""
238-
emb = self.embeddings.weight.detach().cpu().numpy()
239-
w = torch.sigmoid(self.w).detach().cpu().numpy()
250+
with torch.no_grad():
251+
emb = self.embeddings.weight
252+
emb = emb.cpu().numpy()
253+
w = torch.sigmoid(self.w).cpu().numpy()
254+
240255
# If the weights and emb are the same length, the model was not quantized before training.
241256
if len(w) == len(emb):
242257
emb = emb * w[:, None]
243258
return StaticModel(
244-
vectors=emb, weights=None, tokenizer=self.tokenizer, normalize=self.normalize, token_mapping=None
259+
vectors=emb,
260+
weights=None,
261+
tokenizer=self.tokenizer,
262+
normalize=self.normalize,
263+
token_mapping=None,
245264
)
246265
return StaticModel(
247266
vectors=emb,
@@ -265,7 +284,12 @@ def _determine_batch_size(self, batch_size: int | None, train_length: int) -> in
265284
return batch_size
266285

267286
def _check_val_split(
268-
self, X: list[str], y: list, X_val: list[str] | None, y_val: list | None, test_size: float
287+
self,
288+
X: list[str],
289+
y: list,
290+
X_val: list[str] | None,
291+
y_val: list | None,
292+
test_size: float,
269293
) -> tuple[list[str], list[str], Sequence, Sequence]:
270294
if (X_val is not None) != (y_val is not None):
271295
raise ValueError("Both X_val and y_val must be provided together, or neither.")

model2vec/train/classifier.py

Lines changed: 2 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -38,6 +38,7 @@ def __init__(
3838
weights: torch.Tensor | None = None,
3939
freeze: bool = False,
4040
normalize: bool = True,
41+
freeze_weights: bool = False,
4142
) -> None:
4243
"""Initialize a standard classifier model."""
4344
# Alias: Follows scikit-learn. Set to dummy classes
@@ -55,6 +56,7 @@ def __init__(
5556
hidden_dim=hidden_dim,
5657
n_layers=n_layers,
5758
normalize=normalize,
59+
freeze_weights=freeze_weights,
5860
)
5961

6062
@property

model2vec/train/similarity.py

Lines changed: 2 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -30,6 +30,7 @@ def __init__(
3030
weights: torch.Tensor | None = None,
3131
freeze: bool = False,
3232
normalize: bool = True,
33+
freeze_weights: bool = False,
3334
) -> None:
3435
"""Initialize a standard similarity model."""
3536
super().__init__(
@@ -43,6 +44,7 @@ def __init__(
4344
hidden_dim=hidden_dim,
4445
n_layers=n_layers,
4546
normalize=normalize,
47+
freeze_weights=freeze_weights,
4648
)
4749

4850
def fit(

0 commit comments

Comments (0)