Skip to content

Commit 39b3881

Browse files
committed
merge
2 parents ec376c1 + f0bb56f commit 39b3881

2 files changed

Lines changed: 11 additions & 4 deletions

File tree

model2vec/train/base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def __init__(
2525
pad_id: int = 0,
2626
token_mapping: list[int] | None = None,
2727
weights: torch.Tensor | None = None,
28+
freeze: bool = False,
2829
) -> None:
2930
"""
3031
Initialize a trainable StaticModel from a StaticModel.
@@ -35,6 +36,7 @@ def __init__(
3536
:param pad_id: The padding id. This is set to 0 in almost all model2vec models
3637
:param token_mapping: The token mapping. If None, the token mapping is set to the range of the number of vectors.
3738
:param weights: The weights of the model. If None, the weights are initialized to zeros.
39+
:param freeze: Whether to freeze the embeddings. This should be set to False in most cases.
3840
"""
3941
super().__init__()
4042
self.pad_id = pad_id
@@ -54,7 +56,8 @@ def __init__(
5456
else:
5557
self.token_mapping = torch.arange(len(vectors), dtype=torch.int64)
5658
self.token_mapping = nn.Parameter(self.token_mapping, requires_grad=False)
57-
self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=False, padding_idx=pad_id)
59+
self.freeze = freeze
60+
self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=self.freeze, padding_idx=pad_id)
5861
self.head = self.construct_head()
5962
self.w = self.construct_weights() if weights is None else nn.Parameter(weights, requires_grad=True)
6063
self.tokenizer = tokenizer
@@ -63,7 +66,7 @@ def construct_weights(self) -> nn.Parameter:
6366
"""Construct the weights for the model."""
6467
weights = torch.zeros(len(self.token_mapping))
6568
weights[self.pad_id] = -10_000
66-
return nn.Parameter(weights)
69+
return nn.Parameter(weights, requires_grad=not self.freeze)
6770

6871
def construct_head(self) -> nn.Sequential:
6972
"""Method should be overridden for various other classes."""

model2vec/train/classifier.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(
4040
pad_id: int = 0,
4141
token_mapping: list[int] | None = None,
4242
weights: torch.Tensor | None = None,
43+
freeze: bool = False,
4344
) -> None:
4445
"""Initialize a standard classifier model."""
4546
self.n_layers = n_layers
@@ -55,6 +56,7 @@ def __init__(
5556
tokenizer=tokenizer,
5657
token_mapping=token_mapping,
5758
weights=weights,
59+
freeze=freeze,
5860
)
5961

6062
@property
@@ -133,7 +135,7 @@ def predict_proba(self, X: list[str], show_progress_bar: bool = False, batch_siz
133135
pred.append(torch.softmax(logits, dim=1).cpu().numpy())
134136
return np.concatenate(pred, axis=0)
135137

136-
def fit( # noqa: C901 # Refactor later
138+
def fit( # noqa: C901 # Complexity is bad.
137139
self,
138140
X: list[str],
139141
y: LabelType,
@@ -309,7 +311,9 @@ def _initialize(self, y: LabelType) -> None:
309311
self.classes_ = classes
310312
self.out_dim = len(self.classes_) # Update output dimension
311313
self.head = self.construct_head()
312-
self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=self.pad_id)
314+
self.embeddings = nn.Embedding.from_pretrained(
315+
self.vectors.clone(), freeze=self.freeze, padding_idx=self.pad_id
316+
)
313317
self.w = self.construct_weights()
314318
self.train()
315319

0 commit comments

Comments
 (0)