@@ -46,6 +46,7 @@ def __init__(
         weights: torch.Tensor | None = None,
         freeze: bool = False,
         normalize: bool = True,
+        freeze_weights: bool = False,
     ) -> None:
         """Initialize a trainable StaticModel from a StaticModel.

@@ -59,6 +60,7 @@ def __init__(
         :param weights: The weights of the model. If None, the weights are initialized to zeros.
         :param freeze: Whether to freeze the embeddings. This should be set to False in most cases.
         :param normalize: Whether to normalize the embeddings.
+        :param freeze_weights: Whether to freeze the learned token weights.
         """
         super().__init__()
         self.pad_id = pad_id
@@ -67,6 +69,7 @@ def __init__(
         self.hidden_dim = hidden_dim
         self.n_layers = n_layers
         self.normalize = normalize
+        self.freeze_weights = freeze_weights

         self.vectors = vectors
         if self.vectors.dtype != torch.float32:
@@ -92,10 +95,10 @@ def construct_weights(self) -> nn.Parameter:
         """Construct the weights for the model."""
         if self._weights is not None:
             w = logit(self._weights)
-            return nn.Parameter(w.float(), requires_grad=True)
-        weights = torch.zeros(len(self.token_mapping))
-        weights[self.pad_id] = -10_000
-        return nn.Parameter(weights, requires_grad=not self.freeze)
+        else:
+            w = torch.zeros(len(self.token_mapping)).float()
+            w[self.pad_id] = -10_000
+        return nn.Parameter(w, requires_grad=not self.freeze_weights)

     def construct_head(self) -> nn.Sequential:
         """Constructs a simple classifier head."""
@@ -136,7 +139,11 @@ def _initialize(self) -> None:

     @classmethod
     def from_pretrained(
-        cls: type[ModelType], path: str = "minishlab/potion-base-32m", *, token: str | None = None, **kwargs: Any
+        cls: type[ModelType],
+        path: str = "minishlab/potion-base-32m",
+        *,
+        token: str | None = None,
+        **kwargs: Any,
     ) -> ModelType:
         """Load the model from a pretrained model2vec model."""
         if model_name := kwargs.pop("model_name", None):
@@ -147,7 +154,11 @@ def from_pretrained(

     @classmethod
     def from_static_model(
-        cls: type[ModelType], *, model: StaticModel, pad_token: str | None = None, **kwargs: Any
+        cls: type[ModelType],
+        *,
+        model: StaticModel,
+        pad_token: str | None = None,
+        **kwargs: Any,
     ) -> ModelType:
         """Load the model from a static model."""
         model.embedding = np.nan_to_num(model.embedding)
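
For orientation, the two constructors above are alternative entry points: from_pretrained pulls a model2vec checkpoint by path or Hub id, while from_static_model wraps an already-loaded StaticModel, sanitizing NaNs in its embedding first. A hypothetical usage sketch; the concrete subclass name is assumed, since this diff only shows the shared classmethods:

    from model2vec import StaticModel

    # TrainableModel stands in for whatever concrete subclass is used.
    base = StaticModel.from_pretrained("minishlab/potion-base-32m")
    model = TrainableModel.from_static_model(model=base)
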
@@ -179,14 +190,15 @@ def _encode(self, input_ids: torch.Tensor) -> torch.Tensor:
         :param input_ids: A 2D tensor of input ids. All input ids have to be within bounds.
         :return: The mean over the input ids, weighted by token weights.
         """
-        w = self.w[input_ids]
-        w = torch.sigmoid(w)
         zeros = (input_ids != self.pad_id).float()
-        w = w * zeros
         # Add a small epsilon to avoid division by zero
         length = zeros.sum(1) + 1e-16
         input_ids_embeddings = self.token_mapping[input_ids]
         embedded = self.embeddings(input_ids_embeddings)
+
+        w = self.w[input_ids]
+        w = torch.sigmoid(w)
+        w = w * zeros
         # Weigh each token
         embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
         # Mean pooling by dividing by the length
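
The reordering in _encode groups the weight computation next to its use, but the math is unchanged: w[:, None, :] has shape (batch, 1, seq) and embedded has shape (batch, seq, dim), so the bmm collapses the sequence axis into a weighted sum, which is then divided by the count of non-pad tokens. A standalone sketch of the same pooling, with toy shapes assumed:

    import torch

    batch, seq, dim, pad_id = 2, 4, 8, 0
    input_ids = torch.randint(1, 10, (batch, seq))
    input_ids[0, -1] = pad_id                 # one padded position
    embedded = torch.randn(batch, seq, dim)
    logits = torch.randn(batch, seq)          # stands in for self.w[input_ids]

    mask = (input_ids != pad_id).float()
    w = torch.sigmoid(logits) * mask          # zero weight on pad tokens
    length = mask.sum(1) + 1e-16              # epsilon avoids division by zero

    pooled = torch.bmm(w[:, None, :], embedded).squeeze(1) / length[:, None]

    # bmm over (batch, 1, seq) x (batch, seq, dim) is just a weighted sum:
    reference = (w[:, :, None] * embedded).sum(1) / length[:, None]
    assert torch.allclose(pooled, reference, atol=1e-6)
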
@@ -235,13 +247,20 @@ def device(self) -> torch.device:

     def to_static_model(self) -> StaticModel:
         """Convert the model to a static model."""
-        emb = self.embeddings.weight.detach().cpu().numpy()
-        w = torch.sigmoid(self.w).detach().cpu().numpy()
+        with torch.no_grad():
+            emb = self.embeddings.weight
+            emb = emb.cpu().numpy()
+            w = torch.sigmoid(self.w).cpu().numpy()
+
         # If the weights and emb are the same length, the model was not quantized before training.
         if len(w) == len(emb):
             emb = emb * w[:, None]
             return StaticModel(
-                vectors=emb, weights=None, tokenizer=self.tokenizer, normalize=self.normalize, token_mapping=None
+                vectors=emb,
+                weights=None,
+                tokenizer=self.tokenizer,
+                normalize=self.normalize,
+                token_mapping=None,
             )
         return StaticModel(
             vectors=emb,
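
On the to_static_model change: reading the parameters under torch.no_grad() removes the need for explicit detach() calls, and when the token-weight vector has one entry per embedding row (the model was not quantized before training), the sigmoid weights are folded directly into the exported vectors, so the resulting StaticModel carries no separate weights. A toy sketch of that folding, with sizes assumed:

    import numpy as np

    emb = np.random.randn(5, 8).astype(np.float32)  # embedding table
    w = 1.0 / (1.0 + np.exp(-np.random.randn(5)))   # sigmoid of token logits
    folded = emb * w[:, None]                       # scale each row by its token weight
    assert folded.shape == emb.shape
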
@@ -265,7 +284,12 @@ def _determine_batch_size(self, batch_size: int | None, train_length: int) -> int
             return batch_size

     def _check_val_split(
-        self, X: list[str], y: list, X_val: list[str] | None, y_val: list | None, test_size: float
+        self,
+        X: list[str],
+        y: list,
+        X_val: list[str] | None,
+        y_val: list | None,
+        test_size: float,
     ) -> tuple[list[str], list[str], Sequence, Sequence]:
         if (X_val is not None) != (y_val is not None):
             raise ValueError("Both X_val and y_val must be provided together, or neither.")
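
The guard at the top of _check_val_split uses != on two booleans as an exclusive-or, so it fires exactly when one of X_val / y_val is supplied without the other. A one-line illustration with assumed values:

    X_val, y_val = ["some text"], None
    assert (X_val is not None) != (y_val is not None)  # mismatched pair -> ValueError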