Skip to content

Commit 75e117c

Browse files
committed
fix: precision during training
1 parent 2d0417d commit 75e117c

1 file changed

Lines changed: 12 additions & 2 deletions

File tree

model2vec/train/base.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import logging
34
from typing import Any, TypeVar
45

56
import numpy as np
@@ -11,6 +12,8 @@
1112

1213
from model2vec import StaticModel
1314

15+
logger = logging.getLogger(__name__)
16+
1417

1518
class FinetunableStaticModel(nn.Module):
1619
def __init__(self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int = 2, pad_id: int = 0) -> None:
@@ -26,9 +29,16 @@ def __init__(self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int
2629
self.pad_id = pad_id
2730
self.out_dim = out_dim
2831
self.embed_dim = vectors.shape[1]
29-
self.vectors = vectors
3032

31-
self.embeddings = nn.Embedding.from_pretrained(vectors.clone().float(), freeze=False, padding_idx=pad_id)
33+
self.vectors = vectors
34+
if self.vectors.dtype != torch.float32:
35+
dtype = str(self.vectors.dtype)
36+
logger.warning(
37+
f"Your vectors are {dtype} precision, converting to torch.float32 to avoid compatibility issues."
38+
)
39+
self.vectors = vectors.float()
40+
41+
self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=False, padding_idx=pad_id)
3242
self.head = self.construct_head()
3343
self.w = self.construct_weights()
3444
self.tokenizer = tokenizer

0 commit comments

Comments
 (0)