Skip to content

Commit e200a3b

Browse files
committed
fix usage of accelerate
1 parent 641b452 commit e200a3b

16 files changed

Lines changed: 542 additions & 175 deletions

examples/llama_acc.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
set -ex
33

44
# FSDP
5-
./examples/run.sh --model ./hf_models/config/llama-1b --accelerator acc --gc --mbs 4 --fsdp 4
5+
./examples/run.sh --model ./hf_models/config/llama-1b --accelerator acc --gc --mbs 4 --fsdp 4 --use_flash_attn
66

77
# TP
88
# ./examples/run.sh --model ./hf_models/config/llama-1b --accelerator acc --gc --mbs 24 --tp 4

flashmodels/accelerators/acc_baichuan_accelerator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def accelerate_internal(self, model, loader):
1717
raise NotImplementedError("resume_from_checkpoint.")
1818

1919
config = self.get_config(model)
20-
model = ta.accelerate(model, config)
20+
model = ta.accelerate(model, config=config)
2121
return model, loader
2222

2323
def get_config(self, model):

flashmodels/accelerators/acc_gemma_accelerator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def accelerate(self, model, loader):
1212

1313
def accelerate_internal(self, model, loader):
1414
config = self.get_config()
15-
model = ta.accelerate(model, config)
15+
model = ta.accelerate(model, config=config)
1616
return model, loader
1717

1818
def get_config(self):

flashmodels/accelerators/acc_glm_accelerator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def accelerate_internal(self, model, loader):
1717
raise NotImplementedError("resume_from_checkpoint.")
1818

1919
config = self.get_config(model)
20-
model = ta.accelerate(model, config)
20+
model = ta.accelerate(model, config=config)
2121
return model, loader
2222

2323
def get_config(self, model):

flashmodels/accelerators/acc_gpt_accelerator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def accelerate_internal(self, model, loader):
2020
raise NotImplementedError("resume_from_checkpoint.")
2121

2222
config = self.get_config(model)
23-
model = ta.accelerate(model, config)
23+
model = ta.accelerate(model, config=config)
2424
return model, loader
2525

2626
device = lazy_device()

flashmodels/accelerators/acc_llama_accelerator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def accelerate_internal(self, model, loader):
9999
self.args.sp)
100100

101101
config = self.get_config(model)
102-
model = ta.accelerate(model, config)
102+
model = ta.accelerate(model, config=config)
103103

104104
if self.args.tp_num > 1 and self.args.pp_num > 1:
105105
self.parallel_3d(model._get_underlay_model())

flashmodels/accelerators/acc_olmo_accelerator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def accelerate_internal(self, model, loader):
1717
raise NotImplementedError("resume_from_checkpoint.")
1818

1919
config = self.get_config(model)
20-
model = ta.accelerate(model, config)
20+
model = ta.accelerate(model, config=config)
2121
return model, loader
2222
else:
2323
raise NotImplementedError("Currently, only FSDP is supported.")

flashmodels/accelerators/acc_qwen_accelerator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def accelerate_internal(self, model, loader):
3737
raise NotImplementedError("resume_from_checkpoint.")
3838

3939
config = self.get_config(model)
40-
model = ta.accelerate(model, config)
40+
model = ta.accelerate(model, config=config)
4141
return model, loader
4242

4343
def get_config(self, model):

flashmodels/accelerators/cuda_llama_accelerator.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
FullyShardedDataParallel as FSDP
1414
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
1515
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
16+
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
1617

1718
from flashmodels.accelerators.accelerator import (Accelerator,
1819
AcceleratorFactory)
@@ -70,7 +71,7 @@ def apply_checkpointing(self, model):
7071
checkpoint_wrapper,
7172
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
7273
)
73-
check_fn = lambda submodule: isinstance(submodule, transformers.models.llama.modeling_llama.LlamaDecoderLayer)
74+
# Select only Llama decoder layers for activation checkpointing; isinstance() needs the candidate submodule as its first argument.
check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)
7475
apply_activation_checkpointing(
7576
model,
7677
checkpoint_wrapper_fn=non_reentrant_wrapper,
@@ -96,9 +97,7 @@ def fsdp(self, model):
9697
convert_outputs_to_fp32(model.forward.__func__), model)
9798

9899
# Use auto_wrap_policy for nested wrapping instead of only a top-level FSDP.
99-
auto_wrap_policy = ModuleWrapPolicy({
100-
transformers.models.llama.modeling_llama.LlamaDecoderLayer,
101-
})
100+
auto_wrap_policy = ModuleWrapPolicy({LlamaDecoderLayer, })
102101

103102
mixed_precision_policy = None
104103
if self.args.fp16 or self.args.bf16:

0 commit comments

Comments
 (0)