Skip to content

Commit 5d6fa9a

Browse files
committed
fix: add bf16 interface to EncoderOnlyEmbedderM3Runner.get_model
1 parent 9095783 commit 5d6fa9a

1 file changed

Lines changed: 6 additions & 3 deletions

File tree

  • FlagEmbedding/finetune/embedder/encoder_only/m3

FlagEmbedding/finetune/embedder/encoder_only/m3/runner.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import torch
33
import logging
4-
from typing import Tuple
4+
from typing import Tuple, Optional
55
from transformers import (
66
AutoModel, AutoConfig,
77
AutoTokenizer, PreTrainedTokenizer
@@ -44,7 +44,8 @@ def get_model(
4444
model_name_or_path: str,
4545
trust_remote_code: bool = False,
4646
colbert_dim: int = -1,
47-
cache_dir: str = None
47+
cache_dir: str = None,
48+
torch_dtype: Optional[torch.dtype] = None,
4849
):
4950
"""Get the model.
5051
@@ -54,6 +55,7 @@ def get_model(
5455
trust_remote_code (bool, optional): trust_remote_code to use when loading models from HF. Defaults to ``False``.
5556
colbert_dim (int, optional): Colbert dim to set. Defaults to ``-1``.
5657
cache_dir (str, optional): HF cache dir to store the model. Defaults to ``None``.
58+
torch_dtype (Optional[torch.dtype], optional): Torch dtype used when loading model weights. Defaults to ``None``.
5759
5860
Returns:
5961
dict: A dictionary containing the model, colbert linear and sparse linear.
@@ -69,7 +71,8 @@ def get_model(
6971
model = AutoModel.from_pretrained(
7072
model_name_or_path,
7173
cache_dir=cache_folder,
72-
trust_remote_code=trust_remote_code
74+
trust_remote_code=trust_remote_code,
75+
dtype=torch_dtype,
7376
)
7477
colbert_linear = torch.nn.Linear(
7578
in_features=model.config.hidden_size,

0 commit comments

Comments
 (0)