Skip to content

Commit 9eef114

Browse files
authored
Adds fp8 stats logging to ESM2 / LLAMA3 DDP/FSDP2 training (#1413)
### Description This MR enables the option to log FP8 stats during training. #### Usage By adjusting the hydraconfig or by adding the following commandline arguments to your training script you can receive FP8 statistics. ```python python train_fsdp2.py \ fp8_stats_config.enabled=True \ # whether to log stats or not fp8_stats_config.fp8_log_dir=./logs/fp8_stats_logs_dummy \ # where to store the logs fp8_stats_config.fp8_stats_file=./fp8_stats.yaml \ # specifies what stats you want to run. Currently this is saved in this yaml file. fp8_config.enabled=True # set this to use FP8 otherwise stats logging wont work ``` ### Type of changes <!-- Mark the relevant option with an [x] --> - [ ] Bug fix (non-breaking change which fixes an issue) - [X] New feature (non-breaking change which adds functionality) - [ ] Refactor - [ ] Documentation update - [ ] Other (please describe): ### CI Pipeline Configuration Configure CI behavior by applying the relevant labels. By default, only basic unit tests are run. - [ciflow:skip](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:skip) - Skip all CI tests for this PR - [ciflow:notebooks](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:notebooks) - Run Jupyter notebooks execution tests for bionemo2 - [ciflow:slow](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:slow) - Run slow single GPU integration tests marked as @pytest.mark.slow for bionemo2 - [ciflow:all](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:all) - Run all tests (unit tests, slow tests, and notebooks) for bionemo2. This label can be used to enforce running tests for all bionemo2. - [ciflow:all-recipes](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:all-recipes) - Run tests for all recipes (under bionemo-recipes). This label can be used to enforce running tests for all recipes. Unit tests marked as `@pytest.mark.multi_gpu` or `@pytest.mark.distributed` are not run in the PR pipeline. For more details, see [CONTRIBUTING](CONTRIBUTING.md) > [!NOTE] > By default, only basic unit tests are run. Add appropriate labels to enable an additional test coverage. #### Authorizing CI Runs We use [copy-pr-bot](https://docs.gha-runners.nvidia.com/apps/copy-pr-bot/#automation) to manage authorization of CI runs on NVIDIA's compute resources. - If a pull request is opened by a trusted user and contains only trusted changes, the pull request's code will automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123) - If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an `/ok to test` comment on the pull request to trigger CI. This will need to be done for each new commit. ### Pre-submit Checklist <!--- Ensure all items are completed before submitting --> - [ ] I have tested these changes locally - [ ] I have updated the documentation accordingly - [ ] I have added/updated tests as needed - [ ] All existing tests pass successfully <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes * **New Features** * FP8 tensor statistics logging during training with configurable stats collection and debugging outputs. * FP8 analysis tool for automated metric parsing and visualization with publication-quality heatmaps. * Configuration options for enabling FP8 statistics logging and specifying output directories. * **Documentation** * Added FP8 debugging guides with setup and usage instructions. * Added FP8 analysis tool user guide with examples and interpretation guidance. * **Tests** * Added validation tests for FP8 statistics logging in training workflows. * Added comprehensive tests for FP8 analysis tool functionality. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
1 parent e53b726 commit 9eef114

28 files changed

Lines changed: 4466 additions & 1 deletion

File tree

.devcontainer/recipes/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,5 @@ transformers
1616
typer
1717
wandb
1818
zstandard
19+
nvdlfw_inspect @ git+https://github.com/NVIDIA/nvidia-dlfw-inspect
20+
seaborn

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ MNISTCustom/
7474
*.pot
7575

7676
# Django stuff:
77-
*.log
7877
local_settings.py
7978
db.sqlite3
8079

@@ -170,6 +169,10 @@ local/
170169

171170
# Logs
172171
*.log
172+
!bionemo-recipes/recipes/fp8_analysis/dummy_logs_esm2/rank_0/nvdlfw_inspect_logs/nvdlfw_inspect_globalrank-0.log
173+
!bionemo-recipes/recipes/fp8_analysis/dummy_logs_esm2/rank_0/nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log
174+
!bionemo-recipes/recipes/fp8_analysis/dummy_logs_llama3/rank_0/nvdlfw_inspect_logs/nvdlfw_inspect_globalrank-0.log
175+
!bionemo-recipes/recipes/fp8_analysis/dummy_logs_llama3/rank_0/nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log
173176

174177
# Tests
175178
tests/__pycache__/

bionemo-recipes/recipes/esm2_native_te/README.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,27 @@ configuration parameters, including switching to `MXFP8BlockScaling`, can be set
106106
python train_fsdp2.py --config-name L0_sanity fp8_config.enabled=true
107107
```
108108

109+
#### FP8 Debugging
110+
111+
We also provide a mechanism to receive tensor data related to FP8 layers during training which may include activations, weights and gradients.
112+
113+
To enable this please select the following config options.
114+
115+
```python
116+
python train_fsdp2.py \
117+
fp8_stats_config.enabled=True # whether to log stats or not
118+
fp8_stats_config.fp8_log_dir=./logs/fp8_stats_logs_dummy # where to store the logs
119+
fp8_stats_config.fp8_stats_file=./fp8_debugging_stats.yaml # specifies what stats you want to run. Currently this is saved in this yaml file.
120+
fp8_config.enabled=True # set this to use FP8 otherwise stats logging won't work
121+
```
122+
123+
Note: This feature is available for the `train_ddp` and the `train_fsdp2` scripts. It is not yet available for `train_mfsdp`.
124+
125+
The config file structure [fp8_debugging_stats.yaml](fp8_debugging_stats.yaml) is explained in the [NVIDIA Transformer Engine config file documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/debug/2_config_file_structure.html) in more detail. Below we will cover some very basic elements of the file structure.
126+
127+
This comes as a performance cost that is dependent on the `freq` parameter mentioned above. `freq=1` collects stats on every step which in our
128+
experiments caused a ~29% decrease in throughput (executed on a single RTX 5090). We recommend using `freq>=10` to reduce this performance hit.
129+
109130
### Sequence Packing (THD input format)
110131

111132
Sequence packing is handled via a padding-free collator (in `collator.py`) that provides input arguments (e.g.
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
example_fp8_tensor_stat_collection:
2+
enabled: True
3+
layers:
4+
# Match the actual linear layers within attention that support FP8 stats
5+
layer_types: [layernorm_qkv]
6+
transformer_engine:
7+
LogFp8TensorStats:
8+
enabled: True
9+
tensors_struct:
10+
- tensor: activation
11+
stats: [underflows%, scale_inv_min, scale_inv_max, mse]
12+
freq: 10
13+
- tensor: gradient
14+
stats: [underflows%, scale_inv_min, scale_inv_max, mse]
15+
freq: 10
16+
- tensor: weight
17+
stats: [underflows%, scale_inv_min, scale_inv_max, mse]
18+
freq: 10

bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,8 @@ checkpoint:
7575

7676
logger:
7777
frequency: 100
78+
79+
fp8_stats_config:
80+
enabled: false
81+
fp8_stats_file: ./fp8_debugging_stats.yaml
82+
fp8_log_dir: ./log_fp8_stats

bionemo-recipes/recipes/esm2_native_te/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ tqdm
1010
transformer_engine[pytorch]
1111
transformers
1212
wandb
13+
nvdlfw_inspect @ git+https://github.com/NVIDIA/nvidia-dlfw-inspect

bionemo-recipes/recipes/esm2_native_te/tests/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,17 @@ def recipe_path() -> Path:
3434
return Path(__file__).parent.parent
3535

3636

37+
def pytest_collection_modifyitems(items):
38+
"""Run FP8 stats logging tests first to avoid late debug initialization."""
39+
stats_test_names = {
40+
"test_sanity_ddp_fp8_stats_logging",
41+
"test_sanity_fsdp2_fp8_stats_logging",
42+
}
43+
stats_tests = [item for item in items if item.name in stats_test_names]
44+
other_tests = [item for item in items if item.name not in stats_test_names]
45+
items[:] = stats_tests + other_tests
46+
47+
3748
@pytest.fixture(scope="session", autouse=True)
3849
def device_mesh():
3950
"""Create a re-usable device mesh for testing.

bionemo-recipes/recipes/esm2_native_te/tests/test_train.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,46 @@ def test_sanity_ddp_fp8(tmp_path, recipe_path):
142142
main_ddp(sanity_config)
143143

144144

145+
@requires_fp8
146+
def test_sanity_ddp_fp8_stats_logging(tmp_path, recipe_path):
147+
"""Test that FP8 stats logging creates the expected log files."""
148+
fp8_log_dir = tmp_path / "fp8_stats_logs"
149+
150+
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
151+
sanity_config = compose(
152+
config_name="L0_sanity",
153+
overrides=[
154+
f"+wandb_init_args.dir={tmp_path}",
155+
f"checkpoint.ckpt_dir={tmp_path}",
156+
"fp8_config.enabled=true",
157+
"fp8_stats_config.enabled=true",
158+
f"fp8_stats_config.fp8_log_dir={fp8_log_dir}",
159+
"num_train_steps=4",
160+
],
161+
)
162+
163+
main_ddp(sanity_config)
164+
165+
# Verify the log directory structure was created
166+
assert fp8_log_dir.exists(), "FP8 log directory was not created"
167+
assert (fp8_log_dir / "rank_0").exists(), "rank_0 directory was not created"
168+
assert (fp8_log_dir / "rank_0" / "nvdlfw_inspect_logs").exists(), "nvdlfw_inspect_logs directory was not created"
169+
assert (fp8_log_dir / "rank_0" / "nvdlfw_inspect_statistics_logs").exists(), (
170+
"nvdlfw_inspect_statistics_logs directory was not created"
171+
)
172+
173+
# Verify the log files exist
174+
metadata_log = fp8_log_dir / "rank_0" / "nvdlfw_inspect_logs" / "nvdlfw_inspect_globalrank-0.log"
175+
stats_log = fp8_log_dir / "rank_0" / "nvdlfw_inspect_statistics_logs" / "nvdlfw_inspect_globalrank-0.log"
176+
177+
assert metadata_log.exists(), "Metadata log file was not created"
178+
assert stats_log.exists(), "Statistics log file was not created"
179+
180+
# Verify files are non-empty
181+
assert metadata_log.stat().st_size > 0, "Metadata log file is empty"
182+
assert stats_log.stat().st_size > 0, "Statistics log file is empty"
183+
184+
145185
@requires_fp8
146186
def test_sanity_convergence_fsdp2_fp8(tmp_path, recipe_path):
147187
"""For FSDP2, we check that the script can run successfully with FP8 and check convergence."""
@@ -159,6 +199,32 @@ def test_sanity_convergence_fsdp2_fp8(tmp_path, recipe_path):
159199
assert final_loss < 3.0, f"Final loss {final_loss} is too high"
160200

161201

202+
@requires_fp8
203+
def test_sanity_fsdp2_fp8_stats_logging(tmp_path, recipe_path):
204+
"""Test that FP8 stats logging works with FSDP2."""
205+
fp8_log_dir = tmp_path / "fp8_stats_logs"
206+
207+
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
208+
sanity_config = compose(
209+
config_name="L0_sanity",
210+
overrides=[
211+
f"+wandb_init_args.dir={tmp_path}",
212+
f"checkpoint.ckpt_dir={tmp_path}",
213+
"fp8_config.enabled=true",
214+
"fp8_stats_config.enabled=true",
215+
f"fp8_stats_config.fp8_log_dir={fp8_log_dir}",
216+
"num_train_steps=4",
217+
],
218+
)
219+
220+
main_fsdp2(sanity_config)
221+
222+
# Verify log structure (same assertions as above)
223+
assert fp8_log_dir.exists()
224+
assert (fp8_log_dir / "rank_0" / "nvdlfw_inspect_logs" / "nvdlfw_inspect_globalrank-0.log").exists()
225+
assert (fp8_log_dir / "rank_0" / "nvdlfw_inspect_statistics_logs" / "nvdlfw_inspect_globalrank-0.log").exists()
226+
227+
162228
@requires_fp8
163229
@pytest.mark.xfail(reason="MFSDP doesn't seem to support fp8_model_init (BIONEMO-3012)")
164230
def test_sanity_mfsdp_fp8_and_model_init(tmp_path, recipe_path):

bionemo-recipes/recipes/esm2_native_te/train_ddp.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
from pathlib import Path
1818

1919
import hydra
20+
import nvdlfw_inspect.api as debug_api
2021
import torch
22+
import transformer_engine
2123
import transformer_engine.pytorch
2224
from omegaconf import DictConfig
2325
from torch.distributed.device_mesh import init_device_mesh
@@ -50,6 +52,24 @@ def main(args: DictConfig) -> float | None:
5052
torch.distributed.init_process_group(backend="nccl", device_id=device)
5153
torch.cuda.set_device(dist_config.local_rank)
5254

55+
# TE Debug feature logging
56+
if args.fp8_stats_config.enabled and not args.fp8_config.enabled:
57+
raise ValueError(
58+
"fp8_stats_config.enabled is true but fp8_config.enabled is false, please set fp8_config.enabled to true in the config if you wish to collect FP8 stats"
59+
)
60+
61+
if args.fp8_stats_config.enabled:
62+
fp8_stats_file = args.fp8_stats_config.fp8_stats_file
63+
fp8_log_dir = Path(args.fp8_stats_config.fp8_log_dir) / f"rank_{dist_config.rank}"
64+
fp8_log_dir.mkdir(parents=True, exist_ok=True)
65+
logger.info(f"Logging FP8 stats to {fp8_log_dir}")
66+
te_features_dir = str(Path(transformer_engine.__file__).parent / "debug" / "features")
67+
debug_api.initialize(
68+
config_file=fp8_stats_file,
69+
feature_dirs=[te_features_dir],
70+
log_dir=fp8_log_dir,
71+
default_logging_enabled=True,
72+
)
5373
# Create a device mesh for DDP. While this isn't strictly necessary, it mirrors the device mesh we create for FSDP2
5474
# and MFSDP.
5575
device_mesh = init_device_mesh("cuda", mesh_shape=(dist_config.world_size,), mesh_dim_names=("ddp",))
@@ -84,6 +104,9 @@ def main(args: DictConfig) -> float | None:
84104
optimizer = AdamW(model.parameters(), **args.adamw_kwargs)
85105
scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs)
86106

107+
if args.fp8_stats_config.enabled:
108+
debug_api.infer_and_assign_layer_names(model)
109+
87110
model = model.to(device=device)
88111
model = torch.nn.parallel.DistributedDataParallel(
89112
model,
@@ -134,6 +157,8 @@ def main(args: DictConfig) -> float | None:
134157
loss = outputs.loss
135158
loss.backward()
136159

160+
if args.fp8_stats_config.enabled:
161+
debug_api.step()
137162
# Compute and clip gradient norms.
138163
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()
139164

@@ -181,6 +206,8 @@ def main(args: DictConfig) -> float | None:
181206

182207
# Clean up distributed training
183208
perf_logger.finish()
209+
if args.fp8_stats_config.enabled:
210+
debug_api.end_debug()
184211
torch.distributed.destroy_process_group()
185212

186213
return perf_logger.min_loss

bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
from pathlib import Path
1919

2020
import hydra
21+
import nvdlfw_inspect.api as debug_api
2122
import torch
23+
import transformer_engine
2224
import transformer_engine.pytorch
2325
from omegaconf import DictConfig, OmegaConf
2426
from torch.distributed.device_mesh import init_device_mesh
@@ -55,6 +57,25 @@ def main(args: DictConfig) -> float | None:
5557
torch.distributed.init_process_group(backend="nccl", device_id=device)
5658
torch.cuda.set_device(dist_config.local_rank)
5759

60+
# TE Debug feature logging - MUST be done BEFORE FSDP wrapping
61+
if args.fp8_stats_config.enabled and not args.fp8_config.enabled:
62+
raise ValueError(
63+
"fp8_stats_config.enabled is true but fp8_config.enabled is false, please set fp8_config.enabled to true in the config if you wish to collect FP8 stats"
64+
)
65+
66+
if args.fp8_stats_config.enabled:
67+
fp8_stats_file = args.fp8_stats_config.fp8_stats_file
68+
fp8_log_dir = Path(args.fp8_stats_config.fp8_log_dir) / f"rank_{dist_config.rank}"
69+
fp8_log_dir.mkdir(parents=True, exist_ok=True)
70+
logger.info(f"Logging FP8 stats to {fp8_log_dir}")
71+
te_features_dir = str(Path(transformer_engine.__file__).parent / "debug" / "features")
72+
debug_api.initialize(
73+
config_file=fp8_stats_file,
74+
feature_dirs=[te_features_dir],
75+
log_dir=fp8_log_dir,
76+
default_logging_enabled=True,
77+
)
78+
5879
# Create a device mesh for FSDP.
5980
device_mesh = init_device_mesh(
6081
"cuda",
@@ -86,6 +107,7 @@ def main(args: DictConfig) -> float | None:
86107

87108
# We call the transformer stack "layers" in our TE models, but it's called "layer" in the original ESM-2 models.
88109
transformer_stack = model.esm.encoder.layers if hasattr(model.esm.encoder, "layers") else model.esm.encoder.layer
110+
89111
for layer in transformer_stack:
90112
fully_shard(layer, mesh=device_mesh["dp"])
91113
fully_shard(model, mesh=device_mesh["dp"])
@@ -100,6 +122,10 @@ def main(args: DictConfig) -> float | None:
100122
model.to_empty(device=device)
101123
model.apply(model._init_weights)
102124

125+
# Assign names to layers so debug API can identify them
126+
if args.fp8_stats_config.enabled:
127+
debug_api.infer_and_assign_layer_names(model)
128+
103129
# Create optimizer. Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873).
104130
optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore
105131
scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs)
@@ -152,6 +178,10 @@ def main(args: DictConfig) -> float | None:
152178
# Step optimizer.
153179
optimizer.step()
154180
scheduler.step()
181+
182+
if args.fp8_stats_config.enabled:
183+
debug_api.step()
184+
155185
optimizer.zero_grad()
156186

157187
perf_logger.log_step(
@@ -193,6 +223,8 @@ def main(args: DictConfig) -> float | None:
193223

194224
# Clean up distributed training
195225
perf_logger.finish()
226+
if args.fp8_stats_config.enabled:
227+
debug_api.end_debug()
196228
torch.distributed.destroy_process_group()
197229

198230
return perf_logger.min_loss

0 commit comments

Comments
 (0)