Skip to content

Commit 0581046

Browse files
authored
fix: only allows named args in pretrain (#200)
* fix: named arguments for from_pretrained * docs: update README
1 parent 62dcef9 commit 0581046

3 files changed

Lines changed: 6 additions & 6 deletions

File tree

model2vec/train/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ from model2vec.train import StaticModelForClassification
2222

2323
# From a distilled model
2424
distilled_model = distill("baai/bge-base-en-v1.5")
25-
classifier = StaticModelForClassification.from_static_model(distilled_model)
25+
classifier = StaticModelForClassification.from_static_model(model=distilled_model)
2626

2727
# From a pre-trained model: potion is the default
2828
classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32m")

model2vec/train/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,14 @@ def construct_head(self) -> nn.Sequential:
4545

4646
@classmethod
4747
def from_pretrained(
48-
cls: type[ModelType], out_dim: int = 2, model_name: str = "minishlab/potion-base-32m", **kwargs: Any
48+
cls: type[ModelType], *, out_dim: int = 2, model_name: str = "minishlab/potion-base-32m", **kwargs: Any
4949
) -> ModelType:
5050
"""Load the model from a pretrained model2vec model."""
5151
model = StaticModel.from_pretrained(model_name)
52-
return cls.from_static_model(model, out_dim, **kwargs)
52+
return cls.from_static_model(model=model, out_dim=out_dim, **kwargs)
5353

5454
@classmethod
55-
def from_static_model(cls: type[ModelType], model: StaticModel, out_dim: int = 2, **kwargs: Any) -> ModelType:
55+
def from_static_model(cls: type[ModelType], *, model: StaticModel, out_dim: int = 2, **kwargs: Any) -> ModelType:
5656
"""Load the model from a static model."""
5757
model.embedding = np.nan_to_num(model.embedding)
5858
embeddings_converted = torch.from_numpy(model.embedding)

tests/test_trainable.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_init_base_class(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) ->
4141
def test_init_base_from_model(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None:
4242
"""Test initializion from a static model."""
4343
model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer)
44-
s = FinetunableStaticModel.from_static_model(model)
44+
s = FinetunableStaticModel.from_static_model(model=model)
4545
assert s.vectors.shape == mock_vectors.shape
4646
assert s.w.shape[0] == mock_vectors.shape[0]
4747

@@ -55,7 +55,7 @@ def test_init_base_from_model(mock_vectors: np.ndarray, mock_tokenizer: Tokenize
5555
def test_init_classifier_from_model(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None:
5656
"""Test initializion from a static model."""
5757
model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer)
58-
s = StaticModelForClassification.from_static_model(model)
58+
s = StaticModelForClassification.from_static_model(model=model)
5959
assert s.vectors.shape == mock_vectors.shape
6060
assert s.w.shape[0] == mock_vectors.shape[0]
6161

0 commit comments

Comments
 (0)