Skip to content

Commit 19edba7

Browse files
committed
feat: fix the interface of attn_implementation in embedder.decode_only.*.load_model and reranker.decode_only.*.load_model
1 parent 6679caa commit 19edba7

4 files changed

Lines changed: 8 additions & 8 deletions

File tree

FlagEmbedding/finetune/embedder/decoder_only/base/load_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def get_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir: str, re
7171
model = AutoModel.from_pretrained(
7272
model_args.model_name_or_path,
7373
# torch_dtype=torch.bfloat16,
74-
use_flash_attention_2=True if model_args.use_flash_attn else False,
74+
attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None,
7575
token=model_args.token,
7676
cache_dir=model_args.cache_dir,
7777
from_tf=bool(".ckpt" in model_args.model_name_or_path),
@@ -152,7 +152,7 @@ def save_merged_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir:
152152
model = AutoModel.from_pretrained(
153153
model_args.model_name_or_path,
154154
# torch_dtype=torch.bfloat16,
155-
use_flash_attention_2=True if model_args.use_flash_attn else False,
155+
attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None,
156156
token=model_args.token,
157157
cache_dir=model_args.cache_dir,
158158
from_tf=bool(".ckpt" in model_args.model_name_or_path),

FlagEmbedding/finetune/embedder/decoder_only/icl/load_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def get_model(model_args: DecoderOnlyEmbedderICLModelArguments, output_dir: str,
7171
model = AutoModel.from_pretrained(
7272
model_args.model_name_or_path,
7373
# torch_dtype=torch.bfloat16,
74-
use_flash_attention_2=True if model_args.use_flash_attn else False,
74+
attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None,
7575
token=model_args.token,
7676
cache_dir=model_args.cache_dir,
7777
from_tf=bool(".ckpt" in model_args.model_name_or_path),
@@ -150,7 +150,7 @@ def save_merged_model(model_args: DecoderOnlyEmbedderICLModelArguments, output_d
150150
model = AutoModel.from_pretrained(
151151
model_args.model_name_or_path,
152152
# torch_dtype=torch.bfloat16,
153-
use_flash_attention_2=True if model_args.use_flash_attn else False,
153+
attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None,
154154
token=model_args.token,
155155
cache_dir=model_args.cache_dir,
156156
from_tf=bool(".ckpt" in model_args.model_name_or_path),

FlagEmbedding/finetune/reranker/decoder_only/base/load_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def get_model(model_args: RerankerModelArguments):
6767
model = AutoModelForCausalLM.from_pretrained(
6868
model_args.model_name_or_path,
6969
# torch_dtype=torch.bfloat16,
70-
use_flash_attention_2=True if model_args.use_flash_attn else False,
70+
attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None,
7171
token=model_args.token,
7272
cache_dir=model_args.cache_dir,
7373
from_tf=bool(".ckpt" in model_args.model_name_or_path),
@@ -135,7 +135,7 @@ def save_merged_model(model_args: RerankerModelArguments, output_dir: str):
135135
model = AutoModelForCausalLM.from_pretrained(
136136
model_args.model_name_or_path,
137137
# torch_dtype=torch.bfloat16,
138-
use_flash_attention_2=True if model_args.use_flash_attn else False,
138+
attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None,
139139
token=model_args.token,
140140
cache_dir=model_args.cache_dir,
141141
from_tf=bool(".ckpt" in model_args.model_name_or_path),

FlagEmbedding/finetune/reranker/decoder_only/layerwise/load_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def get_model(model_args: RerankerModelArguments, only_for_one_logit):
7777
model_args.model_name_or_path,
7878
trust_remote_code=model_args.trust_remote_code,
7979
# torch_dtype=torch.float16 if training_args.fp16 else torch.bfloat16,
80-
use_flash_attention_2=True if model_args.use_flash_attn else False,
80+
attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None,
8181
token=model_args.token,
8282
cache_dir=model_args.cache_dir,
8383
from_tf=bool(".ckpt" in model_args.model_name_or_path),
@@ -131,7 +131,7 @@ def get_model(model_args: RerankerModelArguments, only_for_one_logit):
131131
model = LayerWiseMiniCPMForCausalLM.from_pretrained(
132132
model_args.model_name_or_path,
133133
# torch_dtype=torch.float16 if training_args.fp16 else torch.bfloat16,
134-
use_flash_attention_2=True if model_args.use_flash_attn else False,
134+
attn_implementation = "flash_attention_2" if model_args.use_flash_attn else None,
135135
token=model_args.token,
136136
cache_dir=model_args.cache_dir,
137137
from_tf=bool(".ckpt" in model_args.model_name_or_path),

0 commit comments

Comments
 (0)