Skip to content

Commit 63f58d3

Browse files
trvachovclaude
andcommitted
Add Mixtral MoE training recipes (FSDP2 + TransformerEngine)
Two self-contained recipes following existing Llama3/ESM2 recipe conventions: - bionemo-recipes/recipes/mixtral_native_te/: TE-accelerated Mixtral FSDP2 training with a Lingua-style DCLM Baseline 1.0 pre-training config for Mixtral-8x1B and 8x7B. Includes DDP and FSDP2 entry points. - bionemo-recipes/recipes/opengenome2_mixtral_native_te/: TE Mixtral for autoregressive DNA on OpenGenome2 metagenomes, mirroring opengenome2_llama_native_te (THD packing, genomic label masking, validation, nucleotide tokenizer packaged with the recipe). Key design decisions: - Self-contained KISS: fused MoE kernels (fused_a2a, fused_token_router, fused_indices_converter), collator, checkpoint, and perf logger are duplicated across both recipes rather than shared, matching repo convention. - Configurable expert parallelism via all-to-all token dispatch; expert_parallel_size=1 by default for parity with the Llama3 recipe. - MXFP8 alignment: pad post-alltoall MoE expert input to a multiple of 32 before GroupedLinear (attribute padding to the last expert so m_splits sums correctly; slice padding off the output). No-op for non-MXFP8 and already-aligned batches. Verified on 8x B300 SXM6 with Mixtral-8x7B EP=8 at SEQ=8192: FP8 1.196 s/step, MXFP8 1.248 s/step. - FSDP2 checkpointing uses DCP format (.distcp files), covered by dedicated distributed checkpointing tests. - CI-robust tests: session-scoped local WordLevel tokenizer fixture avoids HuggingFace Hub dependency; expanded train coverage (7 single-GPU, 4 two-GPU tests per recipe) plus dataset and distributed checkpoint tests. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Timur Rvachov <trvachov@nvidia.com>
1 parent 86c8329 commit 63f58d3

70 files changed

Lines changed: 13844 additions & 11 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

bionemo-recipes/models/mixtral/modeling_mixtral_te.py

Lines changed: 59 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -279,9 +279,22 @@ def _restack_from_views(self) -> None:
279279
device = torch.cuda.current_device()
280280
for attr_name in ("experts_gate_up_weight", "experts_down_weight"):
281281
old_param = getattr(self, attr_name)
282-
new_data = torch.empty_like(old_param, device=device)
283-
torch.nn.init.normal_(new_data, mean=0.0, std=self.initializer_range)
284-
setattr(self, attr_name, nn.Parameter(new_data))
282+
if isinstance(old_param.data, DTensor):
283+
# FSDP2 has sharded this param; materialize the local shard on CUDA
284+
# and reconstruct the DTensor wrapper so FSDP2 can manage it.
285+
local_data = old_param.data.to_local()
286+
new_local = torch.empty(local_data.shape, dtype=local_data.dtype, device=device)
287+
torch.nn.init.normal_(new_local, mean=0.0, std=self.initializer_range)
288+
new_dtensor = DTensor.from_local(
289+
new_local,
290+
device_mesh=old_param.data.device_mesh,
291+
placements=old_param.data.placements,
292+
)
293+
setattr(self, attr_name, nn.Parameter(new_dtensor))
294+
else:
295+
new_data = torch.empty_like(old_param, device=device)
296+
torch.nn.init.normal_(new_data, mean=0.0, std=self.initializer_range)
297+
setattr(self, attr_name, nn.Parameter(new_data))
285298

286299
# Re-sync views to point to the new stacked parameter
287300
self._sync_expert_views()
@@ -298,13 +311,15 @@ def _sync_expert_views(self) -> None:
298311
gate_up_w = self.experts_gate_up_weight
299312
if isinstance(gate_up_w, DTensor):
300313
gate_up_w = gate_up_w.to_local()
301-
for i in range(self.num_local_experts):
314+
num_local = gate_up_w.shape[0]
315+
for i in range(num_local):
302316
object.__setattr__(self.experts_gate_up, f"weight{i}", gate_up_w[i])
303317

304318
down_w = self.experts_down_weight
305319
if isinstance(down_w, DTensor):
306320
down_w = down_w.to_local()
307-
for i in range(self.num_local_experts):
321+
num_local_down = down_w.shape[0]
322+
for i in range(num_local_down):
308323
object.__setattr__(self.experts_down, f"weight{i}", down_w[i])
309324

310325
def set_ep_group(self, ep_group: dist.ProcessGroup, ep_mesh: DeviceMesh) -> None:
@@ -394,7 +409,28 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
394409
self._sync_expert_views()
395410

396411
dispatch_output = self.dispatcher.dispatch(hidden_states, selected_experts, routing_weights)
397-
expert_output = self._expert_ffn(dispatch_output.expert_input, dispatch_output.tokens_per_expert)
412+
413+
expert_input = dispatch_output.expert_input
414+
tokens_per_expert = dispatch_output.tokens_per_expert
415+
416+
# MXFP8 requires both tensor dims divisible by 32. Upstream attention layers
417+
# get this from the collator (pad_sequences_to_be_divisible_by=32), but after
418+
# all-to-all dispatch the per-rank token count is data-dependent (routing
419+
# decisions pick different expert loads). Pad here so GroupedLinear's MXFP8
420+
# kernels don't assert, then slice the padding off afterwards.
421+
n_tokens = expert_input.shape[0]
422+
mxfp8_pad = (32 - n_tokens % 32) % 32
423+
if mxfp8_pad:
424+
expert_input = torch.nn.functional.pad(expert_input, (0, 0, 0, mxfp8_pad))
425+
# Attribute the padding tokens to the last expert so m_splits still sums correctly.
426+
tokens_per_expert = list(tokens_per_expert)
427+
tokens_per_expert[-1] += mxfp8_pad
428+
429+
expert_output = self._expert_ffn(expert_input, tokens_per_expert)
430+
431+
if mxfp8_pad:
432+
expert_output = expert_output[:n_tokens]
433+
398434
output = self.dispatcher.combine(expert_output, dispatch_output.handle)
399435

400436
return output.reshape(original_shape)
@@ -503,12 +539,20 @@ def __init__(
503539
self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = fp8_recipe
504540
self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = fp4_recipe
505541

506-
if fp8_recipe is not None and self.config.layer_precision is None:
507-
if fp4_recipe is not None:
542+
if self.config.layer_precision is None:
543+
if fp8_recipe is not None and fp4_recipe is not None:
508544
raise RuntimeError("Both FP8 and FP4 recipes provided, but no layer precision provided.")
509-
510-
warnings.warn("No layer precision provided, using FP8 recipe for all layers.", UserWarning)
511-
self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers
545+
if fp8_recipe is not None:
546+
warnings.warn("No layer precision provided, using FP8 recipe for all layers.", UserWarning)
547+
self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers
548+
elif fp4_recipe is not None:
549+
raise RuntimeError(
550+
"FP4 recipe provided but no layer_precision configured. "
551+
"Set layer_precision explicitly when using FP4."
552+
)
553+
554+
if self.config.layer_precision is not None and "fp4" in self.config.layer_precision and fp4_recipe is None:
555+
raise RuntimeError("layer_precision contains 'fp4' entries but no fp4_recipe was provided.")
512556

513557
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=config.dtype)
514558

@@ -857,6 +901,10 @@ def _unpad_input(hidden_states, attention_mask, unused_mask=None):
857901
class HFInferenceParams(InferenceParams):
858902
"""Extension of the InferenceParams class to support HF generate() and beam search."""
859903

904+
# Required by transformers >= 5.4 _valid_auto_compile_criteria(); this
905+
# custom TE-based cache is not compatible with torch.compile generate().
906+
is_compileable = False
907+
860908
def get_seq_length(self, layer_idx: int = 0) -> int:
861909
"""Return the current cached sequence length.
862910
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# syntax=docker/dockerfile:1.4
2+
FROM nvcr.io/nvidia/pytorch:26.03-py3
3+
4+
RUN --mount=type=cache,target=/root/.cache/pip \
5+
--mount=type=bind,source=requirements.txt,target=/requirements.txt \
6+
PIP_CONSTRAINT= pip install -r /requirements.txt
7+
8+
WORKDIR /workspace/bionemo
9+
COPY . .
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# TransformerEngine-accelerated Mixtral training with a native PyTorch training loop
2+
3+
This folder demonstrates how to train TE-accelerated Mixtral with a native PyTorch training loop using FSDP2 for
4+
distributed training. The recipe mirrors the structure and conventions of `llama3_native_te`, and includes a Lingua-style
5+
configuration for natural-language pre-training on DCLM Baseline 1.0.
6+
7+
## Commands
8+
9+
Single GPU sanity run:
10+
11+
```bash
12+
python train_fsdp2.py --config-name L0_sanity
13+
```
14+
15+
Single GPU Lingua smoke run:
16+
17+
```bash
18+
python train_fsdp2.py --config-name L2_lingua_8x1B num_train_steps=20 checkpoint.ckpt_dir=./checkpoints
19+
```
20+
21+
Cluster or multi-GPU run:
22+
23+
```bash
24+
torchrun --standalone --nproc_per_node=2 train_fsdp2.py --config-name L2_lingua_8x1B
25+
```
26+
27+
## Notes
28+
29+
- The Lingua config uses the `meta-llama/Meta-Llama-3-8B` tokenizer and streams `mlfoundations/dclm-baseline-1.0`.
30+
- `expert_parallel_size` remains `1` in this v1 recipe so it matches the existing Llama3 Lingua recipe structure.
31+
- Use `HF_TOKEN` for Hugging Face access and `WANDB_KEY` for Weights & Biases logging.

0 commit comments

Comments
 (0)