Skip to content

Commit 8fcada6

Browse files
Jonathan Mitchelljomitchellnv
authored andcommitted
NVFP4 and MXFP8 integrations
- includes capability to log out stats for MXFP8 and NVFP4 at the same time - Enables layer-wise precision setting - Includes support for layer-wise quant config - Adds support for FSDP2 DDP and MFSDP for MXFP8 NVFP4 - note: FP32 master weights wont work with ddp,mfsdp Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
1 parent b2ddae1 commit 8fcada6

33 files changed

Lines changed: 2493 additions & 257 deletions

.vscode/settings.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,6 @@
2626
"editor.rulers": [
2727
120
2828
],
29-
"autoDocstring.docstringFormat": "google-notypes"
29+
"autoDocstring.docstringFormat": "google-notypes",
30+
"search.exclude": { "**/logs/**": true },
3031
}

bionemo-recipes/models/esm2/modeling_esm_te.py

Lines changed: 72 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@
2323
"""
2424

2525
import warnings
26+
from contextlib import nullcontext
2627
from typing import ClassVar, Literal, Optional, Unpack
2728

2829
# TODO: put import guard around transformer_engine here, with an informative error message around
2930
# installation and the nvidia docker container.
3031
import torch
32+
import transformer_engine.common.recipe
3133
import transformer_engine.pytorch
3234
from torch import nn
3335
from torch.nn import CrossEntropyLoss
@@ -71,6 +73,7 @@ def __init__(
7173
max_seq_length: Optional[int] = None,
7274
padded_vocab_size: Optional[int] = 64,
7375
attn_mask_type: str = "padding",
76+
layer_precision: list[str | None] | None = None,
7477
**kwargs,
7578
):
7679
"""Initialize the NVEsmConfig with additional TE-related config options.
@@ -100,6 +103,9 @@ def __init__(
100103
padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults
101104
to vocab_size. Must be greater than or equal to vocab_size.
102105
attn_mask_type: The type of attention mask to use.
106+
layer_precision: Per-layer quantization precision, a list of length ``num_hidden_layers``
107+
where each element is ``"fp8"``, ``"fp4"``, or ``None`` (BF16 fallback). ``None``
108+
(the default) means no quantization is configured.
103109
**kwargs: Additional config options to pass to EsmConfig.
104110
"""
105111
super().__init__(**kwargs)
@@ -111,6 +117,7 @@ def __init__(
111117
self.micro_batch_size = micro_batch_size
112118
self.max_seq_length = max_seq_length
113119
self.attn_mask_type = attn_mask_type
120+
self.layer_precision = layer_precision
114121

115122
# Set padded_vocab_size with default fallback to vocab_size
116123
self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size
@@ -165,6 +172,8 @@ def _init_method(x):
165172
for i in range(config.num_hidden_layers)
166173
]
167174
)
175+
self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = None
176+
self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = None
168177
self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm(
169178
config.hidden_size,
170179
eps=config.layer_norm_eps,
@@ -174,6 +183,49 @@ def _init_method(x):
174183
if config.position_embedding_type == "rotary":
175184
self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
176185

186+
def set_recipes(
187+
self,
188+
fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
189+
fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
190+
) -> None:
191+
"""Attach quantization recipe objects for per-layer autocast.
192+
193+
Recipes are not serializable and must be set at runtime after model creation
194+
and sharding (FSDP/DDP/mFSDP) but before training. The per-layer precision
195+
assignments are read from ``self.config.layer_precision``.
196+
197+
These recipes are also hardware specific, so we should not store them as
198+
attributes of the model and attach them at runtime.
199+
200+
Args:
201+
fp8_recipe: The FP8 recipe instance (e.g., MXFP8BlockScaling), or None.
202+
fp4_recipe: The FP4 recipe instance (e.g., NVFP4BlockScaling), or None.
203+
"""
204+
self._fp8_recipe = fp8_recipe
205+
self._fp4_recipe = fp4_recipe
206+
207+
def get_layer_autocast(self, layer_number: int):
208+
"""Return the appropriate TE autocast context manager for a given layer.
209+
210+
The context interacts with the outer FP8 autocast in the training script:
211+
- FP8 layer: nullcontext() -- lets the outer FP8 autocast take effect.
212+
- FP4 layer: te.pytorch.autocast(enabled=True, recipe=fp4_recipe) -- overrides to FP4.
213+
- BF16 layer: te.pytorch.autocast(enabled=False) -- disables quantized compute.
214+
215+
Args:
216+
layer_number: The 0-indexed layer number.
217+
218+
Returns:
219+
A context manager for the layer's quantization mode.
220+
"""
221+
precision = self.config.layer_precision[layer_number] if self.config.layer_precision is not None else None
222+
if precision == "fp8":
223+
return nullcontext()
224+
elif precision == "fp4":
225+
return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp4_recipe)
226+
else:
227+
return transformer_engine.pytorch.autocast(enabled=False)
228+
177229
def forward(
178230
self,
179231
hidden_states: torch.Tensor,
@@ -201,22 +253,26 @@ def forward(
201253
if te_rope_emb.dtype == torch.float32:
202254
warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)
203255

204-
for layer_module in self.layers:
205-
if kwargs.get("output_hidden_states", False):
206-
all_hidden_states = (*all_hidden_states, hidden_states)
207-
208-
hidden_states = layer_module(
209-
hidden_states,
210-
attention_mask,
211-
rotary_pos_emb=te_rope_emb,
212-
cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
213-
cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
214-
cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
215-
cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
216-
max_seqlen_q=kwargs.get("max_length_q", None),
217-
max_seqlen_kv=kwargs.get("max_length_k", None),
218-
pad_between_seqs=kwargs.get("pad_between_seqs", None),
219-
)
256+
# Outer FP8 autocast enables FP8 compute for the encoder stack. Per-layer overrides (FP4, BF16) are handled
257+
# by get_layer_autocast(), which nests inside this context.
258+
with transformer_engine.pytorch.autocast(enabled=self._fp8_recipe is not None, recipe=self._fp8_recipe):
259+
for layer_number, layer_module in enumerate(self.layers):
260+
if kwargs.get("output_hidden_states", False):
261+
all_hidden_states = (*all_hidden_states, hidden_states)
262+
263+
with self.get_layer_autocast(layer_number):
264+
hidden_states = layer_module(
265+
hidden_states,
266+
attention_mask,
267+
rotary_pos_emb=te_rope_emb,
268+
cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
269+
cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
270+
cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
271+
cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
272+
max_seqlen_q=kwargs.get("max_length_q", None),
273+
max_seqlen_kv=kwargs.get("max_length_k", None),
274+
pad_between_seqs=kwargs.get("pad_between_seqs", None),
275+
)
220276

221277
hidden_states = self.emb_layer_norm_after(hidden_states)
222278

bionemo-recipes/models/esm2/tests/test_distributed_fp8.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,10 @@ def is_main_process(self) -> bool:
161161
)
162162
device = f"cuda:{dist_config.local_rank}"
163163

164+
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_compute_algo="max", amax_history_len=10)
165+
164166
config = NVEsmConfig.from_pretrained("facebook/esm2_t6_8M_UR50D", dtype=torch.bfloat16, revision="c731040f")
167+
config.layer_precision = ["fp8"] * config.num_hidden_layers
165168
model = NVEsmForMaskedLM(config)
166169

167170
if args.strategy is Strategy.FSDP2:
@@ -195,13 +198,15 @@ def is_main_process(self) -> bool:
195198
tp_dim="tp",
196199
)
197200

201+
# Attach FP8 recipes to the encoder (layer precision is already on config).
202+
encoder = model.module.esm.encoder if args.strategy is Strategy.DDP else model.esm.encoder
203+
encoder.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=None)
204+
198205
model.train()
199206

200207
generator = torch.Generator()
201208
generator.manual_seed(torch.distributed.get_rank())
202209

203-
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_compute_algo="max", amax_history_len=10)
204-
205210
for _ in range(3):
206211
input_data = {
207212
"input_ids": torch.randint(0, config.vocab_size, (1, 32), generator=generator),
@@ -211,8 +216,7 @@ def is_main_process(self) -> bool:
211216
input_data = {k: v.to(torch.cuda.current_device()) for k, v in input_data.items()}
212217

213218
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
214-
with transformer_engine.pytorch.autocast(enabled=True, recipe=fp8_recipe):
215-
outputs = model(**input_data)
219+
outputs = model(**input_data)
216220

217221
outputs.loss.backward()
218222

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Unit tests for NVEsmEncoder.set_recipes and get_layer_autocast."""
17+
18+
from contextlib import nullcontext
19+
from unittest.mock import patch
20+
21+
import pytest
22+
import transformer_engine.common.recipe
23+
import transformer_engine.pytorch
24+
25+
from modeling_esm_te import NVEsmConfig, NVEsmEncoder
26+
27+
28+
@pytest.fixture
29+
def encoder():
30+
"""Create a small NVEsmEncoder on CUDA for testing."""
31+
config = NVEsmConfig(
32+
hidden_size=320,
33+
intermediate_size=1280,
34+
num_hidden_layers=6,
35+
num_attention_heads=20,
36+
max_position_embeddings=1026,
37+
)
38+
return NVEsmEncoder(config)
39+
40+
41+
# -- set_recipes --
42+
43+
44+
def test_all_fp8(encoder):
45+
encoder.config.layer_precision = ["fp8"] * 6
46+
fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
47+
encoder.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=None)
48+
assert encoder._fp8_recipe is fp8_recipe
49+
assert encoder._fp4_recipe is None
50+
assert all(p == "fp8" for p in encoder.config.layer_precision)
51+
52+
53+
def test_all_fp4(encoder):
54+
encoder.config.layer_precision = ["fp4"] * 6
55+
fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
56+
encoder.set_recipes(fp8_recipe=None, fp4_recipe=fp4_recipe)
57+
assert encoder._fp8_recipe is None
58+
assert encoder._fp4_recipe is fp4_recipe
59+
assert all(p == "fp4" for p in encoder.config.layer_precision)
60+
61+
62+
def test_all_bf16(encoder):
63+
encoder.config.layer_precision = [None] * 6
64+
encoder.set_recipes(fp8_recipe=None, fp4_recipe=None)
65+
assert all(p is None for p in encoder.config.layer_precision)
66+
67+
68+
def test_mixed_fp8_fp4(encoder):
69+
encoder.config.layer_precision = ["fp8", "fp8", "fp8", "fp4", "fp4", "fp4"]
70+
fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
71+
fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
72+
encoder.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
73+
assert encoder.config.layer_precision == ["fp8", "fp8", "fp8", "fp4", "fp4", "fp4"]
74+
75+
76+
def test_mixed_fp8_bf16(encoder):
77+
encoder.config.layer_precision = ["fp8", None, "fp8", None, "fp8", None]
78+
fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
79+
encoder.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=None)
80+
assert encoder.config.layer_precision == ["fp8", None, "fp8", None, "fp8", None]
81+
82+
83+
def test_mixed_all_three(encoder):
84+
encoder.config.layer_precision = ["fp8", "fp8", None, None, "fp4", "fp4"]
85+
fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
86+
fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
87+
encoder.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
88+
assert encoder.config.layer_precision == ["fp8", "fp8", None, None, "fp4", "fp4"]
89+
90+
91+
def test_covers_all_layers(encoder):
92+
encoder.config.layer_precision = ["fp8"] + [None] * 5
93+
encoder.set_recipes(fp8_recipe=transformer_engine.common.recipe.DelayedScaling(), fp4_recipe=None)
94+
assert len(encoder.config.layer_precision) == 6
95+
96+
97+
def test_recipes_stored_as_attributes(encoder):
98+
encoder.config.layer_precision = ["fp8", "fp4", None, None, None, None]
99+
fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
100+
fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
101+
encoder.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
102+
assert encoder._fp8_recipe is fp8_recipe
103+
assert encoder._fp4_recipe is fp4_recipe
104+
# The precision list only contains strings/None, not recipe objects.
105+
for v in encoder.config.layer_precision:
106+
assert v is None or isinstance(v, str)
107+
108+
109+
# -- get_layer_autocast --
110+
111+
112+
def test_fp8_layer_returns_nullcontext(encoder):
113+
encoder.config.layer_precision = ["fp8"] + [None] * 5
114+
encoder.set_recipes(fp8_recipe=transformer_engine.common.recipe.DelayedScaling(), fp4_recipe=None)
115+
ctx = encoder.get_layer_autocast(0)
116+
assert isinstance(ctx, nullcontext)
117+
118+
119+
def test_fp4_layer_returns_te_autocast(encoder):
120+
fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
121+
encoder.config.layer_precision = ["fp4"] + [None] * 5
122+
encoder.set_recipes(fp8_recipe=None, fp4_recipe=fp4_recipe)
123+
with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
124+
mock_autocast.return_value = "fp4_context"
125+
ctx = encoder.get_layer_autocast(0)
126+
mock_autocast.assert_called_once_with(enabled=True, recipe=fp4_recipe)
127+
assert ctx == "fp4_context"
128+
129+
130+
def test_bf16_layer_returns_te_autocast_disabled(encoder):
131+
encoder.config.layer_precision = [None] * 6
132+
encoder.set_recipes(fp8_recipe=None, fp4_recipe=None)
133+
with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
134+
mock_autocast.return_value = "bf16_context"
135+
ctx = encoder.get_layer_autocast(0)
136+
mock_autocast.assert_called_once_with(enabled=False)
137+
assert ctx == "bf16_context"
138+
139+
140+
def test_uninitialized_defaults_to_bf16(encoder):
141+
"""When layer_precision is None (default), all layers default to BF16."""
142+
assert encoder.config.layer_precision is None
143+
with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
144+
mock_autocast.return_value = "bf16_context"
145+
ctx = encoder.get_layer_autocast(0)
146+
mock_autocast.assert_called_once_with(enabled=False)
147+
assert ctx == "bf16_context"
148+
149+
150+
def test_mixed_layers_return_correct_contexts(encoder):
151+
fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
152+
fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
153+
encoder.config.layer_precision = ["fp8", "fp8", "fp4", "fp4", None, None]
154+
encoder.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
155+
156+
# FP8 layers -> nullcontext
157+
assert isinstance(encoder.get_layer_autocast(0), nullcontext)
158+
assert isinstance(encoder.get_layer_autocast(1), nullcontext)
159+
160+
# FP4 layers -> te.pytorch.autocast
161+
with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
162+
mock_autocast.return_value = "fp4_context"
163+
encoder.get_layer_autocast(2)
164+
mock_autocast.assert_called_with(enabled=True, recipe=fp4_recipe)
165+
166+
# BF16 layers -> te.pytorch.autocast(enabled=False)
167+
with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
168+
mock_autocast.return_value = "bf16_context"
169+
encoder.get_layer_autocast(4)
170+
mock_autocast.assert_called_with(enabled=False)
171+
172+
173+
def test_layer_precision_is_pickleable(encoder):
174+
"""The config.layer_precision list should be trivially pickleable."""
175+
import pickle
176+
177+
encoder.config.layer_precision = ["fp8", "fp8", "fp4", "fp4", None, None]
178+
roundtripped = pickle.loads(pickle.dumps(encoder.config.layer_precision))
179+
assert roundtripped == encoder.config.layer_precision

0 commit comments

Comments
 (0)