@@ -25,6 +25,7 @@ def __init__(
2525 pad_id : int = 0 ,
2626 token_mapping : list [int ] | None = None ,
2727 weights : torch .Tensor | None = None ,
28+ freeze : bool = False ,
2829 ) -> None :
2930 """
3031 Initialize a trainable StaticModel from a StaticModel.
@@ -35,6 +36,7 @@ def __init__(
3536 :param pad_id: The padding id. This is set to 0 in almost all model2vec models
3637 :param token_mapping: The token mapping. If None, the token mapping is set to the range of the number of vectors.
3738 :param weights: The weights of the model. If None, the weights are initialized to zeros, except for the padding position, which is set to -10,000.
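39+ :param freeze: Whether to freeze the embeddings. When True, the token weights created by construct_weights are also made non-trainable; weights passed in explicitly stay trainable. This should be set to False in most cases.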
3840 """
3941 super ().__init__ ()
4042 self .pad_id = pad_id
@@ -54,7 +56,8 @@ def __init__(
5456 else :
5557 self .token_mapping = torch .arange (len (vectors ), dtype = torch .int64 )
5658 self .token_mapping = nn .Parameter (self .token_mapping , requires_grad = False )
57- self .embeddings = nn .Embedding .from_pretrained (vectors .clone (), freeze = False , padding_idx = pad_id )
59+ self .freeze = freeze
60+ self .embeddings = nn .Embedding .from_pretrained (vectors .clone (), freeze = self .freeze , padding_idx = pad_id )
5861 self .head = self .construct_head ()
5962 self .w = self .construct_weights () if weights is None else nn .Parameter (weights , requires_grad = True )
6063 self .tokenizer = tokenizer
@@ -63,7 +66,7 @@ def construct_weights(self) -> nn.Parameter:
6366 """Construct the weights for the model."""
6467 weights = torch .zeros (len (self .token_mapping ))
6568 weights [self .pad_id ] = - 10_000
66- return nn .Parameter (weights )
69+ return nn .Parameter (weights , requires_grad = not self . freeze )
6770
6871 def construct_head (self ) -> nn .Sequential :
6972 """Method should be overridden for various other classes."""
0 commit comments