11import os
22import torch
33import logging
4- from typing import Tuple
4+ from typing import Tuple , Optional
55from transformers import (
66 AutoModel , AutoConfig ,
77 AutoTokenizer , PreTrainedTokenizer
@@ -44,7 +44,8 @@ def get_model(
4444 model_name_or_path : str ,
4545 trust_remote_code : bool = False ,
4646 colbert_dim : int = - 1 ,
47- cache_dir : str = None
47+ cache_dir : str = None ,
48+ torch_dtype : Optional [torch .dtype ] = None ,
4849 ):
4950 """Get the model.
5051
@@ -54,6 +55,7 @@ def get_model(
5455 trust_remote_code (bool, optional): trust_remote_code to use when loading models from HF. Defaults to ``False``.
5556 colbert_dim (int, optional): Colbert dim to set. Defaults to ``-1``.
5657 cache_dir (str, optional): HF cache dir to store the model. Defaults to ``None``.
58+ torch_dtype (Optional[torch.dtype], optional): Torch dtype used when loading model weights. Defaults to ``None``.
5759
5860 Returns:
5961 dict: A dictionary containing the model, colbert linear and sparse linear.
@@ -69,7 +71,8 @@ def get_model(
6971 model = AutoModel .from_pretrained (
7072 model_name_or_path ,
7173 cache_dir = cache_folder ,
72- trust_remote_code = trust_remote_code
74+ trust_remote_code = trust_remote_code ,
75+ dtype = torch_dtype ,
7376 )
7477 colbert_linear = torch .nn .Linear (
7578 in_features = model .config .hidden_size ,
0 commit comments