Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 59 additions & 11 deletions bionemo-recipes/models/mixtral/modeling_mixtral_te.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,22 @@ def _restack_from_views(self) -> None:
device = torch.cuda.current_device()
for attr_name in ("experts_gate_up_weight", "experts_down_weight"):
old_param = getattr(self, attr_name)
new_data = torch.empty_like(old_param, device=device)
torch.nn.init.normal_(new_data, mean=0.0, std=self.initializer_range)
setattr(self, attr_name, nn.Parameter(new_data))
if isinstance(old_param.data, DTensor):
# FSDP2 has sharded this param; materialize the local shard on CUDA
# and reconstruct the DTensor wrapper so FSDP2 can manage it.
local_data = old_param.data.to_local()
new_local = torch.empty(local_data.shape, dtype=local_data.dtype, device=device)
torch.nn.init.normal_(new_local, mean=0.0, std=self.initializer_range)
new_dtensor = DTensor.from_local(
new_local,
device_mesh=old_param.data.device_mesh,
placements=old_param.data.placements,
)
setattr(self, attr_name, nn.Parameter(new_dtensor))
else:
new_data = torch.empty_like(old_param, device=device)
torch.nn.init.normal_(new_data, mean=0.0, std=self.initializer_range)
setattr(self, attr_name, nn.Parameter(new_data))

# Re-sync views to point to the new stacked parameter
self._sync_expert_views()
Expand All @@ -298,13 +311,15 @@ def _sync_expert_views(self) -> None:
gate_up_w = self.experts_gate_up_weight
if isinstance(gate_up_w, DTensor):
gate_up_w = gate_up_w.to_local()
for i in range(self.num_local_experts):
num_local = gate_up_w.shape[0]
for i in range(num_local):
object.__setattr__(self.experts_gate_up, f"weight{i}", gate_up_w[i])

down_w = self.experts_down_weight
if isinstance(down_w, DTensor):
down_w = down_w.to_local()
for i in range(self.num_local_experts):
num_local_down = down_w.shape[0]
for i in range(num_local_down):
object.__setattr__(self.experts_down, f"weight{i}", down_w[i])

def set_ep_group(self, ep_group: dist.ProcessGroup, ep_mesh: DeviceMesh) -> None:
Expand Down Expand Up @@ -394,7 +409,28 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
self._sync_expert_views()

dispatch_output = self.dispatcher.dispatch(hidden_states, selected_experts, routing_weights)
expert_output = self._expert_ffn(dispatch_output.expert_input, dispatch_output.tokens_per_expert)

expert_input = dispatch_output.expert_input
tokens_per_expert = dispatch_output.tokens_per_expert

# MXFP8 requires both tensor dims divisible by 32. Upstream attention layers
# get this from the collator (pad_sequences_to_be_divisible_by=32), but after
# all-to-all dispatch the per-rank token count is data-dependent (routing
# decisions pick different expert loads). Pad here so GroupedLinear's MXFP8
# kernels don't assert, then slice the padding off afterwards.
n_tokens = expert_input.shape[0]
mxfp8_pad = (32 - n_tokens % 32) % 32
if mxfp8_pad:
expert_input = torch.nn.functional.pad(expert_input, (0, 0, 0, mxfp8_pad))
# Attribute the padding tokens to the last expert so m_splits still sums correctly.
tokens_per_expert = list(tokens_per_expert)
tokens_per_expert[-1] += mxfp8_pad

expert_output = self._expert_ffn(expert_input, tokens_per_expert)

if mxfp8_pad:
expert_output = expert_output[:n_tokens]

output = self.dispatcher.combine(expert_output, dispatch_output.handle)

return output.reshape(original_shape)
Expand Down Expand Up @@ -503,12 +539,20 @@ def __init__(
self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = fp8_recipe
self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = fp4_recipe

if fp8_recipe is not None and self.config.layer_precision is None:
if fp4_recipe is not None:
if self.config.layer_precision is None:
if fp8_recipe is not None and fp4_recipe is not None:
raise RuntimeError("Both FP8 and FP4 recipes provided, but no layer precision provided.")

warnings.warn("No layer precision provided, using FP8 recipe for all layers.", UserWarning)
self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers
if fp8_recipe is not None:
warnings.warn("No layer precision provided, using FP8 recipe for all layers.", UserWarning)
self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers
elif fp4_recipe is not None:
raise RuntimeError(
"FP4 recipe provided but no layer_precision configured. "
"Set layer_precision explicitly when using FP4."
)

if self.config.layer_precision is not None and "fp4" in self.config.layer_precision and fp4_recipe is None:
raise RuntimeError("layer_precision contains 'fp4' entries but no fp4_recipe was provided.")

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=config.dtype)

Expand Down Expand Up @@ -857,6 +901,10 @@ def _unpad_input(hidden_states, attention_mask, unused_mask=None):
class HFInferenceParams(InferenceParams):
"""Extension of the InferenceParams class to support HF generate() and beam search."""

# Required by transformers >= 5.4 _valid_auto_compile_criteria(); this
# custom TE-based cache is not compatible with torch.compile generate().
is_compileable = False

def get_seq_length(self, layer_idx: int = 0) -> int:
"""Return the current cached sequence length.

Expand Down
9 changes: 9 additions & 0 deletions bionemo-recipes/recipes/mixtral_native_te/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# syntax=docker/dockerfile:1.4
FROM nvcr.io/nvidia/pytorch:26.03-py3

RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=bind,source=requirements.txt,target=/requirements.txt \
PIP_CONSTRAINT= pip install -r /requirements.txt

WORKDIR /workspace/bionemo
COPY . .
31 changes: 31 additions & 0 deletions bionemo-recipes/recipes/mixtral_native_te/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# TransformerEngine-accelerated Mixtral training with a native PyTorch training loop

This folder demonstrates how to train TE-accelerated Mixtral with a native PyTorch training loop using FSDP2 for
distributed training. The recipe mirrors the structure and conventions of `llama3_native_te`, and includes a Lingua-style
configuration for natural-language pre-training on DCLM Baseline 1.0.

## Commands

Single GPU sanity run:

```bash
python train_fsdp2.py --config-name L0_sanity
```

Single GPU Lingua smoke run:

```bash
python train_fsdp2.py --config-name L2_lingua_8x1B num_train_steps=20 checkpoint.ckpt_dir=./checkpoints
```

Cluster or multi-GPU run:

```bash
torchrun --standalone --nproc_per_node=2 train_fsdp2.py --config-name L2_lingua_8x1B
```

## Notes

- The Lingua config uses the `meta-llama/Meta-Llama-3-8B` tokenizer and streams `mlfoundations/dclm-baseline-1.0`.
- `expert_parallel_size` remains `1` in this v1 recipe so it matches the existing Llama3 Lingua recipe structure.
- Use `HF_TOKEN` for Hugging Face access and `WANDB_KEY` for Weights & Biases logging.
Loading
Loading