Skip to content

Commit 8dda89b

Browse files
committed
adds llama3 MXFP8 NVFP4 layer-wise precision
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
1 parent f873667 commit 8dda89b

22 files changed

Lines changed: 1441 additions & 160 deletions

bionemo-recipes/models/llama3/modeling_llama_te.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class NVLlamaConfig(LlamaConfig):
5252
# "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format)
5353
attn_input_format: str = "thd"
5454
self_attn_mask_type: str = "padding_causal"
55+
layer_precision: list[str | None] | None = None
5556

5657
def __init__(
5758
self,
@@ -217,11 +218,54 @@ def _init_method(x):
217218
self.rotary_emb = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
218219
self.rotary_emb.inv_freq = LlamaRotaryEmbedding(config=config).inv_freq
219220

221+
self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = None
222+
self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = None
223+
220224
self.gradient_checkpointing = False
221225

222226
# Initialize weights and apply final processing
223227
self.post_init()
224228

229+
def set_recipes(
    self,
    fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
    fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
) -> None:
    """Store the quantization recipe objects used for per-layer autocast.

    Recipe objects are not serializable, so they cannot live on the config. Attach
    them at runtime, after the model has been created and sharded (FSDP/DDP) and
    before training starts. Which layer uses which precision is not decided here;
    that comes from ``self.config.layer_precision``.

    Args:
        fp8_recipe: The FP8 recipe instance (e.g., MXFP8BlockScaling), or None.
        fp4_recipe: The FP4 recipe instance (e.g., NVFP4BlockScaling), or None.
    """
    # Both slots are always overwritten, so omitting an argument clears that recipe.
    self._fp8_recipe, self._fp4_recipe = fp8_recipe, fp4_recipe
246+
247+
def get_layer_autocast(self, layer_number: int):
    """Return the TE autocast context manager to wrap a single decoder layer in.

    The returned context is meant to be entered inside the outer FP8 autocast that
    ``forward`` opens around the decoder stack:

    - ``"fp8"`` layer: ``nullcontext()`` -- lets the outer FP8 autocast take effect.
    - ``"fp4"`` layer: ``te.pytorch.autocast(enabled=True, recipe=fp4_recipe)`` -- overrides to FP4.
    - any other value (including ``None``, or when ``config.layer_precision`` is not set):
      ``te.pytorch.autocast(enabled=False)`` -- disables quantized compute (BF16 layer).

    Args:
        layer_number: The 0-indexed layer number.

    Returns:
        A context manager for the layer's quantization mode.

    Raises:
        IndexError: If ``config.layer_precision`` has no entry for ``layer_number``.
        ValueError: If the layer is configured as ``"fp4"`` but no FP4 recipe has been
            attached with ``set_recipes``.
    """
    layer_precision = self.config.layer_precision
    if layer_precision is None:
        precision = None
    else:
        try:
            precision = layer_precision[layer_number]
        except IndexError as err:
            # Re-raise with context: a bare "list index out of range" from inside forward()
            # gives no hint that the config list is shorter than the decoder stack.
            raise IndexError(
                f"config.layer_precision has {len(layer_precision)} entries but layer "
                f"{layer_number} was requested; provide one entry per decoder layer."
            ) from err

    if precision == "fp8":
        return nullcontext()

    if precision == "fp4":
        if self._fp4_recipe is None:
            # Entering an enabled autocast with recipe=None would not run this layer in FP4
            # (presumably TE substitutes its own default recipe -- confirm against the TE
            # docs), so fail loudly instead of training at an unintended precision.
            raise ValueError(
                f"Layer {layer_number} is configured as 'fp4' but no FP4 recipe is attached; "
                "call set_recipes(fp4_recipe=...) before running the model."
            )
        return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp4_recipe)

    return transformer_engine.pytorch.autocast(enabled=False)
268+
225269
def forward(
226270
self,
227271
input_ids: torch.Tensor | None = None,
@@ -298,12 +342,14 @@ def forward(
298342
if te_rope_emb.dtype != torch.float32:
299343
warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)
300344

301-
with self.get_autocast_context(None, outer=True):
302-
for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
345+
# Outer FP8 autocast enables FP8 compute for the decoder stack. Per-layer overrides (FP4, BF16) are handled
346+
# by get_layer_autocast(), which nests inside this context.
347+
with transformer_engine.pytorch.autocast(enabled=self._fp8_recipe is not None, recipe=self._fp8_recipe):
348+
for layer_number, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
303349
if output_hidden_states:
304350
all_hidden_states = (*all_hidden_states, hidden_states)
305351

306-
with self.get_autocast_context(layer_idx):
352+
with self.get_layer_autocast(layer_number):
307353
hidden_states = decoder_layer(
308354
hidden_states,
309355
attention_mask=None if self.config.attn_input_format == "thd" else attention_mask,
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import os
17+
import pickle
18+
import subprocess
19+
20+
import pytest
21+
import torch
22+
from transformer_engine.pytorch.fp8 import check_fp8_support
23+
24+
25+
def requires_fp8(func):
    """Mark ``func`` to be skipped by pytest when FP8 is unavailable.

    The FP8 support check runs once, at decoration time; the reason it reports is
    embedded in the skip message.
    """
    supported, why_not = check_fp8_support()
    skip_without_fp8 = pytest.mark.skipif(not supported, reason=f"FP8 is not supported on this GPU: {why_not}")
    return skip_without_fp8(func)
29+
30+
31+
# Skip marker for tests that need CUDA with two or more visible devices.
requires_multi_gpu = pytest.mark.skipif(
    not (torch.cuda.is_available() and torch.cuda.device_count() >= 2),
    reason="Test requires at least 2 GPUs",
)
35+
36+
37+
def _run_torchrun_worker(strategy, port, nproc_per_node):
    """Run this file's ``__main__`` section under torchrun and fail the test on a non-zero exit.

    Args:
        strategy: Value forwarded to the worker's ``--strategy`` argument.
        port: TCP port used for the c10d rendezvous endpoint on localhost.
        nproc_per_node: Number of worker processes torchrun should start.
    """
    cmd = [
        "torchrun",
        f"--nproc_per_node={nproc_per_node}",
        "--rdzv-backend=c10d",
        f"--rdzv-endpoint=localhost:{port}",
        os.path.relpath(__file__),
        "--strategy",
        strategy,
    ]

    # check=False so the worker's output can be printed before the test is failed.
    result = subprocess.run(
        cmd,
        check=False,
        text=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        timeout=240,
    )
    if result.returncode != 0:
        print(f"STDOUT:\n{result.stdout}")
        print(f"STDERR:\n{result.stderr}")
        pytest.fail(f"Command failed with exit code {result.returncode}")


@pytest.mark.parametrize("strategy", ["ddp", "fsdp2"])
@requires_fp8
def test_single_process_attaches_correct_fp8_recipe(strategy, unused_tcp_port):
    """Run the worker with one process; it checks the stored recipe against the one it attached."""
    _run_torchrun_worker(strategy, unused_tcp_port, nproc_per_node=1)


@pytest.mark.parametrize("strategy", ["ddp", "fsdp2"])
@requires_fp8
@requires_multi_gpu
def test_multi_process_fp8_recipes_are_synced(strategy, unused_tcp_port):
    """Run the worker with two processes; rank 0 compares the FP8 extra states of both ranks."""
    _run_torchrun_worker(strategy, unused_tcp_port, nproc_per_node=2)
90+
91+
92+
if __name__ == "__main__":
93+
import argparse
94+
import enum
95+
import os
96+
import sys
97+
from dataclasses import dataclass, field
98+
from pathlib import Path
99+
100+
# Ensure the model directory is on sys.path for bare module imports.
101+
sys.path.insert(0, Path(__file__).resolve().parent.parent.as_posix())
102+
103+
import torch.distributed as dist
104+
from torch.distributed.device_mesh import init_device_mesh
105+
from torch.distributed.fsdp import fully_shard
106+
from torch.optim import AdamW
107+
from transformer_engine.pytorch.fp8 import DelayedScaling, Format
108+
109+
from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM
110+
111+
def recursive_assert(a, b, path=""):
    """Assert that two nested structures are equal, reporting where they diverge.

    Dicts are compared key by key and lists element by element, recursing into the
    values. Tensors are compared with ``torch.testing.assert_close``; every other
    pair is compared with ``==``.

    Args:
        a: First structure.
        b: Second structure.
        path: Dotted location of ``a``/``b`` inside the top-level structure, used
            only in failure messages.

    Raises:
        AssertionError: On the first mismatch found.
    """
    if isinstance(a, dict) and isinstance(b, dict):
        assert a.keys() == b.keys(), f"Dictionary keys mismatch: {a.keys()} != {b.keys()} at {path}"
        for key, a_value in a.items():
            recursive_assert(a_value, b[key], path=f"{path}.{key}")
    elif isinstance(a, list) and isinstance(b, list):
        assert len(a) == len(b), f"List lengths mismatch: {len(a)} != {len(b)} at {path}"
        for index, (a_item, b_item) in enumerate(zip(a, b)):
            recursive_assert(a_item, b_item, path=f"{path}.{index}")
    elif isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
        # A plain string ``msg`` would replace torch's mismatch report; a callable keeps
        # the generated details and prefixes them with the location.
        torch.testing.assert_close(a, b, msg=lambda generated: f"Tensor mismatch at {path}\n{generated}")
    else:
        assert a == b, f"Value mismatch at {path}: {a} != {b}"
124+
125+
class Strategy(enum.StrEnum):
126+
DDP = "ddp"
127+
FSDP2 = "fsdp2"
128+
129+
@dataclass
class DistributedConfig:
    """Holds the distributed ranks of the current process.

    Fields left unset are resolved when the instance is created.
    """

    # Defaults to dist.get_rank().
    rank: int = field(default_factory=dist.get_rank)
    # Defaults to the integer value of the LOCAL_RANK environment variable.
    local_rank: int = field(default_factory=lambda: int(os.environ["LOCAL_RANK"]))
    # Defaults to dist.get_world_size().
    world_size: int = field(default_factory=dist.get_world_size)

    def is_main_process(self) -> bool:
        """Return True on the global rank 0 process (the one to use for wandb logging, etc.)."""
        main_rank = 0
        return self.rank == main_rank
140+
141+
# Worker body: build a small FP8 Llama model, run a few forward/backward passes, then
# check the FP8 extra state collected from its modules.
parser = argparse.ArgumentParser()
parser.add_argument("--strategy", type=Strategy, default=Strategy.DDP, choices=[Strategy.FSDP2, Strategy.DDP])
args = parser.parse_args()

torch.distributed.init_process_group(backend="nccl")
dist_config = DistributedConfig()
torch.cuda.set_device(dist_config.local_rank)
device_mesh = init_device_mesh(
    "cuda",
    mesh_shape=(dist_config.world_size, 1),
    mesh_dim_names=("dp", "tp"),
)
device = f"cuda:{dist_config.local_rank}"

fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_compute_algo="max", amax_history_len=10)

config = NVLlamaConfig(
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=6,
    num_attention_heads=8,
    num_key_value_heads=4,
    vocab_size=100,
    dtype=torch.bfloat16,
)
# Every decoder layer runs in FP8 for this test.
config.layer_precision = ["fp8"] * config.num_hidden_layers
model = NVLlamaForCausalLM(config)

if args.strategy is Strategy.FSDP2:
    for layer in model.model.layers:
        fully_shard(layer, mesh=device_mesh["dp"])
    fully_shard(model, mesh=device_mesh["dp"])
    model.to(device)

elif args.strategy is Strategy.DDP:
    model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[dist_config.local_rank],
        output_device=dist_config.local_rank,
        device_mesh=device_mesh["dp"],
    )

# The optimizer is constructed but never stepped in this script.
optimizer = AdamW(model.parameters())

# Attach FP8 recipes to the model (layer precision is already on config).
# DDP wraps the model, so the inner NVLlama model sits one level deeper there.
llama_model = model.module.model if args.strategy is Strategy.DDP else model.model
llama_model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=None)

model.train()

# Seed by rank so each rank draws different inputs.
generator = torch.Generator()
generator.manual_seed(torch.distributed.get_rank())

for _ in range(3):
    input_data = {
        "input_ids": torch.randint(0, config.vocab_size, (1, 32), generator=generator),
        "labels": torch.randint(0, config.vocab_size, (1, 32), generator=generator),
        "attention_mask": torch.ones(1, 32),
    }
    input_data = {k: v.to(torch.cuda.current_device()) for k, v in input_data.items()}

    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        outputs = model(**input_data)

    outputs.loss.backward()

# Access FP8 extra states directly from modules instead of state_dict()
# since state_dict() now filters them out for HuggingFace compatibility
# NOTE(review): if no module exposes a callable ``_extra_state``, this dict stays empty and
# every check below passes without verifying anything -- confirm the attribute name against
# Transformer Engine and consider asserting that the dict is non-empty.
fp8_extra_states = {}
for name, module in model.named_modules():
    if hasattr(module, "_extra_state") and callable(module._extra_state):
        extra_state = module._extra_state()
        if extra_state is not None and len(extra_state) > 0:
            fp8_extra_states[f"{name}._extra_state"] = extra_state

# lm_head is BF16, not FP8, so exclude it from FP8 checks
fp8_extra_states = {key: val for key, val in fp8_extra_states.items() if "lm_head." not in key}


def _load_extra_state(state):
    """Unpickle the object stored in the raw bytes of an extra-state tensor."""
    # The payload comes from this job's own modules, not from external input.
    return pickle.loads(state.detach().numpy(force=True).tobytes())


world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()

# 2 ranks, test to ensure that both ranks have the same FP8 extra states
if world_size == 2:
    outputs_list = [None] * world_size if rank == 0 else None
    torch.distributed.gather_object(fp8_extra_states, outputs_list, dst=0)
    if rank == 0:
        assert outputs_list is not None
        states_rank0, states_rank1 = outputs_list

        # Compare the key sets first: iterating only rank 0's keys would never notice
        # a state that exists on rank 1 alone.
        assert states_rank0.keys() == states_rank1.keys(), (
            f"FP8 extra state keys differ between ranks: "
            f"{sorted(states_rank0.keys() ^ states_rank1.keys())}"
        )

        for key, state_1 in states_rank0.items():
            state_2 = states_rank1[key]
            assert len(state_1) > 0, f"No FP8 extra states for {key}, rank 0"
            assert len(state_2) > 0, f"No FP8 extra states for {key}, rank 1"
            recursive_assert(_load_extra_state(state_1), _load_extra_state(state_2))

# One rank, test to ensure the correct FP8 extra states are saved
if world_size == 1:
    for key, val in fp8_extra_states.items():
        assert len(val) > 0, f"No FP8 extra states for {key}"
        fp8_meta_dict = _load_extra_state(val)
        assert fp8_meta_dict["recipe"] == fp8_recipe, f"Recipe mismatch for {key}"

torch.distributed.destroy_process_group()

0 commit comments

Comments
 (0)