Commit 0d26cdf

savitha-eng and claude authored
Add MXFP8/NVFP4 quantization, quantized model init, collator state, a… (#1572)
## Summary

Builds on top of [PR #1500](#1500) (`jm/mxfp8-nvfp4-llama3`) with additional features, CI fixes, and benchmark documentation for the `llama3_native_te` recipe.

### Key changes on top of PR #1500

- **FusedAdam with FP32 master weights**: Replaces the `MixedPrecisionPolicy` approach with TE's `FusedAdam(master_weight_dtype=torch.float32)` for mixed-precision training. This is simpler and better supported for FP8/MXFP8/NVFP4.
- **Quantized model init with `preserve_high_precision_init_val`**: Stores BF16 copies of init values when using `te.quantized_model_init`, needed for FP32 master weight seeding in FP8 training.
- **Unified per-layer init path**: `get_autocast_context(init=True)` now works both standalone (model tests, no outer context) and under an outer `te.quantized_model_init` context (recipe training). BF16 layers exit the outer FP8 context via `quantized_model_init(enabled=False)`.
- **Layer-wise precision control**: The `layer_precision` config allows per-layer FP8/MXFP8/NVFP4/BF16 assignment (e.g., first/last layer in BF16 for stability).
- **NVFP4 support**: Added the `NVFP4BlockScaling` recipe alongside MXFP8.
- **70B configs**: Added Llama-3.1-70B hydra configs with context parallelism and THD input format.
- **CI test fixes**: Parametrized all FP8 tests across recipes (DelayedScaling, Float8CurrentScaling, Float8BlockScaling, MXFP8BlockScaling) with automatic `xfail` for unsupported hardware, matching existing codebase patterns.
- **Restored `is_compileable` property**: Required by the HuggingFace `transformers` `generate()` auto-compile check.
- **Hydra config cleanup**: Renamed `7b` → `8b` configs, removed experiment configs, restored pytest markers.

## MXFP8 Performance Benchmarks

### Headline: MXFP8 vs BF16 throughput uplift (single B300 node)

![MXFP8 throughput uplift on 8B vs 70B](https://raw.githubusercontent.com/NVIDIA/bionemo-framework/savitha/lingua-7b-fp8-clean-pr/docs/docs/assets/images/llama3/lingua-8b-vs-70b-mxfp8-uplift.png)

**Key findings:**

- **Single-node:** MXFP8 over BF16 gives ~30% throughput uplift on both 8B and 70B. Quantized model init (`qinit`) adds ~0.8 pp on 8B but **+9.7 pp on 70B**: the per-layer quantize/dequantize work saved by qinit scales with depth (80 vs 32 layers). On 70B, **MXFP8 + qinit delivers +38.4% throughput gain over BF16** on a single B300 node.
- **Multi-node 8B** (8 nodes / 64× B200): MXFP8 + qinit reaches **22,517 tokens/s/GPU vs 17,644 for BF16: +27.6% throughput (×1.28 speedup, −21.7% step time)**.
- **Multi-node 70B** (4 nodes / 32× B200): MXFP8 + qinit reaches **2,725 tokens/s/GPU vs 1,972 for BF16: +38.2% throughput (×1.38 speedup, −27.6% step time)**. The larger relative gain on 70B vs 8B at scale matches the size-dependent qinit pattern seen on a single node.

<details>
<summary><strong>Single-node detail: per-model 3-way comparisons</strong></summary>

**Llama-3.1-8B** (1 node / 8× B300 SXM6 AC, mbs=4, gbs=32 seqs / 262k tokens, seq_len=8192):

![8B single-node](https://raw.githubusercontent.com/NVIDIA/bionemo-framework/savitha/lingua-7b-fp8-clean-pr/docs/docs/assets/images/llama3/lingua-8b-mxfp8-1node.png)

MXFP8 + qinit (+31.1%) and MXFP8 without qinit (+30.4%) deliver essentially the same throughput gain: at 32 layers the per-layer quantize/dequantize saving is small.

**Llama-3.1-70B** (1 node / 8× B300 SXM6 AC, mbs=1, cp=2, dp=4, gbs=4 seqs, seq_len=8192):

![70B single-node](https://raw.githubusercontent.com/NVIDIA/bionemo-framework/savitha/lingua-7b-fp8-clean-pr/docs/docs/assets/images/llama3/lingua-70b-mxfp8-1node.png)

MXFP8 + qinit (+39.4%) pulls ahead of MXFP8 without qinit (+28.7%), a ~10 pp gap that doesn't appear at 8B. With 80 transformer layers, avoiding per-step quantize/dequantize adds up. `preserve_high_precision_init_val=True` (HPIV) is within 1% of qinit-without-HPIV, so HPIV is essentially free at steady state.

</details>

<details>
<summary><strong>Multi-node throughput (B200, production-scale runs)</strong></summary>

**Llama-3.1-8B** (8 nodes / 64× B200, mbs=2, grad_acc=2, gbs=256, seq_len=8192):

![8B multi-node](https://raw.githubusercontent.com/NVIDIA/bionemo-framework/savitha/lingua-7b-fp8-clean-pr/docs/docs/assets/images/llama3/lingua-7b-mxfp8-multinode.png)

MXFP8 + qinit: **22,517 tokens/s/GPU vs 17,644 for BF16: +27.6% throughput (×1.28 speedup)**

**Llama-3.1-70B** (4 nodes / 32× B200, cp=2, dp=16, mbs=1, gbs=16, seq_len=8192):

![70B multi-node](https://raw.githubusercontent.com/NVIDIA/bionemo-framework/savitha/lingua-7b-fp8-clean-pr/docs/docs/assets/images/llama3/lingua-70b-mxfp8-multinode.png)

MXFP8 + qinit: **2,725 tokens/s/GPU vs 1,972 for BF16: +38.2% throughput (×1.38 speedup)**

</details>

<details>
<summary><strong>Wandb run links</strong></summary>

- Single-node 8B: [BF16](https://wandb.ai/clara-discovery/lingua-7b/runs/lingua_7b_bf16_mbs4_1n_bia) / [MXFP8 + qinit](https://wandb.ai/clara-discovery/lingua-7b/runs/lingua_7b_mxfp8_qinit_mbs4_1n_bia) / [MXFP8 (no qinit)](https://wandb.ai/clara-discovery/lingua-7b/runs/lingua_7b_mxfp8_no_qinit_mbs4_1n_bia)
- Single-node 70B: [BF16](https://wandb.ai/clara-discovery/lingua-70b/runs/lingua_70b_bf16_mbs1_1n_1k_bia) / [MXFP8 + qinit](https://wandb.ai/clara-discovery/lingua-70b/runs/lingua_70b_mxfp8_qinit_mbs1_1n_1k_bia) / [MXFP8 + qinit + HPIV](https://wandb.ai/clara-discovery/lingua-70b/runs/lingua_70b_mxfp8_qinit_hpiv_mbs1_1n_1k_bia) / [MXFP8 (no qinit)](https://wandb.ai/clara-discovery/lingua-70b/runs/lingua_70b_mxfp8_no_qinit_mbs1_1n_1k_bia)
- Multi-node 8B: [BF16](https://wandb.ai/clara-discovery/lingua-7b/runs/lingua-7b-bf16-baseline) / [MXFP8 + qinit](https://wandb.ai/clara-discovery/lingua-7b/runs/lingua_7b_mxfp8_qinit_v6_te_main_8n_prenyx)
- Multi-node 70B: [BF16](https://wandb.ai/clara-discovery/lingua-70b/runs/lingua_70b_bf16_thd_fusedadam_4n_cp2_bia) / [MXFP8 + qinit](https://wandb.ai/clara-discovery/lingua-70b/runs/lingua_70b_mxfp8_qinit_thd_fusedadam_4n_cp2_bia)

</details>

## Test plan

- [x] All existing model-level tests pass (parametrized across DelayedScaling, Float8CurrentScaling, Float8BlockScaling, MXFP8BlockScaling with xfail for unsupported hardware)
- [x] All existing recipe-level tests pass (same parametrization pattern)
- [x] `test_quantized_model_init.py`: 4 tests × 4 recipes = 16 test cases (8 pass on L4, 8 xfail for Hopper/Blackwell-only recipes)
- [x] `check_copied_files.py` passes: all 3 `modeling_llama_te.py` copies are identical
- [x] Pre-commit hooks pass
- [ ] Single-node MXFP8 training verified on Blackwell (benchmarked, see above)
- [ ] Multi-node training verified on B200 cluster (benchmarked, see above)

### Type of changes

- [x] New feature (non-breaking change which adds functionality)

### CI Pipeline Configuration

- [ciflow:all-recipes](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:all-recipes) - Run tests for all recipes

> [!NOTE]
> By default, only basic unit tests are run. Add appropriate labels to enable additional test coverage.

---------

Signed-off-by: Savitha Srinivasan <savithas@nvidia.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
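The multi-node figures quoted in the description can be cross-checked with a few lines of arithmetic. The following is a minimal sketch that uses only the tokens/s/GPU readings from the description; the helper name `uplift` is illustrative and not part of the recipe, and it assumes the global batch size is the same for both runs, so that step time scales as the inverse of throughput.

```python
def uplift(baseline_tps: float, new_tps: float) -> dict[str, float]:
    """Derive relative throughput metrics from two tokens/s/GPU readings."""
    speedup = new_tps / baseline_tps
    return {
        "throughput_gain_pct": (speedup - 1.0) * 100.0,
        "speedup": speedup,
        # At a fixed global batch size, step time scales as 1 / throughput.
        "step_time_change_pct": (1.0 / speedup - 1.0) * 100.0,
    }


runs = [("8B, 64x B200", 17_644, 22_517), ("70B, 32x B200", 1_972, 2_725)]
for name, bf16_tps, mxfp8_tps in runs:
    m = uplift(bf16_tps, mxfp8_tps)
    print(
        f"{name}: {m['throughput_gain_pct']:+.1f}% throughput, "
        f"x{m['speedup']:.2f} speedup, {m['step_time_change_pct']:+.1f}% step time"
    )
```

Because the inputs are already rounded to whole tokens/s, the last digit of a derived percentage can differ slightly from a value computed on the raw measurements.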
1 parent 3a9d005 commit 0d26cdf

39 files changed

Lines changed: 2576 additions & 192 deletions

bionemo-recipes/models/llama3/modeling_llama_te.py

Lines changed: 58 additions & 10 deletions
@@ -52,6 +52,7 @@ class NVLlamaConfig(LlamaConfig):
     # "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format)
     attn_input_format: str = "thd"
     self_attn_mask_type: str = "padding_causal"
+    layer_precision: list[str | None] | None = None

     def __init__(
         self,
@@ -217,11 +218,54 @@ def _init_method(x):
         self.rotary_emb = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
         self.rotary_emb.inv_freq = LlamaRotaryEmbedding(config=config).inv_freq

+        self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = None
+        self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = None
+
         self.gradient_checkpointing = False

         # Initialize weights and apply final processing
         self.post_init()

+    def set_recipes(
+        self,
+        fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
+        fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
+    ) -> None:
+        """Attach quantization recipe objects for per-layer autocast.
+
+        Recipes are not serializable and must be set at runtime after model creation
+        and sharding (FSDP/DDP) but before training. The per-layer precision
+        assignments are read from ``self.config.layer_precision``.
+
+        Args:
+            fp8_recipe: The FP8 recipe instance (e.g., MXFP8BlockScaling), or None.
+            fp4_recipe: The FP4 recipe instance (e.g., NVFP4BlockScaling), or None.
+        """
+        self._fp8_recipe = fp8_recipe
+        self._fp4_recipe = fp4_recipe
+
+    def get_layer_autocast(self, layer_number: int):
+        """Return the appropriate TE autocast context manager for a given layer.
+
+        The context interacts with the outer FP8 autocast in the training script:
+        - FP8 layer: nullcontext() -- lets the outer FP8 autocast take effect.
+        - FP4 layer: te.pytorch.autocast(enabled=True, recipe=fp4_recipe) -- overrides to FP4.
+        - BF16 layer: te.pytorch.autocast(enabled=False) -- disables quantized compute.
+
+        Args:
+            layer_number: The 0-indexed layer number.
+
+        Returns:
+            A context manager for the layer's quantization mode.
+        """
+        precision = self.config.layer_precision[layer_number] if self.config.layer_precision is not None else None
+        if precision == "fp8":
+            return nullcontext()
+        elif precision == "fp4":
+            return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp4_recipe)
+        else:
+            return transformer_engine.pytorch.autocast(enabled=False)
+
     def forward(
         self,
         input_ids: torch.Tensor | None = None,
@@ -298,12 +342,14 @@ def forward(
             if te_rope_emb.dtype != torch.float32:
                 warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)

-        with self.get_autocast_context(None, outer=True):
-            for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
+        # Outer FP8 autocast enables FP8 compute for the decoder stack. Per-layer overrides (FP4, BF16) are handled
+        # by get_layer_autocast(), which nests inside this context.
+        with transformer_engine.pytorch.autocast(enabled=self._fp8_recipe is not None, recipe=self._fp8_recipe):
+            for layer_number, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
                 if output_hidden_states:
                     all_hidden_states = (*all_hidden_states, hidden_states)

-                with self.get_autocast_context(layer_idx):
+                with self.get_layer_autocast(layer_number):
                     hidden_states = decoder_layer(
                         hidden_states,
                         attention_mask=None if self.config.attn_input_format == "thd" else attention_mask,
@@ -363,8 +409,10 @@ def get_autocast_context(

         if init and self.config.use_quantized_model_init:
             if precision in ("fp8", "fp4"):
-                return transformer_engine.pytorch.quantized_model_init(recipe=recipe)
-            return nullcontext()
+                return transformer_engine.pytorch.quantized_model_init(
+                    recipe=recipe, preserve_high_precision_init_val=True
+                )
+            return transformer_engine.pytorch.quantized_model_init(enabled=False)

         if precision == "fp8":
             if recipe is None:
@@ -583,6 +631,11 @@ def get_seq_length(self, layer_idx: int = 0) -> int:
             return 0
         return max(self.sequences.values())

+    @property
+    def is_compileable(self) -> bool:
+        """Required by HuggingFace transformers generate() auto-compile check."""
+        return False
+
     def reorder_cache(self, beam_idx: torch.LongTensor):
         """Reorder the cache based on the beam indices."""
         if isinstance(self.cache_manager, PagedKVCacheManager):
@@ -591,8 +644,3 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
                 updated_key_cache = key_cache.index_select(0, beam_idx)
                 updated_value_cache = value_cache.index_select(0, beam_idx)
                 self.cache_manager.cache[layer_number] = (updated_key_cache, updated_value_cache)
-
-    @property
-    def is_compileable(self) -> bool:
-        """Return False as this cache is not compatible with torch.compile."""
-        return False
Lines changed: 243 additions & 0 deletions
@@ -0,0 +1,243 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import pickle
+import subprocess
+
+import pytest
+import torch
+from transformer_engine.pytorch.fp8 import check_fp8_support
+
+
+def requires_fp8(func):
+    """Decorator to skip tests that require FP8 support."""
+    fp8_available, reason = check_fp8_support()
+    return pytest.mark.skipif(not fp8_available, reason=f"FP8 is not supported on this GPU: {reason}")(func)
+
+
+requires_multi_gpu = pytest.mark.skipif(
+    not torch.cuda.is_available() or torch.cuda.device_count() < 2,
+    reason="Test requires at least 2 GPUs",
+)
+
+
+@pytest.mark.parametrize("strategy", ["ddp", "fsdp2"])
+@requires_fp8
+def test_single_process_attaches_correct_fp8_recipe(strategy, unused_tcp_port):
+    cmd = [
+        "torchrun",
+        "--nproc_per_node=1",
+        "--rdzv-backend=c10d",
+        f"--rdzv-endpoint=localhost:{unused_tcp_port}",
+        os.path.relpath(__file__),
+        "--strategy",
+        strategy,
+    ]
+
+    result = subprocess.run(
+        cmd,
+        check=False,
+        text=True,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        timeout=240,
+    )
+    if result.returncode != 0:
+        print(f"STDOUT:\n{result.stdout}")
+        print(f"STDERR:\n{result.stderr}")
+        pytest.fail(f"Command failed with exit code {result.returncode}")
+
+
+@pytest.mark.parametrize("strategy", ["ddp", "fsdp2"])
+@requires_fp8
+@requires_multi_gpu
+def test_multi_process_fp8_recipes_are_synced(strategy, unused_tcp_port):
+    cmd = [
+        "torchrun",
+        "--nproc_per_node=2",
+        "--rdzv-backend=c10d",
+        f"--rdzv-endpoint=localhost:{unused_tcp_port}",
+        os.path.relpath(__file__),
+        "--strategy",
+        strategy,
+    ]
+
+    result = subprocess.run(
+        cmd,
+        check=False,
+        text=True,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        timeout=240,
+    )
+    if result.returncode != 0:
+        print(f"STDOUT:\n{result.stdout}")
+        print(f"STDERR:\n{result.stderr}")
+        pytest.fail(f"Command failed with exit code {result.returncode}")
+
+
+if __name__ == "__main__":
+    import argparse
+    import enum
+    import os
+    import sys
+    from dataclasses import dataclass, field
+    from pathlib import Path
+
+    # Ensure the model directory is on sys.path for bare module imports.
+    sys.path.insert(0, Path(__file__).resolve().parent.parent.as_posix())
+
+    import torch.distributed as dist
+    from torch.distributed.device_mesh import init_device_mesh
+    from torch.distributed.fsdp import fully_shard
+    from torch.optim import AdamW
+    from transformer_engine.pytorch.fp8 import DelayedScaling, Format
+
+    from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM
+
+    def recursive_assert(a, b, path=""):
+        if isinstance(a, dict) and isinstance(b, dict):
+            assert a.keys() == b.keys(), f"Dictionary keys mismatch: {a.keys()} != {b.keys()} at {path}"
+            for k in a:
+                recursive_assert(a[k], b[k], path=f"{path}.{k}")
+        elif isinstance(a, list) and isinstance(b, list):
+            assert len(a) == len(b), f"List lengths mismatch: {len(a)} != {len(b)} at {path}"
+            for i in range(len(a)):
+                recursive_assert(a[i], b[i], path=f"{path}.{i}")
+        elif isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
+            torch.testing.assert_close(a, b, msg=f"Tensor mismatch at {path}")
+        else:
+            assert a == b, f"Value mismatch at {path}: {a} != {b}"
+
+    class Strategy(enum.StrEnum):
+        DDP = "ddp"
+        FSDP2 = "fsdp2"
+
+    @dataclass
+    class DistributedConfig:
+        """Class to track distributed ranks."""
+
+        rank: int = field(default_factory=dist.get_rank)
+        local_rank: int = field(default_factory=lambda: int(os.environ["LOCAL_RANK"]))
+        world_size: int = field(default_factory=dist.get_world_size)
+
+        def is_main_process(self) -> bool:
+            """This is the global rank 0 process, to be used for wandb logging, etc."""
+            return self.rank == 0
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--strategy", type=Strategy, default=Strategy.DDP, choices=[Strategy.FSDP2, Strategy.DDP])
+    args = parser.parse_args()
+
+    torch.distributed.init_process_group(backend="nccl")
+    dist_config = DistributedConfig()
+    torch.cuda.set_device(dist_config.local_rank)
+    device_mesh = init_device_mesh(
+        "cuda",
+        mesh_shape=(dist_config.world_size, 1),
+        mesh_dim_names=("dp", "tp"),
+    )
+    device = f"cuda:{dist_config.local_rank}"
+
+    fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_compute_algo="max", amax_history_len=10)
+
+    config = NVLlamaConfig(
+        hidden_size=256,
+        intermediate_size=512,
+        num_hidden_layers=6,
+        num_attention_heads=8,
+        num_key_value_heads=4,
+        vocab_size=100,
+        dtype=torch.bfloat16,
+    )
+    config.layer_precision = ["fp8"] * config.num_hidden_layers
+    model = NVLlamaForCausalLM(config)
+
+    if args.strategy is Strategy.FSDP2:
+        for layer in model.model.layers:
+            fully_shard(layer, mesh=device_mesh["dp"])
+        fully_shard(model, mesh=device_mesh["dp"])
+        model.to(device)
+
+    elif args.strategy is Strategy.DDP:
+        model.to(device)
+        model = torch.nn.parallel.DistributedDataParallel(
+            model,
+            device_ids=[dist_config.local_rank],
+            output_device=dist_config.local_rank,
+            device_mesh=device_mesh["dp"],
+        )
+
+    optimizer = AdamW(model.parameters())
+
+    # Attach FP8 recipes to the model (layer precision is already on config).
+    llama_model = model.module.model if args.strategy is Strategy.DDP else model.model
+    llama_model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=None)
+
+    model.train()
+
+    generator = torch.Generator()
+    generator.manual_seed(torch.distributed.get_rank())
+
+    for _ in range(3):
+        input_data = {
+            "input_ids": torch.randint(0, config.vocab_size, (1, 32), generator=generator),
+            "labels": torch.randint(0, config.vocab_size, (1, 32), generator=generator),
+            "attention_mask": torch.ones(1, 32),
+        }
+        input_data = {k: v.to(torch.cuda.current_device()) for k, v in input_data.items()}
+
+        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+            outputs = model(**input_data)
+
+        outputs.loss.backward()
+
+    # Access FP8 extra states directly from modules instead of state_dict()
+    # since state_dict() now filters them out for HuggingFace compatibility
+    fp8_extra_states = {}
+    for name, module in model.named_modules():
+        if hasattr(module, "_extra_state") and callable(module._extra_state):
+            extra_state = module._extra_state()
+            if extra_state is not None and len(extra_state) > 0:
+                fp8_extra_states[f"{name}._extra_state"] = extra_state
+
+    # lm_head is BF16, not FP8, so exclude it from FP8 checks
+    fp8_extra_states = {key: val for key, val in fp8_extra_states.items() if "lm_head." not in key}
+
+    # 2 ranks, test to ensure that both ranks have the same FP8 extra states
+    if torch.distributed.get_world_size() == 2:
+        outputs_list = [None] * torch.distributed.get_world_size() if torch.distributed.get_rank() == 0 else None
+        torch.distributed.gather_object(fp8_extra_states, outputs_list, dst=0)
+        if torch.distributed.get_rank() == 0:
+            assert outputs_list is not None
+
+            for key in outputs_list[0]:
+                state_1 = outputs_list[0][key]
+                state_2 = outputs_list[1][key]
+                assert len(state_1) > 0, f"No FP8 extra states for {key}, rank 0"
+                assert len(state_2) > 0, f"No FP8 extra states for {key}, rank 1"
+                dict_1 = pickle.loads(state_1.detach().numpy(force=True).tobytes())
+                dict_2 = pickle.loads(state_2.detach().numpy(force=True).tobytes())
+                recursive_assert(dict_1, dict_2)
+
+    # One rank, test to ensure the correct FP8 extra states are saved
+    if torch.distributed.get_world_size() == 1:
+        for key, val in fp8_extra_states.items():
+            assert len(val) > 0, f"No FP8 extra states for {key}"
+            fp8_meta_dict = pickle.loads(val.detach().numpy(force=True).tobytes())
+            assert fp8_meta_dict["recipe"] == fp8_recipe, f"Recipe mismatch for {key}"
+
+    torch.distributed.destroy_process_group()
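The final checks in the script decode each module's extra state by turning a tensor into raw bytes and unpickling it, then comparing the decoded structures across ranks. The sketch below reproduces that decode-and-compare pattern with the standard library only. A plain `bytes` object stands in for the byte tensor, `states_match` is a boolean-returning analogue of the script's `recursive_assert` written for illustration, and the `amax_history` key and its values are made up (only the `recipe` key appears in the test).

```python
import pickle


def decode_extra_state(buffer: bytes) -> dict:
    """Unpickle a serialized state blob (the test obtains these bytes from a tensor)."""
    return pickle.loads(buffer)


def states_match(a, b) -> bool:
    """Recursively compare nested dicts and lists, falling back to == for leaves."""
    if isinstance(a, dict) and isinstance(b, dict):
        return a.keys() == b.keys() and all(states_match(a[k], b[k]) for k in a)
    if isinstance(a, list) and isinstance(b, list):
        return len(a) == len(b) and all(states_match(x, y) for x, y in zip(a, b))
    return a == b


# Hypothetical per-rank blobs; in the test these come from each rank's modules.
rank0_blob = pickle.dumps({"recipe": "delayed", "amax_history": [0.5, 1.0]})
rank1_blob = pickle.dumps({"recipe": "delayed", "amax_history": [0.5, 1.0]})

print(states_match(decode_extra_state(rank0_blob), decode_extra_state(rank1_blob)))
```

As in the test, `pickle.loads` should only be applied to buffers produced by the same trusted process group, since unpickling untrusted data can execute arbitrary code.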
