Skip to content

Commit 1e8b454

Browse files
committed
add configurable pad token
1 parent f0bb56f commit 1e8b454

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

model2vec/train/base.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -65,13 +65,15 @@ def from_pretrained(
6565
return cls.from_static_model(model=model, out_dim=out_dim, **kwargs)
6666

6767
@classmethod
68-
def from_static_model(cls: type[ModelType], *, model: StaticModel, out_dim: int = 2, **kwargs: Any) -> ModelType:
68+
def from_static_model(
69+
cls: type[ModelType], *, model: StaticModel, out_dim: int = 2, pad_token: str = "[PAD]", **kwargs: Any
70+
) -> ModelType:
6971
"""Load the model from a static model."""
7072
model.embedding = np.nan_to_num(model.embedding)
7173
embeddings_converted = torch.from_numpy(model.embedding)
7274
return cls(
7375
vectors=embeddings_converted,
74-
pad_id=model.tokenizer.token_to_id("[PAD]"),
76+
pad_id=model.tokenizer.token_to_id(pad_token),
7577
out_dim=out_dim,
7678
tokenizer=model.tokenizer,
7779
**kwargs,

0 commit comments

Comments
 (0)