Skip to content

Commit c157a41

Browse files
authored
add stop+go tests to llama3 recipe, turn off async checkpointing for fp8 (#1494)
async dcp checkpointing is currently not working with fp8 model init, so we need to detect this and switch back to synchronous checkpointing. This also adds tests to ensure the dcp checkpoints are functional Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 493fb3f commit c157a41

8 files changed

Lines changed: 459 additions & 872 deletions

File tree

bionemo-recipes/recipes/llama3_native_te/checkpoint.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@
3434
from torch.distributed.checkpoint.state_dict_saver import async_save as dcp_async_save
3535
from torch.distributed.checkpoint.state_dict_saver import save as dcp_save
3636
from torch.distributed.checkpoint.stateful import Stateful
37+
from torch.distributed.tensor import DTensor
3738
from torchdata.stateful_dataloader import StatefulDataLoader
39+
from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
3840

3941
from distributed_config import DistributedConfig
4042

@@ -115,8 +117,20 @@ def load_checkpoint_ddp(
115117
ckpt_path: str | os.PathLike,
116118
dist_config: DistributedConfig,
117119
dataloader: StatefulDataLoader | None = None,
120+
weights_only: bool = True,
118121
) -> CheckpointOutput:
119-
"""Load DDP checkpoint."""
122+
"""Load DDP checkpoint.
123+
124+
Args:
125+
model: The model to load.
126+
optimizer: The optimizer to load.
127+
scheduler: The LR scheduler to load.
128+
ckpt_path: The path to the checkpoint.
129+
dist_config: The distributed configuration.
130+
dataloader: The dataloader to load.
131+
weights_only: Whether to load the checkpoint weights only. We have to set this to True when loading FP8
132+
checkpoints.
133+
"""
120134
checkpoint_path, _ = get_latest_checkpoint(ckpt_path)
121135

122136
if not checkpoint_path:
@@ -126,7 +140,7 @@ def load_checkpoint_ddp(
126140
checkpoint = torch.load(
127141
checkpoint_path / "checkpoint.pt",
128142
map_location=f"cuda:{dist_config.local_rank}",
129-
weights_only=True,
143+
weights_only=weights_only,
130144
)
131145

132146
model.load_state_dict(checkpoint["model"])
@@ -221,6 +235,7 @@ class AppState(Stateful):
221235
def state_dict(self):
222236
"""Get the state dict for the model, optimizer, scheduler, and step."""
223237
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
238+
model_state_dict = {k: v for k, v in model_state_dict.items() if not k.endswith("_extra_state")}
224239
return {
225240
"model": model_state_dict,
226241
"optim": optimizer_state_dict,
@@ -236,6 +251,7 @@ def load_state_dict(self, state_dict: dict):
236251
self.optimizer,
237252
model_state_dict=state_dict["model"],
238253
optim_state_dict=state_dict["optim"],
254+
options=StateDictOptions(strict=False),
239255
)
240256
self.scheduler.load_state_dict(state_dict["scheduler"])
241257
self.step = state_dict["step"]
@@ -322,6 +338,13 @@ def save_checkpoint_fsdp2(
322338
checkpoint_path = ckpt_path / f"step_{step}"
323339
checkpoint_path.mkdir(parents=True, exist_ok=True)
324340

341+
model_params = (p.to_local() if isinstance(p, DTensor) else p for p in model.parameters())
342+
if async_save and any((isinstance(p, QuantizedTensor) for p in model_params)):
343+
logger.warning(
344+
"Async checkpointing is not supported for FP8 models, falling back to synchronous checkpointing."
345+
)
346+
async_save = False
347+
325348
if dataloader is not None:
326349
save_dataloader(
327350
dataloader=dataloader,

bionemo-recipes/recipes/llama3_native_te/perf_logger.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class PerfLogger:
4444
min_loss: The minimum loss seen so far.
4545
"""
4646

47-
def __init__(self, dist_config: DistributedConfig, args: DictConfig):
47+
def __init__(self, dist_config: DistributedConfig, args: DictConfig, start_step: int):
4848
"""Initialize the logger."""
4949
self._dist_config = dist_config
5050
self._run_config = OmegaConf.to_container(args, resolve=True, throw_on_missing=True)
@@ -75,7 +75,7 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig):
7575
if self._dist_config.is_main_process():
7676
# Log the entire args object to wandb for experiment tracking and reproducibility.
7777
self._wandb_run = wandb.init(**args.wandb, config=self._run_config)
78-
self._progress_bar = tqdm(total=args.num_train_steps, desc="Training")
78+
self._progress_bar = tqdm(initial=start_step, total=args.num_train_steps, desc="Training")
7979

8080
if args.profiler.enabled:
8181
self._profiler = NsightProfiler(

bionemo-recipes/recipes/llama3_native_te/tests/conftest.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import pytest
2121
import torch
22+
from transformer_engine.pytorch import fp8 as te_fp8
2223

2324

2425
sys.path.append(Path(__file__).parent.parent.as_posix())
@@ -61,6 +62,56 @@ def pytest_collection_modifyitems(items):
6162
items[:] = stats_tests + other_tests
6263

6364

65+
# ---------------------------------------------------------------------------
66+
# FP8 recipe parametrization
67+
# ---------------------------------------------------------------------------
68+
69+
# Each entry: (recipe_class_name, hydra_overrides, check_fn)
70+
_FP8_RECIPE_CONFIGS = [
71+
(
72+
"DelayedScaling",
73+
["fp8_config.fp8_recipe=transformer_engine.common.recipe.DelayedScaling"],
74+
te_fp8.check_fp8_support,
75+
),
76+
(
77+
"Float8CurrentScaling",
78+
["fp8_config.fp8_recipe=transformer_engine.common.recipe.Float8CurrentScaling"],
79+
te_fp8.check_fp8_support,
80+
),
81+
(
82+
"Float8BlockScaling",
83+
["fp8_config.fp8_recipe=transformer_engine.common.recipe.Float8BlockScaling"],
84+
te_fp8.check_fp8_block_scaling_support,
85+
),
86+
(
87+
"MXFP8BlockScaling",
88+
["fp8_config.fp8_recipe=transformer_engine.common.recipe.MXFP8BlockScaling"],
89+
te_fp8.check_mxfp8_support,
90+
),
91+
]
92+
93+
94+
def _parametrize_fp8_recipes():
95+
"""Generate pytest.param objects with xfail marks for unsupported FP8 recipes."""
96+
params = []
97+
for name, overrides, check_fn in _FP8_RECIPE_CONFIGS:
98+
supported, reason = check_fn()
99+
params.append(
100+
pytest.param(
101+
overrides,
102+
id=name,
103+
marks=pytest.mark.xfail(condition=not supported, reason=reason),
104+
)
105+
)
106+
return params
107+
108+
109+
@pytest.fixture(params=_parametrize_fp8_recipes())
110+
def fp_recipe(request):
111+
"""Parametrized fixture providing FP8 recipe Hydra overrides for each supported TE recipe."""
112+
return request.param
113+
114+
64115
@pytest.fixture(scope="session", autouse=True)
65116
def device_mesh():
66117
"""Create a re-usable torch process group for testing.

0 commit comments

Comments
 (0)