Skip to content

Commit 1c78d9e

Browse files
committed
Clean and fin
1 parent 4d1e285 commit 1c78d9e

5 files changed

Lines changed: 95 additions & 76 deletions

File tree

benchmarks/fp8/ms_amp/Dockerfile

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
FROM ghcr.io/azure/msamp
22

33
RUN pip install transformers evaluate datasets
4-
# RUN git clone https://github.com/huggingface/accelerate
4+
RUN git clone https://github.com/huggingface/accelerate
55

6-
# RUN cd accelerate && \
7-
# pip install -e . && \
8-
# cd benchmarks/fp8
6+
RUN cd accelerate && \
7+
pip install -e . && \
8+
cd benchmarks/fp8
99

1010
CMD ["bash"]
1111

benchmarks/fp8/ms_amp/distrib_deepspeed.py

Lines changed: 28 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -16,32 +16,27 @@
1616
This script tests to ensure that `accelerate` performs at the same level as raw `MS-AMP`.
1717
1818
This particular script verifies this for DeepSpeed training.
19+
20+
NOTE: MS-AMP does *not* support ZeRO-3.
1921
"""
20-
from unittest.mock import patch
2122

22-
from msamp import deepspeed
23+
# import msamp.deepspeed as msamp_deepspeed
2324
import evaluate
2425
import torch
25-
# import transformer_engine.common.recipe as te_recipe
26-
# import transformer_engine.pytorch as te
2726
from fp8_utils import evaluate_model, get_training_utilities
28-
# from transformer_engine.common.recipe import DelayedScaling
27+
from msamp import deepspeed as msamp_deepspeed
2928

3029
from accelerate import Accelerator, DeepSpeedPlugin
3130
from accelerate.state import AcceleratorState
32-
from accelerate.utils import FP8RecipeKwargs, set_seed
31+
from accelerate.utils import set_seed
3332

3433

3534
MODEL_NAME = "bert-base-cased"
3635
METRIC = evaluate.load("glue", "mrpc")
3736

3837

3938
def train_baseline(zero_stage: int = 1, opt_level: str = "O1"):
40-
# This forces transformers to think Zero-3 Init should be used
41-
with patch("transformers.integrations.deepspeed.is_deepspeed_zero3_enabled") as mock:
42-
mock.return_value = zero_stage == 3
4339
set_seed(42)
44-
4540
accelerator = Accelerator()
4641
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
4742
MODEL_NAME, accelerator=accelerator
@@ -57,7 +52,6 @@ def train_baseline(zero_stage: int = 1, opt_level: str = "O1"):
5752
"stage": zero_stage,
5853
"offload_optimizer": {"device": "none", "nvme_path": None},
5954
"offload_param": {"device": "none", "nvme_path": None},
60-
"stage3_gather_16bit_weights_on_model_save": False,
6155
},
6256
"gradient_clipping": 1.0,
6357
"steps_per_print": np.inf,
@@ -67,15 +61,14 @@ def train_baseline(zero_stage: int = 1, opt_level: str = "O1"):
6761
"msamp": {
6862
"enabled": True,
6963
"opt_level": opt_level,
70-
}
64+
},
7165
}
72-
7366
(
7467
model,
7568
optimizer,
7669
_,
7770
_,
78-
) = deepspeed.initialize(
71+
) = msamp_deepspeed.initialize(
7972
model=model,
8073
optimizer=optimizer,
8174
config_params=config,
@@ -107,18 +100,14 @@ def train_baseline(zero_stage: int = 1, opt_level: str = "O1"):
107100
return base_model_results, trained_model_results
108101

109102

110-
def train_integration(zero_stage: int = 1):
103+
def train_integration(zero_stage: int = 1, opt_level: str = "O1"):
111104
set_seed(42)
112-
FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"}
113-
kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)]
114-
AcceleratorState()._reset_state(True)
115105
deepspeed_plugin = DeepSpeedPlugin(
116106
zero_stage=zero_stage,
117-
zero3_init_flag=zero_stage == 3,
118-
)
119-
accelerator = Accelerator(
120-
mixed_precision="fp8", kwargs_handlers=kwargs_handlers, deepspeed_plugin=deepspeed_plugin
107+
enable_msamp=True,
108+
msamp_opt_level=opt_level,
121109
)
110+
accelerator = Accelerator(mixed_precision="fp8", deepspeed_plugin=deepspeed_plugin)
122111
accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = 16
123112

124113
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
@@ -128,13 +117,9 @@ def train_integration(zero_stage: int = 1):
128117
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
129118
base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
130119
model.train()
131-
model_outputs = []
132-
data = []
133120
for _ in range(2):
134121
for batch in train_dataloader:
135122
outputs = model(**batch)
136-
data.append(batch.to("cpu"))
137-
model_outputs.append(outputs.logits.to("cpu"))
138123
loss = outputs.loss
139124
accelerator.backward(loss)
140125
optimizer.step()
@@ -151,32 +136,26 @@ def train_integration(zero_stage: int = 1):
151136
trained_model_results["f1"] > base_model_results["f1"]
152137
), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
153138

139+
AcceleratorState()._reset_state(True)
154140
return base_model_results, trained_model_results
155141

156142

157143
if __name__ == "__main__":
158-
# results = {"1": [], "2": [], "3": []}
159-
# for zero_stage in [1, 2, 3]:
160-
# for opt_level in ["O1", "O2", "O3"]:
161-
baseline_not_trained, baseline_trained = train_baseline(3, "O3")
162-
print(baseline_not_trained, baseline_trained)
163-
# results[str(zero_stage)].append({"opt_level": opt_level, "not_trained": baseline_not_trained, "trained": baseline_trained})
164-
# for stage, stage_results in results.items():
165-
# print(f'zero_stage={stage}:\n')
166-
# for result in stage_results:
167-
# print(f'opt_level={result["opt_level"]}:\nBaseline not trained: {result["not_trained"]}\nBaseline trained: {result["trained"]}\n')
168-
# accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(zero_stage)
169-
# assert (
170-
# baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
171-
# ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
172-
# assert (
173-
# baseline_not_trained["f1"] == accelerator_not_trained["f1"]
174-
# ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
175-
# assert (
176-
# baseline_trained["accuracy"] == accelerator_trained["accuracy"]
177-
# ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
178-
# assert (
179-
# baseline_trained["f1"] == accelerator_trained["f1"]
180-
# ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'
144+
for zero_stage in [1, 2]:
145+
for opt_level in ["O1", "O2", "O3"]:
146+
baseline_not_trained, baseline_trained = train_baseline(zero_stage, opt_level)
147+
accelerator_not_trained, accelerator_trained = train_integration(zero_stage, opt_level)
148+
assert (
149+
baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
150+
), f'ZERO stage {zero_stage}, opt_level={opt_level}:\nAccuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
151+
assert (
152+
baseline_not_trained["f1"] == accelerator_not_trained["f1"]
153+
), f'ZERO stage {zero_stage}, opt_level={opt_level}:\nF1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
154+
assert (
155+
baseline_trained["accuracy"] == accelerator_trained["accuracy"]
156+
), f'ZERO stage {zero_stage}, opt_level={opt_level}:\nAccuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
157+
assert (
158+
baseline_trained["f1"] == accelerator_trained["f1"]
159+
), f'ZERO stage {zero_stage}, opt_level={opt_level}:\nF1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'
181160

182161
torch.distributed.destroy_process_group()

benchmarks/fp8/ms_amp/fsdp.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,21 @@
1717
1818
This particular script verifies this for FSDP training.
1919
"""
20+
from functools import partial
21+
2022
import evaluate
2123
import msamp
2224
import torch
23-
from accelerate import FullyShardedDataParallelPlugin as FSDPPlugin
25+
from fp8_utils import evaluate_model, get_training_utilities
26+
from msamp.common.dtype import Dtypes
2427
from msamp.fsdp import FP8FullyShardedDataParallel
2528
from msamp.optim import FSDPAdamW
26-
from msamp.common.dtype import Dtypes
27-
from fp8_utils import evaluate_model, get_training_utilities
29+
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
30+
from transformers.models.bert import BertLayer
2831

2932
from accelerate import Accelerator
30-
from accelerate.state import AcceleratorState
33+
from accelerate import FullyShardedDataParallelPlugin as FSDPPlugin
3134
from accelerate.utils import FP8RecipeKwargs, set_seed
32-
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
33-
from transformers.models.bert import BertLayer
34-
from functools import partial
3535

3636

3737
MODEL_NAME = "bert-base-cased"
@@ -45,10 +45,7 @@ def train_baseline(opt_level="O2"):
4545
accelerator = Accelerator()
4646
device = accelerator.device
4747
model, optimizer = msamp.initialize(
48-
model, optimizer,
49-
opt_level=opt_level,
50-
weight_qtype=Dtypes.kfloat8_e4m3,
51-
use_fsdp=True
48+
model, optimizer, opt_level=opt_level, weight_qtype=Dtypes.kfloat8_e4m3, use_fsdp=True
5249
)
5350

5451
model = FP8FullyShardedDataParallel(
@@ -60,7 +57,7 @@ def train_baseline(opt_level="O2"):
6057
backward_prefetch=None,
6158
forward_prefetch=False,
6259
limit_all_gathers=True,
63-
device_id=device
60+
device_id=device,
6461
)
6562
optimizer = FSDPAdamW(optimizer)
6663

@@ -78,8 +75,8 @@ def train_baseline(opt_level="O2"):
7875

7976
trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
8077

81-
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = (
82-
accelerator.free_memory(model, optimizer, train_dataloader, eval_dataloader, lr_scheduler)
78+
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.free_memory(
79+
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
8380
)
8481
assert (
8582
trained_model_results["accuracy"] > base_model_results["accuracy"]
@@ -118,8 +115,8 @@ def train_integration(opt_level="O2"):
118115

119116
trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
120117

121-
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = (
122-
accelerator.free_memory(model, optimizer, train_dataloader, eval_dataloader, lr_scheduler)
118+
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.free_memory(
119+
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
123120
)
124121
assert (
125122
trained_model_results["accuracy"] > base_model_results["accuracy"]

src/accelerate/accelerator.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@
104104
save_fsdp_optimizer,
105105
wait_for_everyone,
106106
)
107-
from .utils.constants import FSDP_PYTORCH_VERSION, PROFILE_PATTERN_NAME
107+
from .utils.constants import PROFILE_PATTERN_NAME
108108
from .utils.modeling import get_state_dict_offloaded_model
109109
from .utils.other import is_compiled_module
110110

@@ -310,8 +310,8 @@ def __init__(
310310
# if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" or isinstance(
311311
# fsdp_plugin, FullyShardedDataParallelPlugin
312312
# ):
313-
# if is_torch_version("<", FSDP_PYTORCH_VERSION):
314-
# raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")
313+
# if is_torch_version("<", FSDP_PYTORCH_VERSION):
314+
# raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")
315315

316316
if fsdp_plugin is None: # init from env variables
317317
fsdp_plugin = (
@@ -507,8 +507,11 @@ def __init__(
507507
elif self.state.mixed_precision == "fp8":
508508
# We always enable `native_amp` for FP8
509509
self.native_amp = True
510-
# MS-AMP requires grad scaler however
511-
if self.fp8_backend == "MSAMP" and self.distributed_type not in (DistributedType.FSDP, DistributedType.DEEPSPEED):
510+
# MS-AMP requires `GradScaler` even with bf16 autocast w/ single GPU or DDP:
511+
if self.fp8_backend == "MSAMP" and self.distributed_type not in (
512+
DistributedType.FSDP,
513+
DistributedType.DEEPSPEED,
514+
):
512515
self.scaler = torch.cuda.amp.GradScaler()
513516

514517
# Start of internal step tracking
@@ -1336,6 +1339,7 @@ def prepare(self, *args, device_placement=None):
13361339
# We need to convert the underlying optimizer to FSDPAdamW *after* FSDP wrapping
13371340
result = list(result)
13381341
from msamp.optim import FSDPAdamW
1342+
13391343
for i, obj in enumerate(result):
13401344
if isinstance(obj, AcceleratedOptimizer):
13411345
result[i].optimizer = FSDPAdamW(optimizer=obj.optimizer)
@@ -1636,6 +1640,13 @@ def _prepare_te(self, *args, device=None):
16361640
def _prepare_deepspeed(self, *args):
16371641
import deepspeed
16381642

1643+
ds_initialize = deepspeed.initialize
1644+
if self.fp8_backend == "MSAMP":
1645+
# MS-AMP requires DeepSpeed patches
1646+
from msamp import deepspeed as msamp_deepspeed
1647+
1648+
ds_initialize = msamp_deepspeed.initialize
1649+
16391650
deepspeed_plugin = self.state.deepspeed_plugin
16401651

16411652
is_dataloader_present = any(isinstance(obj, torch.utils.data.DataLoader) for obj in args)
@@ -1824,7 +1835,7 @@ def _prepare_deepspeed(self, *args):
18241835
if type(scheduler).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES:
18251836
kwargs["lr_scheduler"] = scheduler
18261837

1827-
engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
1838+
engine, optimizer, _, lr_scheduler = ds_initialize(**kwargs)
18281839
if optimizer is not None:
18291840
optimizer = DeepSpeedOptimizerWrapper(optimizer)
18301841
if scheduler is not None:
@@ -2037,8 +2048,10 @@ def _prepare_msamp(self, *args, device_placement):
20372048
# NOTE: MS-AMP FSDP relies on its own MP policy, so we must drop the user's
20382049
self.state.fsdp_plugin.mixed_precision_policy = None
20392050
from msamp.common.dtype import Dtypes
2051+
20402052
model, optimizer = msamp.initialize(
2041-
model, optimizer,
2053+
model,
2054+
optimizer,
20422055
opt_level=self.fp8_recipe_handler.opt_level,
20432056
use_fsdp=self.distributed_type == DistributedType.FSDP,
20442057
weight_qtype=Dtypes.kfloat8_e4m3,
@@ -3595,4 +3608,6 @@ def fp8_backend(self):
35953608
"Returns the configured backend for training in FP8"
35963609
if self.mixed_precision == "fp8" and self.fp8_recipe_handler is not None:
35973610
return self.fp8_recipe_handler.backend
3611+
elif self.state.deepspeed_plugin is not None and self.state.deepspeed_plugin.enable_msamp:
3612+
return "MSAMP"
35983613
return None

src/accelerate/utils/dataclasses.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,16 @@ class DeepSpeedPlugin:
972972
" `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock` ..."
973973
},
974974
)
975+
enable_msamp: bool = field(
976+
default=None,
977+
metadata={"help": "Flag to indicate whether to enable MS-AMP backend for FP8 training."},
978+
)
979+
msamp_opt_level: str = field(
980+
default=None,
981+
metadata={
982+
"help": "Optimization level for MS-AMP. Only applicable if `enable_msamp` is True. Should be one of ['O1', 'O2', 'O3']."
983+
},
984+
)
975985

976986
def __post_init__(self):
977987
from .deepspeed import HfDeepSpeedConfig
@@ -1006,6 +1016,12 @@ def __post_init__(self):
10061016
os.environ.get("ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL", "false") == "true"
10071017
)
10081018

1019+
if self.enable_msamp is None:
1020+
self.enable_msamp = os.environ.get("ACCELERATE_FP8_BACKEND", None) == "MSAMP"
1021+
1022+
if self.msamp_opt_level is None:
1023+
self.msamp_opt_level = os.environ.get("ACCELERATE_FP8_OPT_LEVEL", "O1")
1024+
10091025
if self.hf_ds_config is None:
10101026
self.hf_ds_config = os.environ.get("ACCELERATE_DEEPSPEED_CONFIG_FILE", "none")
10111027
if (
@@ -1075,6 +1091,14 @@ def __post_init__(self):
10751091
if self.zero3_init_flag and not self.hf_ds_config.is_zero3():
10761092
warnings.warn("DeepSpeed Zero3 Init flag is only applicable for ZeRO Stage 3. Setting it to False.")
10771093
self.zero3_init_flag = False
1094+
if self.enable_msamp:
1095+
if self.zero_stage == 3:
1096+
raise NotImplementedError(
1097+
"MS-AMP is not supported for ZeRO Stage 3. Please use ZeRO Stage 0, 1, or 2 instead."
1098+
)
1099+
if self.msamp_opt_level not in ["O1", "O2", "O3"]:
1100+
raise ValueError("Invalid optimization level for MS-AMP. Please use one of ['O1', 'O2', 'O3'].")
1101+
self.deepspeed_config["msamp"] = {"enabled": True, "opt_level": self.msamp_opt_level}
10781102

10791103
def fill_match(self, ds_key_long, mismatches=None, must_match=True, **kwargs):
10801104
mismatches = [] if mismatches is None else mismatches
@@ -1144,6 +1168,10 @@ def set_mixed_precision(self, mixed_precision):
11441168
if "bf16" not in ds_config:
11451169
ds_config["bf16"] = {"enabled": True}
11461170

1171+
if mixed_precision == "fp8" and self.enable_msamp:
1172+
if "msamp" not in ds_config:
1173+
ds_config["msamp"] = {"enabled": True, "opt_level": self.msamp_opt_level}
1174+
11471175
if mixed_precision != "no":
11481176
diff_dtype = "bf16" if mixed_precision == "fp16" else "fp16"
11491177
if str(ds_config.get(diff_dtype, {}).get("enabled", "False")).lower() == "true":

0 commit comments

Comments
 (0)