Skip to content

Commit ac7a274

Browse files
authored
Merge pull request #1566 from lnxtree/feat-fix-embedder_use_bf16
Feat(embedder) use bf16 and fix the interface of attn_implementation in embedder.decoder_only
2 parents dbc6005 + 19edba7 commit ac7a274

File tree

11 files changed

+51
-25
lines changed

11 files changed

+51
-25
lines changed

FlagEmbedding/abc/inference/AbsEmbedder.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
model_name_or_path: str,
5050
normalize_embeddings: bool = True,
5151
use_fp16: bool = True,
52+
use_bf16: bool = False,
5253
query_instruction_for_retrieval: Optional[str] = None,
5354
query_instruction_format: str = "{}{}", # specify the format of query_instruction_for_retrieval
5455
devices: Optional[Union[str, int, List[str], List[int]]] = None,
@@ -62,6 +63,7 @@ def __init__(
6263
self.model_name_or_path = model_name_or_path
6364
self.normalize_embeddings = normalize_embeddings
6465
self.use_fp16 = use_fp16
66+
self.use_bf16 = use_bf16
6567
self.query_instruction_for_retrieval = query_instruction_for_retrieval
6668
self.query_instruction_format = query_instruction_format
6769
self.target_devices = self.get_target_devices(devices)
@@ -81,6 +83,13 @@ def __init__(
8183
self.model = None
8284
self.pool = None
8385

86+
def get_model_torch_dtype(self) -> torch.dtype:
87+
if self.use_bf16:
88+
return torch.bfloat16
89+
if self.use_fp16:
90+
return torch.float16
91+
return torch.float32
92+
8493
def stop_self_pool(self):
8594
if self.pool is not None:
8695
self.stop_multi_process_pool(self.pool)

FlagEmbedding/finetune/embedder/decoder_only/base/load_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def get_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir: str, re
7171
model = AutoModel.from_pretrained(
7272
model_args.model_name_or_path,
7373
# torch_dtype=torch.bfloat16,
74-
use_flash_attention_2=True if model_args.use_flash_attn else False,
74+
attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None,
7575
token=model_args.token,
7676
cache_dir=model_args.cache_dir,
7777
from_tf=bool(".ckpt" in model_args.model_name_or_path),
@@ -152,7 +152,7 @@ def save_merged_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir:
152152
model = AutoModel.from_pretrained(
153153
model_args.model_name_or_path,
154154
# torch_dtype=torch.bfloat16,
155-
use_flash_attention_2=True if model_args.use_flash_attn else False,
155+
attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None,
156156
token=model_args.token,
157157
cache_dir=model_args.cache_dir,
158158
from_tf=bool(".ckpt" in model_args.model_name_or_path),

FlagEmbedding/finetune/embedder/decoder_only/icl/load_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def get_model(model_args: DecoderOnlyEmbedderICLModelArguments, output_dir: str,
7171
model = AutoModel.from_pretrained(
7272
model_args.model_name_or_path,
7373
# torch_dtype=torch.bfloat16,
74-
use_flash_attention_2=True if model_args.use_flash_attn else False,
74+
attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None,
7575
token=model_args.token,
7676
cache_dir=model_args.cache_dir,
7777
from_tf=bool(".ckpt" in model_args.model_name_or_path),
@@ -150,7 +150,7 @@ def save_merged_model(model_args: DecoderOnlyEmbedderICLModelArguments, output_d
150150
model = AutoModel.from_pretrained(
151151
model_args.model_name_or_path,
152152
# torch_dtype=torch.bfloat16,
153-
use_flash_attention_2=True if model_args.use_flash_attn else False,
153+
attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None,
154154
token=model_args.token,
155155
cache_dir=model_args.cache_dir,
156156
from_tf=bool(".ckpt" in model_args.model_name_or_path),

FlagEmbedding/finetune/embedder/encoder_only/m3/runner.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import torch
33
import logging
4-
from typing import Tuple
4+
from typing import Tuple, Optional
55
from transformers import (
66
AutoModel, AutoConfig,
77
AutoTokenizer, PreTrainedTokenizer
@@ -44,7 +44,8 @@ def get_model(
4444
model_name_or_path: str,
4545
trust_remote_code: bool = False,
4646
colbert_dim: int = -1,
47-
cache_dir: str = None
47+
cache_dir: str = None,
48+
torch_dtype: Optional[torch.dtype] = None,
4849
):
4950
"""Get the model.
5051
@@ -54,6 +55,7 @@ def get_model(
5455
trust_remote_code (bool, optional): trust_remote_code to use when loading models from HF. Defaults to ``False``.
5556
colbert_dim (int, optional): Colbert dim to set. Defaults to ``-1``.
5657
cache_dir (str, optional): HF cache dir to store the model. Defaults to ``None``.
58+
torch_dtype (Optional[torch.dtype], optional): Torch dtype used when loading model weights. Defaults to ``None``.
5759
5860
Returns:
5961
dict: A dictionary containing the model, colbert linear and sparse linear.
@@ -69,7 +71,8 @@ def get_model(
6971
model = AutoModel.from_pretrained(
7072
model_name_or_path,
7173
cache_dir=cache_folder,
72-
trust_remote_code=trust_remote_code
74+
trust_remote_code=trust_remote_code,
75+
dtype=torch_dtype,
7376
)
7477
colbert_linear = torch.nn.Linear(
7578
in_features=model.config.hidden_size,

FlagEmbedding/finetune/reranker/decoder_only/base/load_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def get_model(model_args: RerankerModelArguments):
6767
model = AutoModelForCausalLM.from_pretrained(
6868
model_args.model_name_or_path,
6969
# torch_dtype=torch.bfloat16,
70-
use_flash_attention_2=True if model_args.use_flash_attn else False,
70+
attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None,
7171
token=model_args.token,
7272
cache_dir=model_args.cache_dir,
7373
from_tf=bool(".ckpt" in model_args.model_name_or_path),
@@ -135,7 +135,7 @@ def save_merged_model(model_args: RerankerModelArguments, output_dir: str):
135135
model = AutoModelForCausalLM.from_pretrained(
136136
model_args.model_name_or_path,
137137
# torch_dtype=torch.bfloat16,
138-
use_flash_attention_2=True if model_args.use_flash_attn else False,
138+
attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None,
139139
token=model_args.token,
140140
cache_dir=model_args.cache_dir,
141141
from_tf=bool(".ckpt" in model_args.model_name_or_path),

FlagEmbedding/finetune/reranker/decoder_only/layerwise/load_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def get_model(model_args: RerankerModelArguments, only_for_one_logit):
7777
model_args.model_name_or_path,
7878
trust_remote_code=model_args.trust_remote_code,
7979
# torch_dtype=torch.float16 if training_args.fp16 else torch.bfloat16,
80-
use_flash_attention_2=True if model_args.use_flash_attn else False,
80+
attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None,
8181
token=model_args.token,
8282
cache_dir=model_args.cache_dir,
8383
from_tf=bool(".ckpt" in model_args.model_name_or_path),
@@ -131,7 +131,7 @@ def get_model(model_args: RerankerModelArguments, only_for_one_logit):
131131
model = LayerWiseMiniCPMForCausalLM.from_pretrained(
132132
model_args.model_name_or_path,
133133
# torch_dtype=torch.float16 if training_args.fp16 else torch.bfloat16,
134-
use_flash_attention_2=True if model_args.use_flash_attn else False,
134+
attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None,
135135
token=model_args.token,
136136
cache_dir=model_args.cache_dir,
137137
from_tf=bool(".ckpt" in model_args.model_name_or_path),

FlagEmbedding/inference/auto_embedder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def from_finetuned(
2626
model_class: Optional[Union[str, EmbedderModelClass]] = None,
2727
normalize_embeddings: bool = True,
2828
use_fp16: bool = True,
29+
use_bf16: bool = False,
2930
query_instruction_for_retrieval: Optional[str] = None,
3031
devices: Optional[Union[str, List[str]]] = None,
3132
pooling_method: Optional[str] = None,
@@ -102,6 +103,7 @@ def from_finetuned(
102103
model_name_or_path,
103104
normalize_embeddings=normalize_embeddings,
104105
use_fp16=use_fp16,
106+
use_bf16=use_bf16,
105107
query_instruction_for_retrieval=query_instruction_for_retrieval,
106108
query_instruction_format=query_instruction_format,
107109
devices=devices,

FlagEmbedding/inference/embedder/decoder_only/base.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def __init__(
6060
model_name_or_path: str,
6161
normalize_embeddings: bool = True,
6262
use_fp16: bool = True,
63+
use_bf16: bool = False,
6364
query_instruction_for_retrieval: Optional[str] = None,
6465
query_instruction_format: str = "Instruct: {}\nQuery: {}", # specify the format of query_instruction_for_retrieval
6566
devices: Optional[Union[str, List[str]]] = None, # specify devices, such as "cuda:0" or ["cuda:0", "cuda:1"]
@@ -77,6 +78,7 @@ def __init__(
7778
model_name_or_path,
7879
normalize_embeddings=normalize_embeddings,
7980
use_fp16=use_fp16,
81+
use_bf16=use_bf16,
8082
query_instruction_for_retrieval=query_instruction_for_retrieval,
8183
query_instruction_format=query_instruction_format,
8284
devices=devices,
@@ -95,7 +97,8 @@ def __init__(
9597
self.model = AutoModel.from_pretrained(
9698
model_name_or_path,
9799
trust_remote_code=trust_remote_code,
98-
cache_dir=cache_dir
100+
cache_dir=cache_dir,
101+
dtype=self.get_model_torch_dtype(),
99102
)
100103

101104
if self.kwargs.get("pooling_method", "last_token") != "last_token":
@@ -211,8 +214,8 @@ def encode_single_device(
211214
if device is None:
212215
device = self.target_devices[0]
213216

214-
if device == "cpu": self.use_fp16 = False
215-
if self.use_fp16: self.model.half()
217+
if device == "cpu":
218+
self.model.float()
216219

217220
self.model.to(device)
218221
self.model.eval()

FlagEmbedding/inference/embedder/decoder_only/icl.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def __init__(
6868
model_name_or_path: str,
6969
normalize_embeddings: bool = True,
7070
use_fp16: bool = True,
71+
use_bf16: bool = False,
7172
query_instruction_for_retrieval: Optional[str] = None,
7273
query_instruction_format: str = "<instruct>{}\n<query>{}", # specify the format of query_instruction_for_retrieval
7374
suffix: str = '\n<response>',
@@ -90,6 +91,7 @@ def __init__(
9091
model_name_or_path,
9192
normalize_embeddings=normalize_embeddings,
9293
use_fp16=use_fp16,
94+
use_bf16=use_bf16,
9395
query_instruction_for_retrieval=query_instruction_for_retrieval,
9496
query_instruction_format=query_instruction_format,
9597
devices=devices,
@@ -108,7 +110,8 @@ def __init__(
108110
self.model = AutoModel.from_pretrained(
109111
model_name_or_path,
110112
trust_remote_code=trust_remote_code,
111-
cache_dir=cache_dir
113+
cache_dir=cache_dir,
114+
torch_dtype=self.get_model_torch_dtype(),
112115
)
113116
self.examples_for_task = examples_for_task
114117
self.examples_instruction_format = examples_instruction_format
@@ -340,8 +343,8 @@ def encode_queries_single_device(
340343
if device is None:
341344
device = self.target_devices[0]
342345

343-
if device == "cpu": self.use_fp16 = False
344-
if self.use_fp16: self.model.half()
346+
if device == "cpu":
347+
self.model.float()
345348

346349
self.model.to(device)
347350
self.model.eval()

FlagEmbedding/inference/embedder/encoder_only/base.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(
4242
model_name_or_path: str,
4343
normalize_embeddings: bool = True,
4444
use_fp16: bool = True,
45+
use_bf16: bool = False,
4546
query_instruction_for_retrieval: Optional[str] = None,
4647
query_instruction_format: str = "{}{}", # specify the format of query_instruction_for_retrieval
4748
devices: Optional[Union[str, List[str]]] = None, # specify devices, such as "cuda:0" or ["cuda:0", "cuda:1"]
@@ -60,6 +61,7 @@ def __init__(
6061
model_name_or_path,
6162
normalize_embeddings=normalize_embeddings,
6263
use_fp16=use_fp16,
64+
use_bf16=use_bf16,
6365
query_instruction_for_retrieval=query_instruction_for_retrieval,
6466
query_instruction_format=query_instruction_format,
6567
devices=devices,
@@ -79,7 +81,8 @@ def __init__(
7981
self.model = AutoModel.from_pretrained(
8082
model_name_or_path,
8183
trust_remote_code=trust_remote_code,
82-
cache_dir=cache_dir
84+
cache_dir=cache_dir,
85+
dtype=self.get_model_torch_dtype(),
8386
)
8487

8588
def encode_queries(
@@ -192,8 +195,8 @@ def encode_single_device(
192195
if device is None:
193196
device = self.target_devices[0]
194197

195-
if device == "cpu": self.use_fp16 = False
196-
if self.use_fp16: self.model.half()
198+
if device == "cpu":
199+
self.model.float()
197200

198201
self.model.to(device)
199202
self.model.eval()

0 commit comments

Comments
 (0)