@@ -60,6 +60,7 @@ def __init__(
         model_name_or_path: str,
         normalize_embeddings: bool = True,
         use_fp16: bool = True,
+        use_bf16: bool = False,
         query_instruction_for_retrieval: Optional[str] = None,
         query_instruction_format: str = "Instruct: {}\nQuery: {}",  # specify the format of query_instruction_for_retrieval
         devices: Optional[Union[str, List[str]]] = None,  # specify devices, such as "cuda:0" or ["cuda:0", "cuda:1"]
@@ -77,6 +78,7 @@ def __init__(
             model_name_or_path,
             normalize_embeddings=normalize_embeddings,
             use_fp16=use_fp16,
+            use_bf16=use_bf16,
             query_instruction_for_retrieval=query_instruction_for_retrieval,
             query_instruction_format=query_instruction_format,
             devices=devices,
@@ -95,7 +97,8 @@ def __init__(
         self.model = AutoModel.from_pretrained(
             model_name_or_path,
             trust_remote_code=trust_remote_code,
-            cache_dir=cache_dir
+            cache_dir=cache_dir,
+            dtype=self.get_model_torch_dtype(),
         )

         if self.kwargs.get("pooling_method", "last_token") != "last_token":
@@ -211,8 +214,8 @@ def encode_single_device(
         if device is None:
            device = self.target_devices[0]

-        if device == "cpu": self.use_fp16 = False
-        if self.use_fp16: self.model.half()
+        if device == "cpu":
+            self.model.float()

         self.model.to(device)
         self.model.eval()
0 commit comments