diff --git a/bionemo-recipes/models/mixtral/modeling_mixtral_te.py b/bionemo-recipes/models/mixtral/modeling_mixtral_te.py index e61a852016..61430fb053 100644 --- a/bionemo-recipes/models/mixtral/modeling_mixtral_te.py +++ b/bionemo-recipes/models/mixtral/modeling_mixtral_te.py @@ -279,9 +279,22 @@ def _restack_from_views(self) -> None: device = torch.cuda.current_device() for attr_name in ("experts_gate_up_weight", "experts_down_weight"): old_param = getattr(self, attr_name) - new_data = torch.empty_like(old_param, device=device) - torch.nn.init.normal_(new_data, mean=0.0, std=self.initializer_range) - setattr(self, attr_name, nn.Parameter(new_data)) + if isinstance(old_param.data, DTensor): + # FSDP2 has sharded this param; materialize the local shard on CUDA + # and reconstruct the DTensor wrapper so FSDP2 can manage it. + local_data = old_param.data.to_local() + new_local = torch.empty(local_data.shape, dtype=local_data.dtype, device=device) + torch.nn.init.normal_(new_local, mean=0.0, std=self.initializer_range) + new_dtensor = DTensor.from_local( + new_local, + device_mesh=old_param.data.device_mesh, + placements=old_param.data.placements, + ) + setattr(self, attr_name, nn.Parameter(new_dtensor)) + else: + new_data = torch.empty_like(old_param, device=device) + torch.nn.init.normal_(new_data, mean=0.0, std=self.initializer_range) + setattr(self, attr_name, nn.Parameter(new_data)) # Re-sync views to point to the new stacked parameter self._sync_expert_views() @@ -298,13 +311,15 @@ def _sync_expert_views(self) -> None: gate_up_w = self.experts_gate_up_weight if isinstance(gate_up_w, DTensor): gate_up_w = gate_up_w.to_local() - for i in range(self.num_local_experts): + num_local = gate_up_w.shape[0] + for i in range(num_local): object.__setattr__(self.experts_gate_up, f"weight{i}", gate_up_w[i]) down_w = self.experts_down_weight if isinstance(down_w, DTensor): down_w = down_w.to_local() - for i in range(self.num_local_experts): + num_local_down = down_w.shape[0] 
+ for i in range(num_local_down): object.__setattr__(self.experts_down, f"weight{i}", down_w[i]) def set_ep_group(self, ep_group: dist.ProcessGroup, ep_mesh: DeviceMesh) -> None: @@ -394,7 +409,28 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: self._sync_expert_views() dispatch_output = self.dispatcher.dispatch(hidden_states, selected_experts, routing_weights) - expert_output = self._expert_ffn(dispatch_output.expert_input, dispatch_output.tokens_per_expert) + + expert_input = dispatch_output.expert_input + tokens_per_expert = dispatch_output.tokens_per_expert + + # MXFP8 requires both tensor dims divisible by 32. Upstream attention layers + # get this from the collator (pad_sequences_to_be_divisible_by=32), but after + # all-to-all dispatch the per-rank token count is data-dependent (routing + # decisions pick different expert loads). Pad here so GroupedLinear's MXFP8 + # kernels don't assert, then slice the padding off afterwards. + n_tokens = expert_input.shape[0] + mxfp8_pad = (32 - n_tokens % 32) % 32 + if mxfp8_pad: + expert_input = torch.nn.functional.pad(expert_input, (0, 0, 0, mxfp8_pad)) + # Attribute the padding tokens to the last expert so m_splits still sums correctly. 
+ tokens_per_expert = list(tokens_per_expert) + tokens_per_expert[-1] += mxfp8_pad + + expert_output = self._expert_ffn(expert_input, tokens_per_expert) + + if mxfp8_pad: + expert_output = expert_output[:n_tokens] + output = self.dispatcher.combine(expert_output, dispatch_output.handle) return output.reshape(original_shape) @@ -503,12 +539,20 @@ def __init__( self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = fp8_recipe self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = fp4_recipe - if fp8_recipe is not None and self.config.layer_precision is None: - if fp4_recipe is not None: + if self.config.layer_precision is None: + if fp8_recipe is not None and fp4_recipe is not None: raise RuntimeError("Both FP8 and FP4 recipes provided, but no layer precision provided.") - - warnings.warn("No layer precision provided, using FP8 recipe for all layers.", UserWarning) - self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers + if fp8_recipe is not None: + warnings.warn("No layer precision provided, using FP8 recipe for all layers.", UserWarning) + self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers + elif fp4_recipe is not None: + raise RuntimeError( + "FP4 recipe provided but no layer_precision configured. " + "Set layer_precision explicitly when using FP4." 
+ ) + + if self.config.layer_precision is not None and "fp4" in self.config.layer_precision and fp4_recipe is None: + raise RuntimeError("layer_precision contains 'fp4' entries but no fp4_recipe was provided.") self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=config.dtype) @@ -857,6 +901,10 @@ def _unpad_input(hidden_states, attention_mask, unused_mask=None): class HFInferenceParams(InferenceParams): """Extension of the InferenceParams class to support HF generate() and beam search.""" + # Required by transformers >= 5.4 _valid_auto_compile_criteria(); this + # custom TE-based cache is not compatible with torch.compile generate(). + is_compileable = False + def get_seq_length(self, layer_idx: int = 0) -> int: """Return the current cached sequence length. diff --git a/bionemo-recipes/recipes/mixtral_native_te/Dockerfile b/bionemo-recipes/recipes/mixtral_native_te/Dockerfile new file mode 100644 index 0000000000..faedb5f609 --- /dev/null +++ b/bionemo-recipes/recipes/mixtral_native_te/Dockerfile @@ -0,0 +1,9 @@ +# syntax=docker/dockerfile:1.4 +FROM nvcr.io/nvidia/pytorch:26.03-py3 + +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,source=requirements.txt,target=/requirements.txt \ + PIP_CONSTRAINT= pip install -r /requirements.txt + +WORKDIR /workspace/bionemo +COPY . . diff --git a/bionemo-recipes/recipes/mixtral_native_te/README.md b/bionemo-recipes/recipes/mixtral_native_te/README.md new file mode 100644 index 0000000000..e31e8eb476 --- /dev/null +++ b/bionemo-recipes/recipes/mixtral_native_te/README.md @@ -0,0 +1,31 @@ +# TransformerEngine-accelerated Mixtral training with a native PyTorch training loop + +This folder demonstrates how to train TE-accelerated Mixtral with a native PyTorch training loop using FSDP2 for +distributed training. 
The recipe mirrors the structure and conventions of `llama3_native_te`, and includes a Lingua-style +configuration for natural-language pre-training on DCLM Baseline 1.0. + +## Commands + +Single GPU sanity run: + +```bash +python train_fsdp2.py --config-name L0_sanity +``` + +Single GPU Lingua smoke run: + +```bash +python train_fsdp2.py --config-name L2_lingua_8x1B num_train_steps=20 checkpoint.ckpt_dir=./checkpoints +``` + +Cluster or multi-GPU run: + +```bash +torchrun --standalone --nproc_per_node=2 train_fsdp2.py --config-name L2_lingua_8x1B +``` + +## Notes + +- The Lingua config uses the `meta-llama/Meta-Llama-3-8B` tokenizer and streams `mlfoundations/dclm-baseline-1.0`. +- `expert_parallel_size` remains `1` in this v1 recipe so it matches the existing Llama3 Lingua recipe structure. +- Use `HF_TOKEN` for Hugging Face access and `WANDB_KEY` for Weights & Biases logging. diff --git a/bionemo-recipes/recipes/mixtral_native_te/checkpoint.py b/bionemo-recipes/recipes/mixtral_native_te/checkpoint.py new file mode 100644 index 0000000000..82590298d6 --- /dev/null +++ b/bionemo-recipes/recipes/mixtral_native_te/checkpoint.py @@ -0,0 +1,617 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import os +import shutil +import time +from dataclasses import dataclass +from pathlib import Path +from typing import NamedTuple + +import torch +import transformers +from distributed_config import DistributedConfig +from safetensors.torch import save_file +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + get_state_dict, + set_state_dict, +) +from torch.distributed.checkpoint.state_dict_loader import load as dcp_load +from torch.distributed.checkpoint.state_dict_saver import async_save as dcp_async_save +from torch.distributed.checkpoint.state_dict_saver import save as dcp_save +from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.tensor import DTensor +from torchdata.stateful_dataloader import StatefulDataLoader +from transformer_engine.pytorch.quantized_tensor import QuantizedTensor + + +logger = logging.getLogger(__name__) + +# Tracks in-flight async checkpoint futures keyed by strategy name (e.g. "fsdp2"). +# Each entry holds the Future returned by dcp_async_save so we can await it before starting +# the next async save or before shutting down. +_ckpt_futures: dict = {} + + +def get_fsdp2_checkpoint_process_group( + process_group: torch.distributed.ProcessGroup | None, + *, + expert_parallel_size: int, +) -> torch.distributed.ProcessGroup | None: + """Choose the process group used by DCP for FSDP2 checkpoints. + + When EP is active, the state dict contains DTensor expert weights that are sharded on the + EP mesh as well as FSDP shards on the DP mesh. DCP must therefore coordinate across the + full world instead of only the DP subgroup. 
+ """ + if expert_parallel_size > 1: + return None + return process_group + + +def get_ddp_model_checkpoint_path(ckpt_path: Path, rank: int) -> Path: + """Return the per-rank model checkpoint path for DDP EP checkpoints.""" + return ckpt_path / f"model_rank_{rank}.pt" + + +def get_ddp_optimizer_checkpoint_path(ckpt_path: Path, rank: int) -> Path: + """Return the per-rank optimizer checkpoint path for DDP EP checkpoints.""" + return ckpt_path / f"optimizer_rank_{rank}.pt" + + +def get_ddp_metadata_checkpoint_path(ckpt_path: Path) -> Path: + """Return the shared metadata checkpoint path for DDP EP checkpoints.""" + return ckpt_path / "metadata.pt" + + +def unwrap_checkpoint_model(model: torch.nn.Module) -> torch.nn.Module: + """Strip compile/DDP wrappers until we reach the underlying module.""" + unwrapped = model + while True: + if hasattr(unwrapped, "_orig_mod"): + unwrapped = unwrapped._orig_mod # type: ignore[attr-defined] + continue + if hasattr(unwrapped, "module"): + unwrapped = unwrapped.module # type: ignore[attr-defined] + continue + return unwrapped + + +def build_unfiltered_model_state_dict(model: torch.nn.Module) -> dict[str, torch.Tensor]: + """Build a model state dict without losing EP expert parameters or TE extra state.""" + model_state_dict = torch.nn.Module.state_dict(model) + for name, param in model.named_parameters(): + if name not in model_state_dict: + model_state_dict[name] = param.detach().clone() + for name, buffer in model.named_buffers(): + if name not in model_state_dict: + model_state_dict[name] = buffer.detach().clone() + return model_state_dict + + +class CheckpointOutput(NamedTuple): + """Output of checkpoint loading.""" + + model: torch.nn.Module + optimizer: torch.optim.Optimizer + scheduler: torch.optim.lr_scheduler.LRScheduler + dataloader: StatefulDataLoader | None + step: int + epoch: int + + +# ============================================================================ +# Helper functions +# 
============================================================================ + + +def get_latest_checkpoint(ckpt_path: str | os.PathLike) -> tuple[Path | None, int]: + """Get the latest checkpoint path and step number. + + Returns: + Tuple of (checkpoint path, step number). + If no checkpoint files are found, returns (None, 0). + """ + ckpt_path = Path(ckpt_path) + if not ckpt_path.exists(): + return None, 0 + + checkpoints = [f for f in ckpt_path.iterdir() if f.name.startswith("step_")] + + if not checkpoints: + return None, 0 + + latest = max(checkpoints, key=lambda x: int(Path(x).stem.split("_")[1])) + step = int(Path(latest).stem.split("_")[1]) + return latest, step + + +def should_save_checkpoint(step: int, save_every_n_steps: int) -> bool: + """Determine if a checkpoint should be saved.""" + return save_every_n_steps > 0 and step % save_every_n_steps == 0 and step > 0 + + +def prune_checkpoints(ckpt_path: str | os.PathLike, max_checkpoints: int) -> None: + """Prune checkpoints to keep only the latest `max_checkpoints` checkpoints.""" + ckpt_path = Path(ckpt_path) + checkpoints = [f for f in ckpt_path.iterdir() if f.name.startswith("step_")] + checkpoints.sort(key=lambda x: int(Path(x).stem.split("_")[1])) + if len(checkpoints) > max_checkpoints: + for checkpoint in checkpoints[:-max_checkpoints]: + logger.info(f"Pruning checkpoint {checkpoint}") + if checkpoint.is_dir(): + shutil.rmtree(checkpoint) + else: + os.remove(checkpoint) + + +# ============================================================================ +# DDP Checkpointing +# ============================================================================ + + +def load_checkpoint_ddp( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + ckpt_path: str | os.PathLike, + dist_config: DistributedConfig, + dataloader: StatefulDataLoader | None = None, + weights_only: bool = True, + expert_parallel_size: int = 1, +) -> CheckpointOutput: + """Load 
DDP checkpoint. + + Args: + model: The model to load. + optimizer: The optimizer to load. + scheduler: The LR scheduler to load. + ckpt_path: The path to the checkpoint. + dist_config: The distributed configuration. + dataloader: The dataloader to load. + weights_only: Whether to load the checkpoint weights only. We have to set this to True when loading FP8 + checkpoints. + expert_parallel_size: Expert parallelism size. When > 1, loads expert weights with EP-aware state dict handling. + """ + checkpoint_path, _ = get_latest_checkpoint(ckpt_path) + + if not checkpoint_path: + logger.info("No DDP checkpoint found, starting from scratch") + return CheckpointOutput(model, optimizer, scheduler, dataloader, 0, 0) + + if expert_parallel_size > 1: + underlying_model = unwrap_checkpoint_model(model) + underlying_model.load_state_dict( + torch.load( + get_ddp_model_checkpoint_path(checkpoint_path, dist_config.rank), + map_location=f"cuda:{dist_config.local_rank}", + weights_only=weights_only, + ) + ) + optimizer.load_state_dict( + torch.load( + get_ddp_optimizer_checkpoint_path(checkpoint_path, dist_config.rank), + map_location=f"cuda:{dist_config.local_rank}", + weights_only=weights_only, + ) + ) + metadata = torch.load( + get_ddp_metadata_checkpoint_path(checkpoint_path), + map_location=f"cuda:{dist_config.local_rank}", + weights_only=weights_only, + ) + if metadata["world_size"] != dist_config.world_size: + raise RuntimeError( + "DDP EP checkpoints require the same world size when resuming: " + f"{metadata['world_size']} != {dist_config.world_size}" + ) + scheduler.load_state_dict(metadata["scheduler"]) + dataloader = load_dataloader(dataloader, checkpoint_path, dist_config) + step = metadata["step"] + epoch = metadata["epoch"] + else: + checkpoint = torch.load( + checkpoint_path / "checkpoint.pt", + map_location=f"cuda:{dist_config.local_rank}", + weights_only=weights_only, + ) + + underlying_model = unwrap_checkpoint_model(model) + 
underlying_model.load_state_dict(checkpoint["model"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + scheduler.load_state_dict(checkpoint["scheduler"]) + dataloader = load_dataloader(dataloader, checkpoint_path, dist_config) + step = checkpoint["step"] + epoch = checkpoint["epoch"] + + if dist_config.is_main_process(): + logger.info(f"Loaded DDP checkpoint from step {step}") + + # Increment the step by one to avoid re-running the previous step. + return CheckpointOutput(model, optimizer, scheduler, dataloader, step + 1, epoch) + + +def save_checkpoint_ddp( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + ckpt_path: str | os.PathLike, + step: int, + epoch: int, + dist_config: DistributedConfig, + dataloader: StatefulDataLoader | None = None, + max_checkpoints: int | None = None, + expert_parallel_size: int = 1, +) -> None: + """Saves the Dataloader state and the DDP checkpoint.""" + ckpt_path = Path(ckpt_path) + checkpoint_path = ckpt_path / f"step_{step}" + checkpoint_path.mkdir(parents=True, exist_ok=True) + + # Dataloader checkpointing needs to happen on all ranks, while DDP model checkpointing only needs to happen on the + # main process. 
+ save_dataloader(dataloader, checkpoint_path, dist_config) + + if expert_parallel_size == 1 and not dist_config.is_main_process(): + return + + if expert_parallel_size > 1: + underlying_model = unwrap_checkpoint_model(model) + torch.save( + build_unfiltered_model_state_dict(underlying_model), + get_ddp_model_checkpoint_path(checkpoint_path, dist_config.rank), + ) + torch.save( + optimizer.state_dict(), + get_ddp_optimizer_checkpoint_path(checkpoint_path, dist_config.rank), + ) + if dist_config.is_main_process(): + torch.save( + { + "scheduler": scheduler.state_dict(), + "step": step, + "epoch": epoch, + "world_size": dist_config.world_size, + }, + get_ddp_metadata_checkpoint_path(checkpoint_path), + ) + logger.info(f"Saved distributed DDP checkpoint to {checkpoint_path}") + else: + underlying_model = unwrap_checkpoint_model(model) + + model_state_dict = build_unfiltered_model_state_dict(underlying_model) + + torch.save( + { + # Use nn.Module.state_dict to preserve Mixtral EP expert weights and TE _extra_state. 
+ "model": model_state_dict, + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + "step": step, + "epoch": epoch, + }, + checkpoint_path / "checkpoint.pt", + ) + + logger.info(f"Saved DDP checkpoint to {checkpoint_path}") + + if max_checkpoints is not None and dist_config.is_main_process(): + prune_checkpoints(ckpt_path, max_checkpoints) + + +def save_final_model_ddp( + model: torch.nn.Module, + save_directory: str | os.PathLike, + dist_config: DistributedConfig, +) -> None: + """Save final model for DDP - only on main process.""" + if not dist_config.is_main_process(): + return + + # Unwrap model if wrapped + underlying_model: transformers.PreTrainedModel = unwrap_checkpoint_model(model) # type: ignore + + os.makedirs(save_directory, exist_ok=True) + underlying_model.save_pretrained(save_directory, state_dict=underlying_model.state_dict(), safe_serialization=True) + logger.info(f"Saved final DDP model to {save_directory}") + + +# ============================================================================ +# FSDP2 Checkpointing +# ============================================================================ + + +@dataclass +class AppState(Stateful): + """AppState for FSDP2 checkpoint. 
+ + Adapted from https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html + """ + + model: torch.nn.Module + optimizer: torch.optim.Optimizer + scheduler: torch.optim.lr_scheduler.LRScheduler + step: int = 0 + epoch: int = 0 + + def state_dict(self): + """Get the state dict for the model, optimizer, scheduler, and step.""" + model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer) + model_state_dict = {k: v for k, v in model_state_dict.items() if not k.endswith("_extra_state")} + return { + "model": model_state_dict, + "optim": optimizer_state_dict, + "scheduler": self.scheduler.state_dict(), + "step": self.step, + "epoch": self.epoch, + } + + def load_state_dict(self, state_dict: dict): + """Load the state dict for the model, optimizer, scheduler, and step.""" + set_state_dict( + self.model, + self.optimizer, + model_state_dict=state_dict["model"], + optim_state_dict=state_dict["optim"], + options=StateDictOptions(strict=False), + ) + self.scheduler.load_state_dict(state_dict["scheduler"]) + self.step = state_dict["step"] + self.epoch = state_dict["epoch"] + + +def load_checkpoint_fsdp2( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + ckpt_path: str | os.PathLike, + dist_config: DistributedConfig, + dataloader: StatefulDataLoader | None = None, + process_group: torch.distributed.ProcessGroup | None = None, + expert_parallel_size: int = 1, +) -> CheckpointOutput: + """Load FSDP2 checkpoint. + + Args: + model: The model to load. + optimizer: The optimizer to load. + scheduler: The LR scheduler to load. + ckpt_path: The directory containing checkpoints. + dist_config: The distributed configuration. + dataloader: The dataloader to load. + process_group: The process group to use for checkpointing. + expert_parallel_size: Expert parallelism size. When > 1, loads expert weights with EP-aware state dict handling. 
+ """ + checkpoint_path, _ = get_latest_checkpoint(ckpt_path) + if not checkpoint_path: + logger.info("No FSDP2 checkpoint found, starting from scratch") + return CheckpointOutput(model, optimizer, scheduler, dataloader, 0, 0) + + app_state = AppState( + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + state_dict = {"app": app_state} + dcp_load( + state_dict, + checkpoint_id=checkpoint_path, + process_group=get_fsdp2_checkpoint_process_group(process_group, expert_parallel_size=expert_parallel_size), + ) + + if dataloader is not None: + load_dataloader( + dataloader=dataloader, + ckpt_path=checkpoint_path, + dist_config=dist_config, + ) + + logger.info(f"Loaded distributed FSDP2 checkpoint from step {app_state.step}") + + # Increment the step by one to avoid re-running the previous step. + return CheckpointOutput(model, optimizer, scheduler, dataloader, app_state.step + 1, app_state.epoch) + + +def save_checkpoint_fsdp2( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + ckpt_path: str | os.PathLike, + step: int, + epoch: int, + dist_config: DistributedConfig, + dataloader: StatefulDataLoader | None = None, + process_group: torch.distributed.ProcessGroup | None = None, + max_checkpoints: int | None = None, + async_save: bool = False, + expert_parallel_size: int = 1, +) -> None: + """Save FSDP2 checkpoint. + + Args: + model: The model to save. + optimizer: The optimizer to save. + scheduler: The LR scheduler to save. + ckpt_path: The directory to save the checkpoint. + step: The step number to save the checkpoint. + epoch: The epoch number to save the checkpoint. + dist_config: The distributed configuration. + dataloader: The dataloader to save. + process_group: The process group to use for checkpointing. + max_checkpoints: The maximum number of checkpoints to keep. + async_save: Whether to save the checkpoint asynchronously. + expert_parallel_size: Expert parallelism size. 
When > 1, saves expert weights with EP-aware state dict handling. + """ + start_time = time.perf_counter() + ckpt_path = Path(ckpt_path) + checkpoint_path = ckpt_path / f"step_{step}" + checkpoint_path.mkdir(parents=True, exist_ok=True) + + model_params = (p.to_local() if isinstance(p, DTensor) else p for p in model.parameters()) + if async_save and any((isinstance(p, QuantizedTensor) for p in model_params)): + logger.warning( + "Async checkpointing is not supported for FP8 models, falling back to synchronous checkpointing." + ) + async_save = False + + if dataloader is not None: + save_dataloader( + dataloader=dataloader, + ckpt_path=checkpoint_path, + dist_config=dist_config, + ) + logger.info(f"Saved FSDP2 dataloader to {ckpt_path}") + + state_dict = {"app": AppState(model=model, optimizer=optimizer, scheduler=scheduler, step=step, epoch=epoch)} + checkpoint_process_group = get_fsdp2_checkpoint_process_group( + process_group, + expert_parallel_size=expert_parallel_size, + ) + if async_save: + # If we're using asynchronous checkpointing, make sure we only have one checkpoint future at a time. 
+ if "fsdp2" in _ckpt_futures and _ckpt_futures["fsdp2"] is not None: + _ckpt_futures["fsdp2"].result() + + _ckpt_futures["fsdp2"] = dcp_async_save( + state_dict, + checkpoint_id=checkpoint_path, + process_group=checkpoint_process_group, + ) + else: + dcp_save(state_dict, checkpoint_id=checkpoint_path, process_group=checkpoint_process_group) + + if max_checkpoints is not None and dist_config.is_main_process(): + prune_checkpoints(ckpt_path, max_checkpoints) + + if dist_config.is_main_process(): + logger.info( + f"Saved distributed FSDP2 checkpoint to {checkpoint_path} in {time.perf_counter() - start_time:.2f} seconds" + ) + + +def save_final_model_fsdp2( + model: torch.nn.Module, + save_directory: str | os.PathLike, + dist_config: DistributedConfig, +) -> None: + """Save final model for FSDP2 - gather on all ranks, save on main.""" + # ALL ranks must participate in gathering + model_state_dict = get_model_state_dict( + model=model, + options=StateDictOptions( + full_state_dict=True, + cpu_offload=True, + ), + ) + + # Only main process saves + if not dist_config.is_main_process(): + return + + os.makedirs(save_directory, exist_ok=True) + + # Save just the weights using safetensors + save_file(model_state_dict, os.path.join(save_directory, "model.safetensors")) + + # Save the config + underlying_model = model.module if hasattr(model, "module") else model + if hasattr(underlying_model, "config"): + underlying_model.config.save_pretrained(save_directory) + + logger.info(f"Saved final FSDP2 model to {save_directory} (weights + config only)") + + +# ============================================================================ +# Dataloader Checkpointing +# ============================================================================ + + +def save_dataloader( + dataloader: StatefulDataLoader | None, + ckpt_path: str | os.PathLike, + dist_config: DistributedConfig, +): + """Save the dataloader state to a file. 
+ + For resuming training with long epochs, we save the dataloader state as part of the checkpoint to allow for resuming + from the exact same step. Here we save the dataloader state based on global rank. Note, the total number of ranks + and dataloader num_workers should match for resuming training. + + Args: + dataloader: The dataloader to save the state of. + ckpt_path: The path to save the dataloader state to. + dist_config: The distributed configuration. + """ + if dataloader is None: + return + + ckpt_path = Path(ckpt_path) + ckpt_path.mkdir(parents=True, exist_ok=True) + dataloader_path = ckpt_path / f"dataloader_rank_{dist_config.rank}.pt" + + dataloader_state = dataloader.state_dict() + dataloader_state["num_workers"] = dataloader.num_workers + dataloader_state["num_ranks"] = dist_config.world_size + torch.save(dataloader_state, dataloader_path) + if dist_config.is_main_process(): + logger.info(f"Saved dataloader state to {dataloader_path}") + + +def load_dataloader( + dataloader: StatefulDataLoader | None, + ckpt_path: str | os.PathLike, + dist_config: DistributedConfig, +) -> StatefulDataLoader | None: + """Load the dataloader state from a file. + + Here we load the dataloader state based on global rank. + + Args: + dataloader: The dataloader to load the state of. + ckpt_path: The path to load the dataloader state from. + dist_config: The distributed configuration. + """ + if dataloader is None: + return dataloader + + dataloader_path = Path(ckpt_path) / f"dataloader_rank_{dist_config.rank}.pt" + if not dataloader_path.exists(): + logger.warning( + f"No dataloader checkpoint found for rank {dist_config.rank}, starting dataloader from scratch." 
+ ) + return dataloader + + dataloader_state = torch.load(dataloader_path, weights_only=True) + + if ( + dataloader.num_workers != dataloader_state["num_workers"] + or dist_config.world_size != dataloader_state["num_ranks"] + ): + logger.warning( + f"Dataloader num_workers mismatch: {dataloader.num_workers} != {dataloader_state['num_workers']} or " + f"num_ranks mismatch: {dist_config.world_size} != {dataloader_state['num_ranks']}, " + "starting dataloader from scratch." + ) + return dataloader + + dataloader.load_state_dict(dataloader_state) + if dist_config.is_main_process(): + logger.info(f"Loaded dataloader state from {dataloader_path}") + + return dataloader diff --git a/bionemo-recipes/recipes/mixtral_native_te/collator.py b/bionemo-recipes/recipes/mixtral_native_te/collator.py new file mode 100644 index 0000000000..4555c1762a --- /dev/null +++ b/bionemo-recipes/recipes/mixtral_native_te/collator.py @@ -0,0 +1,1042 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# --- BEGIN COPIED FILE NOTICE --- +# This file is copied from: bionemo-recipes/models/esm2/collator.py +# Do not modify this file directly. Instead, modify the source and run: +# python ci/scripts/check_copied_files.py --fix +# --- END COPIED FILE NOTICE --- + +"""Data collators for sequence packing and context parallel training. 
@dataclass
class DataCollatorWithFlattening:
    """Data collator that wraps a DataCollatorForLanguageModeling and flattens inputs for flash-attention.

    The wrapped collator first produces a regular padded (BSHD) batch with MLM masking applied. The padding is then
    stripped and all sequences are packed back-to-back into a single row, so no compute is spent on inter-sequence
    padding.

    The collator also generates the metadata required for Flash Attention or context-parallel attention:
    - `cu_seq_lens_q` and `cu_seq_lens_k` tensors, denoting cumulative sequence lengths so that sequence boundaries
      within the packed tensor are known during attention computation.

    Optionally, the collator can:
    - Pad the total number of tokens in the batch to be divisible by `pad_to_multiple_of` (by appending a mock
      sequence).
    - Pad each individual sequence to be divisible by `pad_sequences_to_be_divisible_by` if provided.

    The two padding options are mutually exclusive. Only PyTorch tensors (`return_tensors="pt"`) are supported.

    Args:
        collator (DataCollatorForLanguageModeling): The collator to use for MLM masking. This is a captive
            collator and should be constructed externally and passed in.
        return_position_ids (bool): Whether to return position ids (default False).
        pad_to_multiple_of (int, optional): If set, pads the total sequence length to be divisible by this number.
        pad_sequences_to_be_divisible_by (int, optional): If set, each individual sequence is padded so its length
            is divisible by this value.
        separator_id (int, optional): If set, the label at the first token of every sequence after the first is
            overwritten with this value (existing labels are replaced, nothing is inserted). Typically -100 for
            causal LM so that no loss is computed across a sequence boundary.

    Example:
        >>> from transformers import AutoTokenizer, DataCollatorForLanguageModeling
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
        >>> mlm_collator = DataCollatorForLanguageModeling(tokenizer)
        >>> flat_collator = DataCollatorWithFlattening(
        ...     collator=mlm_collator,
        ...     pad_to_multiple_of=8,
        ... )
        >>>
        >>> # Input: variable length protein sequences
        >>> sequences = [
        ...     {"input_ids": [0, 5, 6, 7, 2]},  # 5 tokens
        ...     {"input_ids": [0, 8, 9, 10, 11, 2]},  # 6 tokens
        ...     {"input_ids": [0, 12, 13, 2]},  # 4 tokens
        ... ]  # Total: 15 tokens
        >>> batch = flat_collator(sequences)
        >>> print(batch['input_ids'].shape)  # torch.Size([1, 16])
        >>> print(batch['labels'].shape)  # torch.Size([1, 16])
        >>> print(batch['cu_seq_lens_q'])  # tensor([0, 5, 11, 15, 16], dtype=torch.int32)

    Note:
        The output is a THD-format batch (total tokens, heads, head dim), where all input sequences are packed
        without inter-sequence padding. Sequence boundaries are preserved using `cu_seq_lens_q`/`cu_seq_lens_k`,
        enabling Flash Attention or context-parallelism without traditional attention masks.
    """

    collator: DataCollatorForLanguageModeling
    return_position_ids: bool = False
    pad_to_multiple_of: int | None = None
    pad_sequences_to_be_divisible_by: int | None = None
    separator_id: int | None = None

    def __post_init__(self):
        """Ensure padding options are not used together."""
        if self.pad_sequences_to_be_divisible_by is not None and self.pad_to_multiple_of is not None:
            raise ValueError("pad_sequences_to_be_divisible_by and pad_to_multiple_of cannot be used together")

    def __call__(self, features, return_tensors=None):
        """Process a batch of variable-length sequences for Flash Attention with MLM.

        This method performs the following steps:
        1. Runs the wrapped collator to obtain a padded batch with MLM masking applied.
        2. Flattens the raw features into a single packed tensor with Flash Attention metadata.
        3. Drops the padding positions from the masked batch and stores the result in the packed batch.
        4. Optionally pads the packed batch (whole batch or per sequence).

        Args:
            features (List[Dict[str, List[int]]]): List of tokenized sequences, each containing
                'input_ids' and optionally 'attention_mask'. Example:
                [
                    {"input_ids": [0, 5, 6, 7, 2]},      # Protein sequence 1
                    {"input_ids": [0, 8, 9, 10, 11, 2]}, # Protein sequence 2
                    {"input_ids": [0, 12, 13, 2]}        # Protein sequence 3
                ]
            return_tensors (str, optional): Format for returned tensors. Only "pt" (PyTorch)
                is supported. Defaults to None (uses collator default).

        Returns:
            Dict[str, torch.Tensor]: Batch dictionary containing:
                - input_ids (torch.Tensor): Flattened and MLM-masked token sequences.
                  Shape: [1, total_tokens] where total_tokens = sum of all sequence lengths
                  (plus padding if a padding option is set).
                - labels (torch.Tensor): Labels produced by the wrapped collator, flattened the same way.
                  Same shape as input_ids.
                - cu_seq_lens_q (torch.IntTensor): Cumulative sequence lengths for queries.
                  Shape: [num_sequences + 1] or [num_sequences + 2] if a mock padding sequence is added.
                  Example: [0, 5, 11, 15] or [0, 5, 11, 15, 16] with padding.
                - cu_seq_lens_k (torch.IntTensor): Cumulative sequence lengths for keys.
                  Same as cu_seq_lens_q for self-attention.
                - max_length_q (int): Maximum sequence length in the batch.
                - max_length_k (int): Same as max_length_q for self-attention.
                - attention_mask / position_ids: only present when the flattening step produces them
                  (features carry an attention mask / `return_position_ids` is set).
                - cu_seq_lens_q_padded, cu_seq_lens_k_padded, pad_between_seqs: only present when
                  `pad_sequences_to_be_divisible_by` is set.

        Raises:
            NotImplementedError: If return_tensors is given and is not "pt".

        Example:
            >>> # Input features
            >>> features = [
            ...     {"input_ids": [0, 5, 6, 7, 2]},      # 5 tokens
            ...     {"input_ids": [0, 8, 9, 10, 11, 2]}, # 6 tokens
            ... ]
            >>>
            >>> batch = collator(features)
            >>>
            >>> # Output shapes and values
            >>> batch['input_ids'].shape  # torch.Size([1, 11]) or larger if padded
            >>> batch['labels'].shape     # torch.Size([1, 11]) or larger if padded
            >>> batch['cu_seq_lens_q']    # tensor([0, 5, 11], dtype=torch.int32) or larger

        Note:
            The output is in THD format with batch_size=1 and sequence_length=total_tokens. When
            pad_to_multiple_of is used, an additional mock sequence is appended to reach the desired total length.
        """
        if return_tensors is not None and return_tensors != "pt":
            raise NotImplementedError(f"Only return_tensors='pt' is supported, got '{return_tensors}'")

        # Perform the masking with the BSHD collator.
        bshd_batch = self.collator(features, return_tensors=return_tensors)

        # Create the flattened batch to get the cu_seq_lens_q and cu_seq_lens_k values.
        packed_batch = _pt_flatten_collate(features, return_position_ids=self.return_position_ids)

        # Select only the real (non-padding) tokens; boolean indexing flattens row by row, which yields the
        # sequences back-to-back. The mask is computed once and shared by input_ids and labels.
        token_mask = bshd_batch["attention_mask"].bool()
        masked_input_ids = bshd_batch["input_ids"][token_mask].unsqueeze(0)
        masked_labels = bshd_batch["labels"][token_mask].unsqueeze(0)

        if self.separator_id is not None:
            # cu_seq_lens_q[1:-1] are the start offsets of every sequence except the first.
            masked_labels[:, packed_batch["cu_seq_lens_q"][1:-1]] = self.separator_id

        # Update the packed batch with the masked input_ids and labels.
        packed_batch["input_ids"] = masked_input_ids
        packed_batch["labels"] = masked_labels

        if self.pad_to_multiple_of is not None:
            packed_batch = self._pad_batch_to_multiple_of(packed_batch)

        elif self.pad_sequences_to_be_divisible_by is not None:
            packed_batch = self._pad_sequences_to_be_divisible_by(packed_batch)

        return packed_batch

    def _resolve_pad_token_id(self) -> int:
        """Return the tokenizer's pad token id, falling back to 1 (with a warning) when it is not an integer."""
        pad_token_id = self.collator.tokenizer.pad_token_id
        if not isinstance(pad_token_id, int):
            logger.warning("tokenizer.pad_token_id is not an integer, using 1 instead: %s", pad_token_id)
            pad_token_id = 1
        return pad_token_id

    def _pad_batch_to_multiple_of(self, batch):
        """Add a mock sequence to make the total number of tokens divisible by pad_to_multiple_of."""
        pad_token_id = self._resolve_pad_token_id()

        assert self.pad_to_multiple_of is not None, "pad_to_multiple_of must be set"

        return _pt_pad_to_multiple_of(
            batch,
            self.pad_to_multiple_of,
            token_pad=pad_token_id,
            label_pad=-100,
        )

    def _pad_sequences_to_be_divisible_by(self, batch):
        """Pad individual sequences using cu_seq_lens_*_padded for context parallelism."""
        pad_token_id = self._resolve_pad_token_id()

        assert self.pad_sequences_to_be_divisible_by is not None, "pad_sequences_to_be_divisible_by must be set"

        input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp(
            batch["input_ids"],
            batch["labels"],
            batch["cu_seq_lens_q"],
            self.pad_sequences_to_be_divisible_by,
            padding_token_id=pad_token_id,
            padding_label_id=-100,
        )

        # The original (unpadded) cu_seq_lens_q / cu_seq_lens_k stay in the batch; the padded offsets are added
        # under separate keys.
        batch["input_ids"] = input_ids_padded.unsqueeze(0)
        batch["labels"] = labels_padded.unsqueeze(0)
        batch["cu_seq_lens_q_padded"] = cu_seqlens_padded.to(torch.int32)
        batch["cu_seq_lens_k_padded"] = cu_seqlens_padded.to(torch.int32)
        batch["pad_between_seqs"] = True
        return batch
@dataclass
class TokenPackingDataset(torch.utils.data.IterableDataset):
    """Iterable dataset that groups samples into batches bounded by a token budget (sequence packing)."""

    dataset: datasets.IterableDataset
    """Source dataset whose samples are packed."""
    max_tokens_per_batch: int
    """Token budget of a single batch."""
    drop_last: bool = True
    """Whether to drop the trailing batch that did not reach max_tokens_per_batch."""
    split_samples: bool = False
    """Whether to split a sample across two batches so that batches are filled up to max_tokens_per_batch."""
    pad_sequences_to_be_divisible_by: int | None = None
    """If set, every sequence counts with its length rounded up to the next multiple of this value.

    This mirrors the per-sequence padding of DataCollatorWithFlattening configured with the same value. Together
    with split_samples=True, the split point is rounded down to a multiple of this value so that the padded first
    part fits the capacity left in the batch.
    """

    def __post_init__(self):
        """Warn when the token budget is not a multiple of the per-sequence padding divisor."""
        divisor = self.pad_sequences_to_be_divisible_by
        if divisor is None:
            return
        if self.max_tokens_per_batch % divisor == 0:
            return
        logger.warning(
            "max_tokens_per_batch (%d) is not divisible by pad_sequences_to_be_divisible_by (%d). "
            "Batches may not fill to exactly max_tokens_per_batch when split_samples=True.",
            self.max_tokens_per_batch,
            divisor,
        )

    def _padded_len(self, length: int) -> int:
        """Return `length` rounded up to the next multiple of pad_sequences_to_be_divisible_by (if configured)."""
        divisor = self.pad_sequences_to_be_divisible_by
        if divisor is None:
            return length
        whole, leftover = divmod(length, divisor)
        return (whole + (1 if leftover else 0)) * divisor

    def __iter__(self):
        """Yield lists of samples whose combined (padded) token count does not exceed max_tokens_per_batch.

        With split_samples=True, a sample that would overflow the batch is cut in two: the first part tops up the
        current batch and the remainder opens the next one. With pad_sequences_to_be_divisible_by set, every
        sample is accounted for with its padded length.

        Raises:
            ValueError: If a single (padded) sample is larger than max_tokens_per_batch.

        Returns:
            A generator of batches (lists of samples).
        """
        pending = []
        pending_tokens = 0

        for sample in self.dataset:
            padded_len = self._padded_len(len(sample["input_ids"]))
            if padded_len > self.max_tokens_per_batch:
                raise ValueError(
                    f"TokenPackingDataset: Padded sample length ({padded_len}) exceeds max_tokens_per_batch "
                    f"({self.max_tokens_per_batch}). Set truncation or a maximum length in your tokenizer or dataset to"
                    " ensure all samples fit within max_tokens_per_batch."
                )

            projected = pending_tokens + padded_len

            if projected < self.max_tokens_per_batch:
                # Still room left: keep accumulating.
                pending.append(sample)
                pending_tokens = projected
                continue

            if projected == self.max_tokens_per_batch:
                # Exact fit: emit the batch including this sample.
                yield [*pending, sample]
                pending = []
                pending_tokens = 0
                continue

            # Overflow. Work out how many tokens of this sample may still go into the current batch;
            # without splitting that is none at all.
            room = 0
            if self.split_samples:
                room = self.max_tokens_per_batch - pending_tokens
                if self.pad_sequences_to_be_divisible_by is not None:
                    step = self.pad_sequences_to_be_divisible_by
                    # Round down so the padded first part cannot exceed the remaining capacity.
                    room = (room // step) * step

            if room <= 0:
                # Nothing of this sample fits: flush what we have and start over with the sample.
                if pending:
                    yield pending
                pending = [sample]
                pending_tokens = padded_len
            else:
                first_part, remaining_part = _split_sample_by_num_tokens(sample, room)
                yield [*pending, first_part]
                pending = [remaining_part]
                pending_tokens = self._padded_len(len(remaining_part["input_ids"]))

        if not self.drop_last and pending:
            yield pending

    def set_epoch(self, epoch: int):
        """Forward the epoch to the wrapped dataset."""
        self.dataset.set_epoch(epoch)
@dataclass
class DataCollatorForContextParallel:
    """Collator that turns one batch into per-rank shards for context parallelism (CP).

    The wrapped collator is expected to return padded sequences; this collator slices them into one shard per CP
    rank. The resulting list is typically handed to ContextParallelDataLoaderWrapper, which scatters the entries to
    the matching GPUs.

    Note:
        When tensor parallelism (TP) is used as well, every CP shard is replicated for the TP ranks, and the order
        of the returned list follows the order of the mesh dimension names.

        If "cp" comes before "tp" in the mesh dimension names (CP row-major), the flattened batch will be:
            [(cp0, tp0), (cp0, tp1), ..., (cp1, tp0), (cp1, tp1), ...]

        If "tp" comes before "cp" (TP row-major), the flattened batch will be:
            [(tp0, cp0), (tp0, cp1), ..., (tp1, cp0), (tp1, cp1), ...]

    Args:
        collator: The collator producing the full batch.
        device_mesh: Device mesh with named dimensions; a "cp" and/or "tp" dimension is looked up by name.
        qkv_format: Layout of the batch, "thd" or "bshd".
        is_causal_lm: If True, labels are shifted by one position before sharding and returned under
            `shift_labels` (with `labels` set to None).
    """

    collator: DataCollator
    device_mesh: torch.distributed.device_mesh.DeviceMesh
    qkv_format: str = "thd"
    is_causal_lm: bool = False

    # Derived fields, initialized in __post_init__.
    cp_world_size: int = field(init=False)
    tp_world_size: int | None = field(init=False)
    _is_cp_row_major: bool = field(init=False)

    def __post_init__(self):
        """Derive cp_world_size, tp_world_size and the flattening order from the device mesh."""
        names = self.device_mesh.mesh_dim_names
        if names is None:
            raise ValueError("device_mesh must have mesh_dim_names")

        has_cp = "cp" in names
        has_tp = "tp" in names

        self.cp_world_size = self.device_mesh.size(names.index("cp")) if has_cp else 1
        self.tp_world_size = self.device_mesh.size(names.index("tp")) if has_tp else None

        # In a flattened 2D mesh the index of the outer (row) dimension changes slowest. CP is the row
        # dimension when "cp" is listed before "tp"; without both dimensions the distinction is moot.
        if has_cp and has_tp:
            self._is_cp_row_major = names.index("cp") < names.index("tp")
        else:
            self._is_cp_row_major = True

    def __call__(self, features) -> list[dict[str, Any]]:
        """Collate `features` and return one batch dictionary per (CP, TP) rank.

        Args:
            features: List of tokenized sequences, each containing 'input_ids' and optionally 'labels'.

        Returns:
            A list of batch dictionaries, ordered like the flattened CP/TP mesh.
        """
        batch = self.collator(features)

        # The attention mask is not valid once the batch is sharded for CP.
        batch.pop("attention_mask", None)

        if self.is_causal_lm:
            # Shift left by one: append an ignore label at the end, then drop the first position.
            padded_labels = torch.nn.functional.pad(batch["labels"], (0, 1), value=-100)
            batch["labels"] = padded_labels[..., 1:].contiguous()

        shards = []
        for cp_rank in range(self.cp_world_size):
            ids_shard, labels_shard = _split_batch_by_cp_rank(
                cu_seqlens_padded=batch.get("cu_seq_lens_q_padded", None),  # This will be None for BSHD format.
                input_ids_padded=batch["input_ids"],
                labels_padded=batch["labels"],
                qvk_format=self.qkv_format,
                cp_rank=cp_rank,
                cp_world_size=self.cp_world_size,
            )

            shard = dict(batch)
            shard["input_ids"] = ids_shard
            if self.is_causal_lm:
                shard["shift_labels"] = labels_shard
                shard["labels"] = None
            else:
                shard["labels"] = labels_shard

            # Longest sequence in the batch, measured on the padded layout.
            if self.qkv_format == "thd":
                offsets = shard["cu_seq_lens_q_padded"]
                max_length = (offsets[1:] - offsets[:-1]).max().item()
            elif self.qkv_format == "bshd":
                max_length = batch["input_ids"].shape[1]
            else:
                raise ValueError(f"Unsupported qvk_format: {self.qkv_format}!")

            # Round up to the next multiple of 64.
            rounded = ((max_length + 63) // 64) * 64
            shard["max_length_k"] = rounded
            shard["max_length_q"] = rounded
            shards.append(shard)

        if self.tp_world_size is None:
            return shards

        # Replicate each CP shard for the TP ranks, following the layout of the flattened mesh.
        if self._is_cp_row_major:
            # Flattened mesh: [(cp0,tp0), (cp0,tp1), (cp1,tp0), (cp1,tp1)]
            # Output: [cp0, cp0, cp1, cp1]
            replicated = []
            for shard in shards:
                replicated.extend([shard] * self.tp_world_size)
            return replicated

        # Flattened mesh: [(tp0,cp0), (tp0,cp1), (tp1,cp0), (tp1,cp1)]
        # Output: [cp0, cp1, cp0, cp1]
        return shards * self.tp_world_size
class ContextParallelDataLoaderWrapper:
    """A dataloader that is aware of context and tensor parallelism."""

    def __init__(
        self,
        dataloader: torch.utils.data.DataLoader | None,
        cp_tp_mesh: torch.distributed.device_mesh.DeviceMesh,
    ):
        """A dataloader wrapper that distributes the data across the context and tensor parallelism groups.

        This class materializes a single dataloader for each data parallel mesh rank, and splits / replicates the data
        from this dataloader across the context and tensor parallelism groups.

        Args:
            dataloader: The dataloader to use. Must be provided on local rank 0 of the mesh and be None elsewhere.
            cp_tp_mesh: The context parallel mesh, or a flattened, combined context parallel and tensor parallel mesh.
                If a flattened mesh is provided, the cp / tp dimensions should be in the order they appeared in the
                mesh_dim_names as passed to DataCollatorForContextParallel.
        """
        if cp_tp_mesh.get_local_rank() == 0:
            assert dataloader is not None, "dataloader must be provided on rank 0"
            self.dataloader = dataloader

        else:
            assert dataloader is None, "Dataloader on non-rank 0 will not be used"

        self.cp_tp_rank = cp_tp_mesh.get_local_rank()
        self.cp_tp_group = cp_tp_mesh.get_group()
        self.num_cp_tp_ranks = cp_tp_mesh.size()
        self._iterator = None
        self._prefetch_thread: threading.Thread | None = None
        self._prefetch_result: Any = None
        self._cuda_device: int | None = None

        logger.debug(
            "Created ContextParallelDataLoaderWrapper on global rank %s, cp rank %s",
            torch.distributed.get_rank() if torch.distributed.is_initialized() else "",
            self.cp_tp_rank,
        )

    def __iter__(self):
        """Make the dataloader iterable and start prefetching the first batch."""
        # Join any in-flight prefetch BEFORE replacing the iterator. Otherwise a prefetch thread left over from
        # a previous iteration could pull (and then discard) the first batch of the new iterator.
        self.close()
        if self.cp_tp_rank == 0:
            self._iterator = iter(self.dataloader)  # < --- collator output.
        # Capture CUDA device from main thread; torch.cuda.set_device is per-thread,
        # so the background thread needs to set it explicitly.
        self._cuda_device = torch.cuda.current_device() if torch.cuda.is_available() else None
        self._kick_prefetch()
        return self

    @nvtx.annotate("ContextParallelDataLoaderWrapper __next__", color="blue")
    def __next__(self):
        """Get the batch from the dataloader for the current CP rank.

        Waits for the pending prefetch, re-raises any exception it captured (including StopIteration), and
        otherwise starts prefetching the following batch before returning the current one.
        """
        if self._prefetch_thread is not None:
            self._prefetch_thread.join()
        result = self._prefetch_result
        if isinstance(result, Exception):
            # Do not start another prefetch once the stream ended or failed.
            self._prefetch_thread = None
            raise result
        self._kick_prefetch()
        return result

    def _kick_prefetch(self):
        """Start a background thread to prefetch exactly one batch via scatter."""
        self._prefetch_thread = threading.Thread(target=self._do_one_prefetch, daemon=True)
        self._prefetch_thread.start()

    def _do_one_prefetch(self):
        """Fetch one batch in the background.

        This function calls the _send_data_to_cp_tp_ranks function to materialize the next batches for all ranks in the
        given CP/TP group and receive the one for this rank. The result - or the exception raised while fetching,
        StopIteration included - is stored in _prefetch_result and handed out / re-raised by __next__.
        """
        if self._cuda_device is not None:
            torch.cuda.set_device(self._cuda_device)
        try:
            self._prefetch_result = self._send_data_to_cp_tp_ranks()
        except Exception as e:  # StopIteration is an Exception subclass and is captured here as well.
            self._prefetch_result = e

    def close(self):
        """Stop the prefetch thread. Must be called before destroy_process_group()."""
        if self._prefetch_thread is not None:
            # Bounded wait; the thread is a daemon, so it cannot keep the process alive if it is still running.
            self._prefetch_thread.join(timeout=10)
            self._prefetch_thread = None

    @nvtx.annotate("ContextParallelDataLoaderWrapper _send_data_to_cp_tp_ranks", color="green")
    def _send_data_to_cp_tp_ranks(self):
        """Send data to all the CP/TP ranks.

        This function will get the batch from the dataloader on CP rank 0, which already contains the shards for
        all the different CP group members:
            combined_batch = [<cp_rank_0_shard>, <cp_rank_1_shard>, ..., <cp_rank_n_shard>]
        Then it will scatter the shards to the different CP group members and return the shard for the current
        CP rank to the caller.

        If tensor parallelism is also being used, the combined batch will look like:
            combined_batch = [<cp0_shard>, <cp0_shard>, ..., <cp1_shard>, <cp1_shard>, ...]
        where there are cp_world_size shards, and each shard is replicated tp_world_size times. The ordering of the
        shards depends on which dimension forms the rows in the flattened mesh.

        Scalability:
            Rank 0's work grows linearly with CP size, but the other ranks do not need to store all the shards so they
            do not grow linearly with CP size.

        Returns:
            batch: The batch for the current CP/TP rank.

        Raises:
            StopIteration: On every rank of the group once the dataloader on rank 0 is exhausted.
        """
        try:
            with nvtx.annotate("ContextParallelDataLoaderWrapper next batch", color="green"):
                combined_batch = next(self._iterator) if self.cp_tp_rank == 0 else None
        except StopIteration as ex:
            # If we encounter a StopIteration in the dataloader, we want to raise this error on all the CP ranks, so
            # that the dataloader can be restarted.
            combined_batch = [ex] * self.num_cp_tp_ranks

        batch_on_this_rank = _scatter_batch_to_cp_tp_ranks(combined_batch, self.cp_tp_group)

        if isinstance(batch_on_this_rank, StopIteration):
            raise batch_on_this_rank

        return batch_on_this_rank

    def state_dict(self):
        """Get the state dict by delegating to the dataloader (empty dict on ranks without a dataloader)."""
        if self.cp_tp_rank != 0:
            return {}
        elif hasattr(self.dataloader, "state_dict"):
            return {"dataloader": self.dataloader.state_dict()}
        else:
            logger.warning(
                "Attempting to get the state dict of the dataloader, but the dataloader does not support state_dict, "
                "returning empty dict"
            )
            return {"dataloader": {}}

    def load_state_dict(self, state_dict):
        """Load the state dict by delegating to the dataloader (no-op on ranks without a dataloader)."""
        if self.cp_tp_rank != 0:
            return
        elif hasattr(self.dataloader, "load_state_dict"):
            self.dataloader.load_state_dict(state_dict["dataloader"])
        else:
            logger.warning(
                "Attempting to load the state dict of the dataloader, but the dataloader does not support "
                "load_state_dict, returning without loading the state dict."
            )
            return

    @property
    def num_workers(self):
        """Get the number of workers of the dataloader (0 on ranks without a dataloader)."""
        if self.cp_tp_rank != 0:
            return 0
        else:
            return self.dataloader.num_workers
def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tuple[dict[str, Any], dict[str, Any]]:
    """Split a sample dictionary at a specified number of tokens.

    The first part receives exactly `num_tokens` tokens and the second part the rest. Only the known per-token
    fields (input_ids, attention_mask, token_type_ids, token_type, labels) are split; every other field - and any
    per-token field whose value cannot be sliced - is carried over unchanged (same object) into both parts.

    Args:
        sample: Dictionary containing sample data with fields like input_ids, attention_mask, token_type_ids,
            labels, etc. Must contain `input_ids`.
        num_tokens: Number of tokens to include in the first part of the split.

    Returns:
        A tuple of two dictionaries: (first_part, remaining_part), where:
            - first_part contains the first `num_tokens` tokens from each sequence field
            - remaining_part contains the remaining tokens from each sequence field

    Raises:
        ValueError: If `num_tokens` is not strictly between 0 and the sample length.

    Example:
        >>> sample = {
        ...     "input_ids": [0, 5, 6, 7, 8, 9, 2],
        ...     "attention_mask": [1, 1, 1, 1, 1, 1, 1],
        ...     "labels": [0, 5, 6, 7, 8, 9, 2]
        ... }
        >>> first, remaining = _split_sample_by_num_tokens(sample, 3)
        >>> first["input_ids"]  # [0, 5, 6]
        >>> remaining["input_ids"]  # [7, 8, 9, 2]
    """
    sample_length = len(sample["input_ids"])
    if num_tokens >= sample_length:
        raise ValueError(
            f"num_tokens ({num_tokens}) must be less than sample length ({sample_length}) to split the sample"
        )
    if num_tokens <= 0:
        raise ValueError(f"num_tokens ({num_tokens}) must be positive")

    # Fields that should be split by tokens (sequence fields)
    sequence_fields = ("input_ids", "attention_mask", "token_type_ids", "token_type", "labels")

    first_part = {}
    remaining_part = {}

    for key, value in sample.items():
        # Default: the value is shared by both parts (metadata, or a field that cannot be sliced).
        head, tail = value, value
        if key in sequence_fields:
            if isinstance(value, torch.Tensor):
                # Clone so the two parts do not alias the storage of the original tensor.
                head, tail = value[:num_tokens].clone(), value[num_tokens:].clone()
            else:
                # Lists and any other sliceable type take the same path.
                try:
                    head, tail = value[:num_tokens], value[num_tokens:]
                except (TypeError, IndexError):
                    pass
        first_part[key] = head
        remaining_part[key] = tail

    return first_part, remaining_part
def _pt_flatten_collate(features: list[dict[str, list[int]]], return_position_ids: bool = False):
    """Pack tokenized samples back-to-back into a single-row batch and record the sequence boundaries.

    Whether optional fields (``labels``, ``attention_mask``) are packed is decided by looking at the first sample
    only.

    Args:
        features: List of tokenized samples, each containing at least ``input_ids``.
        return_position_ids: Whether to add ``position_ids`` that restart at 0 for every sample.

    Returns:
        A dictionary with packed ``input_ids`` of shape [1, total_tokens], int32
        ``cu_seq_lens_q``/``cu_seq_lens_k`` (the same tensor object), ``max_length_q``/``max_length_k`` and,
        where applicable, packed ``labels``, ``attention_mask`` and ``position_ids``.
    """

    def _pack(key):
        # Concatenate the per-sample lists of `key` into one row of int64 values.
        flat = []
        for sample in features:
            flat.extend(sample[key])
        return torch.tensor([flat], dtype=torch.int64)

    first = features[0]
    has_labels = "labels" in first
    lengths = [len(sample["input_ids"]) for sample in features]
    longest = max(lengths)

    batch = {}
    batch["max_length_q"] = longest
    batch["max_length_k"] = longest
    batch["input_ids"] = _pack("input_ids")
    if has_labels:
        batch["labels"] = _pack("labels")

    # Cumulative lengths with a leading zero: [0, l0, l0 + l1, ...].
    running_total = torch.cumsum(torch.tensor(lengths), dim=0, dtype=torch.int32)
    boundaries = torch.nn.functional.pad(running_total, (1, 0))
    batch["cu_seq_lens_q"] = boundaries
    batch["cu_seq_lens_k"] = boundaries

    if "attention_mask" in first:
        batch["attention_mask"] = _pack("attention_mask")

    if return_position_ids:
        per_sample = [torch.arange(length, dtype=torch.int64) for length in lengths]
        batch["position_ids"] = torch.cat(per_sample).unsqueeze(0)

    return batch
def _find_seq_dim(tensor: torch.Tensor, seq_len: int) -> int:
    """Return the index of the dimension of `tensor` whose size equals `seq_len`.

    Only dimensions 0 and 1 are considered. For tensors with two or more dimensions, dimension 1 takes
    precedence over dimension 0 when both match.

    Args:
        tensor: The tensor to inspect.
        seq_len: The expected sequence length to match against tensor dimensions.

    Returns:
        The dimension index (0 or 1) that matches the sequence length.

    Raises:
        ValueError: If the tensor is 0-dimensional or neither candidate dimension has size `seq_len`.
    """
    rank = tensor.ndim
    if rank < 1:
        raise ValueError(f"Unexpected tensor ndim={tensor.ndim}")

    if rank == 1:
        if tensor.shape[0] != seq_len:
            raise ValueError(f"1D tensor shape {tensor.shape} doesn't match sequence length {seq_len}")
        return 0

    # Check dim 1 first, then fall back to dim 0.
    for candidate in (1, 0):
        if tensor.shape[candidate] == seq_len:
            return candidate
    raise ValueError(f"Tensor shape {tensor.shape} doesn't match sequence length {seq_len} in dim 0 or 1")
def _process_tensor_thd(
    val: torch.Tensor | None,
    seq_len: int,
    slice_sizes: torch.Tensor,
    cu_seqlens_padded: torch.Tensor,
    cp_rank: int,
    total_slices: int,
) -> torch.Tensor | None:
    """Extract the THD context-parallel shard of a single tensor for one CP rank.

    Every sequence is treated as `total_slices` equally sized slices. The rank receives slice `cp_rank` and its
    mirror image `total_slices - cp_rank - 1`, i.e. one slice from the front and one from the back of each
    sequence.

    Args:
        val: The tensor to shard, or None (returned as-is).
        seq_len: Total sequence length (from cu_seqlens_padded[-1]); used to locate the sequence dimension.
        slice_sizes: Per-sequence slice sizes, computed as sequence_lengths // total_slices.
        cu_seqlens_padded: Cumulative sequence lengths including padding.
        cp_rank: The context parallelism rank index.
        total_slices: Total number of slices per sequence (2 * cp_world_size).

    Returns:
        The sharded tensor for the given CP rank, or None if val is None.
    """
    if val is None:
        return val

    seq_dim = _find_seq_dim(val, seq_len)

    # Slice taken from the front of each sequence, and its counterpart counted from the back.
    front_slice = cp_rank
    back_slice = total_slices - cp_rank - 1

    index_pieces = []
    for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]):
        for slice_no in (front_slice, back_slice):
            begin = seq_start + (slice_no * slice_size)
            index_pieces.append(torch.arange(begin, begin + slice_size, device=val.device))

    return val.index_select(seq_dim, torch.cat(index_pieces))
+ """ + if val is None: + return val + + if val.ndim < 2: + raise ValueError(f"BSHD format requires at least 2D tensors, got {val.ndim}D") + + seq_len = val.shape[1] + + # Calculate chunk size + total_chunks = 2 * cp_world_size + chunk_size = seq_len // total_chunks + + if seq_len % total_chunks != 0: + raise ValueError( + f"Sequence length {seq_len} must be divisible by {total_chunks} " + f"(2 * cp_world_size) for BSHD context parallelism" + ) + + # Determine which chunks this rank should get + # Rank 0 gets chunks [0, total_chunks-1] + # Rank 1 gets chunks [1, total_chunks-2] + # Rank k gets chunks [k, total_chunks-k-1] + chunk_indices = [cp_rank, total_chunks - cp_rank - 1] + + # Collect slices for this rank + rank_slices = [] + for chunk_idx in chunk_indices: + start_idx = chunk_idx * chunk_size + end_idx = start_idx + chunk_size + rank_slices.append(torch.arange(start_idx, end_idx, device=val.device)) + + # Concatenate indices for all chunks this rank should get + indices = torch.cat(rank_slices) + + # Select along sequence dimension (dim=1) + return val.index_select(1, indices) + + +def _pt_pad_to_multiple_of(batch: dict[str, Any], pad_to_multiple_of: int, token_pad: int, label_pad: int): + """Pad a batch to a multiple of pad_to_multiple_of. + + Appends a mock sequence to the end of the batch with the given token_pad and label_pad to make the total number of + tokens divisible by pad_to_multiple_of. + + Args: + batch: Input batch, possibly containing labels and/or cu_seq_lens / max_length keys. + pad_to_multiple_of: Multiple to pad to. + token_pad: Token to pad with. + label_pad: Label to pad with. + + Returns: + Batch dictionary with padded input_ids, labels, cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k. 
def _pt_pad_to_multiple_of(batch: dict[str, Any], pad_to_multiple_of: int, token_pad: int, label_pad: int):
    """Pad a batch to a multiple of pad_to_multiple_of.

    Appends a mock sequence to the end of the batch with the given token_pad and label_pad to make the total number of
    tokens divisible by pad_to_multiple_of. The batch dictionary is updated in place and also returned. All padding
    tensors are created with the dtype and device of the tensor they extend, so batches that do not live on the
    default device are handled as well.

    Args:
        batch: Input batch with ``input_ids`` of shape [1, total_tokens], possibly containing labels,
            attention_mask, position_ids and/or cu_seq_lens / max_length keys.
        pad_to_multiple_of: Multiple to pad to.
        token_pad: Token to pad with.
        label_pad: Label to pad with.

    Returns:
        Batch dictionary with padded input_ids, labels, cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k.
        Returned unchanged when the token count is already a multiple of pad_to_multiple_of.
    """
    input_ids = batch["input_ids"]

    # Number of tokens we need to pad to make the total number of tokens divisible by pad_to_multiple_of
    remainder = -input_ids.numel() % pad_to_multiple_of

    if remainder == 0:
        return batch

    # new_full / new_zeros inherit dtype and device from the tensor being extended.
    batch["input_ids"] = torch.cat([input_ids, input_ids.new_full((1, remainder), token_pad)], dim=1)

    if "labels" in batch:
        labels = batch["labels"]
        batch["labels"] = torch.cat([labels, labels.new_full((1, remainder), label_pad)], dim=1)

    if "cu_seq_lens_q" in batch:
        cu_seq_lens = batch["cu_seq_lens_q"]
        # The mock sequence ends `remainder` tokens after the last real boundary.
        batch["cu_seq_lens_q"] = torch.cat([cu_seq_lens, cu_seq_lens[-1:] + remainder], dim=0)
        batch["cu_seq_lens_k"] = batch["cu_seq_lens_q"]

    if "max_length_q" in batch:
        # The mock sequence may be longer than any real sequence.
        batch["max_length_q"] = max(batch["max_length_q"], remainder)
        batch["max_length_k"] = batch["max_length_q"]

    if "attention_mask" in batch:
        attention_mask = batch["attention_mask"]
        batch["attention_mask"] = torch.cat([attention_mask, attention_mask.new_zeros((1, remainder))], dim=1)

    if "position_ids" in batch:
        position_ids = batch["position_ids"]
        mock_positions = torch.arange(remainder, dtype=position_ids.dtype, device=position_ids.device)
        batch["position_ids"] = torch.cat([position_ids, mock_positions.unsqueeze(0)], dim=1)

    return batch
It will not work for cross attention because + it does not handle the case where the sequence length of the query and key are different. + Which are parallelized across GPUs in a context parallel group. + This version works with variable-length sequences using cumulative sequence lengths for THD format, + and with padded sequences for BSHD format. + + Args: + cu_seqlens_padded: Cumulative sequence length. Required for THD format, optional for BSHD format. + input_ids_padded: Input IDs. + labels_padded: Labels. + cp_group: Context parallel group. + qvk_format: Format of the input data ("thd" or "bshd"). + cp_world_size: The size of the context parallelism group. If provided, the function will use this value to determine the rank. + cp_rank: Optional manual CP rank index. When provided, the function shards tensors as if it + were executing on that rank without querying `torch.distributed.get_rank`. + """ + if qvk_format not in ["thd", "bshd", "sbhd"]: + raise ValueError(f"Unsupported qvk_format: {qvk_format}!") + + if cp_world_size is None or cp_world_size <= 1: + # No splitting needed + return input_ids_padded, labels_padded + + if cp_rank is None: + cp_rank = torch.distributed.get_rank(group=cp_group) + elif not (0 <= cp_rank < cp_world_size): + raise ValueError(f"cp_rank must be in [0, {cp_world_size}), but received {cp_rank}.") + + if qvk_format == "thd": + if cu_seqlens_padded is None: + raise ValueError("cu_seqlens_padded is required for THD format") + + # Calculate the chunk sizes for each sequence + total_slices_of_any_sequence = 2 * cp_world_size + slice_sizes = (cu_seqlens_padded[1:] - cu_seqlens_padded[:-1]) // total_slices_of_any_sequence + + # Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor + last_elem = cu_seqlens_padded[-1] + seq_len_val = last_elem.item() if isinstance(last_elem, torch.Tensor) else last_elem + + input_ids_padded = _process_tensor_thd( + input_ids_padded, seq_len_val, slice_sizes, cu_seqlens_padded, cp_rank, 
class BatchType(TypedDict):
    """The fields in the batch dictionary for THD context parallel.

    This is the per-rank payload type used in the signature of `_scatter_batch_to_cp_tp_ranks`.

    NOTE(review): the per-field notes below are inferred from the field names and annotations only; confirm them
    against the collator that builds these dictionaries.
    """

    # Presumably the packed token ids for this rank.
    input_ids: torch.Tensor
    # Optional: either of the two label fields may be None.
    labels: torch.Tensor | None
    shift_labels: torch.Tensor | None
    # Presumably cumulative sequence boundaries for queries / keys.
    cu_seq_lens_q: torch.Tensor
    cu_seq_lens_k: torch.Tensor
    # Presumably the same boundaries after padding has been applied -- TODO confirm.
    cu_seq_lens_q_padded: torch.Tensor
    cu_seq_lens_k_padded: torch.Tensor
    # Plain Python ints, not tensors.
    max_length_q: int
    max_length_k: int
    # Presumably signals that padding tokens sit between packed sequences -- TODO confirm.
    pad_between_seqs: bool
def create_tokenized_dataset(
    distributed_config: DistributedConfig,
    tokenizer_name_or_path: str,
    load_dataset_kwargs: dict,
    max_seq_length: int = 8192,
    stride: int = 200,
    buffer_size: int = 5_000,
    text_column: str = "text",
    tokenize_batch_size: int = 100,
):
    """Load a dataset and tokenize it into overlapping fixed-size windows.

    Streaming (iterable) datasets are additionally sharded across ranks before tokenization and shuffled with a
    fixed seed afterwards. Map-style datasets are only tokenized.

    Args:
        distributed_config: The distributed configuration (provides ``rank`` and ``world_size`` for sharding).
        tokenizer_name_or_path: Name or path to the tokenizer directory.
        load_dataset_kwargs: Keyword arguments forwarded to `datasets.load_dataset`.
        max_seq_length: The maximum length of each tokenized window.
        stride: The stride passed to the tokenizer for windowing (overlap = stride tokens).
        buffer_size: The shuffle buffer size used for streaming datasets.
        text_column: Name of the column containing text sequences (default: "text").
        tokenize_batch_size: Number of examples handed to the tokenizer per `map` call.

    Returns:
        Tuple of (tokenized_dataset, tokenizer). The tokenizer is guaranteed to have a padding token; the EOS token
        is used when none was configured.
    """
    logger.info(f"Loading dataset with kwargs: {load_dataset_kwargs}")
    dataset = datasets.load_dataset(**load_dataset_kwargs)

    if isinstance(dataset, datasets.IterableDataset):
        rank = distributed_config.rank
        world_size = distributed_config.world_size
        # `split_dataset_by_node` falls back to loading every shard on every node (with strided sampling) whenever
        # the shard count is not perfectly divisible by the world size, which is wasteful for many shards/workers.
        # Prefer `dataset.shard` whenever there are at least as many shards as ranks.
        if world_size <= dataset.num_shards:
            logger.info(f"Sharding dataset with {dataset.num_shards} shards with dataset.shard")
            dataset = dataset.shard(num_shards=world_size, index=rank)
        else:
            logger.info(f"Sharding dataset with {dataset.num_shards} shards with split_dataset_by_node")
            dataset = datasets.distributed.split_dataset_by_node(dataset, rank=rank, world_size=world_size)

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

    # `return_overflowing_tokens` turns one long text into several windows (one-to-many mapping).
    windowing_kwargs = {
        "max_length": max_seq_length,
        "stride": stride,
        "truncation": True,
        "return_overflowing_tokens": True,
        "add_special_tokens": True,
    }

    def tokenize_with_windowing(examples):
        """Tokenize a batch of text sequences into overlapping windows."""
        return tokenizer(examples[text_column], **windowing_kwargs)

    text_only = dataset.select_columns(text_column)
    tokenized_dataset = text_only.map(
        tokenize_with_windowing,
        batched=True,
        batch_size=tokenize_batch_size,
        remove_columns=[text_column],
    )

    # Shuffling happens after tokenization so that the windows, not the raw documents, are shuffled.
    if isinstance(tokenized_dataset, datasets.IterableDataset):
        tokenized_dataset = tokenized_dataset.shuffle(seed=42, buffer_size=buffer_size)

    # The base MLM collator needs a padding token, even in THD mode.
    if tokenizer.pad_token is None:
        logger.warning(f"Tokenizer does not have a padding token. Setting it to the EOS token: {tokenizer.eos_token}")
        tokenizer.pad_token = tokenizer.eos_token

    return tokenized_dataset, tokenizer
+ """ + tokenized_dataset, tokenizer = create_tokenized_dataset( + distributed_config=distributed_config, + tokenizer_name_or_path=tokenizer_name_or_path, + load_dataset_kwargs=load_dataset_kwargs, + max_seq_length=max_seq_length, + stride=stride, + buffer_size=buffer_size, + text_column=text_column, + tokenize_batch_size=micro_batch_size * prefetch_factor, + ) + + if isinstance(tokenized_dataset, datasets.IterableDataset): + sampler = None + else: + sampler = DistributedSampler( + tokenized_dataset, + rank=distributed_config.rank, + num_replicas=distributed_config.world_size, + seed=seed, + ) + + # Create base collator + base_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False, # Causal language modeling + pad_to_multiple_of=pad_sequences_to_be_divisible_by, + ) + + data_collator = base_collator + logger.info("Using standard DataCollatorForLanguageModeling") + + # TODO(BIONEMO-3246) - remove the pin_memory=False once StatefulDataLoader supports pin_memory again. 
def create_thd_dataloader(
    distributed_config: DistributedConfig,
    tokenizer_name_or_path: str,
    load_dataset_kwargs: dict,
    micro_batch_size: int | None = None,
    token_micro_batch_size: int | None = None,
    num_workers: int = 1,
    prefetch_factor: int = 4,
    max_seq_length: int = 8192,
    stride: int = 200,
    buffer_size: int = 500_000,
    use_stateful_dataloader: bool = False,
    text_column: str = "text",
    split_samples_in_token_packing: bool = True,
    pad_sequences_to_be_divisible_by: int | None = None,
):
    """Create a dataloader that packs up to the maximum number of tokens per batch.

    Exactly one of ``micro_batch_size`` and ``token_micro_batch_size`` must be provided; either way the result is a
    per-batch token budget that is handed to `TokenPackingDataset`.

    Args:
        distributed_config: The distributed configuration.
        tokenizer_name_or_path: Name or path to the tokenizer directory.
        load_dataset_kwargs: Keyword arguments to pass to `load_dataset`.
        micro_batch_size: The batch size per device, expressed in sequences. It is converted to a token budget of
            ``micro_batch_size * max_seq_length``. Mutually exclusive with ``token_micro_batch_size``.
        token_micro_batch_size: The maximum number of tokens per batch. If None, ``micro_batch_size * max_seq_length``
            will be used. Defaults to None.
        num_workers: The number of workers to use for the dataloader.
        prefetch_factor: The prefetch factor to use for the dataloader (ignored when ``num_workers`` is 0).
        max_seq_length: The maximum length of sequences (window size).
        stride: The stride for windowing (overlap = stride tokens).
        buffer_size: The buffer size for shuffle.
        use_stateful_dataloader: Whether to use the StatefulDataLoader to enable checkpointing the dataloader state.
        text_column: Name of the column containing text sequences (default: "text").
        split_samples_in_token_packing: Whether to split samples to form batches with exactly token_micro_batch_size
            tokens. Default: True.
        pad_sequences_to_be_divisible_by: If provided, sequences will be padded to be divisible by this value.
            This is useful for context parallelism. Defaults to None.

    Returns:
        A tuple of (dataloader, tokenized_dataset).

    Raises:
        AssertionError: If the dataset is not a streaming dataset, if neither or both of ``micro_batch_size`` and
            ``token_micro_batch_size`` are provided, or if the resulting token budget is smaller than
            ``max_seq_length``.
    """
    tokenized_dataset, tokenizer = create_tokenized_dataset(
        distributed_config=distributed_config,
        tokenizer_name_or_path=tokenizer_name_or_path,
        load_dataset_kwargs=load_dataset_kwargs,
        max_seq_length=max_seq_length,
        stride=stride,
        buffer_size=buffer_size,
        text_column=text_column,
    )

    assert isinstance(tokenized_dataset, datasets.IterableDataset), "THD token packing requires a streaming dataset."
    if token_micro_batch_size is None:
        # Neither size was given: the caller has to pick one.
        assert micro_batch_size is not None, "One of micro_batch_size or token_micro_batch_size must be provided."
        token_micro_batch_size = micro_batch_size * max_seq_length
    else:
        assert micro_batch_size is None, "Only one of micro_batch_size or token_micro_batch_size can be provided."
    # Checked for both paths so that a non-positive micro_batch_size is rejected as well: a batch must be able to
    # hold at least one full-length window.
    assert token_micro_batch_size >= max_seq_length, (
        "token_micro_batch_size must be greater than or equal to max_seq_length."
    )

    # Create base MLM collator and wrap with flattening collator
    data_collator = DataCollatorWithFlattening(
        collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
        separator_id=-100,
        pad_sequences_to_be_divisible_by=pad_sequences_to_be_divisible_by,
    )

    # TODO(BIONEMO-3246) - remove the pin_memory=False once StatefulDataLoader supports pin_memory again.
    dataloader_class = StatefulDataLoader if use_stateful_dataloader else DataLoader
    train_dataloader = dataloader_class(
        TokenPackingDataset(
            tokenized_dataset,
            max_tokens_per_batch=token_micro_batch_size,
            split_samples=split_samples_in_token_packing,
        ),
        batch_size=None,  # The TokenPackingDataset will handle the batching.
        collate_fn=data_collator,
        num_workers=num_workers,
        pin_memory=not use_stateful_dataloader,
        persistent_workers=num_workers > 0,
        prefetch_factor=prefetch_factor if num_workers > 0 else None,
    )

    return train_dataloader, tokenized_dataset
+ """ + + rank: int = field(default_factory=lambda: int(os.environ.setdefault("RANK", "0"))) + local_rank: int = field(default_factory=lambda: int(os.environ.setdefault("LOCAL_RANK", "0"))) + world_size: int = field(default_factory=lambda: int(os.environ.setdefault("WORLD_SIZE", "1"))) + _master_addr: str = field(default_factory=lambda: os.environ.setdefault("MASTER_ADDR", "localhost")) + _master_port: str = field(default_factory=lambda: os.environ.setdefault("MASTER_PORT", "12355")) + + def is_main_process(self) -> bool: + """This is the global rank 0 process, to be used for wandb logging, etc.""" + return self.rank == 0 diff --git a/bionemo-recipes/recipes/mixtral_native_te/dlcm_sanity_dataset.parquet b/bionemo-recipes/recipes/mixtral_native_te/dlcm_sanity_dataset.parquet new file mode 100644 index 0000000000..a4348af6f5 Binary files /dev/null and b/bionemo-recipes/recipes/mixtral_native_te/dlcm_sanity_dataset.parquet differ diff --git a/bionemo-recipes/recipes/mixtral_native_te/fp8_debugging.py b/bionemo-recipes/recipes/mixtral_native_te/fp8_debugging.py new file mode 100644 index 0000000000..79bfa06b2b --- /dev/null +++ b/bionemo-recipes/recipes/mixtral_native_te/fp8_debugging.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def initialize_fp8_debugging(
    dist_config: DistributedConfig,
    enabled: bool,
    fp8_stats_file: str,
    fp8_log_dir: str | os.PathLike,
    fp8_enabled: bool,
) -> None:
    """Set up FP8 statistics collection for this rank; a no-op when debugging is disabled.

    Args:
        dist_config: The distributed configuration; its ``rank`` selects the per-rank log sub-directory.
        enabled: Whether to enable FP8 debugging. When False the function returns immediately.
        fp8_stats_file: The config file describing which FP8 stats to collect.
        fp8_log_dir: The base directory to log the FP8 stats to; ``rank_<rank>`` is appended and created.
        fp8_enabled: Whether FP8 autocast is enabled.

    Raises:
        ValueError: If debugging is enabled while FP8 itself is disabled.
    """
    if not enabled:
        return
    if not fp8_enabled:
        raise ValueError(
            "fp8_stats_config.enabled is true but fp8_config.enabled is false, "
            "please set fp8_config.enabled to true in the config if you wish to collect FP8 stats"
        )

    # Every rank writes into its own sub-directory of the requested log directory.
    fp8_log_dir = Path(fp8_log_dir).joinpath(f"rank_{dist_config.rank}")
    fp8_log_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Logging FP8 stats to {fp8_log_dir}")

    # The debug features are looked up inside the installed transformer_engine package.
    te_package_dir = Path(transformer_engine.__file__).parent
    feature_dirs = [str(te_package_dir / "debug" / "features")]
    debug_api.initialize(
        config_file=fp8_stats_file,
        feature_dirs=feature_dirs,
        log_dir=fp8_log_dir.as_posix(),
        default_logging_enabled=True,
    )
def get_hidden_bytes(x: torch.Tensor) -> int:
    """Return the byte size of one row of ``x``, counting every element as at least two bytes.

    Args:
        x (torch.Tensor): Input tensor; its size along dim 1 is taken as the hidden dimension.

    Returns:
        int: ``x.size(1)`` multiplied by the larger of ``x.element_size()`` and 2.
    """
    hidden_size = x.size(1)
    bytes_per_element = x.element_size()
    if bytes_per_element < 2:
        # One-byte dtypes are still accounted for with two bytes per element.
        bytes_per_element = 2
    return hidden_size * bytes_per_element
async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + # Do MoE dispatch + # NOTES: the CPU will wait for GPU's signal to arrive, + # so this is not compatible with CUDA graph + ( + recv_x, + recv_token_indices, + recv_token_probs, + num_recv_tokens_per_expert_list, + handle, + after_event_overlap, + ) = buffer.dispatch( + x, + topk_idx=token_indices, + topk_weights=token_probs, # DeepEP only supports float32 probs + num_tokens_per_rank=num_tokens_per_rank, + num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, + is_token_in_rank=is_token_in_rank, + num_tokens_per_expert=num_tokens_per_expert, + previous_event=event, # wait in deepep::intra/inter_dispatch + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + # Make sure current stream is synchronized + if async_finish: + after_event_overlap.current_stream_wait() + + # Save for backward + ctx.group = group + ctx.handle = handle + ctx.async_finish = async_finish + ctx.allocate_on_comm_stream = allocate_on_comm_stream + tokens_per_expert = torch.tensor(num_recv_tokens_per_expert_list) + + return (recv_x, recv_token_indices, recv_token_probs, tokens_per_expert, handle) + + @staticmethod + def backward( + ctx, + grad_output, + grad_token_indices, + grad_token_probs, + grad_tokens_per_expert, + grad_handle, + ): + """Backward pass of fused dispatch.""" + buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output)) + handle = ctx.handle + previous_event = None + if ctx.async_finish: + previous_event = EventOverlap(EventHandle()) + grad_x, grad_token_probs, after_event = buffer.combine( + grad_output.contiguous(), + handle, + topk_weights=grad_token_probs.float(), + previous_event=previous_event, + async_finish=ctx.async_finish, + allocate_on_comm_stream=ctx.allocate_on_comm_stream, + ) + # Make sure current stream is synchronized + if ctx.async_finish: + after_event.current_stream_wait() + return grad_x, None, grad_token_probs, None, None, None, None + + 
class FusedCombine(torch.autograd.Function):
    """Autograd wrapper around the buffer's ``combine`` call; its backward pass is the matching ``dispatch``."""

    @staticmethod
    def forward(ctx, x, group, handle, async_finish=False, allocate_on_comm_stream=False):
        """Combine ``x`` across ``group`` using ``handle`` and remember what backward needs.

        Returns a pair ``(combined_x, None)``.
        """
        # The event is created before the buffer lookup, exactly as the communication call expects it.
        wait_event = EventOverlap(EventHandle()) if async_finish else None
        comm_buffer = get_buffer(group, get_hidden_bytes(x))
        combined_x, _, done_event = comm_buffer.combine(
            x,
            handle=handle,
            previous_event=wait_event,
            async_finish=async_finish,
            allocate_on_comm_stream=allocate_on_comm_stream,
        )
        if async_finish:
            # Make sure current stream is synchronized
            done_event.current_stream_wait()

        # State required by backward.
        ctx.group = group
        ctx.handle = handle
        ctx.async_finish = async_finish
        ctx.allocate_on_comm_stream = allocate_on_comm_stream
        return combined_x, None

    @staticmethod
    def backward(ctx, grad_output, previous_event=None):
        """Dispatch ``grad_output`` back through the saved handle.

        The second positional argument is the gradient slot for the ``None`` returned by forward; its value is
        never used. One gradient is returned per forward input, and only ``x`` receives a real gradient.
        """
        run_async = ctx.async_finish
        wait_event = EventOverlap(EventHandle()) if run_async else None
        comm_buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output))
        dispatch_result = comm_buffer.dispatch(
            grad_output.contiguous(),
            handle=ctx.handle,
            previous_event=wait_event,
            async_finish=run_async,
            allocate_on_comm_stream=ctx.allocate_on_comm_stream,
        )
        grad_x, _, _, _, _, done_event = dispatch_result
        if run_async:
            # Make sure current stream is synchronized
            done_event.current_stream_wait()
        return grad_x, None, None, None, None
+ + Args: + x: Input tensor [num_tokens, hidden_size] + token_indices: Token routing indices [num_tokens, topk] + token_probs: Token routing probabilities [num_tokens, topk] + num_experts: Number of experts + group: Process group + async_finish: Whether to finish asynchronously + allocate_on_comm_stream: Whether to allocate on communication stream + + Returns: + Result of FusedDispatch + """ + return FusedDispatch.apply( + x.contiguous(), + token_indices, + token_probs, + num_experts, + group, + async_finish, + allocate_on_comm_stream, + ) + + def fused_combine(x, group, handle, async_finish=False, allocate_on_comm_stream=False): + """Perform fused combine operation if deep_ep is available. + + Args: + x: Input tensor + group: Process group + handle: Communication handle + async_finish: Whether to finish asynchronously + allocate_on_comm_stream: Whether to allocate on communication stream + + Returns: + Result of FusedCombine + """ + return FusedCombine.apply(x, group, handle, async_finish, allocate_on_comm_stream) + +else: + fused_dispatch = None + fused_combine = None diff --git a/bionemo-recipes/recipes/mixtral_native_te/fused_indices_converter.py b/bionemo-recipes/recipes/mixtral_native_te/fused_indices_converter.py new file mode 100644 index 0000000000..71bfaba4d3 --- /dev/null +++ b/bionemo-recipes/recipes/mixtral_native_te/fused_indices_converter.py @@ -0,0 +1,311 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# Assign a block to a row([1,topk]), generate a local routing map([1,num_of_local_experts])
@triton.jit
def _indices_to_multihot_kernel(
    indices_ptr,
    probs_in_indices_ptr,
    multihot_indices_ptr,  # bool
    probs_in_multihot_ptr,
    position_map_ptr,
    num_of_local_experts: tl.constexpr,
    num_of_local_experts_next_power_of_2: tl.constexpr,
    topk: tl.constexpr,
    topk_next_power_of_2: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,  # noqa: N803
):
    """Triton kernel for converting indices to multihot representation.

    Each program instance (``tl.program_id(0)``) handles one row, i.e. one token.

    Input:
        indices: [num_of_tokens, topk]
        probs_in_indices: [num_of_tokens, topk]
    Output:
        multihot_indices: [num_of_tokens, num_of_local_experts]
        probs_in_multihot: [num_of_tokens, num_of_local_experts]
        position_map: [num_of_tokens, num_of_local_experts]

    Every output row is first initialised (0 for ``multihot_indices`` and ``probs_in_multihot``, -1 for
    ``position_map``). Then, for every valid index in the input row, the column of that expert is set to 1, to the
    matching probability, and to the top-k slot the index came from, respectively. An index is skipped when it is
    -1 or not smaller than ``num_of_local_experts``.

    ``BLOCK_SIZE`` is accepted but not referenced in the kernel body.

    Assume that topk = 2 , num_of_local_experts = 4, num_of_tokens = 2,
    then the kernel can process the following conversion:

    Input Example:
        indices = [
            [0, 1],
            [1, 2]
        ]
        probs_in_indices = [
            [0.1, 0.2],
            [0.3, 0.4]
        ]
    Output Example:
        multihot_indices = [
            [1, 1, 0, 0],
            [0, 1, 1, 0]
        ]
        probs_in_multihot = [
            [0.1, 0.2, 0.0, 0.0],
            [0.0, 0.3, 0.4, 0.0]
        ]
        position_map = [
            [0, 1, -1, -1],
            [-1, 0, 1, -1]
        ]
    """
    # Prepare the [0, topk) row
    # The range is built over the padded power-of-two length; lanes at or beyond topk are marked -1 and masked off.
    topk_row = tl.arange(0, topk_next_power_of_2)
    topk_row = tl.where(topk_row < topk, topk_row, -1)
    topk_row_mask = topk_row != -1
    # Prepare the [0, num_of_local_experts) row
    # Same padding scheme as above, for the expert dimension.
    num_exp_row = tl.arange(0, num_of_local_experts_next_power_of_2)
    num_exp_row = tl.where(num_exp_row < num_of_local_experts, num_exp_row, -1)
    num_exp_row_mask = num_exp_row != -1

    # Load a [1, topk] row from the indices buffer
    row_idx = tl.program_id(0)
    indices_row = tl.load(indices_ptr + row_idx * topk + topk_row, mask=topk_row_mask)
    # Padded lanes get the -1 sentinel so that they are treated like invalid indices below.
    indices_row = tl.where(topk_row_mask, indices_row, -1)
    probs_row = tl.load(probs_in_indices_ptr + row_idx * topk + topk_row, mask=topk_row_mask)

    # Get the position of each index in the indices_row, which is saved for backwards
    position_row = tl.where(indices_row != -1, topk_row, -1)
    # Mask of the valid indices
    mask = (indices_row != -1) & (indices_row < num_of_local_experts)

    row_idx_offset = row_idx * num_of_local_experts
    # Store to initialize
    tl.store(multihot_indices_ptr + row_idx_offset + num_exp_row, 0, mask=num_exp_row_mask)
    tl.store(probs_in_multihot_ptr + row_idx_offset + num_exp_row, 0, mask=num_exp_row_mask)
    tl.store(position_map_ptr + row_idx_offset + num_exp_row, -1, mask=num_exp_row_mask)
    # Use barrier to make sure the initialization is done
    tl.debug_barrier()
    # Store the indices and probs_in_indices
    tl.store(multihot_indices_ptr + row_idx_offset + indices_row, 1, mask)
    tl.store(probs_in_multihot_ptr + row_idx_offset + indices_row, probs_row, mask)
    # Store the position of the position_row for backwards
    tl.store(position_map_ptr + row_idx_offset + indices_row, position_row, mask)
valid value in the the indices + position_map_row = tl.load(position_map_ptr + ptr_offset, mask=num_exp_row_mask) + position_map_row = tl.where(num_exp_row_mask, position_map_row, -1) + mask = position_map_row != -1 + + # Store to initialize + tl.store(probs_indices_ptr + row_idx * topk + topk_row, 0, mask=topk_row_mask) + # Use barrier to make sure the initialization is done + tl.debug_barrier() + # Restore the indices and probs_indices + tl.store( + probs_indices_ptr + row_idx * topk + position_map_row, + probs_in_multihot_row, + mask, + ) + + +class IndicesToMultihot(torch.autograd.Function): + """Convert moe topk indices to multihot representation. + + This class implements a custom forward and backward propagation + operation for efficiently converting indices to multihot + representation. + It is an experimental feature and may change in future versions. + """ + + @staticmethod + def forward(ctx, indices, probs_indices, num_of_local_experts): # noqa: D417 + """Forward function for IndicesToMultihot. + + Convert indices to multihot representation. 
+ + Args: + indices: [num_of_tokens, topk] + probs_indices: [num_of_tokens, topk] + num_of_local_experts: int + + Returns: + multihot_indices: [num_of_tokens, num_of_local_experts] + probs_in_multihot: [num_of_tokens, num_of_local_experts] + """ + assert HAVE_TRITON, "Triton is not installed" + num_of_tokens = indices.shape[0] + assert indices.shape == probs_indices.shape, "indices and probs_indices must have the same shape" + topk = indices.shape[1] + device = indices.device + multihot_indices = torch.empty((num_of_tokens, num_of_local_experts), dtype=torch.bool, device=device) + probs_in_multihot = torch.empty( + (num_of_tokens, num_of_local_experts), + dtype=probs_indices.dtype, + device=device, + ) + position_map = torch.empty((num_of_tokens, num_of_local_experts), dtype=torch.int32, device=device) + # Compute the next power of 2 for the topk and num_of_local_experts + topk_next_power_of_2 = 2 ** math.ceil(math.log2(topk)) + num_of_local_experts_next_power_of_2 = 2 ** math.ceil(math.log2(num_of_local_experts)) + grid = (num_of_tokens,) + _indices_to_multihot_kernel[grid]( + indices, + probs_indices, + multihot_indices, + probs_in_multihot, + position_map, + num_of_local_experts, + num_of_local_experts_next_power_of_2, + topk, + topk_next_power_of_2, + BLOCK_SIZE=32, # use only 1 warp per block + num_warps=1, + ) + + ctx.save_for_backward(position_map) + ctx.num_of_tokens = num_of_tokens + ctx.num_of_local_experts = num_of_local_experts + ctx.topk = topk + return multihot_indices, probs_in_multihot + + @staticmethod + def backward(ctx, grad_multihot_indices, grad_probs_in_multihot): # noqa: D417 + """Backward function for IndicesToMultihot. + + Convert multihot probs representation to indices. + indices is ignored in the backward function. 
+ + Args: + grad_multihot_indices: [num_of_tokens, num_of_local_experts] + grad_probs_in_multihot: [num_of_tokens, num_of_local_experts] + + Returns: + grad_probs_indices: [num_of_tokens, topk] + """ + position_map = ctx.saved_tensors[0] + num_of_tokens = ctx.num_of_tokens + num_of_local_experts = ctx.num_of_local_experts + topk = ctx.topk + + # Initialize the gradient of the indices and probs_indices + grad_probs_indices = torch.empty( + (num_of_tokens, topk), dtype=grad_probs_in_multihot.dtype, device=grad_probs_in_multihot.device + ) + # Compute the next power of 2 for the topk and num_of_local_experts + topk_next_power_of_2 = 2 ** math.ceil(math.log2(topk)) + num_of_local_experts_next_power_of_2 = 2 ** math.ceil(math.log2(num_of_local_experts)) + + grid = (num_of_tokens,) + _multihot_to_indices_kernel[grid]( + # if the grad_probs_in_multihot is all-one/all-zero, + # overlapping stride will cause error without contiguous() + grad_probs_in_multihot.contiguous(), + position_map, + grad_probs_indices, + num_of_local_experts, + num_of_local_experts_next_power_of_2, + topk, + topk_next_power_of_2, + BLOCK_SIZE=32, # use only 1 warp per block + num_warps=1, + ) + return None, grad_probs_indices, None + + +def fused_indices_to_multihot(indices, probs_indices, num_of_local_experts): + """Convert moe topk indices to multihot representation.""" + return IndicesToMultihot.apply(indices, probs_indices, num_of_local_experts) diff --git a/bionemo-recipes/recipes/mixtral_native_te/fused_token_router.py b/bionemo-recipes/recipes/mixtral_native_te/fused_token_router.py new file mode 100644 index 0000000000..7dd3c68e5e --- /dev/null +++ b/bionemo-recipes/recipes/mixtral_native_te/fused_token_router.py @@ -0,0 +1,158 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# --- BEGIN COPIED FILE NOTICE --- +# This file is copied from: bionemo-recipes/models/mixtral/fused_token_router.py +# Do not modify this file directly. Instead, modify the source and run: +# python ci/scripts/check_copied_files.py --fix +# --- END COPIED FILE NOTICE --- + +"""DeepEP-backed TokenDispatcher using fused all-to-all and Triton index conversion.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import torch +import torch.distributed as dist +import transformer_engine.pytorch +from fused_a2a import fused_combine, fused_dispatch +from fused_indices_converter import HAVE_TRITON, fused_indices_to_multihot +from modeling_mixtral_te import DispatchOutput + + +@dataclass +class _FusedHandle: + """Opaque state for FusedTokenRouter between dispatch and combine.""" + + deepep_handle: Any + row_id_map: torch.Tensor + probs_multihot: torch.Tensor + recv_shape: torch.Size + + +class FusedTokenRouter: + """TokenDispatcher using DeepEP fused communication and Triton index conversion. + + Dispatch flow: + 1. ``fused_dispatch`` — DeepEP all-to-all sends tokens to expert-owning ranks. + 2. ``fused_indices_to_multihot`` — Triton kernel converts sparse ``[N, top_k]`` + indices to dense ``[N, num_local_experts]`` mask with differentiable probs. + 3. ``moe_permute(map_type="mask")`` — TE sorts received tokens by local expert. 
+ + Combine flow: + 1. ``moe_unpermute(map_type="mask")`` — TE unsorts and applies routing weights. + 2. ``fused_combine`` — DeepEP reverse all-to-all sends results back. + + Args: + num_experts: Total number of experts (global, across all EP ranks). + num_local_experts: Number of experts hosted on this rank. + hidden_size: Hidden dimension size. + ep_size: Expert parallel world size. + """ + + def __init__(self, num_experts: int, num_local_experts: int, hidden_size: int, ep_size: int): + """Initialize the FusedTokenRouter.""" + if fused_dispatch is None or fused_combine is None: + raise ImportError("deep_ep is required for FusedTokenRouter. Install it with: pip install deep_ep") + if not HAVE_TRITON: + raise ImportError( + "Triton is required for FusedTokenRouter (used by fused_indices_to_multihot). " + "Install it with: pip install triton" + ) + self.num_experts = num_experts + self.num_local_experts = num_local_experts + self.hidden_size = hidden_size + self.ep_size = ep_size + self._ep_group: dist.ProcessGroup | None = None + + def set_ep_group(self, ep_group: dist.ProcessGroup) -> None: + """Set the expert-parallel process group for communication.""" + self._ep_group = ep_group + + def dispatch( + self, + hidden_states: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + ) -> DispatchOutput: + """Dispatch tokens to their assigned experts via DeepEP fused all-to-all. + + Args: + hidden_states: Flattened input tensor of shape ``[N, H]``. + selected_experts: Expert assignments, shape ``[N, top_k]``, int. + routing_weights: Normalized routing probabilities, shape ``[N, top_k]``, float32. + + Returns: + DispatchOutput with expert-sorted tokens, per-expert counts, and an opaque handle. 
+ """ + assert self._ep_group is not None, "EP group must be set via set_ep_group() before dispatch" + + # Step 1: Fused all-to-all dispatch (DeepEP) + recv_x, recv_indices, recv_probs, tokens_per_expert, deepep_handle = fused_dispatch( + hidden_states, + selected_experts, + routing_weights.float(), # DeepEP requires float32 probs + self.num_experts, + self._ep_group, + ) + + # Step 2: Convert sparse [N, top_k] indices to dense [N, num_local_experts] multihot (Triton) + # Note: DeepEP returns local expert indices (0-based per rank), not global indices. + multihot_mask, probs_multihot = fused_indices_to_multihot(recv_indices, recv_probs, self.num_local_experts) + + # Step 3: Permute received tokens by local expert for GroupedLinear + num_out_tokens = int(tokens_per_expert.sum().item()) + permuted_x, row_id_map = transformer_engine.pytorch.moe_permute( + recv_x, multihot_mask.to(torch.int32), num_out_tokens=num_out_tokens, map_type="mask" + ) + + handle = _FusedHandle( + deepep_handle=deepep_handle, + row_id_map=row_id_map, + probs_multihot=probs_multihot, + recv_shape=recv_x.shape, + ) + + return DispatchOutput( + expert_input=permuted_x, + tokens_per_expert=tokens_per_expert.tolist(), + handle=handle, + ) + + def combine(self, expert_output: torch.Tensor, handle: _FusedHandle) -> torch.Tensor: + """Combine expert outputs back to the original token order. + + Args: + expert_output: Expert output tensor of shape ``[total_recv_tokens, H]``. + handle: Opaque state from ``dispatch()``. + + Returns: + Combined output tensor of shape ``[N, H]`` with routing weights applied. 
+ """ + # Step 1: Unpermute expert output and apply routing weights + unpermuted = transformer_engine.pytorch.moe_unpermute( + expert_output, + handle.row_id_map, + merging_probs=handle.probs_multihot, + restore_shape=handle.recv_shape, + map_type="mask", + ) + + # Step 2: Fused all-to-all combine (reverse dispatch) + combined, _ = fused_combine(unpermuted, self._ep_group, handle.deepep_handle) + + return combined diff --git a/bionemo-recipes/recipes/mixtral_native_te/hydra_config/L0_sanity.yaml b/bionemo-recipes/recipes/mixtral_native_te/hydra_config/L0_sanity.yaml new file mode 100644 index 0000000000..a06dca8ac9 --- /dev/null +++ b/bionemo-recipes/recipes/mixtral_native_te/hydra_config/L0_sanity.yaml @@ -0,0 +1,53 @@ +defaults: + - defaults + - _self_ + +# Use tiny Mixtral config for fast convergence testing +config_name_or_path: ./model_configs/mixtral-8x1B +config_kwargs: + hidden_size: 384 + intermediate_size: 1536 + num_hidden_layers: 2 + num_attention_heads: 6 + num_key_value_heads: 6 + num_local_experts: 4 + num_experts_per_tok: 2 + max_position_embeddings: 256 + attn_input_format: "bshd" + self_attn_mask_type: "causal" + router_jitter_noise: 0.0 + +num_train_steps: 20 + +use_torch_compile: false +use_meta_device: false # small model fits on device directly; avoids meta-device complexity with EP + +# EP=1 for single-GPU sanity testing. Multi-GPU EP tests are in test_fsdp_ep.py. 
+expert_parallel_size: 1 + +dataset: + tokenizer_name_or_path: nvidia/Llama-3.1-8B-Instruct-FP8 + micro_batch_size: 1 + num_workers: 0 + max_seq_length: 256 + stride: 32 + text_column: "text" + load_dataset_kwargs: + path: "parquet" + split: "train" + data_files: "dlcm_sanity_dataset.parquet" + streaming: true + +wandb: + name: "mixtral_8x1B_sanity" + mode: "offline" + +lr_scheduler_kwargs: + num_warmup_steps: 10 + +checkpoint: + ckpt_dir: null + save_final_model: false + +logger: + frequency: 1 diff --git a/bionemo-recipes/recipes/mixtral_native_te/hydra_config/L2_lingua_8x1B.yaml b/bionemo-recipes/recipes/mixtral_native_te/hydra_config/L2_lingua_8x1B.yaml new file mode 100644 index 0000000000..a572e49767 --- /dev/null +++ b/bionemo-recipes/recipes/mixtral_native_te/hydra_config/L2_lingua_8x1B.yaml @@ -0,0 +1,57 @@ +# Config to match the Llama 3 Lingua 1B training recipe using Mixtral-8x1B architecture. + +defaults: + - defaults + - _self_ + +config_name_or_path: ./model_configs/mixtral-8x1B + +config_kwargs: + attn_input_format: thd + self_attn_mask_type: padding_causal + +use_sequence_packing: true + +wandb: + name: mixtral-lingua-8x1B-te + project: null + +num_train_steps: 60_000 + +dataset: + tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B + micro_batch_size: 2 + num_workers: 8 + max_seq_length: 4096 + stride: 512 + buffer_size: 50_000 + use_stateful_dataloader: false + load_dataset_kwargs: + path: "mlfoundations/dclm-baseline-1.0" + data_dir: "global-shard_01_of_10" + split: "train" + streaming: true + +adamw_kwargs: + lr: 0.003 + fused: true + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 0.033 + +lr_scheduler_kwargs: + num_warmup_steps: 5_000 + num_decay_steps: 55_000 + min_lr_ratio: 1e-6 + +checkpoint: + ckpt_dir: null + save_final_model: true + resume_from_checkpoint: true + save_every_n_steps: 10_000 + async_save: true + +profiler: + enabled: false + start_step: 250 + end_step: 260 diff --git 
a/bionemo-recipes/recipes/mixtral_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/mixtral_native_te/hydra_config/defaults.yaml new file mode 100644 index 0000000000..0e9ba6ff98 --- /dev/null +++ b/bionemo-recipes/recipes/mixtral_native_te/hydra_config/defaults.yaml @@ -0,0 +1,91 @@ +# Training config +use_te: true # Whether to use TransformerEngine layers through NVMixtralForCausalLM (if false, use HF's MixtralForCausalLM) +config_name_or_path: ??? # E.g., ./model_configs/mixtral-8x1B or a HuggingFace model name +config_kwargs: {} # Arguments to pass to the AutoConfig.from_pretrained method + +num_train_steps: ??? +grad_acc_steps: 1 # Gradient accumulation steps - effective batch = micro_batch_size * num_gpus * grad_acc_steps + +use_meta_device: true +use_torch_compile: false +use_sequence_packing: false + +# Expert parallelism: number of GPUs per expert-parallel group. +# Must divide world_size evenly. Set > 1 to enable MoE expert parallelism. +expert_parallel_size: 1 +# Token dispatcher for EP runs. Options: "alltoall" (NCCL, always available) or +# "fused_deepep" (requires deep_ep + Triton). token_dispatcher_fallback controls +# what happens when fused_deepep is unavailable: "alltoall" to fall back silently, +# or "error" to raise immediately. +token_dispatcher: alltoall +token_dispatcher_fallback: error + +dataset: + tokenizer_name_or_path: ??? # Set to the path of your tokenizer (e.g., nvidia/Llama-3.1-8B-Instruct-FP8) + micro_batch_size: 2 + num_workers: 1 + max_seq_length: 4096 # Window size for text sequences + stride: 512 # Overlap for windowing + buffer_size: 500_000 # Shuffle buffer size + use_stateful_dataloader: false + pad_sequences_to_be_divisible_by: null + load_dataset_kwargs: + path: ??? + split: "train" + streaming: true + +# WandB config +wandb: + name: ??? + project: null # Optional: set to your wandb project name + +# TransformerEngine FP8 config. 
See +# https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html for more information on +# supported formats. +fp8_config: + enabled: false + fp8_recipe: transformer_engine.common.recipe.DelayedScaling + fp8_format: "HYBRID" + fp8_recipe_kwargs: {} + +fp4_config: + enabled: false + fp4_recipe: transformer_engine.common.recipe.NVFP4BlockScaling + fp4_format: "E2M1" + fp4_recipe_kwargs: {} + +# Optimizer config +adamw_kwargs: + lr: 3e-3 + fused: true + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 0.033 + +# Learning rate scheduler config +lr_scheduler_kwargs: + num_warmup_steps: 2_000 + num_decay_steps: 498_000 + min_lr_ratio: 0.000001 + +# Checkpoint config +checkpoint: + ckpt_dir: ??? + save_final_model: true + resume_from_checkpoint: true + save_every_n_steps: 50 + max_checkpoints: 5 # Keep only the latest 5 checkpoints + async_save: true # Whether to save the checkpoint asynchronously, currently only supported with FSDP2. + +logger: + frequency: 100 + +fp8_stats_config: + enabled: false + fp8_stats_file: ./fp8_debugging_stats.yaml + fp8_log_dir: ./log_fp8_stats + +profiler: + enabled: false + start_step: 10 + end_step: 15 diff --git a/bionemo-recipes/recipes/mixtral_native_te/model_configs/mixtral-8x1B/config.json b/bionemo-recipes/recipes/mixtral_native_te/model_configs/mixtral-8x1B/config.json new file mode 100644 index 0000000000..712cc4cd49 --- /dev/null +++ b/bionemo-recipes/recipes/mixtral_native_te/model_configs/mixtral-8x1B/config.json @@ -0,0 +1,29 @@ +{ + "architectures": [ + "MixtralForCausalLM" + ], + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 4096, + "max_position_embeddings": 4096, + "model_type": "mixtral", + "num_attention_heads": 16, + "num_hidden_layers": 16, + "num_key_value_heads": 8, + "num_local_experts": 8, + "num_experts_per_tok": 2, + "output_router_logits": false, + 
"rms_norm_eps": 1e-05, + "rope_theta": 1000000.0, + "router_aux_loss_coef": 0.0, + "router_jitter_noise": 0.0, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.0", + "use_cache": true, + "vocab_size": 128256 +} diff --git a/bionemo-recipes/recipes/mixtral_native_te/model_configs/mixtral-8x7B/config.json b/bionemo-recipes/recipes/mixtral_native_te/model_configs/mixtral-8x7B/config.json new file mode 100644 index 0000000000..48b4408f25 --- /dev/null +++ b/bionemo-recipes/recipes/mixtral_native_te/model_configs/mixtral-8x7B/config.json @@ -0,0 +1,29 @@ +{ + "architectures": [ + "MixtralForCausalLM" + ], + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 4096, + "model_type": "mixtral", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "num_local_experts": 8, + "num_experts_per_tok": 2, + "output_router_logits": false, + "rms_norm_eps": 1e-05, + "rope_theta": 1000000.0, + "router_aux_loss_coef": 0.0, + "router_jitter_noise": 0.0, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.0", + "use_cache": true, + "vocab_size": 128256 +} diff --git a/bionemo-recipes/recipes/mixtral_native_te/modeling_mixtral_te.py b/bionemo-recipes/recipes/mixtral_native_te/modeling_mixtral_te.py new file mode 100644 index 0000000000..23ad635f88 --- /dev/null +++ b/bionemo-recipes/recipes/mixtral_native_te/modeling_mixtral_te.py @@ -0,0 +1,1176 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# --- BEGIN COPIED FILE NOTICE --- +# This file is copied from: bionemo-recipes/models/mixtral/modeling_mixtral_te.py +# Do not modify this file directly. Instead, modify the source and run: +# python ci/scripts/check_copied_files.py --fix +# --- END COPIED FILE NOTICE --- + +"""TransformerEngine-optimized Mixtral model with Mixture of Experts.""" + +import logging +import os +import warnings +from collections import OrderedDict +from contextlib import nullcontext +from dataclasses import dataclass +from typing import Any, ClassVar, ContextManager, Protocol, Unpack + +import torch +import torch.distributed as dist +import torch.nn as nn +import transformer_engine.common.recipe +import transformer_engine.pytorch +import transformers +from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict +from torch.distributed.tensor import DTensor, Shard +from torch.distributed.tensor.device_mesh import DeviceMesh +from transformer_engine.pytorch.attention import InferenceParams +from transformer_engine.pytorch.attention.inference import PagedKVCacheManager +from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding +from transformers import MixtralConfig, PreTrainedModel +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding +from transformers.utils.generic import TransformersKwargs + + +logger = logging.getLogger(__name__) + + +AUTO_MAP = { + "AutoConfig": "modeling_mixtral_te.NVMixtralConfig", + 
"AutoModel": "modeling_mixtral_te.NVMixtralModel", + "AutoModelForCausalLM": "modeling_mixtral_te.NVMixtralForCausalLM", +} + + +class NVMixtralConfig(MixtralConfig): + """NVMixtral configuration.""" + + # Attention input format: + # "bshd" = Batch, Sequence, Head, Dimension (standard padded format) + # "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format) + attn_input_format: str = "thd" + self_attn_mask_type: str = "padding_causal" + layer_precision: list[str | None] | None = None + use_quantized_model_init: bool = False + expert_parallel_size: int = 1 + moe_aux_loss_coeff: float = 0.0 + + def __init__(self, **kwargs): + """Initialize the NVMixtralConfig with additional TE-related config options.""" + super().__init__(**kwargs) + + if self.layer_precision is not None: + if len(self.layer_precision) != self.num_hidden_layers: + raise ValueError(f"layer_precision must be a list of length {self.num_hidden_layers}") + for precision in self.layer_precision: + if precision not in {"fp8", "fp4", None}: + raise ValueError(f'layer_precision element must be "fp8", "fp4", or None, got {precision!r}') + + if self.num_local_experts % self.expert_parallel_size != 0: + raise ValueError( + f"num_local_experts ({self.num_local_experts}) must be divisible by " + f"expert_parallel_size ({self.expert_parallel_size})" + ) + + +@dataclass +class DispatchOutput: + """Output of TokenDispatcher.dispatch(). + + Attributes: + expert_input: Tokens sorted by local expert, shape ``[total_recv_tokens, H]``. + tokens_per_expert: Token count per local expert. + handle: Opaque state needed by ``combine()`` to reverse the dispatch. + """ + + expert_input: torch.Tensor + tokens_per_expert: list[int] + handle: Any + + +class TokenDispatcher(Protocol): + """Protocol for MoE token dispatch/combine strategies. 
+ + Encapsulates the full dispatch cycle (permute -> communicate -> sort) and + combine cycle (unsort -> communicate -> unpermute) so that the MoE block + is agnostic to the communication backend (NCCL all-to-all, HybridEP, etc.). + """ + + def dispatch( + self, + hidden_states: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + ) -> DispatchOutput: + """Dispatch tokens to their assigned experts. + + Args: + hidden_states: Flattened input tensor of shape ``[N, H]``. + selected_experts: Expert assignments, shape ``[N, top_k]``, int. + routing_weights: Normalized routing probabilities, shape ``[N, top_k]``, float32. + + Returns: + DispatchOutput with expert-sorted tokens, per-expert counts, and an opaque handle. + """ + ... + + def combine( + self, + expert_output: torch.Tensor, + handle: Any, + ) -> torch.Tensor: + """Combine expert outputs back to the original token order. + + Args: + expert_output: Expert output tensor of shape ``[total_recv_tokens, H]``. + handle: Opaque state from ``dispatch()``. + + Returns: + Combined output tensor of shape ``[N, H]`` with routing weights applied. + """ + ... + + def set_ep_group(self, ep_group: dist.ProcessGroup) -> None: + """Set the expert-parallel process group for communication.""" + ... + + +class NVMixtralPreTrainedModel(PreTrainedModel): + """Base class for NVMixtral models.""" + + config_class = NVMixtralConfig + base_model_prefix = "model" + _no_split_modules = ("NVMixtralDecoderLayer",) + _skip_keys_device_placement = ("past_key_values",) + _do_not_quantize = ("lm_head", "model.layers.*.mlp.gate") # Flag for testing that these layers are not quantized. 
+ + def init_empty_weights(self): + """Handles moving the model from the meta device to the cuda device and initializing the weights.""" + for module in self.modules(): + if hasattr(module, "reset_parameters"): + module.reset_parameters() + + # After reset_parameters materializes GroupedLinear views on CUDA, + # re-stack them into the authoritative stacked parameters. + for module in self.modules(): + if isinstance(module, NVMixtralSparseMoeBlock): + module._restack_from_views() + + self.model.embed_tokens.to_empty(device="cuda") + self.model.embed_tokens.apply(self._init_weights) + + self.model.rotary_emb.inv_freq = LlamaRotaryEmbedding(config=self.model.config).inv_freq.to("cuda") + + self.tie_weights() + + def _init_weights(self, module): + """Initialize module weights. + + We only use this method for standard pytorch modules, TE modules handle their own weight initialization through + `init_method` parameters and the `reset_parameters` method. + """ + if module.__module__.startswith("transformer_engine.pytorch"): + return + + super()._init_weights(module) + + def state_dict(self, *args, **kwargs): + """Override state_dict to filter out TransformerEngine's _extra_state keys.""" + state_dict = super().state_dict(*args, **kwargs) + return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")} + + +class NVMixtralSparseMoeBlock(nn.Module): + """Mixture of Experts block using TransformerEngine GroupedLinear.""" + + def __init__(self, config: MixtralConfig, dispatcher: TokenDispatcher | None = None): + """Initialize the sparse MoE block.""" + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + self.jitter_noise = config.router_jitter_noise + + # Expert parallelism + self.ep_size = getattr(config, "expert_parallel_size", 1) + self.num_local_experts = self.num_experts // self.ep_size + 
self.moe_aux_loss_coeff = getattr(config, "moe_aux_loss_coeff", 0.0) + self._aux_loss: torch.Tensor = torch.tensor(0.0) + self.initializer_range = config.initializer_range + + self.dispatcher: TokenDispatcher = dispatcher or AllToAllTokenDispatcher( + self.num_experts, + self.num_local_experts, + self.hidden_size, + self.ep_size, + ) + + device = "meta" if torch.get_default_device() == torch.device("meta") else "cuda" + + def _init_method(x): + torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range) + + # Router always outputs num_experts logits (replicated across EP ranks) + with transformer_engine.pytorch.quantized_model_init(enabled=False): + self.gate = transformer_engine.pytorch.Linear( + self.hidden_size, + self.num_experts, + bias=False, + device=device, + params_dtype=config.dtype, + init_method=_init_method, + ) + + # Expert FFNs — only num_local_experts per rank when EP > 1 + self.experts_gate_up = transformer_engine.pytorch.GroupedLinear( + num_gemms=self.num_local_experts, + in_features=self.hidden_size, + out_features=2 * self.intermediate_size, + bias=False, + params_dtype=config.dtype, + device=device, + init_method=_init_method, + ) + self.experts_down = transformer_engine.pytorch.GroupedLinear( + num_gemms=self.num_local_experts, + in_features=self.intermediate_size, + out_features=self.hidden_size, + bias=False, + params_dtype=config.dtype, + device=device, + init_method=_init_method, + ) + + # Stack per-expert weights into single parameters (authoritative weight store). + # GroupedLinear's _parameters dict is emptied; weight attributes are set as views + # so that reset_parameters() / _get_weight_tensors() can still find them. 
+ self.experts_gate_up_weight = nn.Parameter( + torch.stack( + [self.experts_gate_up._parameters.pop(f"weight{i}").data for i in range(self.num_local_experts)] + ) + ) # [num_local_experts, 2*intermediate_size, hidden_size] + + self.experts_down_weight = nn.Parameter( + torch.stack([self.experts_down._parameters.pop(f"weight{i}").data for i in range(self.num_local_experts)]) + ) # [num_local_experts, hidden_size, intermediate_size] + + # Set views back on GroupedLinear so getattr(self, "weight{i}") still works + # (needed by GroupedLinear.reset_parameters and _get_weight_tensors). + self._sync_expert_views() + + def _restack_from_views(self) -> None: + """Re-create stacked parameters on CUDA after meta init. + + Called by ``init_empty_weights()`` after ``reset_parameters()`` has been called + on all TE modules. Since GroupedLinear has no registered parameters (we popped them), + its ``reset_parameters()`` cannot move them from meta to CUDA. This method explicitly + creates the stacked parameters on CUDA and reinitializes them. + """ + device = torch.cuda.current_device() + for attr_name in ("experts_gate_up_weight", "experts_down_weight"): + old_param = getattr(self, attr_name) + if isinstance(old_param.data, DTensor): + # FSDP2 has sharded this param; materialize the local shard on CUDA + # and reconstruct the DTensor wrapper so FSDP2 can manage it. 
+ local_data = old_param.data.to_local() + new_local = torch.empty(local_data.shape, dtype=local_data.dtype, device=device) + torch.nn.init.normal_(new_local, mean=0.0, std=self.initializer_range) + new_dtensor = DTensor.from_local( + new_local, + device_mesh=old_param.data.device_mesh, + placements=old_param.data.placements, + ) + setattr(self, attr_name, nn.Parameter(new_dtensor)) + else: + new_data = torch.empty_like(old_param, device=device) + torch.nn.init.normal_(new_data, mean=0.0, std=self.initializer_range) + setattr(self, attr_name, nn.Parameter(new_data)) + + # Re-sync views to point to the new stacked parameter + self._sync_expert_views() + + def _sync_expert_views(self) -> None: + """Set GroupedLinear weight attributes as views of the stacked parameters. + + GroupedLinear internally uses ``getattr(self, f"weight{i}")`` in methods like + ``reset_parameters()`` and ``_get_weight_tensors()``. After popping the original + parameters, we set views of the stacked tensor so these methods keep working. + Uses ``object.__setattr__`` to bypass ``nn.Module.__setattr__`` and avoid + re-registering them as parameters. + """ + gate_up_w = self.experts_gate_up_weight + if isinstance(gate_up_w, DTensor): + gate_up_w = gate_up_w.to_local() + num_local = gate_up_w.shape[0] + for i in range(num_local): + object.__setattr__(self.experts_gate_up, f"weight{i}", gate_up_w[i]) + + down_w = self.experts_down_weight + if isinstance(down_w, DTensor): + down_w = down_w.to_local() + num_local_down = down_w.shape[0] + for i in range(num_local_down): + object.__setattr__(self.experts_down, f"weight{i}", down_w[i]) + + def set_ep_group(self, ep_group: dist.ProcessGroup, ep_mesh: DeviceMesh) -> None: + """Set the expert-parallel process group and convert stacked weights to DTensors. + + Must be called before the first forward pass when ``ep_size > 1``. + + Args: + ep_group: A ``torch.distributed.ProcessGroup`` whose world size equals ``self.ep_size``. 
+ ep_mesh: A 1-D ``DeviceMesh`` for expert parallelism. Used to wrap stacked weights + as ``DTensor(Shard(0))`` so that DCP can save/load/reshard them automatically. + """ + self.dispatcher.set_ep_group(ep_group) + # Convert stacked parameters to DTensors with Shard(0) on the expert dimension. + # Global shape is [num_experts, ...]; each rank stores [num_local_experts, ...]. + # Guard: only wrap plain tensors; skip if already DTensors (e.g. repeated calls). + if not isinstance(self.experts_gate_up_weight.data, DTensor): + self.experts_gate_up_weight = nn.Parameter( + DTensor.from_local(self.experts_gate_up_weight.data, device_mesh=ep_mesh, placements=[Shard(0)]) + ) + if not isinstance(self.experts_down_weight.data, DTensor): + self.experts_down_weight = nn.Parameter( + DTensor.from_local(self.experts_down_weight.data, device_mesh=ep_mesh, placements=[Shard(0)]) + ) + + def _expert_ffn(self, tokens: torch.Tensor, m_splits: list[int]) -> torch.Tensor: + """Run the expert SwiGLU FFN (gate_up -> silu -> down). + + Args: + tokens: Input tensor of shape [total_tokens, H], sorted by expert. + m_splits: Number of tokens per local expert. + + Returns: + Output tensor of shape [total_tokens, H]. + """ + gate_up_output = self.experts_gate_up(tokens, m_splits=m_splits) + gate_output, up_output = gate_up_output.chunk(2, dim=-1) + intermediate = torch.nn.functional.silu(gate_output) * up_output + return self.experts_down(intermediate, m_splits=m_splits) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Forward pass for the MoE block. + + Args: + hidden_states: Input tensor of shape [B, S, H] (bshd) or [T, H] (thd). + + Returns: + Output tensor of the same shape as the input. 
+ """ + original_shape = hidden_states.shape + + # Apply multiplicative jitter noise to hidden states during training to encourage load balancing + if self.training and self.jitter_noise > 0: + hidden_states = hidden_states * torch.empty_like(hidden_states).uniform_( + 1.0 - self.jitter_noise, 1.0 + self.jitter_noise + ) + + # Flatten to [N, H] for routing + if hidden_states.dim() == 3: + hidden_states = hidden_states.reshape(-1, self.hidden_size) + + # Router: compute expert assignments + with transformer_engine.pytorch.autocast(enabled=False): + # Keep the router logits in bf16 during FP8 training + router_logits = self.gate(hidden_states) # [N, num_experts] + + routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float32) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) # [N, top_k] + # Normalize routing weights + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + + # Auxiliary load-balancing loss (switch transformer style) + if self.moe_aux_loss_coeff > 0: + num_tokens = hidden_states.shape[0] + m_splits_tensor = torch.bincount(selected_experts.reshape(-1), minlength=self.num_experts).int() + # f_i: fraction of tokens dispatched to each expert + f = m_splits_tensor.float() / (num_tokens * self.top_k) + # P_i: mean router probability per expert (over all tokens) + router_probs = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float32) + p = router_probs.mean(dim=0) + self._aux_loss = self.moe_aux_loss_coeff * self.num_experts * (f * p).sum() + else: + self._aux_loss = torch.tensor(0.0, device=hidden_states.device) + + # Populate GroupedLinear weight attributes from stacked parameters. + # For EP, the stacked parameter is a DTensor; .to_local() gives the local shard. 
+ self._sync_expert_views() + + dispatch_output = self.dispatcher.dispatch(hidden_states, selected_experts, routing_weights) + + expert_input = dispatch_output.expert_input + tokens_per_expert = dispatch_output.tokens_per_expert + + # MXFP8 requires both tensor dims divisible by 32. Upstream attention layers + # get this from the collator (pad_sequences_to_be_divisible_by=32), but after + # all-to-all dispatch the per-rank token count is data-dependent (routing + # decisions pick different expert loads). Pad here so GroupedLinear's MXFP8 + # kernels don't assert, then slice the padding off afterwards. + n_tokens = expert_input.shape[0] + mxfp8_pad = (32 - n_tokens % 32) % 32 + if mxfp8_pad: + expert_input = torch.nn.functional.pad(expert_input, (0, 0, 0, mxfp8_pad)) + # Attribute the padding tokens to the last expert so m_splits still sums correctly. + tokens_per_expert = list(tokens_per_expert) + tokens_per_expert[-1] += mxfp8_pad + + expert_output = self._expert_ffn(expert_input, tokens_per_expert) + + if mxfp8_pad: + expert_output = expert_output[:n_tokens] + + output = self.dispatcher.combine(expert_output, dispatch_output.handle) + + return output.reshape(original_shape) + + +class NVMixtralDecoderLayer(nn.Module): + """Mixtral decoder layer using TE attention and MoE MLP.""" + + def __init__(self, config: MixtralConfig, layer_idx: int, dispatcher: TokenDispatcher | None = None): + """Initialize the decoder layer.""" + super().__init__() + self.hidden_size = config.hidden_size + + device = "meta" if torch.get_default_device() == torch.device("meta") else "cuda" + + def _init_method(x): + torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range) + + self.self_attention = transformer_engine.pytorch.MultiheadAttention( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_gqa_groups=config.num_key_value_heads, + bias=False, + layernorm_epsilon=config.rms_norm_eps, + attention_dropout=0, + fuse_qkv_params=True, + 
qkv_weight_interleaved=True, + normalization="RMSNorm", + input_layernorm=True, + qkv_format=config.attn_input_format, + attn_mask_type=config.self_attn_mask_type, + layer_number=layer_idx + 1, + params_dtype=config.dtype, + device=device, + init_method=_init_method, + output_layer_init_method=_init_method, + ) + + self.post_attention_layernorm = transformer_engine.pytorch.RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.dtype, + device=device, + ) + + self.mlp = NVMixtralSparseMoeBlock(config, dispatcher) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + rotary_pos_emb: torch.Tensor | None = None, + inference_params: InferenceParams | None = None, + **kwargs, + ) -> torch.Tensor: + """Forward pass for the decoder layer.""" + # Self attention with fused input layernorm + attn_output = self.self_attention( + hidden_states, + attention_mask=attention_mask, + rotary_pos_emb=rotary_pos_emb, + inference_params=inference_params, + cu_seqlens_q=kwargs.get("cu_seqlens_q", None), + cu_seqlens_kv=kwargs.get("cu_seqlens_kv", None), + cu_seqlens_q_padded=kwargs.get("cu_seqlens_q_padded", None), + cu_seqlens_kv_padded=kwargs.get("cu_seqlens_kv_padded", None), + max_seqlen_q=kwargs.get("max_seqlen_q", None), + max_seqlen_kv=kwargs.get("max_seqlen_kv", None), + pad_between_seqs=kwargs.get("pad_between_seqs", None), + ) + + # Residual connection + hidden_states = hidden_states + attn_output + + # Post-attention layernorm + MoE MLP + residual + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class NVMixtralModel(NVMixtralPreTrainedModel): + """Mixtral model implemented in Transformer Engine.""" + + def __init__( + self, + config: MixtralConfig, + fp8_recipe: transformer_engine.common.recipe.Recipe | None = None, + fp4_recipe: 
transformer_engine.common.recipe.Recipe | None = None, + dispatcher: TokenDispatcher | None = None, + ): + """Initialize the NVMixtral model. + + Args: + config: The configuration of the model. + fp8_recipe: The FP8 recipe for the model. + fp4_recipe: The FP4 recipe for the model. + dispatcher: The token dispatcher for the model. If None, the default AllToAllTokenDispatcher will be used. + """ + super().__init__(config) + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = fp8_recipe + self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = fp4_recipe + + if self.config.layer_precision is None: + if fp8_recipe is not None and fp4_recipe is not None: + raise RuntimeError("Both FP8 and FP4 recipes provided, but no layer precision provided.") + if fp8_recipe is not None: + warnings.warn("No layer precision provided, using FP8 recipe for all layers.", UserWarning) + self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers + elif fp4_recipe is not None: + raise RuntimeError( + "FP4 recipe provided but no layer_precision configured. " + "Set layer_precision explicitly when using FP4." 
+ ) + + if self.config.layer_precision is not None and "fp4" in self.config.layer_precision and fp4_recipe is None: + raise RuntimeError("layer_precision contains 'fp4' entries but no fp4_recipe was provided.") + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=config.dtype) + + layers: list[NVMixtralDecoderLayer] = [] + for layer_idx in range(config.num_hidden_layers): + with self.get_autocast_context(layer_idx, init=True): + layers += [NVMixtralDecoderLayer(config, layer_idx, dispatcher)] + + self.layers = nn.ModuleList(layers) + + self.norm = transformer_engine.pytorch.RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + ) + + self.rotary_emb = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads) + self.rotary_emb.inv_freq = LlamaRotaryEmbedding(config=config).inv_freq + + self.gradient_checkpointing = False + + self.post_init() + + def set_ep_groups(self, ep_group: dist.ProcessGroup, ep_mesh: DeviceMesh) -> None: + """Propagate an expert-parallel process group and mesh to every MoE block. + + Args: + ep_group: The EP process group to set on each ``NVMixtralSparseMoeBlock``. + ep_mesh: A 1-D ``DeviceMesh`` for expert parallelism. 
+ """ + for layer in self.layers: + layer.mlp.set_ep_group(ep_group, ep_mesh) + + def forward( + self, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + past_key_values: InferenceParams | None = None, + inputs_embeds: torch.Tensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + """Forward pass for the NVMixtral model.""" + all_hidden_states = [] + output_hidden_states = kwargs.get("output_hidden_states", False) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + # TE-specific input handling + has_thd_input = [x in kwargs for x in ["cu_seq_lens_q", "cu_seq_lens_k", "max_length_q", "max_length_k"]] + should_pack_inputs = not any(has_thd_input) and self.config.attn_input_format == "thd" + + if should_pack_inputs: + assert attention_mask is not None, "Attention mask is required when packing BSHD inputs." 
+ batch_size = hidden_states.size(0) + padded_seq_len = input_ids.size(1) + hidden_states, indices, cu_seqlens, max_seqlen, _ = _unpad_input(hidden_states, attention_mask) + kwargs["cu_seq_lens_q"] = kwargs["cu_seq_lens_k"] = cu_seqlens + kwargs["max_length_q"] = kwargs["max_length_k"] = max_seqlen + + if self.config.attn_input_format == "thd" and hidden_states.dim() == 3 and hidden_states.size(0) == 1: + hidden_states = hidden_states.squeeze(0) + + if self.config.attn_input_format == "bshd" and attention_mask is not None and attention_mask.dim() == 2: + # Convert HF mask (1=attend, 0=pad) to TE boolean mask (True=masked, False=attend) + attention_mask = ~attention_mask[:, None, None, :].bool() + + if isinstance(past_key_values, InferenceParams): + lengths = ( + attention_mask.sum(dim=1).tolist() + if attention_mask.shape == input_ids.shape + else [1] * input_ids.shape[0] + ) + past_key_values.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths))) + + with torch.autocast(device_type="cuda", enabled=False): + te_rope_emb = self.rotary_emb(max_seq_len=self.config.max_position_embeddings) + + with self.get_autocast_context(None, outer=True): + for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + if output_hidden_states: + all_hidden_states = (*all_hidden_states, hidden_states) + + with self.get_autocast_context(layer_idx): + hidden_states = decoder_layer( + hidden_states, + attention_mask=None if self.config.attn_input_format == "thd" else attention_mask, + rotary_pos_emb=te_rope_emb, + inference_params=past_key_values, + cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), + cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), + cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), + cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), + max_seqlen_q=kwargs.get("max_length_q", None), + max_seqlen_kv=kwargs.get("max_length_k", None), + pad_between_seqs=kwargs.get("pad_between_seqs", None), + ) + + hidden_states = 
self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states = (*all_hidden_states, hidden_states) + + if should_pack_inputs: + hidden_states = _pad_input(hidden_states, indices, batch_size, padded_seq_len) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states if output_hidden_states else None, + ) + + def get_autocast_context( + self, layer_number: int | None, init: bool = False, outer: bool = False + ) -> ContextManager: + """Return the appropriate TE autocast context manager for a given layer. + + This function handles both the quantized_model_init during layer creation and the te.autocast() during layer + forward pass. + + Args: + layer_number: The 0-indexed layer number. + init: Whether to return a `quantized_model_init` context for layer initialization. + outer: Whether to return a global te.autocast() context to wrap the entire model stack. + """ + if self.config.layer_precision is None: + return nullcontext() + + if outer: + if "fp8" not in self.config.layer_precision: + return nullcontext() + if self._fp8_recipe is None: + warnings.warn("No FP8 recipe provided, using default recipe.", UserWarning) + return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp8_recipe) + + precision = self.config.layer_precision[layer_number] + recipe = {"fp8": self._fp8_recipe, "fp4": self._fp4_recipe}.get(precision) + + if init and self.config.use_quantized_model_init: + if precision in ("fp8", "fp4"): + return transformer_engine.pytorch.quantized_model_init(recipe=recipe) + return nullcontext() + + if precision == "fp8": + if recipe is None: + warnings.warn("No FP8 recipe provided, using default recipe.", UserWarning) + return transformer_engine.pytorch.autocast(enabled=True, recipe=recipe) + if precision == "fp4": + if recipe is None: + raise RuntimeError("No FP4 recipe provided, but layer precision is set to FP4.") + return 
transformer_engine.pytorch.autocast(enabled=True, recipe=recipe) + return transformer_engine.pytorch.autocast(enabled=False) + + +class NVMixtralForCausalLM(NVMixtralPreTrainedModel, transformers.GenerationMixin): + """Mixtral model with causal language head.""" + + _tied_weights_keys: ClassVar[list[str]] = [] + + def __init__( + self, + config, + fp8_recipe: transformer_engine.common.recipe.Recipe | None = None, + fp4_recipe: transformer_engine.common.recipe.Recipe | None = None, + dispatcher: TokenDispatcher | None = None, + ): + """Initialize the NVMixtralForCausalLM model. + + Args: + config: The configuration of the model. + fp8_recipe: The FP8 recipe for the model. + fp4_recipe: The FP4 recipe for the model. + dispatcher: The token dispatcher for expert parallelism. If None, the default + AllToAllTokenDispatcher will be used. + """ + super().__init__(config) + self.model = NVMixtralModel(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe, dispatcher=dispatcher) + self.vocab_size = config.vocab_size + + with transformer_engine.pytorch.quantized_model_init(enabled=False): + self.lm_head = transformer_engine.pytorch.Linear( + config.hidden_size, + config.vocab_size, + bias=False, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + + self.post_init() + + def forward( + self, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + past_key_values: tuple[tuple[torch.Tensor, ...], ...] 
| None = None, + inputs_embeds: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + shift_labels: torch.Tensor | None = None, + use_cache: bool | None = None, + cache_position: torch.Tensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + """Forward pass for the NVMixtralForCausalLM model.""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + + with transformer_engine.pytorch.autocast(enabled=False): + if hidden_states.ndim == 3: + logits = self.lm_head(hidden_states[:, slice_indices, :]) + else: + logits = self.lm_head(hidden_states[slice_indices, :]) + + loss = None + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, shift_labels=shift_labels, vocab_size=self.config.vocab_size, **kwargs + ) + + # Collect auxiliary load-balancing loss from all MoE layers + if self.config.moe_aux_loss_coeff > 0 and loss is not None: + aux_loss = sum(layer.mlp._aux_loss for layer in self.model.layers) + loss = loss + aux_loss + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def save_final_model_ep( + model: NVMixtralForCausalLM, + save_directory: str | os.PathLike, + dist_config=None, +) -> None: + """Gather all EP-sharded expert weights and save as safetensors. 
+ + Uses ``get_model_state_dict(full_state_dict=True)`` to all-gather DTensors, + matching the pattern from ``save_final_model_fsdp2`` in the llama3 checkpoint module. + + All ranks must call this function. Only rank 0 writes files. + + Args: + model: The NVMixtral model (may have DTensor expert parameters). + save_directory: Directory to save ``model.safetensors`` and config. + dist_config: Optional distributed config with ``is_main_process()`` method. + If ``None``, only rank 0 saves. + """ + from safetensors.torch import save_file + + model_state_dict = get_model_state_dict( + model, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + + # Filter out TE _extra_state keys + model_state_dict = {k: v for k, v in model_state_dict.items() if not k.endswith("_extra_state")} + + is_main = dist_config.is_main_process() if dist_config is not None else (dist.get_rank() == 0) + if is_main: + os.makedirs(save_directory, exist_ok=True) + save_file(model_state_dict, os.path.join(save_directory, "model.safetensors")) + model.config.save_pretrained(save_directory) + logger.info(f"Saved final EP model to {save_directory}") + + +# Required for torch.compile'd functions below (_pad_input, _unpad_input, _build_expert_sort_indices) +# that use data-dependent scalar values (e.g., max_seqlen_in_batch.item()) or produce tensors +# whose shape depends on input data (e.g., repeat_interleave with tensor counts). +# These must be set at module level because torch.compile traces lazily on first call, +# so a scoped setting would not be active at trace time. +torch._dynamo.config.capture_scalar_outputs = True +torch._dynamo.config.capture_dynamic_output_shape_ops = True + + +@torch.compile +def _pad_input(hidden_states, indices, batch, seqlen): + """Convert a THD tensor to a BSHD equivalent tensor. 
+ + Adapted from huggingface/transformers/modeling_flash_attention_utils.py + """ + dim = hidden_states.shape[1:] + output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype) + output[indices] = hidden_states + return output.view(batch, seqlen, *dim) + + +@torch.compile +def _unpad_input(hidden_states, attention_mask, unused_mask=None): + """Convert a BSHD tensor to a THD equivalent tensor. + + Adapted from huggingface/transformers/modeling_flash_attention_utils.py + """ + batch_size = hidden_states.size(0) + seq_length = hidden_states.size(1) + + if attention_mask.shape[1] != seq_length: + return ( + hidden_states.squeeze(1), + torch.arange(batch_size, dtype=torch.int64, device=hidden_states.device), + torch.arange(batch_size + 1, dtype=torch.int32, device=hidden_states.device), + 1, + 1, + ) + + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + + return ( + hidden_states.reshape(-1, *hidden_states.shape[2:])[indices], + indices, + cu_seqlens, + max_seqlen_in_batch, + used_seqlens_in_batch, + ) + + +class HFInferenceParams(InferenceParams): + """Extension of the InferenceParams class to support HF generate() and beam search.""" + + # Required by transformers >= 5.4 _valid_auto_compile_criteria(); this + # custom TE-based cache is not compatible with torch.compile generate(). + is_compileable = False + + def get_seq_length(self, layer_idx: int = 0) -> int: + """Return the current cached sequence length. + + Required by HuggingFace transformers generate() to determine how many + tokens have already been cached. 
+ """ + if not self.sequences: + return 0 + return max(self.sequences.values()) + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorder the cache based on the beam indices.""" + if isinstance(self.cache_manager, PagedKVCacheManager): + raise NotImplementedError("Beam search is not supported for paged cache manager.") + for layer_number, (key_cache, value_cache) in self.cache_manager.cache.items(): + updated_key_cache = key_cache.index_select(0, beam_idx) + updated_value_cache = value_cache.index_select(0, beam_idx) + self.cache_manager.cache[layer_number] = (updated_key_cache, updated_value_cache) + + +@torch.compile(fullgraph=True) +def _build_expert_sort_indices(recv_counts: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Build sort and unsort index tensors for reordering received tokens by local expert. + + After all-to-all, tokens arrive grouped by source rank: + ``[src0_exp0..src0_expL, src1_exp0..src1_expL, ...]``. ``GroupedLinear`` expects them + grouped by expert: ``[all_exp0, all_exp1, ...]``. + + Uses only vectorized tensor operations (no ``.item()`` calls or Python-level loops) + so that it is compatible with ``torch.compile(fullgraph=True)``. + + Args: + recv_counts: Integer tensor of shape ``[ep_size, num_local_experts]`` giving the + number of tokens received from each source rank for each local expert. + + Returns: + A ``(sort_indices, unsort_indices)`` pair of 1-D ``int64`` tensors that can be + used to reorder and restore the token dimension. + """ + ep_size, num_local_experts = recv_counts.shape + device = recv_counts.device + num_blocks = ep_size * num_local_experts + + # Source-grouped (row-major) block offsets: [s0e0, s0e1, ..., s1e0, s1e1, ...] + counts_src = recv_counts.reshape(-1).long() + offsets_src = torch.zeros(num_blocks, dtype=torch.long, device=device) + offsets_src[1:] = counts_src[:-1].cumsum(0) + + # Expert-grouped (column-major) block offsets: [e0s0, e0s1, ..., e1s0, e1s1, ...] 
+ counts_exp = recv_counts.t().contiguous().reshape(-1).long() + offsets_exp = torch.zeros(num_blocks, dtype=torch.long, device=device) + offsets_exp[1:] = counts_exp[:-1].cumsum(0) + + total = counts_src.sum() + + # Mapping from source block index (s * L + e) to expert block index (e * S + s) + s_idx = torch.arange(ep_size, device=device).unsqueeze(1).expand(ep_size, num_local_experts) + e_idx = torch.arange(num_local_experts, device=device).unsqueeze(0).expand(ep_size, num_local_experts) + src_to_exp = (e_idx * ep_size + s_idx).reshape(-1) + + # Per-block positional shift from source layout to expert layout + shifts = offsets_exp[src_to_exp] - offsets_src + + # Expand per-block shifts to per-token + token_shifts = shifts.repeat_interleave(counts_src) + + # Map each source-grouped position to its expert-grouped destination + src_positions = torch.arange(total, device=device) + dst_positions = src_positions + token_shifts + + # sort_indices[exp_pos] = src_pos (gathers source tokens into expert order) + sort_indices = torch.empty(total, dtype=torch.long, device=device) + sort_indices[dst_positions] = src_positions + + # unsort_indices: inverse permutation (restores expert-ordered output to source order) + unsort_indices = torch.empty_like(sort_indices) + unsort_indices[sort_indices] = torch.arange(total, device=device) + + return sort_indices, unsort_indices + + +@dataclass +class _AllToAllHandle: + """Opaque handle for AllToAllTokenDispatcher, storing state between dispatch and combine.""" + + row_id_map: torch.Tensor + routing_weights: torch.Tensor + unsort_indices: torch.Tensor | None = None + input_split_sizes: list[int] | None = None + output_split_sizes: list[int] | None = None + + +class _DifferentiableAllToAll(torch.autograd.Function): + """Differentiable wrapper around dist.all_to_all_single. + + The forward pass performs the standard all-to-all communication. 
+ The backward pass reverses the communication direction (swapping + input/output split sizes) so that gradients flow correctly. + """ + + @staticmethod + def forward( + ctx, + input: torch.Tensor, + output_split_sizes: list[int], + input_split_sizes: list[int], + group: dist.ProcessGroup, + ) -> torch.Tensor: + """Perform all-to-all forward and save sizes for backward.""" + ctx.input_split_sizes = input_split_sizes + ctx.output_split_sizes = output_split_sizes + ctx.group = group + output = torch.empty( + sum(output_split_sizes), + input.shape[1], + device=input.device, + dtype=input.dtype, + ) + dist.all_to_all_single(output, input.contiguous(), output_split_sizes, input_split_sizes, group=group) + return output + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None, None]: + """Reverse all-to-all: swap input and output split sizes.""" + grad_input = torch.empty( + sum(ctx.input_split_sizes), + grad_output.shape[1], + device=grad_output.device, + dtype=grad_output.dtype, + ) + dist.all_to_all_single( + grad_input, + grad_output.contiguous(), + ctx.input_split_sizes, + ctx.output_split_sizes, + group=ctx.group, + ) + return grad_input, None, None, None + + +class AllToAllTokenDispatcher: + """TokenDispatcher using NCCL all-to-all for expert-parallel communication. + + Handles both EP=1 (no communication, just permute/unpermute) and EP>1 + (all-to-all token exchange between ranks) cases transparently. + + Args: + num_experts: Total number of experts (global). + num_local_experts: Number of experts on this rank. + hidden_size: Hidden dimension size. + ep_size: Expert parallel world size. 
+ """ + + def __init__(self, num_experts: int, num_local_experts: int, hidden_size: int, ep_size: int): + """Initialize the AllToAllTokenDispatcher.""" + self.num_experts = num_experts + self.num_local_experts = num_local_experts + self.hidden_size = hidden_size + self.ep_size = ep_size + self._ep_group: dist.ProcessGroup | None = None + + def set_ep_group(self, ep_group: dist.ProcessGroup) -> None: + """Set the expert-parallel process group for all-to-all communication.""" + self._ep_group = ep_group + + def dispatch( + self, + hidden_states: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + ) -> DispatchOutput: + """Dispatch tokens to their assigned experts via permute and optional all-to-all. + + Args: + hidden_states: Flattened input tensor of shape ``[N, H]``. + selected_experts: Expert assignments, shape ``[N, top_k]``, int. + routing_weights: Normalized routing probabilities, shape ``[N, top_k]``, float32. + + Returns: + DispatchOutput with expert-sorted tokens, per-expert counts, and an opaque handle. 
+ """ + # Permute tokens by expert using TE moe_permute + permuted_hidden, row_id_map = transformer_engine.pytorch.moe_permute( + hidden_states, selected_experts.to(torch.int32), map_type="index" + ) + + # Compute m_splits: number of tokens per expert + m_splits_tensor = torch.bincount(selected_experts.reshape(-1), minlength=self.num_experts).int() + + if self._ep_group is not None: + ep_group = self._ep_group + + # Token counts per expert, reshaped to [ep_size, num_local_experts] + send_counts = m_splits_tensor.reshape(self.ep_size, self.num_local_experts) + + # Exchange per-expert token counts between EP ranks + recv_counts = torch.empty_like(send_counts) + dist.all_to_all_single(recv_counts.flatten(), send_counts.flatten(), group=ep_group) + + # Derive split sizes for the token all-to-all + input_split_sizes = send_counts.sum(dim=1).tolist() + output_split_sizes = recv_counts.sum(dim=1).tolist() + local_m_splits = recv_counts.sum(dim=0).int().tolist() + + # Dispatch tokens to expert-owning ranks (differentiable) + recv_tokens = _DifferentiableAllToAll.apply( + permuted_hidden, output_split_sizes, input_split_sizes, ep_group + ) + + # Sort received tokens by local expert index. + # After all_to_all layout is [src0_exp0..src0_expL, src1_exp0..src1_expL, ...]. + # GroupedLinear needs [all_exp0, all_exp1, ...]. 
+ sort_indices, unsort_indices = _build_expert_sort_indices(recv_counts) + + handle = _AllToAllHandle( + row_id_map=row_id_map, + routing_weights=routing_weights, + unsort_indices=unsort_indices, + input_split_sizes=input_split_sizes, + output_split_sizes=output_split_sizes, + ) + return DispatchOutput( + expert_input=recv_tokens[sort_indices], + tokens_per_expert=local_m_splits, + handle=handle, + ) + + handle = _AllToAllHandle(row_id_map=row_id_map, routing_weights=routing_weights) + return DispatchOutput( + expert_input=permuted_hidden, + tokens_per_expert=m_splits_tensor.tolist(), + handle=handle, + ) + + def combine(self, expert_output: torch.Tensor, handle: _AllToAllHandle) -> torch.Tensor: + """Combine expert outputs back to the original token order. + + Args: + expert_output: Expert output tensor of shape ``[total_recv_tokens, H]``. + handle: Handle from ``dispatch()`` containing state for the reverse operation. + + Returns: + Combined output tensor of shape ``[N, H]`` with routing weights applied. + """ + if self._ep_group is not None: + assert handle.unsort_indices is not None + # Unsort back to source-rank-grouped order and reverse all_to_all (differentiable) + combined = _DifferentiableAllToAll.apply( + expert_output[handle.unsort_indices], + handle.input_split_sizes, + handle.output_split_sizes, + self._ep_group, + ) + else: + combined = expert_output + + # Unpermute and combine with routing weights (keep probs in float32 for numerical stability) + return transformer_engine.pytorch.moe_unpermute( + combined, + handle.row_id_map, + merging_probs=handle.routing_weights, + map_type="index", + ) diff --git a/bionemo-recipes/recipes/mixtral_native_te/perf_logger.py b/bionemo-recipes/recipes/mixtral_native_te/perf_logger.py new file mode 100644 index 0000000000..6db6175c61 --- /dev/null +++ b/bionemo-recipes/recipes/mixtral_native_te/perf_logger.py @@ -0,0 +1,308 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import time + +import nvdlfw_inspect.api as debug_api +import nvtx +import torch +import torchmetrics +import wandb +from distributed_config import DistributedConfig +from omegaconf import DictConfig, OmegaConf +from torch.distributed.tensor import DTensor +from tqdm import tqdm +from transformers.modeling_outputs import CausalLMOutputWithPast + + +logger = logging.getLogger(__name__) + + +class PerfLogger: + """Class to log performance metrics to stdout and wandb, and print final averaged metrics at the end of training. + + Args: + dist_config: The distributed configuration. + args: The arguments. + + Attributes: + min_loss: The minimum loss seen so far. 
+ """ + + def __init__(self, dist_config: DistributedConfig, args: DictConfig, start_step: int): + """Initialize the logger.""" + self._dist_config = dist_config + self._run_config = OmegaConf.to_container(args, resolve=True, throw_on_missing=True) + + self._device = torch.device(f"cuda:{dist_config.local_rank}") + self.min_loss = torch.tensor(float("inf"), device=self._device) + + self.logging_frequency = args.logger.frequency + + metrics_dict = { + "train/loss": torchmetrics.MeanMetric(), + "train/grad_norm": torchmetrics.MeanMetric(), + "train/learning_rate": torchmetrics.MeanMetric(), + "train/step_time": torchmetrics.MeanMetric(), + "train/tokens_per_second_per_gpu": torchmetrics.MeanMetric(), + "train/unpadded_tokens_per_second_per_gpu": torchmetrics.MeanMetric(), + "train/total_unpadded_tokens_per_batch": torchmetrics.SumMetric(), + "train/gpu_memory_allocated_max_gb": torchmetrics.MaxMetric(), + "train/gpu_memory_allocated_mean_gb": torchmetrics.MeanMetric(), + } + + self.metrics = torchmetrics.MetricCollection(metrics_dict) + # We move metrics to a GPU device so we can use torch.distributed to aggregate them before logging. + self.metrics.to(self._device) + self.previous_step_time = time.perf_counter() + self._profiler = None + + if self._dist_config.is_main_process(): + # Log the entire args object to wandb for experiment tracking and reproducibility. 
+ self._wandb_run = wandb.init(**args.wandb, config=self._run_config) + self._progress_bar = tqdm(initial=start_step, total=args.num_train_steps, desc="Training") + + if args.profiler.enabled: + self._profiler = NsightProfiler( + **args.profiler, + wandb_run=self._wandb_run, + dist_config=dist_config, + ) + + # Gradient accumulation tracking + self.num_tokens = 0 + self.num_unpadded_tokens = torch.tensor(0, dtype=torch.int64, device=self._device) + self.running_loss = torch.tensor(0.0, device=self._device) + self.grad_acc_step_count = 0 + + # Whether to step debug_api.step() after each step + self.fp8_stats_enabled = args.fp8_stats_config.enabled + + @nvtx.annotate("PerfLogger.log_micro_step", color="pink") + def log_micro_step(self, step: int, batch: dict[str, torch.Tensor], outputs: CausalLMOutputWithPast): + """Store data on micro step for gradient accumulation metrics. + + Args: + step: The step number. + batch: The batch of data for the micro step. + outputs: The outputs of the micro step. + """ + if self._dist_config.local_rank == 0: + logger.debug("log_micro_step") + + assert outputs.loss is not None, "Loss is None" + + with torch.no_grad(): + self.grad_acc_step_count += 1 + self.running_loss += outputs.loss + + if step % self.logging_frequency == 0 and step > 0: + self.num_tokens += batch["input_ids"].numel() + # Use attention_mask to count unpadded tokens (works for both BSHD and THD) + if "attention_mask" in batch: + self.num_unpadded_tokens += batch["attention_mask"].sum() + else: + # Fallback for pure sequence packing with no padding: all tokens are unpadded + self.num_unpadded_tokens += batch["input_ids"].numel() + + @nvtx.annotate("PerfLogger.log_step", color="purple") + def log_step( + self, + step: int, + grad_norm: torch.Tensor | DTensor, + lr: float, + ): + """Log a step to the logger and wandb. + + Args: + step: The step number. + grad_norm: The gradient norm of the step. + lr: The learning rate of the step. 
+ """ + if self._dist_config.local_rank == 0: + logger.debug("log_step %s", step) + + with torch.no_grad(): + # Use accumulated metrics from gradient accumulation + assert self.grad_acc_step_count > 0, ( + f"Gradient accumulation steps ({self.grad_acc_step_count}) must be greater than 0, " + f"and can be incremented by log_micro_step()." + ) + + if isinstance(grad_norm, DTensor): + grad_norm = grad_norm.to_local() + + if self._profiler is not None: + self._profiler.step(step) + + if self.fp8_stats_enabled: + debug_api.step() + + if step % self.logging_frequency == 0 and step > 0: + # Calculate average loss over all micro steps in the logging window + avg_loss = self.running_loss / self.grad_acc_step_count + self.min_loss = torch.minimum(self.min_loss, avg_loss) + + # Calculate an average step time over all steps in the logging window + now = time.perf_counter() + step_time = (now - self.previous_step_time) / self.logging_frequency + self.previous_step_time = now + + # For some reason, these trigger a CudaStreamSynchronize call, which blocks the dataloader in the next + # step. We therefore only update these once every logging_frequency steps. 
+ self.metrics["train/loss"].update(avg_loss) + self.metrics["train/learning_rate"].update(lr) + self.metrics["train/grad_norm"].update(grad_norm) + self.metrics["train/step_time"].update(step_time) + self.metrics["train/tokens_per_second_per_gpu"].update(self.num_tokens / step_time) + self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(self.num_unpadded_tokens / step_time) + self.metrics["train/total_unpadded_tokens_per_batch"].update(self.num_unpadded_tokens) + + memory_allocated = torch.cuda.memory_allocated() / (1024**3) + self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated) + self.metrics["train/gpu_memory_allocated_mean_gb"].update(memory_allocated) + + metrics = self.metrics.compute() + self.metrics.reset() + metrics = { + k: v.detach().cpu().item() if isinstance(v, torch.Tensor) and v.dim() == 0 else v + for k, v in metrics.items() + } + metrics["train/global_step"] = step + + if self._dist_config.is_main_process(): + wandb.log(metrics, step=step) + self._progress_bar.update(self.logging_frequency) + self._progress_bar.set_postfix({"loss": avg_loss.item()}) + + if self._dist_config.local_rank == 0: + logger.info(", ".join([f"{k.split('/')[1]}: {v:.3g}" for k, v in metrics.items()])) + + # Reset running loss and other tracking variables for next window + self.running_loss.zero_() + self.num_tokens = 0 + self.num_unpadded_tokens.zero_() + self.grad_acc_step_count = 0 + + def finish(self): + """Finish the logger and close the progress bar.""" + if not self._dist_config.is_main_process(): + return + + wandb.finish() + self._progress_bar.close() + + if self.fp8_stats_enabled: + debug_api.end_debug() + + +class NsightProfiler: + """Nsight Systems profiler wrapper for performance analysis. + + This profiler uses NVIDIA Nsight Systems to capture detailed performance traces + including CUDA kernels, CPU activities, and memory operations. The profiler + uploads results to wandb as artifacts. 
+ + Args: + enabled: Whether profiling is enabled. + start_step: The step number at which to start profiling. + end_step: The step number at which to end profiling. + wandb_run: The wandb run for logging artifacts. + dist_config: The distributed configuration. + + Attributes: + start_step: The step number at which to start profiling. + end_step: The step number at which to end profiling. + current_step: Current step counter. + profiling_started: Whether profiling has been started. + profiling_finished: Whether profiling has been finished. + """ + + def __init__( + self, + enabled: bool, + start_step: int, + end_step: int, + wandb_run: wandb.Run, + dist_config: DistributedConfig, + ): + """Initialize the Nsight profiler.""" + self._wandb_run = wandb_run + self._dist_config = dist_config + + self.start_step = start_step + self.end_step = end_step + + self.current_step = 0 + self.profiling_started = False + self.profiling_finished = False + + # Check if running under nsys + self.running_under_nsys = "NSYS_PROFILING_SESSION_ID" in os.environ + + if self.running_under_nsys: + logger.info("Detected running under nsys - will use CUDA Profiler API for range control") + else: + logger.warning( + "Not running under nsys. Profiling will be skipped. " + "To enable profiling, run your script with: " + "nsys profile -o output_trace --trace=cuda,nvtx,osrt,cudnn,cublas --capture-range=cudaProfilerApi " + "--capture-range-end=stop python train_fsdp2.py profiler.enabled=true" + ) + + def step(self, step_num: int): + """Record a training step and control profiling based on the schedule. + + Args: + step_num: The current training step number. 
+ """ + if not self.running_under_nsys or self.profiling_finished: + return + + self.current_step = step_num + + # Start profiling at start_step + if self.current_step == self.start_step and not self.profiling_started: + self._start_profiling() + # Stop profiling at end_step + elif self.current_step == self.end_step and self.profiling_started: + self._stop_profiling() + + def _start_profiling(self): + """Start CUDA profiling using the CUDA Profiler API.""" + if self.profiling_started: + return + + logger.info(f"Starting Nsight profiling at step {self.current_step}") + try: + torch.cuda.cudart().cudaProfilerStart() # type: ignore[attr-defined] + self.profiling_started = True + except Exception as e: + logger.error(f"Failed to start CUDA profiler: {e}") + + def _stop_profiling(self): + """Stop CUDA profiling using the CUDA Profiler API.""" + if not self.profiling_started or self.profiling_finished: + return + + logger.info(f"Stopping Nsight profiling at step {self.current_step}") + try: + torch.cuda.cudart().cudaProfilerStop() # type: ignore[attr-defined] + self.profiling_started = False + self.profiling_finished = True + except Exception as e: + logger.error(f"Failed to stop CUDA profiler: {e}") diff --git a/bionemo-recipes/recipes/mixtral_native_te/requirements.txt b/bionemo-recipes/recipes/mixtral_native_te/requirements.txt new file mode 100644 index 0000000000..073d9b39e3 --- /dev/null +++ b/bionemo-recipes/recipes/mixtral_native_te/requirements.txt @@ -0,0 +1,12 @@ +datasets +hydra-core +torch +torchao!=0.14.0 +torchdata +torchmetrics +tqdm +transformer_engine[pytorch] +transformers +wandb +zstandard +nvdlfw_inspect @ git+https://github.com/NVIDIA/nvidia-dlfw-inspect diff --git a/bionemo-recipes/recipes/mixtral_native_te/run_ep_test.sh b/bionemo-recipes/recipes/mixtral_native_te/run_ep_test.sh new file mode 100644 index 0000000000..f787a706cf --- /dev/null +++ b/bionemo-recipes/recipes/mixtral_native_te/run_ep_test.sh @@ -0,0 +1,26 @@ +#!/bin/bash +# Usage: 
def get_cosine_annealing_schedule_with_warmup(
    optimizer,
    num_warmup_steps=2_000,
    num_decay_steps=500_000,
    min_lr_ratio=0.0,
    last_epoch=-1,
):
    """Cosine annealing scheduler with linear warmup.

    The learning-rate multiplier rises linearly from 0 to 1 over ``num_warmup_steps``, then follows
    half a cosine from 1 down to ``min_lr_ratio`` over ``num_decay_steps``, and stays at
    ``min_lr_ratio`` afterwards. ``LambdaLR`` multiplies each parameter group's initial learning
    rate by this value.

    The multiplier is computed directly as a ratio, so the schedule no longer divides by the
    optimizer's learning rate (an initial lr of 0 used to raise ``ZeroDivisionError``), and
    ``num_decay_steps == 0`` is accepted (the lr drops to the minimum right after warmup).

    Args:
        optimizer: The optimizer to schedule.
        num_warmup_steps: Number of warmup steps.
        num_decay_steps: Number of decay steps after warmup.
        min_lr_ratio: Minimum learning rate as a ratio of the initial learning rate.
            If 0.0, decays to 0.
        last_epoch: The index of the last epoch. Default: -1.

    Returns:
        A ``LambdaLR`` scheduler attached to ``optimizer``.
    """

    def lr_lambda(current_step: int) -> float:
        # Warmup phase: linear ramp that reaches 1.0 exactly at num_warmup_steps.
        if num_warmup_steps > 0 and current_step <= num_warmup_steps:
            return float(current_step) / float(num_warmup_steps)

        # Past the end of the decay window: hold the minimum.
        if current_step > num_warmup_steps + num_decay_steps:
            return min_lr_ratio

        # Cosine phase. With no decay window, only the very first step can reach this point
        # (no warmup, step 0); treat it as the start of the decay, i.e. the full learning rate.
        elapsed = current_step - num_warmup_steps
        decay_ratio = float(elapsed) / float(num_decay_steps) if num_decay_steps > 0 else 0.0
        coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
        # Same value as (min_lr + coeff * (max_lr - min_lr)) / max_lr, without needing max_lr.
        return min_lr_ratio + coeff * (1.0 - min_lr_ratio)

    return LambdaLR(optimizer, lr_lambda, last_epoch)
def _create_local_tokenizer(directory: Path) -> str:
    """Build a tiny WordLevel tokenizer, save it under ``directory`` and return that path as a string.

    The tokenizer is created entirely in-process, so tests don't depend on HF Hub.
    """
    directory.mkdir(parents=True, exist_ok=True)

    special_tokens = {
        "unk_token": "[UNK]",
        "pad_token": "[PAD]",
        "bos_token": "[BOS]",
        "eos_token": "[EOS]",
    }
    # Token ids are simply the positions in this list.
    words = ["[UNK]", "[PAD]", "[BOS]", "[EOS]", "the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog"]
    vocab = {word: index for index, word in enumerate(words)}

    word_level = Tokenizer(WordLevel(vocab=vocab, unk_token=special_tokens["unk_token"]))
    word_level.pre_tokenizer = Whitespace()

    wrapped = PreTrainedTokenizerFast(tokenizer_object=word_level, **special_tokens)
    wrapped.save_pretrained(directory)
    return str(directory)
def pytest_collection_modifyitems(items):
    """Reorder the collected tests in place so the FP8 stats logging tests run before all others.

    Running them first avoids late debug initialization.
    """
    run_first = frozenset(
        (
            "test_sanity_ddp_fp8_stats_logging",
            "test_sanity_fsdp2_fp8_stats_logging",
        )
    )
    # sorted() is stable: False (run first) sorts before True, and the original relative order
    # is kept inside each of the two groups.
    items[:] = sorted(items, key=lambda item: item.name not in run_first)
def _parametrize_fp8_recipes():
    """Build one ``pytest.param`` per entry of ``_FP8_RECIPE_CONFIGS``.

    Each entry's check function is called once; when it reports the recipe as unsupported, the
    parameter is marked ``xfail`` with the reason the check returned.
    """

    def _as_param(recipe_name, hydra_overrides, is_supported_fn):
        is_supported, why_not = is_supported_fn()
        xfail_mark = pytest.mark.xfail(condition=not is_supported, reason=why_not)
        return pytest.param(hydra_overrides, id=recipe_name, marks=xfail_mark)

    return [_as_param(*config) for config in _FP8_RECIPE_CONFIGS]
+ torch.distributed.destroy_process_group() + torch.cuda.empty_cache() + torch.cuda.synchronize() diff --git a/bionemo-recipes/recipes/mixtral_native_te/tests/distributed_helpers.py b/bionemo-recipes/recipes/mixtral_native_te/tests/distributed_helpers.py new file mode 100644 index 0000000000..826d82c6c7 --- /dev/null +++ b/bionemo-recipes/recipes/mixtral_native_te/tests/distributed_helpers.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def get_dummy_batch(vocab_size: int, seq_len: int = 32, batch_size: int = 2, device: str = "cuda"):
    """Create a deterministic dummy batch for testing.

    The global torch RNG is re-seeded with 42 on every call, so repeated calls with the same
    arguments return the same token ids.

    Args:
        vocab_size: Exclusive upper bound for the sampled token ids.
        seq_len: Size of the second dimension of each tensor.
        batch_size: Size of the first dimension of each tensor.
        device: Device on which the tensors are created.

    Returns:
        A dict with ``input_ids``, an all-ones ``attention_mask`` of the same shape, and ``labels``
        (a copy of ``input_ids``).
    """
    torch.manual_seed(42)
    token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len), device=device)
    return {
        "input_ids": token_ids,
        "attention_mask": torch.ones_like(token_ids),
        "labels": token_ids.clone(),
    }
def _cleanup():
    """Run the garbage collector, then empty the CUDA cache when CUDA is available."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
"checkpoint": 7, + "mixtral": 8, + "data": 9, + "test": 10, + "pack": 11, + }, + unk_token="[UNK]", + ) + ) + tokenizer.pre_tokenizer = Whitespace() + + fast_tokenizer = PreTrainedTokenizerFast( + tokenizer_object=tokenizer, + unk_token="[UNK]", + pad_token="[PAD]", + bos_token="[BOS]", + eos_token="[EOS]", + ) + fast_tokenizer.save_pretrained(tokenizer_dir) + return str(tokenizer_dir) + + +def _make_tiny_dataset(): + """Return an in-memory HuggingFace dataset with short repeated text.""" + return datasets.Dataset.from_dict({"text": ["hello world token checkpoint " * 50] * 20}) + + +@pytest.fixture +def local_tokenizer(tmp_path): + return _create_local_tokenizer(tmp_path) + + +@pytest.fixture +def tiny_parquet(tmp_path): + """Write the tiny dataset to a parquet file and return its path.""" + ds = _make_tiny_dataset() + path = tmp_path / "tiny.parquet" + ds.to_parquet(str(path)) + return str(path) + + +def test_create_bshd_dataloader_returns_correct_batch_keys(local_tokenizer, tiny_parquet): + """BSHD dataloader batches must contain input_ids, attention_mask, and labels with matching shapes.""" + dist_config = DistributedConfig(rank=0, world_size=1) + micro_batch_size = 4 + max_seq_length = 32 + + dataloader, _ = create_bshd_dataloader( + distributed_config=dist_config, + tokenizer_name_or_path=local_tokenizer, + load_dataset_kwargs={"path": "parquet", "data_files": tiny_parquet, "split": "train"}, + micro_batch_size=micro_batch_size, + num_workers=0, + max_seq_length=max_seq_length, + stride=10, + ) + + batch = next(iter(dataloader)) + + assert "input_ids" in batch, "Batch missing input_ids" + assert "attention_mask" in batch, "Batch missing attention_mask" + assert "labels" in batch, "Batch missing labels" + + assert batch["input_ids"].shape[0] == micro_batch_size + assert batch["attention_mask"].shape == batch["input_ids"].shape + assert batch["labels"].shape == batch["input_ids"].shape + + _cleanup() + + +def 
def test_bshd_dataloader_sequence_length(local_tokenizer, tiny_parquet):
    """Every BSHD batch must have at most ``max_seq_length`` entries along dim 1 of ``input_ids``."""
    max_seq_length = 16

    dataloader, _ = create_bshd_dataloader(
        distributed_config=DistributedConfig(rank=0, world_size=1),
        tokenizer_name_or_path=local_tokenizer,
        load_dataset_kwargs={"path": "parquet", "data_files": tiny_parquet, "split": "train"},
        micro_batch_size=2,
        num_workers=0,
        max_seq_length=max_seq_length,
        stride=5,
    )

    # Walk the whole dataloader lazily; the first over-long batch fails the test.
    observed_lengths = (batch["input_ids"].shape[1] for batch in dataloader)
    for seq_len in observed_lengths:
        assert seq_len <= max_seq_length, f"Sequence length {seq_len} exceeds max_seq_length {max_seq_length}"

    _cleanup()
load_dataset_kwargs={"path": "parquet", "data_files": tiny_parquet, "split": "train", "streaming": True}, + token_micro_batch_size=128, + num_workers=0, + max_seq_length=max_seq_length, + stride=10, + ) + + for batch in dataloader: + assert batch["input_ids"].shape[0] == 1, ( + f"THD packed batch should have batch_size=1, got {batch['input_ids'].shape[0]}" + ) + + _cleanup() diff --git a/bionemo-recipes/recipes/mixtral_native_te/tests/test_distributed_checkpointing.py b/bionemo-recipes/recipes/mixtral_native_te/tests/test_distributed_checkpointing.py new file mode 100644 index 0000000000..0a4610a547 --- /dev/null +++ b/bionemo-recipes/recipes/mixtral_native_te/tests/test_distributed_checkpointing.py @@ -0,0 +1,279 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def _reserve_port() -> int:
    """Return a TCP port number that the OS just handed out on 127.0.0.1.

    The probing socket is closed before returning, so the port is only known to have been free
    at the time of the call.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as listener:
        # Port 0 asks the OS to choose a free port.
        listener.bind(("127.0.0.1", 0))
        _host, port = listener.getsockname()
    return port
def _assert_loss_valid(loss):
    """Fail unless ``loss`` is not ``None`` and converts to a finite float."""
    assert loss is not None
    loss_val = float(loss)
    finite = torch.isfinite(torch.tensor(loss_val))
    assert finite, f"Loss is not finite: {loss_val}"


def _assert_checkpoint_step(ckpt_subdir, step, num_ranks, is_ddp, use_distributed_checkpoint=False):
    """Check that ``step_<step>`` exists under ``ckpt_subdir`` with the expected files.

    Args:
        ckpt_subdir: Directory that holds the per-step checkpoint folders.
        step: Training step whose folder is inspected.
        num_ranks: Minimum number of per-rank files expected.
        is_ddp: Whether the checkpoint came from the DDP script.
        use_distributed_checkpoint: Whether per-rank model/optimizer files plus
            ``metadata.pt`` are expected instead of a single ``checkpoint.pt``.
    """
    step_dir = os.path.join(ckpt_subdir, f"step_{step}")
    assert os.path.isdir(step_dir), f"Step {step} directory not found: {step_dir}"
    files = os.listdir(step_dir)

    if use_distributed_checkpoint:
        model_count = sum(1 for name in files if name.startswith("model_rank_"))
        optimizer_count = sum(1 for name in files if name.startswith("optimizer_rank_"))
        assert model_count >= num_ranks, f"Expected model files for {num_ranks} ranks in {step_dir}: {files}"
        assert optimizer_count >= num_ranks, (
            f"Expected optimizer files for {num_ranks} ranks in {step_dir}: {files}"
        )
        assert "metadata.pt" in files, f"Missing metadata.pt in {step_dir}: {files}"
    elif is_ddp:
        assert "checkpoint.pt" in files, f"Missing checkpoint.pt in {step_dir}: {files}"

    # Dataloader state is expected for every configuration.
    dataloader_count = sum(1 for name in files if "dataloader" in name)
    assert dataloader_count >= num_ranks, (
        f"Expected dataloader files for {num_ranks} ranks in {step_dir}: {files}"
    )


def _ep_size_from_overrides(extra_overrides):
    """Return the integer value of the first ``expert_parallel_size=`` override."""
    return int(next(o.split("=", 1)[1] for o in extra_overrides if o.startswith("expert_parallel_size=")))


def _run_single_process_checkpoint_test(recipe_path, tmp_path, main_fn, ckpt_subdir_name, extra_overrides, is_ddp):
    """Train 10 steps in-process, resume to 15 steps, and check the saved checkpoints.

    Checkpoints are written every 5 steps. After the first run step 5 must
    exist; after the resumed run steps 5 and 10 must exist. Both returned
    losses must be finite.
    """
    tokenizer_path = _create_local_tokenizer(tmp_path)
    use_distributed_checkpoint = is_ddp and _ep_size_from_overrides(extra_overrides) > 1
    shared = [
        "checkpoint.save_every_n_steps=5",
        "checkpoint.async_save=false",
        f"dataset.tokenizer_name_or_path={tokenizer_path}",
        *extra_overrides,
    ]
    ckpt_subdir = os.path.join(str(tmp_path / "ckpt"), ckpt_subdir_name)

    def run_phase(num_steps, resume):
        cfg = _compose_config(
            recipe_path,
            tmp_path,
            [f"num_train_steps={num_steps}", f"checkpoint.resume_from_checkpoint={resume}", *shared],
        )
        loss = main_fn(cfg)
        gc.collect()
        torch.cuda.empty_cache()
        return loss

    def check(step):
        _assert_checkpoint_step(
            ckpt_subdir, step, num_ranks=1, is_ddp=is_ddp, use_distributed_checkpoint=use_distributed_checkpoint
        )

    loss1 = run_phase(10, "false")
    check(5)

    loss2 = run_phase(15, "true")
    check(5)
    check(10)

    _assert_loss_valid(loss1)
    _assert_loss_valid(loss2)


def _run_multi_process_checkpoint_test(
    recipe_path, tmp_path, train_script_name, ckpt_subdir_name, extra_overrides, is_ddp
):
    """Same stop-go scenario as the single-process test, launched on two ranks via ``torchrun``."""
    ckpt_dir = str(tmp_path / "ckpt")
    tokenizer_path = _create_local_tokenizer(tmp_path)
    use_distributed_checkpoint = is_ddp and _ep_size_from_overrides(extra_overrides) > 1

    env = os.environ.copy()
    env.update(
        {
            "WANDB_MODE": "disabled",
            "MASTER_PORT": str(_reserve_port()),
            "PATH": f"/usr/local/cuda/bin:{env['PATH']}",
            "CPATH": f"/usr/local/cuda/include:{env.get('CPATH', '')}".rstrip(":"),
            "BIONEMO_DISABLE_TORCH_COMPILE_HELPERS": "1",
            "TOKENIZERS_PARALLELISM": "false",
            "NCCL_DEBUG": "WARN",
            "TRITON_PTXAS_PATH": "/usr/local/cuda/bin/ptxas",
        }
    )

    train_script = recipe_path / train_script_name
    shared = [
        f"checkpoint.ckpt_dir={ckpt_dir}",
        "checkpoint.save_every_n_steps=5",
        "checkpoint.async_save=false",
        "dataset.use_stateful_dataloader=true",
        f"dataset.tokenizer_name_or_path={tokenizer_path}",
        *extra_overrides,
    ]
    launcher = ["torchrun", "--standalone", "--nproc_per_node=2", str(train_script)]
    ckpt_subdir = os.path.join(ckpt_dir, ckpt_subdir_name)

    def run_phase(phase, num_steps, resume):
        result = subprocess.run(
            [*launcher, f"num_train_steps={num_steps}", f"checkpoint.resume_from_checkpoint={resume}", *shared],
            check=False,
            capture_output=True,
            text=True,
            env=env,
        )
        assert result.returncode == 0, f"Phase {phase} failed: {result.stderr}"

    def check(step):
        _assert_checkpoint_step(
            ckpt_subdir, step, num_ranks=2, is_ddp=is_ddp, use_distributed_checkpoint=use_distributed_checkpoint
        )

    run_phase(1, 10, "false")
    check(5)

    run_phase(2, 15, "true")
    check(5)
    check(10)


def test_checkpoint_save_and_load_single_process_ddp_ep1(recipe_path, tmp_path):
    """Stop-go checkpointing with the DDP script, one process, EP=1."""
    _run_single_process_checkpoint_test(
        recipe_path,
        tmp_path,
        main_ddp,
        ckpt_subdir_name="train_ddp",
        extra_overrides=["expert_parallel_size=1"],
        is_ddp=True,
    )


def test_checkpoint_save_and_load_single_process_fsdp2_ep1(recipe_path, tmp_path):
    """Stop-go checkpointing with the FSDP2 script, one process, EP=1."""
    _run_single_process_checkpoint_test(
        recipe_path,
        tmp_path,
        main_fsdp2,
        ckpt_subdir_name="train_fsdp2",
        extra_overrides=["expert_parallel_size=1"],
        is_ddp=False,
    )


@requires_multi_gpu
@pytest.mark.xfail(
    reason=(
        "DDP stop-go checkpointing with expert_parallel_size > 1 is currently unsupported in this recipe. "
        "Resume drops EP expert weights from the saved model state; use the FSDP2 recipe for EP save/resume."
    ),
    strict=False,
)
def test_checkpoint_save_and_load_two_processes_ddp_ep2(recipe_path, tmp_path):
    """Stop-go checkpointing with the DDP script on two ranks, EP=2 (expected to fail)."""
    _run_multi_process_checkpoint_test(
        recipe_path,
        tmp_path,
        "train_ddp.py",
        ckpt_subdir_name="train_ddp",
        extra_overrides=["expert_parallel_size=2"],
        is_ddp=True,
    )


@requires_multi_gpu
def test_checkpoint_save_and_load_two_processes_fsdp2_ep2(recipe_path, tmp_path):
    """Stop-go checkpointing with the FSDP2 script on two ranks, EP=2."""
    _run_multi_process_checkpoint_test(
        recipe_path,
        tmp_path,
        "train_fsdp2.py",
        ckpt_subdir_name="train_fsdp2",
        extra_overrides=["expert_parallel_size=2"],
        is_ddp=False,
    )


# ---- patch boundary, kept verbatim from the source diff ----
# diff --git a/bionemo-recipes/recipes/mixtral_native_te/tests/test_fsdp_ep.py b/bionemo-recipes/recipes/mixtral_native_te/tests/test_fsdp_ep.py
# new file mode 100644
# index 0000000000..e3b58b645c
# --- /dev/null
# +++ b/bionemo-recipes/recipes/mixtral_native_te/tests/test_fsdp_ep.py
# @@ -0,0 +1,289 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for FSDP2 + Expert Parallelism (EP) in the mixtral_native_te recipe.

Verifies that FSDP2 and EP can be composed together:
- FSDP=2, EP=1 (2 GPUs): Data-parallel sharding, all experts on each rank.
- FSDP=1, EP=2 (2 GPUs): Expert-parallel training, no data parallelism.
"""

import subprocess
import sys
from pathlib import Path


sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
sys.path.insert(0, str(Path(__file__).resolve().parent))

import pytest
import torch
from distributed_helpers import DistributedConfig, create_small_mixtral_config, get_dummy_batch
from modeling_mixtral_te import NVMixtralForCausalLM


requires_2_gpus = pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.device_count() < 2,
    reason="Test requires at least 2 GPUs",
)


def _distribute_state_dict(full_state_dict: dict, model: torch.nn.Module, device: torch.device) -> dict:
    """Build a state dict for ``model`` from a full, unsharded state dict.

    Each key of the model's own state dict is handled as follows:
    ``_extra_state`` entries are taken from the model unchanged; keys missing
    from ``full_state_dict`` are dropped; entries that are DTensors in the
    model are produced with ``distribute_tensor`` using that DTensor's mesh
    and placements; everything else is copied from ``full_state_dict`` as-is.

    Args:
        full_state_dict: Source state dict holding complete tensors.
        model: Target model whose state dict defines keys and sharding.
        device: Device the source tensor is moved to before distributing.

    Returns:
        A new dict suitable for ``model.load_state_dict``.
    """
    from torch.distributed.tensor import DTensor, distribute_tensor

    # nn.Module.state_dict is called directly rather than model.state_dict();
    # the original author's note says the model's override filters out the
    # _extra_state keys that load_state_dict(strict=True) needs.
    target_state = torch.nn.Module.state_dict(model)

    distributed_state: dict = {}
    for key, target in target_state.items():
        if key.endswith("_extra_state"):
            distributed_state[key] = target
            continue
        if key not in full_state_dict:
            continue
        source = full_state_dict[key]
        if isinstance(target, DTensor):
            source = distribute_tensor(source.to(device), target.device_mesh, list(target.placements))
        distributed_state[key] = source
    return distributed_state


def _train_step(model, batch):
    """Run a single forward + backward + optimizer step.

    Returns:
        Tuple of (loss value, dict of gradient norms, dict of weight change norms).
    """

    def materialize(tensor):
        # Objects exposing full_tensor() are gathered before norms are taken.
        return tensor.full_tensor() if hasattr(tensor, "full_tensor") else tensor

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # Keep a copy of every parameter so the update size can be measured later.
    before = {name: param.detach().clone() for name, param in model.named_parameters()}

    optimizer.zero_grad()
    loss = model(**batch).loss
    loss.backward()

    grad_norms = {
        name: materialize(param.grad).detach().float().norm().item()
        for name, param in model.named_parameters()
        if param.grad is not None
    }

    optimizer.step()

    weight_changes = {}
    for name, param in model.named_parameters():
        # Snapshot first, current value second: same gather order as the original.
        old = materialize(before[name])
        new = materialize(param.detach())
        weight_changes[name] = (new.float() - old.float()).norm().item()

    return loss.detach().item(), grad_norms, weight_changes


# ---------------------------------------------------------------------------
# Pytest entry points — launch torchrun subprocesses
# ---------------------------------------------------------------------------


def _run_torchrun(test_fn_name: str, port: int, nproc: int = 2):
    """Run a named worker function via torchrun."""
    this_file = Path(__file__).resolve()
    cmd = [
        "torchrun",
        f"--nproc_per_node={nproc}",
        "--rdzv-backend=c10d",
        f"--rdzv-endpoint=localhost:{port}",
        str(this_file),
        test_fn_name,
    ]
    result = subprocess.run(
        cmd,
        check=False,
        text=True,
        cwd=str(this_file.parent.parent),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        timeout=300,
    )
    if result.returncode == 0:
        return
    print(f"STDOUT:\n{result.stdout}")
    print(f"STDERR:\n{result.stderr}")
    pytest.fail(f"{test_fn_name} failed with exit code {result.returncode}")


@requires_2_gpus
def test_fsdp2_ep1(free_tcp_port):
    """Test FSDP=2, EP=1: data-parallel training with all experts on each rank."""
    _run_torchrun("fsdp2_ep1", free_tcp_port, nproc=2)


@requires_2_gpus
def test_fsdp1_ep2(free_tcp_port):
    """Test FSDP=1, EP=2: expert-parallel training without data parallelism."""
    _run_torchrun("fsdp1_ep2", free_tcp_port, nproc=2)


# ---------------------------------------------------------------------------
# Distributed workers executed via torchrun
# ---------------------------------------------------------------------------


def _init_rank():
    """Select this rank's CUDA device and start the NCCL process group.

    Returns:
        Tuple of (DistributedConfig, torch.device) for the calling rank.
    """
    dist_config = DistributedConfig()
    device = torch.device(f"cuda:{dist_config.local_rank}")
    torch.cuda.set_device(device)
    torch.distributed.init_process_group(backend="nccl", device_id=device)
    return dist_config, device


def _wrap_with_fsdp2(model, device_mesh):
    """Apply ``fully_shard`` to every decoder layer, then to the whole model, on the ``dp`` sub-mesh."""
    from torch.distributed.fsdp import fully_shard

    for layer in model.model.layers:
        fully_shard(layer, mesh=device_mesh["dp"])
    fully_shard(model, mesh=device_mesh["dp"])


def _train_once_and_check(model, vocab_size, device):
    """Run one training step on a dummy batch and assert finite loss/grads and a weight update."""
    model.train()
    batch = get_dummy_batch(vocab_size, device=str(device))

    loss_val, grad_norms, weight_changes = _train_step(model, batch)

    assert torch.isfinite(torch.tensor(loss_val)), f"Loss is not finite: {loss_val}"
    assert len(grad_norms) > 0, "No gradients computed"
    for name, gnorm in grad_norms.items():
        assert torch.isfinite(torch.tensor(gnorm)), f"Gradient for {name} is not finite: {gnorm}"
    assert any(wc > 0 for wc in weight_changes.values()), "No weights updated after optimizer step"


def _worker_fsdp2_ep1():
    """FSDP=2, EP=1 worker: build a ``(world_size, 1)`` mesh and train one step.

    Steps, in order:
    1. Start the process group and create a 2D ``("dp", "ep")`` mesh with ep=1.
    2. Build the model with EP=1 and register the size-1 EP group/mesh on it.
    3. Wrap layers and model with FSDP2 on the ``dp`` sub-mesh.
    4. Run one training step; check loss, gradients and weight updates.
    """
    from torch.distributed.device_mesh import init_device_mesh

    dist_config, device = _init_rank()

    ep_size = 1
    dp_size = dist_config.world_size
    device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, ep_size), mesh_dim_names=("dp", "ep"))

    config = create_small_mixtral_config(expert_parallel_size=ep_size)
    torch.manual_seed(0)
    model = NVMixtralForCausalLM(config).to(dtype=torch.bfloat16, device=device)

    # EP groups are registered before FSDP2 wrapping, as in the original.
    ep_mesh = device_mesh["ep"]
    model.model.set_ep_groups(ep_mesh.get_group(), ep_mesh)

    _wrap_with_fsdp2(model, device_mesh)
    _train_once_and_check(model, config.vocab_size, device)

    torch.distributed.destroy_process_group()


def _worker_fsdp1_ep2():
    """FSDP=1, EP=2 worker: build a ``(1, world_size)`` mesh and train one step.

    Steps, in order:
    1. Start the process group and create a 2D ``("dp", "ep")`` mesh with dp=1.
    2. Build an EP=1 model on CPU and clone its state dict as reference weights.
    3. Build the EP model, register the EP group/mesh, and load the reference
       weights through ``_distribute_state_dict`` with ``strict=True``.
    4. Wrap layers and model with FSDP2 on the size-1 ``dp`` sub-mesh.
    5. Run one training step; check loss, gradients and weight updates.
    """
    from torch.distributed.device_mesh import init_device_mesh

    dist_config, device = _init_rank()

    ep_size = dist_config.world_size
    dp_size = 1
    device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, ep_size), mesh_dim_names=("dp", "ep"))

    ep_mesh = device_mesh["ep"]
    ep_group = ep_mesh.get_group()

    # Reference weights come from a complete EP=1 model kept on CPU.
    reference_config = create_small_mixtral_config(expert_parallel_size=1)
    torch.manual_seed(0)
    reference_model = NVMixtralForCausalLM(reference_config).to(dtype=torch.bfloat16, device="cpu")
    full_state_dict = {key: tensor.clone() for key, tensor in reference_model.state_dict().items()}
    del reference_model

    config_ep = create_small_mixtral_config(expert_parallel_size=ep_size)
    torch.manual_seed(0)
    model = NVMixtralForCausalLM(config_ep).to(dtype=torch.bfloat16, device=device)

    # EP groups must be registered before the weights are distributed and loaded.
    model.model.set_ep_groups(ep_group, ep_mesh)
    model.load_state_dict(_distribute_state_dict(full_state_dict, model, device), strict=True)

    _wrap_with_fsdp2(model, device_mesh)
    _train_once_and_check(model, config_ep.vocab_size, device)

    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    # torchrun passes the worker name as the first positional argument.
    test_name = sys.argv[1]

    workers = {
        "fsdp2_ep1": _worker_fsdp2_ep1,
        "fsdp1_ep2": _worker_fsdp1_ep2,
    }
    workers[test_name]()

# ---- patch boundary, kept verbatim from the source diff ----
# diff --git a/bionemo-recipes/recipes/mixtral_native_te/tests/test_lingua_8x1B.py b/bionemo-recipes/recipes/mixtral_native_te/tests/test_lingua_8x1B.py
# new file mode 100644
# index 0000000000..dc62ce2160
# --- /dev/null
# +++ b/bionemo-recipes/recipes/mixtral_native_te/tests/test_lingua_8x1B.py
# @@ -0,0 +1,42 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from hydra import compose, initialize_config_dir
from omegaconf import OmegaConf
from scheduler import get_cosine_annealing_schedule_with_warmup
from torch.optim import AdamW


def test_lingua_8x1b_optimizer_golden_values(recipe_path):
    """Check that AdamW picks up the ``L2_lingua_8x1B`` settings and the LR stays positive.

    The optimizer is built on a throwaway linear layer from
    ``config.adamw_kwargs``. After three optimizer/scheduler steps the
    learning rate must still be greater than zero.
    """
    config_dir = str(recipe_path / "hydra_config")
    with initialize_config_dir(config_dir=config_dir, version_base="1.2"):
        config = compose(config_name="L2_lingua_8x1B")

    adamw_kwargs = OmegaConf.to_container(config.adamw_kwargs, resolve=True)
    optimizer = AdamW(torch.nn.Linear(10, 1).parameters(), **adamw_kwargs)  # type: ignore[arg-type]

    group = optimizer.param_groups[0]
    assert group["lr"] == config.adamw_kwargs.lr
    assert list(group["betas"]) == list(config.adamw_kwargs.betas)
    assert group["eps"] == config.adamw_kwargs.eps
    assert group["weight_decay"] == config.adamw_kwargs.weight_decay

    scheduler = get_cosine_annealing_schedule_with_warmup(optimizer, **config.lr_scheduler_kwargs)

    steps_left = 3
    while steps_left:
        optimizer.step()
        scheduler.step()
        steps_left -= 1

    assert optimizer.param_groups[0]["lr"] > 0

# ---- patch boundary, kept verbatim from the source diff ----
# diff --git a/bionemo-recipes/recipes/mixtral_native_te/tests/test_train.py b/bionemo-recipes/recipes/mixtral_native_te/tests/test_train.py
# new file mode 100644
# index 0000000000..069a2cc0ac
# --- /dev/null
# +++ b/bionemo-recipes/recipes/mixtral_native_te/tests/test_train.py
# @@ -0,0 +1,217 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random

import pytest
import torch
from hydra import compose, initialize_config_dir
from train_ddp import main as main_ddp
from train_fsdp2 import main as main_fsdp2
from transformer_engine.pytorch.fp8 import check_fp8_support


# TODO(@jomitchell): Delete once https://nvbugspro.nvidia.com/bug/5458694 is fixed.
requires_datacenter_hardware = pytest.mark.skipif(
    not torch.cuda.is_available()
    or not any(
        gpu_name in torch.cuda.get_device_name(0).upper() for gpu_name in ["H100", "H200", "B100", "B200", "B300"]
    ),
    reason="Test requires datacenter hardware (H100, H200, B100, B200, B300)",
)

_fp8_support_result = check_fp8_support() if torch.cuda.is_available() else (False, "CUDA not available")
requires_fp8 = pytest.mark.skipif(
    not torch.cuda.is_available() or not _fp8_support_result[0],
    reason=f"Test requires FP8 support: {_fp8_support_result[1]}",
)


def _cleanup():
    """Collect garbage and, when CUDA is present, release cached GPU memory."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()


@pytest.fixture(autouse=True)
def set_seed():
    """Set random seeds for reproducibility."""
    seed = 42
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _compose_sanity_config(recipe_path, tmp_path, tokenizer_path, *extra_overrides):
    """Compose ``L0_sanity`` with the overrides shared by every convergence test.

    ``extra_overrides`` are appended after the shared ones, in the order given.
    """
    overrides = [
        f"+wandb.dir={tmp_path}",
        f"checkpoint.ckpt_dir={tmp_path}",
        f"dataset.tokenizer_name_or_path={tokenizer_path}",
        "checkpoint.resume_from_checkpoint=false",
        "num_train_steps=40",
        *extra_overrides,
    ]
    with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
        return compose(config_name="L0_sanity", overrides=overrides)


def _train_and_assert_converged(main_fn, sanity_config):
    """Run ``main_fn`` on the config, free memory, and require a final loss below 8.5."""
    final_loss = main_fn(sanity_config)
    _cleanup()

    assert final_loss < 8.5, f"Final loss {final_loss} is too high, expected < 8.5"


def test_sanity_convergence_fsdp2_te_bshd(tmp_path, recipe_path, local_tokenizer_path):
    """Test FSDP2 training with the BSHD attention input format."""
    sanity_config = _compose_sanity_config(
        recipe_path,
        tmp_path,
        local_tokenizer_path,
        "config_kwargs.attn_input_format=bshd",
    )
    _train_and_assert_converged(main_fsdp2, sanity_config)


def test_sanity_convergence_fsdp2_te_thd(tmp_path, recipe_path, local_tokenizer_path):
    """Test FSDP2 training with sequence packing and the THD attention input format."""
    sanity_config = _compose_sanity_config(
        recipe_path,
        tmp_path,
        local_tokenizer_path,
        "use_sequence_packing=true",
        "config_kwargs.attn_input_format=thd",
        "config_kwargs.self_attn_mask_type=padding_causal",
    )
    _train_and_assert_converged(main_fsdp2, sanity_config)


def test_sanity_convergence_fsdp2_te_bshd_grad_acc(tmp_path, recipe_path, local_tokenizer_path):
    """Test FSDP2 training with gradient accumulation."""
    sanity_config = _compose_sanity_config(
        recipe_path,
        tmp_path,
        local_tokenizer_path,
        "config_kwargs.attn_input_format=bshd",
        "grad_acc_steps=2",
    )
    # Grad accumulation halves effective optimizer steps; the loss bound is the same 8.5 as elsewhere.
    _train_and_assert_converged(main_fsdp2, sanity_config)


def test_sanity_convergence_ddp_te(tmp_path, recipe_path, local_tokenizer_path):
    """Test that DDP training converges on sanity-scale data."""
    sanity_config = _compose_sanity_config(
        recipe_path,
        tmp_path,
        local_tokenizer_path,
        "config_kwargs.attn_input_format=bshd",
    )
    _train_and_assert_converged(main_ddp, sanity_config)


def test_sanity_convergence_ddp_te_grad_acc(tmp_path, recipe_path, local_tokenizer_path):
    """Test DDP training with gradient accumulation."""
    sanity_config = _compose_sanity_config(
        recipe_path,
        tmp_path,
        local_tokenizer_path,
        "config_kwargs.attn_input_format=bshd",
        "grad_acc_steps=2",
    )
    _train_and_assert_converged(main_ddp, sanity_config)


def test_sanity_convergence_fsdp2_hf(tmp_path, recipe_path, local_tokenizer_path):
    """Test that FSDP2 training converges with HuggingFace (non-TE) model."""
    sanity_config = _compose_sanity_config(
        recipe_path,
        tmp_path,
        local_tokenizer_path,
        "use_te=false",
        "use_meta_device=false",
    )
    _train_and_assert_converged(main_fsdp2, sanity_config)


@requires_fp8
@requires_datacenter_hardware
def test_sanity_convergence_fsdp2_te_fp8(tmp_path, recipe_path, local_tokenizer_path, fp_recipe):
    """Test FSDP2 training with FP8 enabled using parametrized FP8 recipes."""
    sanity_config = _compose_sanity_config(
        recipe_path,
        tmp_path,
        local_tokenizer_path,
        "config_kwargs.attn_input_format=bshd",
        "fp8_config.enabled=true",
        *fp_recipe,
    )
    _train_and_assert_converged(main_fsdp2, sanity_config)

# ---- patch boundary, kept verbatim from the source diff ----
# diff --git a/bionemo-recipes/recipes/mixtral_native_te/tests/test_train_two_gpu.py b/bionemo-recipes/recipes/mixtral_native_te/tests/test_train_two_gpu.py
# new file mode 100644
# index 0000000000..665f5650b6
# --- /dev/null
# +++ b/bionemo-recipes/recipes/mixtral_native_te/tests/test_train_two_gpu.py
# @@ -0,0 +1,129 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import subprocess

import pytest
import torch


requires_multi_gpu = pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.device_count() < 2,
    reason="Test requires at least 2 GPUs",
)


def run_train_cmd(cmd, recipe_path):
    """Run a training command and check for errors.

    The command runs with ``recipe_path`` as its working directory and a
    240-second timeout. On a non-zero exit code the captured stdout/stderr are
    printed and the calling test is failed via ``pytest.fail``.
    """
    completed = subprocess.run(
        cmd,
        check=False,
        text=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        timeout=240,
        cwd=str(recipe_path),
    )
    if completed.returncode == 0:
        return

    print(f"STDOUT:\n{completed.stdout}")
    print(f"STDERR:\n{completed.stderr}")
    pytest.fail(f"Command:\n{' '.join(cmd)}\nfailed with exit code {completed.returncode}")


def _torchrun_cmd(script, *overrides):
    """Build the two-process ``torchrun`` argv for ``script`` using the ``L0_sanity`` config."""
    return [
        "torchrun",
        "--standalone",
        "--nproc_per_node",
        "2",
        script,
        "--config-name",
        "L0_sanity",
        *overrides,
    ]


@requires_multi_gpu
def test_multi_gpu_train_ddp(recipe_path, local_tokenizer_path):
    """Test DDP training on 2 GPUs."""
    cmd = _torchrun_cmd(
        "train_ddp.py",
        "num_train_steps=4",
        "expert_parallel_size=1",
        f"dataset.tokenizer_name_or_path={local_tokenizer_path}",
    )
    run_train_cmd(cmd, recipe_path)


@requires_multi_gpu
def test_multi_gpu_train_fsdp2(recipe_path, local_tokenizer_path):
    """Test FSDP2 training on 2 GPUs."""
    cmd = _torchrun_cmd(
        "train_fsdp2.py",
        "num_train_steps=4",
        f"dataset.tokenizer_name_or_path={local_tokenizer_path}",
    )
    run_train_cmd(cmd, recipe_path)


@requires_multi_gpu
def test_multi_gpu_train_fsdp2_with_checkpointing(tmp_path, recipe_path, local_tokenizer_path):
    """Test FSDP2 training on 2 GPUs with checkpoint saving."""
    cmd = _torchrun_cmd(
        "train_fsdp2.py",
        "num_train_steps=10",
        f"checkpoint.ckpt_dir={tmp_path}",
        "checkpoint.save_every_n_steps=5",
        "dataset.use_stateful_dataloader=true",
        "expert_parallel_size=1",
        f"dataset.tokenizer_name_or_path={local_tokenizer_path}",
    )
    run_train_cmd(cmd, recipe_path)

    # The run saves every 5 steps, so step_5 must exist under the script's sub-directory.
    ckpt_dir = tmp_path / "train_fsdp2"
    assert ckpt_dir.exists(), f"Checkpoint directory not created: {ckpt_dir}"
    assert (ckpt_dir / "step_5").exists(), "Checkpoint at step 5 not found"


@requires_multi_gpu
def test_multi_gpu_train_fsdp2_ep2(recipe_path, local_tokenizer_path):
    """Test FSDP2 training with expert parallelism (EP=2) on 2 GPUs."""
    cmd = _torchrun_cmd(
        "train_fsdp2.py",
        "num_train_steps=4",
        "expert_parallel_size=2",
        f"dataset.tokenizer_name_or_path={local_tokenizer_path}",
    )
    run_train_cmd(cmd, recipe_path)

# ---- patch boundary, kept verbatim from the source diff ----
# diff --git a/bionemo-recipes/recipes/mixtral_native_te/train_ddp.py b/bionemo-recipes/recipes/mixtral_native_te/train_ddp.py
# new file mode 100644
# index 0000000000..d6152a438c
# --- /dev/null
# +++ b/bionemo-recipes/recipes/mixtral_native_te/train_ddp.py
# @@ -0,0 +1,233 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@hydra.main(config_path="hydra_config", config_name="L0_sanity", version_base="1.2")
def main(args: DictConfig) -> float | None:
    """Train Mixtral with TE layers using DDP.

    The function initializes the process group, builds either the TE model
    (``NVMixtralForCausalLM``) or the Hugging Face model (``MixtralForCausalLM``)
    depending on ``args.use_te``, wraps it in ``DistributedDataParallel``, optionally
    resumes from a checkpoint, and then runs the gradient-accumulating training loop
    until ``args.num_train_steps`` optimizer steps have been taken.

    Args:
        args: Hydra/OmegaConf configuration for the run.

    Returns:
        float: The minimum loss value observed during training.

    Raises:
        ValueError: If the world size is not divisible by ``expert_parallel_size``, or if
            ``expert_parallel_size`` is greater than 1 (rejected by this script).
    """
    # --- Distributed Setup ---
    dist_config = DistributedConfig()
    logger.info("Initializing distributed training: %s", dist_config)
    device = torch.device(f"cuda:{dist_config.local_rank}")
    torch.distributed.init_process_group(backend="cpu:gloo,cuda:nccl", device_id=device)
    torch.cuda.set_device(dist_config.local_rank)

    if args.fp8_stats_config.enabled:
        initialize_fp8_debugging(dist_config, **args.fp8_stats_config, fp8_enabled=args.fp8_config.enabled)

    # --- Model Configuration ---
    ep_size = args.expert_parallel_size
    if dist_config.world_size % ep_size != 0:
        raise ValueError(
            f"world_size ({dist_config.world_size}) must be divisible by expert_parallel_size ({ep_size})"
        )
    # EP is rejected outright here, so from this point on ep_size is always 1 and every
    # `ep_size > 1` branch further down in this function is unreachable.
    if ep_size > 1:
        raise ValueError(
            "DDP stop-go checkpointing with expert_parallel_size > 1 is currently unsupported for this recipe. "
            "Use train_fsdp2.py for EP checkpoint save/resume."
        )
    dp_size = dist_config.world_size // ep_size
    # 2-D mesh (dp, ep); only the "dp" sub-mesh is handed to DDP below.
    device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, ep_size), mesh_dim_names=("dp", "ep"))

    # The recipe classes are resolved by dotted name from the config; None means "not enabled".
    fp8_recipe = None
    if args.fp8_config.enabled:
        fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
            fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
        )

    fp4_recipe = None
    if args.fp4_config.enabled:
        fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)(**args.fp4_config.fp4_recipe_kwargs)

    # --- Model Initialization ---
    # With use_meta_device the model is constructed under torch.device("meta") and
    # materialized later in the "Distributed Wrapping" section.
    if args.use_te:
        config = NVMixtralConfig.from_pretrained(
            args.config_name_or_path,
            dtype=torch.bfloat16,
            expert_parallel_size=ep_size,
            **args.config_kwargs,
        )
        dispatcher = _build_dispatcher(args, config)
        with torch.device("meta") if args.use_meta_device else nullcontext():
            model = NVMixtralForCausalLM(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe, dispatcher=dispatcher)
    else:
        config = MixtralConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
        with torch.device("meta") if args.use_meta_device else nullcontext():
            model = MixtralForCausalLM(config)

    logger.info("Initialized Model:\n%s", model)

    # --- Expert Parallelism Setup ---
    # NOTE(review): unreachable in this script because ep_size > 1 raises above.
    if args.use_te and ep_size > 1:
        ep_mesh = device_mesh["ep"]
        ep_group = ep_mesh.get_group()
        model.model.set_ep_groups(ep_group, ep_mesh)

    # --- Distributed Wrapping (DDP) ---
    # Materialize meta-device weights before moving the model and wrapping it in DDP.
    if args.use_meta_device:
        if args.use_te:
            model.init_empty_weights()
        else:
            model.to_empty(device=device)
            model.apply(model._init_weights)

    if args.fp8_stats_config.enabled:
        debug_api.infer_and_assign_layer_names(model)

    model = model.to(device=device)
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[dist_config.local_rank],
        output_device=dist_config.local_rank,
        device_mesh=device_mesh["dp"],
    )

    # --- Optimizer & Scheduler ---
    optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True))  # type: ignore[arg-type]
    scheduler = get_cosine_annealing_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs)

    # Compilation happens after the optimizer has been built from the uncompiled module's parameters.
    if args.use_torch_compile:
        model = torch.compile(model)

    # --- Data Loading ---
    if args.use_sequence_packing:
        train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset)
    else:
        train_dataloader, dataset_or_sampler = create_bshd_dataloader(dist_config, **args.dataset)

    # --- Checkpoint Resume ---
    # Checkpoints for this script live in a "train_ddp" subdirectory; no ckpt_dir disables checkpointing.
    ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_ddp" if args.checkpoint.ckpt_dir else None
    if args.checkpoint.resume_from_checkpoint and ckpt_path:
        logger.info("Attempting to load checkpoint from %s", ckpt_path)
        model, optimizer, scheduler, train_dataloader, start_step, epoch = load_checkpoint_ddp(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            ckpt_path=ckpt_path,
            dist_config=dist_config,
            dataloader=train_dataloader,
            expert_parallel_size=ep_size,
        )
        logger.info("Checkpoint loaded, resuming from step %s, epoch %s", start_step, epoch)
    else:
        logger.info("No checkpoint to load, starting from scratch")
        start_step = 0
        epoch = 0

    perf_logger = PerfLogger(dist_config, args, start_step=start_step)

    # --- Training Loop ---
    gc.collect()
    torch.cuda.empty_cache()

    logger.info("Starting training loop from step %s to %s", start_step, args.num_train_steps)
    step = start_step  # counts optimizer steps, not micro-batches
    micro_step = 0  # position inside the current gradient-accumulation window
    while step < args.num_train_steps:
        for batch in train_dataloader:
            device_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
            micro_step += 1

            # Enter DDP's no_sync() on every micro-step except the last one of the accumulation window.
            with model.no_sync() if micro_step % args.grad_acc_steps != 0 else nullcontext():
                with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe):
                    outputs = model(**device_batch)

                # Scale the loss so the accumulated gradient is the average over the window.
                loss = outputs.loss / args.grad_acc_steps
                loss.backward()

            perf_logger.log_micro_step(step=step, batch=device_batch, outputs=outputs)

            if micro_step % args.grad_acc_steps == 0:
                micro_step = 0

                # max_norm is fixed at 1.0 here rather than read from the config.
                total_norm = clip_grad_norm_ep_aware(model.parameters(), max_norm=1.0, ep_size=ep_size)

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

                perf_logger.log_step(step=step, grad_norm=total_norm, lr=optimizer.param_groups[0]["lr"])

                if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps):
                    save_checkpoint_ddp(
                        model=model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        ckpt_path=ckpt_path,
                        step=step,
                        epoch=epoch,
                        dist_config=dist_config,
                        dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None,
                        max_checkpoints=args.checkpoint.max_checkpoints,
                        expert_parallel_size=ep_size,
                    )

                step += 1
                if step >= args.num_train_steps:
                    break

        # Reached both when the dataloader is exhausted and after the break above.
        epoch += 1
        dataset_or_sampler.set_epoch(epoch)

    # --- Cleanup ---
    if args.checkpoint.save_final_model and ckpt_path:
        save_final_model_ddp(
            model=model,
            save_directory=ckpt_path / "final_model",
            dist_config=dist_config,
        )

    perf_logger.finish()
    torch.distributed.destroy_process_group()

    return perf_logger.min_loss
+ +"""Fully Sharded Data Parallel v2 (FSDP2) training script for Mixtral with TransformerEngine.""" + +import gc +import logging +from contextlib import nullcontext +from pathlib import Path +from typing import Iterable + +import hydra +import nvdlfw_inspect.api as debug_api +import torch +import transformer_engine.pytorch +from checkpoint import ( + _ckpt_futures, + load_checkpoint_fsdp2, + save_checkpoint_fsdp2, + save_final_model_fsdp2, + should_save_checkpoint, +) +from dataset import create_bshd_dataloader, create_thd_dataloader +from distributed_config import DistributedConfig +from fp8_debugging import initialize_fp8_debugging +from modeling_mixtral_te import NVMixtralConfig, NVMixtralForCausalLM +from omegaconf import DictConfig, OmegaConf +from perf_logger import PerfLogger +from scheduler import get_cosine_annealing_schedule_with_warmup +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import fully_shard +from torch.optim import AdamW +from transformer_engine.common.recipe import Format +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def _build_dispatcher(args: DictConfig, config: NVMixtralConfig): + """Build the requested token dispatcher for EP runs. + + Returns None for the default alltoall dispatcher (handled natively by TE). + Returns a FusedTokenRouter when token_dispatcher=fused_deepep and deep_ep is available. + Falls back to alltoall (returns None) when fused_deepep is requested but unavailable, + if token_dispatcher_fallback=alltoall is set. 
def clip_grad_norm_ep_aware(params: Iterable[torch.nn.Parameter], max_norm: float, ep_size: int) -> torch.Tensor:
    """Clip gradients in place to a maximum L2 norm, tolerating DTensor gradients.

    With ``ep_size == 1`` this simply delegates to ``torch.nn.utils.clip_grad_norm_``.
    Otherwise the norm is assembled from each gradient's local data (``to_local()`` is
    called on any gradient that exposes it), so gradients living on different device
    meshes never have to be stacked together as DTensors.

    Args:
        params: Model parameters; entries whose ``grad`` is ``None`` are ignored.
        max_norm: Maximum allowed gradient norm.
        ep_size: Expert parallelism size; 1 selects the standard PyTorch implementation.

    Returns:
        The total gradient norm measured before clipping. For ``ep_size > 1`` it is built
        only from the data local to this rank (no cross-rank reduction is performed here),
        so it is an approximation. A CPU zero tensor is returned when no parameter has a
        gradient.
    """
    if ep_size == 1:
        return torch.nn.utils.clip_grad_norm_(params, max_norm=max_norm)

    # Materialize once: ``params`` may be a one-shot generator such as ``model.parameters()``.
    with_grads = [p for p in params if p.grad is not None]
    if not with_grads:
        return torch.tensor(0.0)

    def _local_grad_norm(param: torch.nn.Parameter) -> torch.Tensor:
        """Return the float32 L2 norm of the locally held part of ``param.grad``."""
        grad = param.grad.detach()
        local = grad.to_local() if hasattr(grad, "to_local") else grad
        return local.float().norm()

    # The norm of the per-parameter norms equals the norm over all local gradient elements.
    total_norm = torch.stack([_local_grad_norm(p) for p in with_grads]).norm()

    # Only ever scale down; the epsilon guards against division by a zero norm.
    scale = torch.clamp(max_norm / (total_norm + 1e-6), max=1.0)
    for p in with_grads:
        p.grad.detach().mul_(scale.to(p.grad.device))

    return total_norm
hydra.utils.get_class(args.fp4_config.fp4_recipe)(**args.fp4_config.fp4_recipe_kwargs) + + # --- Model Initialization --- + if args.use_te: + # Pass expert_parallel_size to config so the model initializes with the correct + # num_local_experts = num_experts // expert_parallel_size per rank. + config = NVMixtralConfig.from_pretrained( + args.config_name_or_path, + dtype=torch.bfloat16, + expert_parallel_size=ep_size, + **args.config_kwargs, + ) + dispatcher = _build_dispatcher(args, config) + with torch.device("meta") if args.use_meta_device else nullcontext(): + model = NVMixtralForCausalLM(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe, dispatcher=dispatcher) + else: + config = MixtralConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs) + with torch.device("meta") if args.use_meta_device else nullcontext(): + model = MixtralForCausalLM(config) + + logger.info("Initialized Model:\n%s", model) + + # --- Expert Parallelism Setup --- + # Expert parallelism setup — MUST happen before fully_shard() + # Wraps expert weights as DTensors with Shard(0) on the expert dimension. 
+ if args.use_te and ep_size > 1: + ep_mesh = device_mesh["ep"] + ep_group = ep_mesh.get_group() + model.model.set_ep_groups(ep_group, ep_mesh) + + # --- Distributed Wrapping (FSDP2) --- + for layer in model.model.layers: + fully_shard(layer, mesh=device_mesh["dp"]) + fully_shard(model, mesh=device_mesh["dp"]) + + if args.use_meta_device: + if args.use_te: + model.init_empty_weights() + else: + model.to_empty(device=device) + model.apply(model._init_weights) + + if args.fp8_stats_config.enabled: + debug_api.infer_and_assign_layer_names(model) + + # --- Optimizer & Scheduler --- + optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore[arg-type] + scheduler = get_cosine_annealing_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) + + if args.use_torch_compile: + model = torch.compile(model) + + # --- Data Loading --- + if args.use_sequence_packing: + train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset) + else: + train_dataloader, dataset_or_sampler = create_bshd_dataloader(dist_config, **args.dataset) + + # --- Checkpoint Resume --- + ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_fsdp2" if args.checkpoint.ckpt_dir else None + if args.checkpoint.resume_from_checkpoint and ckpt_path: + logger.info("Attempting to load checkpoint from %s", ckpt_path) + model, optimizer, scheduler, train_dataloader, start_step, epoch = load_checkpoint_fsdp2( + model=model, + optimizer=optimizer, + scheduler=scheduler, + ckpt_path=ckpt_path, + dist_config=dist_config, + dataloader=train_dataloader, + process_group=device_mesh.get_group("dp"), + expert_parallel_size=ep_size, + ) + logger.info("Checkpoint loaded, resuming from step %s, epoch %s", start_step, epoch) + else: + logger.info("No checkpoint to load, starting from scratch") + start_step = 0 + epoch = 0 + + perf_logger = PerfLogger(dist_config, args, start_step=start_step) + + # --- Training Loop --- + gc.collect() + 
torch.cuda.empty_cache() + + logger.info("Starting training loop from step %s to %s", start_step, args.num_train_steps) + step = start_step + micro_step = 0 + while step < args.num_train_steps: + for batch in train_dataloader: + device_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} + micro_step += 1 + + with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe): + outputs = model(**device_batch) + + loss = outputs.loss / args.grad_acc_steps + loss.backward() + + perf_logger.log_micro_step(step=step, batch=device_batch, outputs=outputs) + + if micro_step % args.grad_acc_steps == 0: + micro_step = 0 + + total_norm = clip_grad_norm_ep_aware(model.parameters(), max_norm=1.0, ep_size=ep_size) + + optimizer.step() + scheduler.step() + optimizer.zero_grad() + + perf_logger.log_step(step=step, grad_norm=total_norm, lr=optimizer.param_groups[0]["lr"]) + + if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): + save_checkpoint_fsdp2( + model=model, + optimizer=optimizer, + scheduler=scheduler, + ckpt_path=ckpt_path, + step=step, + epoch=epoch, + dist_config=dist_config, + dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None, + process_group=device_mesh.get_group("dp"), + max_checkpoints=args.checkpoint.max_checkpoints, + async_save=args.checkpoint.async_save, + expert_parallel_size=ep_size, + ) + + step += 1 + if step >= args.num_train_steps: + break + + epoch += 1 + dataset_or_sampler.set_epoch(epoch) + + # --- Cleanup --- + if args.checkpoint.save_final_model and ckpt_path: + save_final_model_fsdp2( + model=model, + save_directory=ckpt_path / "final_model", + dist_config=dist_config, + ) + + if args.checkpoint.async_save and "fsdp2" in _ckpt_futures and _ckpt_futures["fsdp2"] is not None: + _ckpt_futures["fsdp2"].result() + + perf_logger.finish() + torch.distributed.destroy_process_group() + + return perf_logger.min_loss + + +if __name__ == 
"__main__": + main() diff --git a/bionemo-recipes/recipes/opengenome2_mixtral_native_te/Dockerfile b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/Dockerfile new file mode 100644 index 0000000000..faedb5f609 --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/Dockerfile @@ -0,0 +1,9 @@ +# syntax=docker/dockerfile:1.4 +FROM nvcr.io/nvidia/pytorch:26.03-py3 + +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,source=requirements.txt,target=/requirements.txt \ + PIP_CONSTRAINT= pip install -r /requirements.txt + +WORKDIR /workspace/bionemo +COPY . . diff --git a/bionemo-recipes/recipes/opengenome2_mixtral_native_te/README.md b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/README.md new file mode 100644 index 0000000000..4a0ed66d3a --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/README.md @@ -0,0 +1,54 @@ +# TransformerEngine-accelerated Mixtral training for OpenGenome2 + +This folder demonstrates how to train TE-accelerated Mixtral with a native PyTorch training loop for autoregressive DNA +token prediction on the OpenGenome2 metagenome subset. It follows the same recipe conventions as +`opengenome2_llama_native_te` for dataset loading, genomic masking, validation, checkpointing, and W&B logging. 
+ +## Supported features + +- FSDP2 training +- THD sequence packing +- nucleotide tokenizer packaged with the recipe +- genomic label masking +- FP32 master weights through FSDP mixed precision policy +- validation logging during training + +## Not supported in this v1 recipe + +- context parallelism +- Llama-specific OG2 initialization features such as Spike-No-More and Megatron scaled residual init + +## Commands + +Single-GPU sanity run: + +```bash +python train_fsdp2.py --config-name L0_sanity +``` + +Single-GPU bounded OG2 smoke run: + +```bash +python train_fsdp2.py --config-name og2_small_thd_moe \ + num_train_steps=20 \ + checkpoint.ckpt_dir=./checkpoints +``` + +Cluster handoff: + +```bash +torchrun --standalone --nproc_per_node=2 train_fsdp2.py --config-name og2_small_thd_moe +``` + +## Data + +Download a bounded OpenGenome2 subset for local runs: + +```bash +hf download arcinstitute/opengenome2 --repo-type dataset \ + --include "json/pretraining_or_both_phases/metagenomes/data_metagenomics_train_chunk1.jsonl.gz" \ + --include "json/pretraining_or_both_phases/metagenomes/data_metagenomics_valid_chunk1.jsonl.gz" \ + --local-dir /data/opengenome2 +``` + +Use `WANDB_KEY` for Weights & Biases logging. diff --git a/bionemo-recipes/recipes/opengenome2_mixtral_native_te/checkpoint.py b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/checkpoint.py new file mode 100644 index 0000000000..a40a8a982a --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/checkpoint.py @@ -0,0 +1,430 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import shutil +import time +from dataclasses import dataclass +from pathlib import Path +from typing import NamedTuple + +import torch +from distributed_config import DistributedConfig +from safetensors.torch import save_file +from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + get_state_dict, + set_state_dict, +) +from torch.distributed.checkpoint.state_dict_loader import load as dcp_load +from torch.distributed.checkpoint.state_dict_saver import async_save as dcp_async_save +from torch.distributed.checkpoint.state_dict_saver import save as dcp_save +from torch.distributed.checkpoint.stateful import Stateful +from torchdata.stateful_dataloader import StatefulDataLoader + + +logger = logging.getLogger(__name__) + + +class LenientLoadPlanner(DefaultLoadPlanner): + """A load planner that skips keys missing from the checkpoint. + + Handles checkpoints saved without TransformerEngine _extra_state keys + (FP8 metadata). These keys are registered by newer TE versions even when + FP8 is disabled, but older checkpoints don't contain them. + """ + + def create_local_plan(self): + """Create a local load plan, skipping keys missing from the checkpoint.""" + missing_keys = [fqn for fqn in self.state_dict if fqn not in self.metadata.state_dict_metadata] + if missing_keys: + logger.warning( + "Skipping %d keys not found in checkpoint: %s%s", + len(missing_keys), + missing_keys[:5], + "..." 
if len(missing_keys) > 5 else "", + ) + for key in missing_keys: + del self.state_dict[key] + return super().create_local_plan() + + +# Tracks in-flight async checkpoint futures keyed by strategy name (e.g. "fsdp2"). +# Each entry holds the Future returned by dcp_async_save so we can await it before starting +# the next async save or before shutting down. +_ckpt_futures: dict = {} + + +class CheckpointOutput(NamedTuple): + """Output of checkpoint loading.""" + + model: torch.nn.Module + optimizer: torch.optim.Optimizer + scheduler: torch.optim.lr_scheduler.LRScheduler + dataloader: StatefulDataLoader | None + step: int + epoch: int + + +# ============================================================================ +# Helper functions +# ============================================================================ + + +def get_latest_checkpoint(ckpt_path: str | os.PathLike) -> tuple[Path | None, int]: + """Get the latest checkpoint path and step number. + + Returns: + Tuple of (checkpoint path, step number). + If no checkpoint files are found, returns (None, 0). 
+ """ + ckpt_path = Path(ckpt_path) + if not ckpt_path.exists(): + return None, 0 + + checkpoints = [f for f in ckpt_path.iterdir() if f.name.startswith("step_")] + + if not checkpoints: + return None, 0 + + latest = max(checkpoints, key=lambda x: int(Path(x).stem.split("_")[1])) + step = int(Path(latest).stem.split("_")[1]) + return latest, step + + +def should_save_checkpoint(step: int, save_every_n_steps: int) -> bool: + """Determine if a checkpoint should be saved.""" + if save_every_n_steps > 0 and step % save_every_n_steps == 0 and step > 0: + return True + return False + + +def prune_checkpoints(ckpt_path: str | os.PathLike, max_checkpoints: int) -> None: + """Prune checkpoints to keep only the latest `max_checkpoints` checkpoints.""" + ckpt_path = Path(ckpt_path) + checkpoints = [f for f in ckpt_path.iterdir() if f.name.startswith("step_")] + checkpoints.sort(key=lambda x: int(Path(x).stem.split("_")[1])) + if len(checkpoints) > max_checkpoints: + for checkpoint in checkpoints[:-max_checkpoints]: + logger.info(f"Pruning checkpoint {checkpoint}") + if checkpoint.is_dir(): + shutil.rmtree(checkpoint) + else: + os.remove(checkpoint) + + +def get_fsdp2_checkpoint_process_group( + process_group: torch.distributed.ProcessGroup | None, + *, + expert_parallel_size: int, +) -> torch.distributed.ProcessGroup | None: + """Choose the DCP process group for FSDP2 checkpoints. + + Expert-parallel parameters are DTensors sharded across the EP mesh, so all + participating ranks must join checkpoint save/load when EP is enabled. + Returning None lets DCP use the default world process group. + """ + if expert_parallel_size > 1: + return None + return process_group + + +# ============================================================================ +# FSDP2 Checkpointing +# ============================================================================ + + +@dataclass +class AppState(Stateful): + """AppState for FSDP2 checkpoint. 
def load_checkpoint_fsdp2(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    ckpt_path: str | os.PathLike,
    dist_config: DistributedConfig,
    dataloader: StatefulDataLoader | None = None,
    process_group: torch.distributed.ProcessGroup | None = None,
    expert_parallel_size: int = 1,
) -> CheckpointOutput:
    """Restore model, optimizer, scheduler and (optionally) dataloader from the newest FSDP2 checkpoint.

    Args:
        model: The model to load into.
        optimizer: The optimizer to load into.
        scheduler: The LR scheduler to load into.
        ckpt_path: The directory containing ``step_<N>`` checkpoints.
        dist_config: The distributed configuration.
        dataloader: Optional dataloader whose state should be restored as well.
        process_group: The process group to use for checkpointing.
        expert_parallel_size: Expert parallelism size; forwarded to
            ``get_fsdp2_checkpoint_process_group`` to pick the group used for the load.

    Returns:
        A ``CheckpointOutput``. When no checkpoint exists, step and epoch are both 0;
        otherwise step is the saved step plus one and epoch is the saved epoch.
    """
    newest, _ = get_latest_checkpoint(ckpt_path)
    if not newest:
        logger.info("No FSDP2 checkpoint found, starting from scratch")
        return CheckpointOutput(
            model=model, optimizer=optimizer, scheduler=scheduler, dataloader=dataloader, step=0, epoch=0
        )

    # The load fills this object in place; step and epoch are read back from it below.
    app_state = AppState(model=model, optimizer=optimizer, scheduler=scheduler)
    group = get_fsdp2_checkpoint_process_group(process_group, expert_parallel_size=expert_parallel_size)
    dcp_load(
        {"app": app_state},
        checkpoint_id=newest,
        process_group=group,
        planner=LenientLoadPlanner(),
    )

    if dataloader is not None:
        load_dataloader(dataloader=dataloader, ckpt_path=newest, dist_config=dist_config)

    logger.info(f"Loaded distributed FSDP2 checkpoint from step {app_state.step}")

    # Resume at the step after the saved one so the checkpointed step is not repeated.
    resume_step = app_state.step + 1
    return CheckpointOutput(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        dataloader=dataloader,
        step=resume_step,
        epoch=app_state.epoch,
    )
+ optimizer: The optimizer to save. + scheduler: The LR scheduler to save. + ckpt_path: The directory to save the checkpoint. + step: The step number to save the checkpoint. + epoch: The epoch number to save the checkpoint. + dist_config: The distributed configuration. + dataloader: The dataloader to save. + process_group: The process group to use for checkpointing. + max_checkpoints: The maximum number of checkpoints to keep. + async_save: Whether to save the checkpoint asynchronously. + expert_parallel_size: Expert parallelism size. When > 1, saves expert weights with EP-aware state dict handling. + """ + start_time = time.perf_counter() + ckpt_path = Path(ckpt_path) + checkpoint_path = ckpt_path / f"step_{step}" + checkpoint_path.mkdir(parents=True, exist_ok=True) + + if dataloader is not None: + save_dataloader( + dataloader=dataloader, + ckpt_path=checkpoint_path, + dist_config=dist_config, + ) + logger.info(f"Saved FSDP2 dataloader to {ckpt_path}") + + state_dict = {"app": AppState(model=model, optimizer=optimizer, scheduler=scheduler, step=step, epoch=epoch)} + checkpoint_process_group = get_fsdp2_checkpoint_process_group( + process_group, + expert_parallel_size=expert_parallel_size, + ) + if async_save: + # If we're using asynchronous checkpointing, make sure we only have one checkpoint future at a time. 
def save_final_model_fsdp2(
    model: torch.nn.Module,
    save_directory: str | os.PathLike,
    dist_config: DistributedConfig,
) -> None:
    """Write the final FSDP2 model (weights and config) to ``save_directory`` from the main process.

    Args:
        model: The model to export.
        save_directory: Destination directory; created if it does not exist.
        dist_config: The distributed configuration, used to identify the main process.
    """
    # Every rank calls this so the full state dict can be gathered; only main writes files.
    gather_options = StateDictOptions(full_state_dict=True, cpu_offload=True)
    full_state = get_model_state_dict(model=model, options=gather_options)

    if dist_config.is_main_process():
        os.makedirs(save_directory, exist_ok=True)

        # Weights go into a single safetensors file.
        weights_file = os.path.join(save_directory, "model.safetensors")
        save_file(full_state, weights_file)

        # If the model is a wrapper exposing ``.module``, the config is looked up on the wrapped module.
        unwrapped = model.module if hasattr(model, "module") else model
        if hasattr(unwrapped, "config"):
            unwrapped.config.save_pretrained(save_directory)

        logger.info(f"Saved final FSDP2 model to {save_directory} (weights + config only)")
def load_dataloader(
    dataloader: StatefulDataLoader | None,
    ckpt_path: str | os.PathLike,
    dist_config: DistributedConfig,
) -> StatefulDataLoader | None:
    """Restore this rank's dataloader state from a checkpoint directory, if compatible.

    The state is read from ``dataloader_rank_{rank}.pt`` under ``ckpt_path``. It is applied
    only when the file exists and both the stored ``num_workers`` and ``num_ranks`` match the
    current run; otherwise a warning is logged and the dataloader is returned untouched.

    Args:
        dataloader: The dataloader to restore, or None (returned as-is).
        ckpt_path: Directory containing the per-rank dataloader state files.
        dist_config: Distributed configuration providing ``rank`` and ``world_size``.

    Returns:
        The same dataloader object that was passed in (possibly with state loaded), or None.
    """
    if dataloader is None:
        return dataloader

    state_file = Path(ckpt_path) / f"dataloader_rank_{dist_config.rank}.pt"
    if not state_file.exists():
        logger.warning(
            f"No dataloader checkpoint found for rank {dist_config.rank}, starting dataloader from scratch."
        )
        return dataloader

    saved_state = torch.load(state_file, weights_only=True)

    # A saved state is only reusable when the worker count and world size are unchanged.
    topology_unchanged = (
        dataloader.num_workers == saved_state["num_workers"]
        and dist_config.world_size == saved_state["num_ranks"]
    )
    if not topology_unchanged:
        logger.warning(
            f"Dataloader num_workers mismatch: {dataloader.num_workers} != {saved_state['num_workers']} or "
            f"num_ranks mismatch: {dist_config.world_size} != {saved_state['num_ranks']}, "
            "starting dataloader from scratch."
        )
        return dataloader

    dataloader.load_state_dict(saved_state)
    if dist_config.is_main_process():
        logger.info(f"Loaded dataloader state from {state_file}")

    return dataloader
@dataclass
class DataCollatorWithFlattening:
    """Collator that runs a wrapped LM collator and then packs the batch into one padding-free row.

    The wrapped ``collator`` is called first to produce a padded batch (it must return
    ``input_ids``, ``labels`` and ``attention_mask``). The non-padding positions of that batch
    are then flattened into a single ``[1, total_tokens]`` row, and sequence boundaries are
    described by ``cu_seq_lens_q`` / ``cu_seq_lens_k`` (cumulative sequence lengths) instead of
    an attention mask between sequences.

    Two mutually exclusive padding modes are available:
        - ``pad_to_multiple_of``: pad the total token count of the packed batch.
        - ``pad_sequences_to_be_divisible_by``: pad each sequence individually (used for
          context parallelism); adds ``cu_seq_lens_q_padded`` / ``cu_seq_lens_k_padded``.

    Only PyTorch tensors (``return_tensors="pt"``) are supported.

    Args:
        collator (DataCollatorForLanguageModeling): Externally constructed collator that performs
            the masking and provides the tokenizer (used for its ``pad_token_id``).
        return_position_ids (bool): Whether the packed batch should include ``position_ids``.
        pad_to_multiple_of (int, optional): If set, the packed batch is padded so its total length
            is divisible by this value.
        pad_sequences_to_be_divisible_by (int, optional): If set, every sequence is padded
            individually to a multiple of this value.
        separator_id (int, optional): If set, this value is written into ``labels`` at the first
            position of every sequence except the first (typically -100 for causal LM).

    Example:
        >>> from transformers import AutoTokenizer, DataCollatorForLanguageModeling
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
        >>> flat_collator = DataCollatorWithFlattening(
        ...     collator=DataCollatorForLanguageModeling(tokenizer),
        ...     pad_to_multiple_of=8,
        ... )
        >>> batch = flat_collator(
        ...     [
        ...         {"input_ids": [0, 5, 6, 7, 2]},  # 5 tokens
        ...         {"input_ids": [0, 8, 9, 10, 11, 2]},  # 6 tokens
        ...         {"input_ids": [0, 12, 13, 2]},  # 4 tokens
        ...     ]
        ... )
        >>> print(batch["input_ids"].shape)  # torch.Size([1, 16])
        >>> print(batch["cu_seq_lens_q"])  # tensor([0, 5, 11, 15, 16], dtype=torch.int32)
    """

    collator: DataCollatorForLanguageModeling
    return_position_ids: bool = False
    pad_to_multiple_of: int | None = None
    pad_sequences_to_be_divisible_by: int | None = None
    separator_id: int | None = None

    def __post_init__(self):
        """Reject configurations that request both padding modes at once."""
        both_modes_requested = (
            self.pad_to_multiple_of is not None and self.pad_sequences_to_be_divisible_by is not None
        )
        if both_modes_requested:
            raise ValueError("pad_sequences_to_be_divisible_by and pad_to_multiple_of cannot be used together")

    def __call__(self, features, return_tensors=None):
        """Mask, flatten and optionally pad a list of tokenized samples.

        Args:
            features (List[Dict[str, List[int]]]): Tokenized samples, each with ``input_ids`` and
                optionally other per-token fields.
            return_tensors (str, optional): Must be ``None`` or ``"pt"``; forwarded to the wrapped
                collator.

        Returns:
            Dict[str, Any]: The packed batch. It always contains ``input_ids`` and ``labels`` of
            shape ``[1, total_tokens]`` taken from the wrapped collator's output, plus the
            ``cu_seq_lens_q`` / ``cu_seq_lens_k`` and ``max_length_q`` / ``max_length_k`` entries
            produced by the flattening step. Padding modes may add further keys.

        Raises:
            NotImplementedError: If ``return_tensors`` is neither ``None`` nor ``"pt"``.
        """
        unsupported_format = return_tensors is not None and return_tensors != "pt"
        if unsupported_format:
            raise NotImplementedError(f"Only return_tensors='pt' is supported, got '{return_tensors}'")

        # Step 1: let the wrapped collator build the padded (BSHD) batch, including masking.
        padded_batch = self.collator(features, return_tensors=return_tensors)

        # Step 2: flatten the raw features to obtain the cumulative-length metadata.
        packed = _pt_flatten_collate(features, return_position_ids=self.return_position_ids)

        # Step 3: keep only real (non-padding) positions of the collated tensors, as one row.
        real_tokens = padded_batch["attention_mask"].bool()
        flat_input_ids = padded_batch["input_ids"][real_tokens].unsqueeze(0)
        flat_labels = padded_batch["labels"][real_tokens].unsqueeze(0)

        if self.separator_id is not None:
            # Interior boundaries are the start offsets of every sequence after the first.
            flat_labels[:, packed["cu_seq_lens_q"][1:-1]] = self.separator_id

        packed["input_ids"] = flat_input_ids
        packed["labels"] = flat_labels

        # Step 4: apply at most one of the two padding modes.
        if self.pad_to_multiple_of is not None:
            return self._pad_batch_to_multiple_of(packed)
        if self.pad_sequences_to_be_divisible_by is not None:
            return self._pad_sequences_to_be_divisible_by(packed)
        return packed

    def _resolve_pad_token_id(self):
        """Return the tokenizer's pad token id, falling back to 1 (with a warning) if it is not an int."""
        pad_token_id = self.collator.tokenizer.pad_token_id
        if isinstance(pad_token_id, int):
            return pad_token_id
        logger.warning(f"tokenizer.pad_token_id is not an integer, using 1 instead: {pad_token_id}")
        return 1

    def _pad_batch_to_multiple_of(self, batch):
        """Pad the packed batch so its total token count is divisible by ``pad_to_multiple_of``.

        Padding tokens use the tokenizer's pad token id (or 1 as a fallback) and padded labels
        are set to -100.
        """
        token_pad = self._resolve_pad_token_id()

        assert self.pad_to_multiple_of is not None, "pad_to_multiple_of must be set"

        return _pt_pad_to_multiple_of(
            batch,
            self.pad_to_multiple_of,
            token_pad=token_pad,
            label_pad=-100,
        )

    def _pad_sequences_to_be_divisible_by(self, batch):
        """Pad every sequence individually and record the padded cumulative lengths.

        Adds ``cu_seq_lens_q_padded`` / ``cu_seq_lens_k_padded`` (int32) and sets
        ``pad_between_seqs`` to True on the batch.
        """
        token_pad = self._resolve_pad_token_id()

        assert self.pad_sequences_to_be_divisible_by is not None, "pad_sequences_to_be_divisible_by must be set"

        padded_ids, padded_labels, padded_cu_seqlens = pad_thd_sequences_for_cp(
            batch["input_ids"],
            batch["labels"],
            batch["cu_seq_lens_q"],
            self.pad_sequences_to_be_divisible_by,
            padding_token_id=token_pad,
            padding_label_id=-100,
        )

        batch["input_ids"] = padded_ids.unsqueeze(0)
        batch["labels"] = padded_labels.unsqueeze(0)
        batch["cu_seq_lens_q_padded"] = padded_cu_seqlens.to(torch.int32)
        batch["cu_seq_lens_k_padded"] = padded_cu_seqlens.to(torch.int32)
        batch["pad_between_seqs"] = True
        return batch
@dataclass
class TokenPackingDataset(torch.utils.data.IterableDataset):
    """Iterable dataset that groups samples into batches bounded by a token budget.

    Samples are read from ``dataset`` in order and accumulated until adding the next one would
    reach or exceed ``max_tokens_per_batch``. Each yielded item is a list of sample dicts.
    """

    # Dataset to pack.
    dataset: datasets.IterableDataset
    # Maximum number of tokens per batch.
    max_tokens_per_batch: int
    # Whether to drop the trailing batch that did not fill up before the dataset ended.
    drop_last: bool = True
    # Whether to split the overflowing sample so batches are filled up to the budget.
    split_samples: bool = False
    # If set, each sequence counts towards the budget with its length rounded up to the nearest
    # multiple of this value (matching per-sequence padding applied by the collator). With
    # split_samples=True the split point is rounded down to a multiple of this value so the
    # first part still fits in the remaining capacity after padding.
    pad_sequences_to_be_divisible_by: int | None = None

    def __post_init__(self):
        """Warn when the token budget is not a multiple of the per-sequence padding divisor."""
        divisor = self.pad_sequences_to_be_divisible_by
        if divisor is None:
            return
        if self.max_tokens_per_batch % divisor != 0:
            logger.warning(
                "max_tokens_per_batch (%d) is not divisible by pad_sequences_to_be_divisible_by (%d). "
                "Batches may not fill to exactly max_tokens_per_batch when split_samples=True.",
                self.max_tokens_per_batch,
                divisor,
            )

    def _padded_len(self, length: int) -> int:
        """Return ``length`` rounded up to a multiple of ``pad_sequences_to_be_divisible_by`` (if set)."""
        divisor = self.pad_sequences_to_be_divisible_by
        if divisor is None:
            return length
        whole_chunks, leftover = divmod(length, divisor)
        if leftover:
            whole_chunks += 1
        return whole_chunks * divisor

    def __iter__(self):
        """Yield lists of samples whose (padded) token counts stay within ``max_tokens_per_batch``.

        With ``split_samples=True`` the sample that overflows the budget is cut in two: the first
        part completes the current batch and the remainder opens the next one. Without it, the
        current batch is emitted as-is and the overflowing sample opens the next batch.

        Returns:
            A generator of batches (lists of sample dicts).

        Raises:
            ValueError: If a single sample's padded length exceeds ``max_tokens_per_batch``.
        """
        pending = []
        pending_tokens = 0
        for sample in iter(self.dataset):
            footprint = self._padded_len(len(sample["input_ids"]))
            if footprint > self.max_tokens_per_batch:
                raise ValueError(
                    f"TokenPackingDataset: Padded sample length ({footprint}) exceeds max_tokens_per_batch "
                    f"({self.max_tokens_per_batch}). Set truncation or a maximum length in your tokenizer or dataset to"
                    " ensure all samples fit within max_tokens_per_batch."
                )

            total_with_sample = pending_tokens + footprint

            # Case 1: the sample fits with room to spare; keep accumulating.
            if total_with_sample < self.max_tokens_per_batch:
                pending.append(sample)
                pending_tokens = total_with_sample
                continue

            # Case 2: the sample fills the batch exactly; emit and reset.
            if total_with_sample == self.max_tokens_per_batch:
                yield [*pending, sample]
                pending = []
                pending_tokens = 0
                continue

            # Case 3: the sample overflows the batch.
            if self.split_samples:
                room = self.max_tokens_per_batch - pending_tokens
                divisor = self.pad_sequences_to_be_divisible_by
                if divisor is not None:
                    # Round down so the first part's padded length cannot exceed the remaining room.
                    room = (room // divisor) * divisor
                if room > 0:
                    head, tail = _split_sample_by_num_tokens(sample, room)
                    yield [*pending, head]
                    pending = [tail]
                    pending_tokens = self._padded_len(len(tail["input_ids"]))
                    continue

            # Either splitting is disabled or no tokens of this sample fit: flush what we have
            # (if anything) and let this sample start the next batch.
            if pending:
                yield pending
            pending = [sample]
            pending_tokens = footprint

        if not self.drop_last and pending:
            yield pending

    def set_epoch(self, epoch: int):
        """Forward the epoch to the wrapped dataset."""
        self.dataset.set_epoch(epoch)
+ + """ + + collator: DataCollator + device_mesh: torch.distributed.device_mesh.DeviceMesh + qkv_format: str = "thd" + is_causal_lm: bool = False + + # Derived fields, initialized in __post_init__. + cp_world_size: int = field(init=False) + tp_world_size: int | None = field(init=False) + _is_cp_row_major: bool = field(init=False) + + def __post_init__(self): + """Initialize the cp_world_size, tp_world_size, and _is_cp_row_major fields based on the device mesh.""" + dim_names = self.device_mesh.mesh_dim_names + if dim_names is None: + raise ValueError("device_mesh must have mesh_dim_names") + + self.cp_world_size = self.device_mesh.size(dim_names.index("cp")) if "cp" in dim_names else 1 + self.tp_world_size = self.device_mesh.size(dim_names.index("tp")) if "tp" in dim_names else None + + # Determine whether CP is the row (outer) dimension of the 2D mesh. + # When flattened, the row-major dimension's index changes slowest. + # If "cp" comes before "tp" in mesh_dim_names, CP is the row dimension. + if "cp" in dim_names and "tp" in dim_names: + self._is_cp_row_major = dim_names.index("cp") < dim_names.index("tp") + else: + self._is_cp_row_major = True + + def __call__(self, features) -> list[dict[str, Any]]: + """Process batches of data and create shards for each context parallelism rank. + + Args: + features: List of tokenized sequences, each containing 'input_ids' and optionally 'labels'. + + Returns: + A list of dictionaries, each containing a shard of the batch for a given context parallelism rank. + """ + batch = self.collator(features) + + # Remove the attention mask from the batch, it's not valid for CP. 
+ batch.pop("attention_mask", None) + + if self.is_causal_lm: + labels = torch.nn.functional.pad(batch["labels"], (0, 1), value=-100) + batch["labels"] = labels[..., 1:].contiguous() + + combined_batch = [] + for cp_rank in range(self.cp_world_size): + input_ids_sharded, labels_sharded = _split_batch_by_cp_rank( + cu_seqlens_padded=batch.get("cu_seq_lens_q_padded", None), # This will be None for BSHD format. + input_ids_padded=batch["input_ids"], + labels_padded=batch["labels"], + qvk_format=self.qkv_format, + cp_rank=cp_rank, + cp_world_size=self.cp_world_size, + ) + batch_shard = dict(batch) + batch_shard["input_ids"] = input_ids_sharded + if self.is_causal_lm: + batch_shard["shift_labels"] = labels_sharded + batch_shard["labels"] = None + else: + batch_shard["labels"] = labels_sharded + # Now determine the max length of the sequence. + if self.qkv_format == "thd": + seqlens_q = batch_shard["cu_seq_lens_q_padded"][1:] - batch_shard["cu_seq_lens_q_padded"][:-1] + max_length = seqlens_q.max().item() + elif self.qkv_format == "bshd": + max_length = batch["input_ids"].shape[1] + else: + raise ValueError(f"Unsupported qvk_format: {self.qkv_format}!") + + batch_shard["max_length_k"] = batch_shard["max_length_q"] = ((max_length + 63) // 64) * 64 + combined_batch.append(batch_shard) + + if self.tp_world_size is not None: + # Replicate each CP shard for TP ranks. The ordering depends on which dimension forms the rows in the + # flattened mesh. 
class ContextParallelDataLoaderWrapper:
    """A dataloader wrapper that is aware of context and tensor parallelism.

    Only local rank 0 of the CP/TP mesh owns a real dataloader. Each call to ``__next__`` returns
    the batch destined for the current rank; batches are fetched one step ahead by a background
    thread so that the fetch/scatter for step N+1 overlaps with the consumer's work on step N.
    """

    def __init__(
        self,
        dataloader: torch.utils.data.DataLoader | None,
        cp_tp_mesh: torch.distributed.device_mesh.DeviceMesh,
    ):
        """A dataloader wrapper that distributes the data across the context and tensor parallelism groups.

        This class materializes a single dataloader for each data parallel mesh rank, and splits / replicates the data
        from this dataloader across the context and tensor parallelism groups.

        Args:
            dataloader: The dataloader to use. Must be provided on local rank 0 of ``cp_tp_mesh`` and must be None
                on every other rank.
            cp_tp_mesh: The context parallel mesh, or a flattened, combined context parallel and tensor parallel mesh.
                If a flattened mesh is provided, the cp / tp dimensions should be in the order they appeared in the
                mesh_dim_names as passed to DataCollatorForContextParallel.
        """
        if cp_tp_mesh.get_local_rank() == 0:
            assert dataloader is not None, "dataloader must be provided on rank 0"
            self.dataloader = dataloader

        else:
            assert dataloader is None, "Dataloader on non-rank 0 will not be used"

        self.cp_tp_rank = cp_tp_mesh.get_local_rank()
        self.cp_tp_group = cp_tp_mesh.get_group()
        self.num_cp_tp_ranks = cp_tp_mesh.size()
        self._iterator = None
        self._prefetch_thread: threading.Thread | None = None
        self._prefetch_result: Any = None
        self._cuda_device: int | None = None

        logger.debug(
            "Created ContextParallelDataLoaderWrapper on global rank %s, cp rank %s",
            torch.distributed.get_rank() if torch.distributed.is_initialized() else "",
            self.cp_tp_rank,
        )

    def __iter__(self):
        """Start (or restart) iteration and kick off the first prefetch."""
        # Retire any prefetch still in flight BEFORE replacing the iterator. The background
        # thread reads ``self._iterator`` when it runs, so swapping the iterator first would let
        # a stale thread pull (and then discard) the first batch of the new iterator.
        self.close()
        if self.cp_tp_rank == 0:
            self._iterator = iter(self.dataloader)  # < --- collator output.
        # Capture CUDA device from main thread; torch.cuda.set_device is per-thread,
        # so the background thread needs to set it explicitly.
        self._cuda_device = torch.cuda.current_device() if torch.cuda.is_available() else None
        self._kick_prefetch()
        return self

    @nvtx.annotate("ContextParallelDataLoaderWrapper __next__", color="blue")
    def __next__(self):
        """Get the batch from the dataloader for the current CP rank.

        Waits for the outstanding prefetch (if any), re-raises any exception it captured
        (including ``StopIteration``), and otherwise starts the next prefetch before returning.
        """
        if self._prefetch_thread is not None:
            self._prefetch_thread.join()
        result = self._prefetch_result
        if isinstance(result, Exception):
            # Do not prefetch again after a failure / end of data.
            self._prefetch_thread = None
            raise result
        self._kick_prefetch()
        return result

    def _kick_prefetch(self):
        """Start a background thread to prefetch exactly one batch via scatter."""
        self._prefetch_thread = threading.Thread(target=self._do_one_prefetch, daemon=True)
        self._prefetch_thread.start()

    def _do_one_prefetch(self):
        """Fetch one batch in the background.

        Runs ``_send_data_to_cp_tp_ranks`` and stores its return value in ``_prefetch_result``.
        Any exception it raises is stored instead, so that ``__next__`` can re-raise it on the
        consumer thread.
        """
        if self._cuda_device is not None:
            torch.cuda.set_device(self._cuda_device)
        try:
            self._prefetch_result = self._send_data_to_cp_tp_ranks()
        except Exception as e:  # StopIteration is an Exception subclass and is captured here too.
            self._prefetch_result = e

    def close(self):
        """Stop the prefetch thread. Must be called before destroy_process_group().

        Waits up to 10 seconds for the outstanding prefetch thread and then drops the reference
        to it.
        """
        if self._prefetch_thread is not None:
            self._prefetch_thread.join(timeout=10)
            self._prefetch_thread = None

    @nvtx.annotate("ContextParallelDataLoaderWrapper _send_data_to_cp_tp_ranks", color="green")
    def _send_data_to_cp_tp_ranks(self):
        """Send data to all the CP/TP ranks.

        Rank 0 pulls the next item from its dataloader iterator; that item is the list of shards
        for all group members:
            combined_batch = [<shard for rank 0>, <shard for rank 1>, ..., <shard for rank N-1>]
        Every other rank contributes ``None``. The list is handed to
        ``_scatter_batch_to_cp_tp_ranks`` together with the CP/TP process group, and whatever
        that helper returns on this rank is returned to the caller.

        If tensor parallelism is also being used, the combined batch holds cp_world_size distinct
        shards, each replicated tp_world_size times. The ordering of the shards depends on which
        dimension forms the rows in the flattened mesh.

        When the dataloader on rank 0 is exhausted, the ``StopIteration`` is replicated to all
        group members so that every rank stops together.

        Scalability:
            Rank 0's work grows linearly with CP size, but the other ranks do not need to store all the shards so they
            do not grow linearly with CP size.

        Returns:
            batch: The batch for the current CP/TP rank.

        Raises:
            StopIteration: If the item received for this rank is a ``StopIteration``.
        """
        try:
            with nvtx.annotate("ContextParallelDataLoaderWrapper next batch", color="green"):
                combined_batch = next(self._iterator) if self.cp_tp_rank == 0 else None
        except StopIteration as ex:
            # If we encounter a StopIteration in the dataloader, we want to raise this error on all the CP ranks, so
            # that the dataloader can be restarted.
            combined_batch = [ex] * self.num_cp_tp_ranks

        batch_on_this_rank = _scatter_batch_to_cp_tp_ranks(combined_batch, self.cp_tp_group)

        if isinstance(batch_on_this_rank, StopIteration):
            raise batch_on_this_rank

        return batch_on_this_rank

    def state_dict(self):
        """Get the state dict by delegating to the dataloader.

        Returns an empty dict on ranks other than 0 (they own no dataloader).
        """
        if self.cp_tp_rank != 0:
            return {}
        elif hasattr(self.dataloader, "state_dict"):
            return {"dataloader": self.dataloader.state_dict()}
        else:
            logger.warning(
                "Attempting to get the state dict of the dataloader, but the dataloader does not support state_dict, "
                "returning empty dict"
            )
            return {"dataloader": {}}

    def load_state_dict(self, state_dict):
        """Load the state dict by delegating to the dataloader.

        A no-op on ranks other than 0.
        """
        if self.cp_tp_rank != 0:
            return
        elif hasattr(self.dataloader, "load_state_dict"):
            self.dataloader.load_state_dict(state_dict["dataloader"])
        else:
            logger.warning(
                "Attempting to load the state dict of the dataloader, but the dataloader does not support "
                "load_state_dict, returning without loading the state dict."
            )
            return

    @property
    def num_workers(self):
        """Get the number of workers of the dataloader (0 on ranks that own no dataloader)."""
        if self.cp_tp_rank != 0:
            return 0
        else:
            return self.dataloader.num_workers
All fields that are sequences (input_ids, attention_mask, + token_type_ids, labels, etc.) are split accordingly. + + Args: + sample: Dictionary containing sample data with fields like input_ids, attention_mask, token_type_ids, labels, etc. + num_tokens: Number of tokens to include in the first part of the split. + + Returns: + A tuple of two dictionaries: (first_part, remaining_part), where: + - first_part contains the first `num_tokens` tokens from each sequence field + - remaining_part contains the remaining tokens from each sequence field + + Example: + >>> sample = { + ... "input_ids": [0, 5, 6, 7, 8, 9, 2], + ... "attention_mask": [1, 1, 1, 1, 1, 1, 1], + ... "labels": [0, 5, 6, 7, 8, 9, 2] + ... } + >>> first, remaining = split_sample_by_num_tokens(sample, 3) + >>> first["input_ids"] # [0, 5, 6] + >>> remaining["input_ids"] # [7, 8, 9, 2] + """ + sample_length = len(sample["input_ids"]) + if num_tokens >= sample_length: + raise ValueError( + f"num_tokens ({num_tokens}) must be less than sample length ({sample_length}) to split the sample" + ) + if num_tokens <= 0: + raise ValueError(f"num_tokens ({num_tokens}) must be positive") + + first_part = {} + remaining_part = {} + + # Fields that should be split by tokens (sequence fields) + sequence_fields = ["input_ids", "attention_mask", "token_type_ids", "token_type", "labels"] + + for key, value in sample.items(): + if key in sequence_fields: + # Handle both list and tensor inputs + if isinstance(value, torch.Tensor): + first_part[key] = value[:num_tokens].clone() + remaining_part[key] = value[num_tokens:].clone() + elif isinstance(value, list): + first_part[key] = value[:num_tokens] + remaining_part[key] = value[num_tokens:] + else: + # For other types, try to slice if possible + try: + first_part[key] = value[:num_tokens] + remaining_part[key] = value[num_tokens:] + except (TypeError, IndexError): + # If slicing doesn't work, copy the value to both parts + # This handles fields that shouldn't be split (like 
def _pt_flatten_collate(features: list[dict[str, list[int]]], return_position_ids: bool = False):
    """Pack tokenized samples into a single row and record where each sequence starts and ends.

    Args:
        features: List of tokenized samples, each containing at least ``input_ids``. Whether
            ``labels`` / ``attention_mask`` are packed is decided by the first sample's keys.
        return_position_ids: Whether to add ``position_ids`` that restart at 0 for each sample.

    Returns:
        A dictionary with ``max_length_q``/``max_length_k`` (longest sample length), packed
        ``input_ids`` of shape ``[1, total_tokens]`` (int64), ``cu_seq_lens_q``/``cu_seq_lens_k``
        (one shared int32 tensor of cumulative lengths, starting at 0), and optionally packed
        ``labels``, ``attention_mask`` and ``position_ids``.
    """

    def _pack(field_name):
        # Concatenate one per-token field across all samples into a single-row int64 tensor.
        flat_values = [value for sample in features for value in sample[field_name]]
        return torch.tensor([flat_values], dtype=torch.int64)

    first_sample = features[0]
    lengths = [len(sample["input_ids"]) for sample in features]
    longest = max(lengths)

    packed = {"max_length_q": longest, "max_length_k": longest}
    packed["input_ids"] = _pack("input_ids")
    if "labels" in first_sample:
        packed["labels"] = _pack("labels")

    # Cumulative lengths with a leading zero: entry i is the start offset of sample i.
    running_totals = torch.cumsum(torch.tensor(lengths), dim=0, dtype=torch.int32)
    boundaries = torch.cat([torch.zeros(1, dtype=torch.int32), running_totals])
    packed["cu_seq_lens_q"] = boundaries
    packed["cu_seq_lens_k"] = boundaries  # same tensor object for queries and keys

    if "attention_mask" in first_sample:
        packed["attention_mask"] = _pack("attention_mask")

    if return_position_ids:
        per_sample_positions = [torch.arange(length, dtype=torch.int64) for length in lengths]
        packed["position_ids"] = torch.hstack(per_sample_positions).unsqueeze(0)

    return packed
def _process_tensor_thd(
    val: torch.Tensor | None,
    seq_len: int,
    slice_sizes: torch.Tensor,
    cu_seqlens_padded: torch.Tensor,
    cp_rank: int,
    total_slices: int,
) -> torch.Tensor | None:
    """Select the token positions of a packed (THD) tensor that belong to one CP rank.

    Every sequence is treated as ``total_slices`` equal slices. The rank keeps slice
    ``cp_rank`` counted from the front and slice ``cp_rank`` counted from the back
    (the zigzag pairing), for each sequence in turn.

    Args:
        val: Tensor to shard; ``None`` is passed straight through.
        seq_len: Total packed length, used to locate the sequence dimension of ``val``.
        slice_sizes: Per-sequence slice length (sequence length // ``total_slices``).
        cu_seqlens_padded: Cumulative padded sequence lengths; all but the last entry are
            used as per-sequence start offsets.
        cp_rank: Context-parallel rank whose shard is extracted.
        total_slices: Number of slices per sequence (2 * CP world size per the caller).

    Returns:
        ``val`` indexed along its sequence dimension by the selected positions, or ``None``.
    """
    if val is None:
        return val

    seq_dim = _find_seq_dim(val, seq_len)

    # Slice index of the rank's second segment, mirrored from the end of the sequence.
    mirrored = total_slices - cp_rank - 1

    # For each sequence: front slice first, then the mirrored slice.
    gather_index = [
        torch.arange(
            start + (chunk * size),
            start + ((chunk + 1) * size),
            device=val.device,
        )
        for size, start in zip(slice_sizes, cu_seqlens_padded[:-1])
        for chunk in (cp_rank, mirrored)
    ]

    return val.index_select(seq_dim, torch.cat(gather_index))
+ """ + if val is None: + return val + + if val.ndim < 2: + raise ValueError(f"BSHD format requires at least 2D tensors, got {val.ndim}D") + + seq_len = val.shape[1] + + # Calculate chunk size + total_chunks = 2 * cp_world_size + chunk_size = seq_len // total_chunks + + if seq_len % total_chunks != 0: + raise ValueError( + f"Sequence length {seq_len} must be divisible by {total_chunks} " + f"(2 * cp_world_size) for BSHD context parallelism" + ) + + # Determine which chunks this rank should get + # Rank 0 gets chunks [0, total_chunks-1] + # Rank 1 gets chunks [1, total_chunks-2] + # Rank k gets chunks [k, total_chunks-k-1] + chunk_indices = [cp_rank, total_chunks - cp_rank - 1] + + # Collect slices for this rank + rank_slices = [] + for chunk_idx in chunk_indices: + start_idx = chunk_idx * chunk_size + end_idx = start_idx + chunk_size + rank_slices.append(torch.arange(start_idx, end_idx, device=val.device)) + + # Concatenate indices for all chunks this rank should get + indices = torch.cat(rank_slices) + + # Select along sequence dimension (dim=1) + return val.index_select(1, indices) + + +def _pt_pad_to_multiple_of(batch: dict[str, Any], pad_to_multiple_of: int, token_pad: int, label_pad: int): + """Pad a batch to a multiple of pad_to_multiple_of. + + Appends a mock sequence to the end of the batch with the given token_pad and label_pad to make the total number of + tokens divisible by pad_to_multiple_of. + + Args: + batch: Input batch, possibly containing labels and/or cu_seq_lens / max_length keys. + pad_to_multiple_of: Multiple to pad to. + token_pad: Token to pad with. + label_pad: Label to pad with. + + Returns: + Batch dictionary with padded input_ids, labels, cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k. 
+ """ + # Number of tokens we need to pad to make the total number of tokens divisible by pad_to_multiple_of + remainder = -batch["input_ids"].numel() % pad_to_multiple_of + + if remainder == 0: + return batch + + batch["input_ids"] = torch.cat( + [batch["input_ids"], torch.full((1, remainder), token_pad, dtype=batch["input_ids"].dtype)], dim=1 + ) + + if "labels" in batch: + batch["labels"] = torch.cat( + [batch["labels"], torch.full((1, remainder), label_pad, dtype=batch["labels"].dtype)], dim=1 + ) + + if "cu_seq_lens_q" in batch: + batch["cu_seq_lens_q"] = torch.cat( + [ + batch["cu_seq_lens_q"], + torch.tensor([batch["cu_seq_lens_q"][-1] + remainder], dtype=batch["cu_seq_lens_q"].dtype), + ], + dim=0, + ) + batch["cu_seq_lens_k"] = batch["cu_seq_lens_q"] + + if "max_length_q" in batch: + batch["max_length_q"] = max(batch["max_length_q"], remainder) + batch["max_length_k"] = batch["max_length_q"] + + if "attention_mask" in batch: + batch["attention_mask"] = torch.cat( + [batch["attention_mask"], torch.zeros((1, remainder), dtype=batch["attention_mask"].dtype)], dim=1 + ) + + if "position_ids" in batch: + batch["position_ids"] = torch.cat( + [batch["position_ids"], torch.arange(remainder, dtype=batch["position_ids"].dtype).unsqueeze(0)], dim=1 + ) + + return batch + + +# TODO(@jomitchell): Once this gets merged: https://github.com/NVIDIA/TransformerEngine/pull/2387 +# we can replace this with the one in TransformerEngine. +@nvtx.annotate("collator._split_batch_by_cp_rank", color="green") +def _split_batch_by_cp_rank( + cu_seqlens_padded: torch.Tensor | None, + input_ids_padded: torch.Tensor, + labels_padded: torch.Tensor, + cp_group: torch.distributed.ProcessGroup | None = None, + qvk_format: str = "thd", + cp_rank: int | None = None, + cp_world_size: int | None = None, +): + """Slice batch input along sequence dimension into multiple chunks for THD or BSHD format. + + This function is intended for use in self attention. 
It will not work for cross attention because + it does not handle the case where the sequence length of the query and key are different. + Which are parallelized across GPUs in a context parallel group. + This version works with variable-length sequences using cumulative sequence lengths for THD format, + and with padded sequences for BSHD format. + + Args: + cu_seqlens_padded: Cumulative sequence length. Required for THD format, optional for BSHD format. + input_ids_padded: Input IDs. + labels_padded: Labels. + cp_group: Context parallel group. + qvk_format: Format of the input data ("thd" or "bshd"). + cp_world_size: The size of the context parallelism group. If provided, the function will use this value to determine the rank. + cp_rank: Optional manual CP rank index. When provided, the function shards tensors as if it + were executing on that rank without querying `torch.distributed.get_rank`. + """ + if qvk_format not in ["thd", "bshd", "sbhd"]: + raise ValueError(f"Unsupported qvk_format: {qvk_format}!") + + if cp_world_size is None or cp_world_size <= 1: + # No splitting needed + return input_ids_padded, labels_padded + + if cp_rank is None: + cp_rank = torch.distributed.get_rank(group=cp_group) + elif not (0 <= cp_rank < cp_world_size): + raise ValueError(f"cp_rank must be in [0, {cp_world_size}), but received {cp_rank}.") + + if qvk_format == "thd": + if cu_seqlens_padded is None: + raise ValueError("cu_seqlens_padded is required for THD format") + + # Calculate the chunk sizes for each sequence + total_slices_of_any_sequence = 2 * cp_world_size + slice_sizes = (cu_seqlens_padded[1:] - cu_seqlens_padded[:-1]) // total_slices_of_any_sequence + + # Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor + last_elem = cu_seqlens_padded[-1] + seq_len_val = last_elem.item() if isinstance(last_elem, torch.Tensor) else last_elem + + input_ids_padded = _process_tensor_thd( + input_ids_padded, seq_len_val, slice_sizes, cu_seqlens_padded, cp_rank, 
class BatchType(TypedDict):
    """The fields in the batch dictionary for THD context parallel."""

    # Token ids and their (optional) training targets.
    input_ids: torch.Tensor
    labels: torch.Tensor | None
    shift_labels: torch.Tensor | None
    # Cumulative sequence lengths for queries and keys.
    # NOTE(review): the *_padded variants presumably include per-sequence padding -- confirm with the producer.
    cu_seq_lens_q: torch.Tensor
    cu_seq_lens_k: torch.Tensor
    cu_seq_lens_q_padded: torch.Tensor
    cu_seq_lens_k_padded: torch.Tensor
    # Longest sequence length for queries and keys.
    max_length_q: int
    max_length_k: int
    pad_between_seqs: bool


@nvtx.annotate("collator._scatter_batch_to_cp_tp_ranks", color="green")
def _scatter_batch_to_cp_tp_ranks(
    all_batches: list[BatchType] | list[StopIteration], cp_tp_group: torch.distributed.ProcessGroup | None = None
) -> BatchType | StopIteration:
    """Hand each rank of the CP/TP group its own entry of ``all_batches``.

    Args:
        all_batches (list[BatchType] | list[StopIteration]): Already-sharded batches, one per
            rank of the group, that are scattered from group rank 0.
        cp_tp_group (torch.distributed.ProcessGroup | None): Process group to scatter within.

    Returns:
        BatchType | StopIteration: The object received by the calling rank.
    """
    # scatter_object_list writes the received object into slot 0 of this list.
    received: list = [None]
    # No async_op handle is available for this collective, so the call blocks.
    torch.distributed.scatter_object_list(
        scatter_object_output_list=received,
        scatter_object_input_list=all_batches,
        group=cp_tp_group,
        group_src=0,
    )
    return received[0]
def create_tokenized_dataset(
    distributed_config: DistributedConfig,
    tokenizer_name_or_path: str,
    load_dataset_kwargs: dict,
    max_seq_length: int | None = 8192,
    stride: int | None = 200,
    buffer_size: int = 50_000,
    text_column: str = "text",
    tokenize_batch_size: int = 100,
):
    """Load a dataset, shard it across ranks, tokenize it and shuffle streaming data.

    With a ``stride``, the tokenizer is called with truncation to ``max_seq_length`` and
    ``return_overflowing_tokens=True`` so long sequences yield several overlapping windows.
    With ``stride=None`` the data is treated as pre-chunked: it is tokenized with special
    tokens added and without truncation or windowing.

    Args:
        distributed_config: Supplies ``rank`` and ``world_size`` for sharding.
        tokenizer_name_or_path: Name or path passed to ``AutoTokenizer.from_pretrained``.
        load_dataset_kwargs: Keyword arguments forwarded to ``datasets.load_dataset``.
        max_seq_length: Window size; only used when ``stride`` is not None.
        stride: Overlap between windows. ``None`` disables windowing.
        buffer_size: Shuffle buffer size for streaming datasets.
        text_column: Column that holds the raw sequences.
        tokenize_batch_size: Number of rows per tokenization batch.

    Returns:
        Tuple of (tokenized_dataset, tokenizer).
    """
    logger.info(f"Loading dataset with kwargs: {load_dataset_kwargs}")
    dataset = datasets.load_dataset(**load_dataset_kwargs)

    if isinstance(dataset, datasets.IterableDataset):
        rank = distributed_config.rank
        world_size = distributed_config.world_size
        # `split_dataset_by_node` falls back to reading the same shards everywhere with strided
        # sampling when the shard count is not divisible by the world size, which is wasteful
        # with many shards and workers. It is therefore only used when there are fewer shards
        # than ranks; otherwise `dataset.shard` assigns shards directly.
        if world_size > dataset.num_shards:
            logger.info(f"Sharding dataset with {dataset.num_shards} shards with split_dataset_by_node")
            dataset = datasets.distributed.split_dataset_by_node(dataset, rank=rank, world_size=world_size)
        else:
            logger.info(f"Sharding dataset with {dataset.num_shards} shards with dataset.shard")
            dataset = dataset.shard(num_shards=world_size, index=rank)

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

    if stride is not None:
        # Windowed path: one input row may produce several tokenized windows.
        tokenizer_options = {
            "max_length": max_seq_length,
            "stride": stride,
            "truncation": True,
            "return_overflowing_tokens": True,
            "add_special_tokens": True,
        }
        logger.info(f"Using windowed tokenization: max_seq_length={max_seq_length}, stride={stride}")
    else:
        # Pre-chunked path: only special tokens are added.
        tokenizer_options = {"add_special_tokens": True, "truncation": False}
        logger.info("Using direct tokenization (pre-chunked dataset, no windowing)")

    def tokenize_fn(examples):
        """Tokenize one batch of raw sequences with the options selected above."""
        return tokenizer(examples[text_column], **tokenizer_options)

    text_only = dataset.select_columns(text_column)
    tokenized_dataset = text_only.map(
        tokenize_fn,
        batched=True,
        batch_size=tokenize_batch_size,
        remove_columns=[text_column],
    )

    # Streaming datasets are always shuffled after tokenization for batch diversity.
    if isinstance(tokenized_dataset, datasets.IterableDataset):
        logger.info(f"Shuffling tokenized windows with buffer_size={buffer_size}")
        tokenized_dataset = tokenized_dataset.shuffle(seed=42, buffer_size=buffer_size)

    # The base MLM collator needs a padding token, even in THD mode.
    if tokenizer.pad_token is None:
        logger.warning(f"Tokenizer does not have a padding token. Setting it to the EOS token: {tokenizer.eos_token}")
        tokenizer.pad_token = tokenizer.eos_token

    return tokenized_dataset, tokenizer
def create_thd_dataloader(
    distributed_config: DistributedConfig,
    tokenizer_name_or_path: str,
    load_dataset_kwargs: dict,
    micro_batch_size: int | None = None,
    token_micro_batch_size: int | None = None,
    num_workers: int = 1,
    prefetch_factor: int = 4,
    max_seq_length: int | None = 8192,
    stride: int | None = 200,
    buffer_size: int = 50_000,
    use_stateful_dataloader: bool = False,
    text_column: str = "text",
    uppercase_labels: bool = False,
    mask_degenerate_bases: bool = True,
    split_samples_in_token_packing: bool = True,
    pad_sequences_to_be_divisible_by: int | None = None,
):
    """Create a dataloader that packs up to the maximum number of tokens per batch.

    Exactly one of ``micro_batch_size`` and ``token_micro_batch_size`` must be given.

    Args:
        distributed_config: The distributed configuration.
        tokenizer_name_or_path: Name or path to the nucleotide tokenizer directory.
        load_dataset_kwargs: Keyword arguments to pass to `load_dataset`.
        micro_batch_size: The batch size per device. Mutually exclusive with
            ``token_micro_batch_size``; requires ``max_seq_length`` to be set.
        token_micro_batch_size: The maximum number of tokens per batch. If None,
            ``micro_batch_size * max_seq_length`` is used. Defaults to None.
        num_workers: The number of workers to use for the dataloader.
        prefetch_factor: The prefetch factor to use for the dataloader (ignored when
            ``num_workers`` is 0).
        max_seq_length: The maximum length of sequences (window size).
        stride: The stride for windowing (overlap = stride tokens).
        buffer_size: The buffer size for shuffle operations.
        use_stateful_dataloader: Whether to use the StatefulDataLoader. Memory pinning is
            turned off when it is used.
        text_column: Name of the column containing genomic sequences (default: "text").
        uppercase_labels: Whether to uppercase labels (genomic masking). Default: False.
        mask_degenerate_bases: Whether to mask non-ACGT bases in labels. Default: True.
        split_samples_in_token_packing: Whether to split samples to form batches with exactly
            token_micro_batch_size tokens. Default: True.
        pad_sequences_to_be_divisible_by: If provided, sequences will be padded to be divisible by this value.

    Returns:
        A tuple of (dataloader, tokenized_dataset).

    Raises:
        AssertionError: If the dataset is not streaming, if neither or both batch-size
            arguments are given, or if ``micro_batch_size`` is used without ``max_seq_length``.
    """
    tokenized_dataset, tokenizer = create_tokenized_dataset(
        distributed_config=distributed_config,
        tokenizer_name_or_path=tokenizer_name_or_path,
        load_dataset_kwargs=load_dataset_kwargs,
        max_seq_length=max_seq_length,
        stride=stride,
        buffer_size=buffer_size,
        text_column=text_column,
    )

    assert isinstance(tokenized_dataset, datasets.IterableDataset), "THD token packing requires a streaming dataset."
    if token_micro_batch_size is None:
        # Neither size was given: say so, instead of claiming that too many were provided.
        assert micro_batch_size is not None, "One of micro_batch_size or token_micro_batch_size must be provided."
        assert max_seq_length is not None, (
            "max_seq_length must be set when using micro_batch_size (needed to compute token_micro_batch_size). "
            "Use token_micro_batch_size directly for pre-chunked datasets."
        )
        token_micro_batch_size = micro_batch_size * max_seq_length
    else:
        assert micro_batch_size is None, "Only one of micro_batch_size or token_micro_batch_size can be provided."

    data_collator = GenomicDataCollator(
        base_collator=DataCollatorWithFlattening(
            collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
            pad_sequences_to_be_divisible_by=pad_sequences_to_be_divisible_by,
        ),
        uppercase_labels=uppercase_labels,
        mask_degenerate_bases=mask_degenerate_bases,
    )

    dataloader_class = StatefulDataLoader if use_stateful_dataloader else DataLoader
    train_dataloader = dataloader_class(
        TokenPackingDataset(
            tokenized_dataset,
            max_tokens_per_batch=token_micro_batch_size,
            split_samples=split_samples_in_token_packing,
        ),
        batch_size=None,  # The TokenPackingDataset will handle the batching.
        collate_fn=data_collator,
        num_workers=num_workers,
        pin_memory=not use_stateful_dataloader,
        persistent_workers=num_workers > 0,
        # prefetch_factor is only passed when worker processes exist.
        prefetch_factor=prefetch_factor if num_workers > 0 else None,
    )

    return train_dataloader, tokenized_dataset
def _env_or_default(name: str, fallback: str) -> str:
    """Return ``os.environ[name]``, storing ``fallback`` there first when the variable is unset."""
    return os.environ.setdefault(name, fallback)


@dataclass(frozen=True)
class DistributedConfig:
    """Immutable snapshot of the distributed-training environment variables.

    Each field defaults to the matching environment variable. A variable that is not set is
    written to ``os.environ`` with its fallback value as a side effect of construction.
    """

    # Global rank, from RANK (fallback "0").
    rank: int = field(default_factory=lambda: int(_env_or_default("RANK", "0")))
    # Rank on the local node, from LOCAL_RANK (fallback "0").
    local_rank: int = field(default_factory=lambda: int(_env_or_default("LOCAL_RANK", "0")))
    # Total number of processes, from WORLD_SIZE (fallback "1").
    world_size: int = field(default_factory=lambda: int(_env_or_default("WORLD_SIZE", "1")))
    # Rendezvous address and port, from MASTER_ADDR / MASTER_PORT; kept as strings.
    _master_addr: str = field(default_factory=lambda: _env_or_default("MASTER_ADDR", "localhost"))
    _master_port: str = field(default_factory=lambda: _env_or_default("MASTER_PORT", "12355"))

    def is_main_process(self) -> bool:
        """Return True on global rank 0."""
        return self.rank == 0
def initialize_fp8_debugging(
    dist_config: DistributedConfig,
    enabled: bool,
    fp8_stats_file: str,
    fp8_log_dir: str | os.PathLike,
    fp8_enabled: bool,
) -> None:
    """Set up FP8 statistics logging for this rank, if requested.

    Args:
        dist_config: Supplies the rank used to name the per-rank log directory.
        enabled: Whether FP8 statistics collection is requested. When False nothing happens.
        fp8_stats_file: Config file handed to the debug API.
        fp8_log_dir: Base log directory; a ``rank_<rank>`` subdirectory is created inside it.
        fp8_enabled: Whether FP8 itself is enabled.

    Raises:
        ValueError: If statistics are requested while FP8 is disabled.
    """
    if not enabled:
        return

    if not fp8_enabled:
        raise ValueError(
            "fp8_stats_config.enabled is true but fp8_config.enabled is false, "
            "please enable fp8_config.enabled to collect FP8 stats"
        )

    # Each rank writes to its own directory.
    rank_log_dir = Path(fp8_log_dir).joinpath(f"rank_{dist_config.rank}")
    rank_log_dir.mkdir(parents=True, exist_ok=True)
    logger.info("Logging FP8 stats to %s", rank_log_dir)

    # The debug features are looked up inside the installed transformer_engine package.
    features_dir = Path(transformer_engine.__file__).parent.joinpath("debug", "features")
    debug_api.initialize(
        config_file=fp8_stats_file,
        feature_dirs=[str(features_dir)],
        log_dir=rank_log_dir.as_posix(),
        default_logging_enabled=True,
    )
def _is_nvshmem_available() -> bool:
    """Report whether DeepEP was built with NVSHMEM support, caching the answer.

    The probe asks a dispatch config for an RDMA buffer size hint and treats a
    ``RuntimeError`` as "no NVSHMEM". ``is_sm90_compiled()`` alone is not a reliable
    proxy, because SM90 can be compiled while NVSHMEM is still disabled.
    """
    global _nvshmem_available  # noqa: PLW0603
    if _nvshmem_available is not None:
        # Already probed once in this process.
        return _nvshmem_available

    try:
        Buffer.get_dispatch_config(2).get_rdma_buffer_size_hint(256, 2)
    except RuntimeError:
        _nvshmem_available = False
    else:
        _nvshmem_available = True
    return _nvshmem_available


def get_hidden_bytes(x: torch.Tensor) -> int:
    """Return the per-row byte count used to size the communication buffer.

    Args:
        x (torch.Tensor): Input tensor; its size along dimension 1 is used.

    Returns:
        int: ``x.size(1)`` times the element size, where element sizes below 2 bytes
        are counted as 2 bytes.
    """
    bytes_per_element = x.element_size()
    if bytes_per_element < 2:
        bytes_per_element = 2
    return x.size(1) * bytes_per_element
+ + Args: + group (torch.distributed.ProcessGroup): Process group for communication + hidden_bytes (int): Number of hidden bytes needed + + Returns: + Buffer: Communication buffer + """ + global _buffer # noqa: PLW0603 + num_nvl_bytes, num_rdma_bytes = 0, 0 + nvshmem = _is_nvshmem_available() + for config in ( + Buffer.get_dispatch_config(group.size()), + Buffer.get_combine_config(group.size()), + ): + num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes) + if nvshmem: + num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes) + + # Allocate buffer if not existed or not enough buffer + # NOTES: the adaptive routing configuration of the network **must be off** + if ( + _buffer is None + or _buffer.group != group + or _buffer.num_nvl_bytes < num_nvl_bytes + or _buffer.num_rdma_bytes < num_rdma_bytes + ): + _buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes) + return _buffer + + +class FusedDispatch(torch.autograd.Function): + """Fused dispatch operation for MoE routing combining computation and communication.""" + + @staticmethod + def forward( + ctx, + x, + token_indices, + token_probs, + num_experts, + group, + async_finish=False, + allocate_on_comm_stream=False, + ): + """Forward pass of fused dispatch.""" + previous_event = None + if async_finish: + previous_event = EventOverlap(EventHandle()) + # Calculate layout before actual dispatch + buffer = get_buffer(group, get_hidden_bytes(x)) + ( + num_tokens_per_rank, + num_tokens_per_rdma_rank, + num_tokens_per_expert, + is_token_in_rank, + event, + ) = buffer.get_dispatch_layout( + token_indices, + num_experts, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + # Do MoE dispatch + # NOTES: the CPU will wait for GPU's signal to arrive, + # so this is not compatible with CUDA graph + ( + recv_x, + recv_token_indices, + recv_token_probs, + 
class FusedCombine(torch.autograd.Function):
    """Fused combine operation for MoE output combining computation and communication.

    ``forward`` calls ``buffer.combine`` and ``backward`` calls ``buffer.dispatch``
    with the same handle, so the two passes mirror each other.
    """

    @staticmethod
    def forward(ctx, x, group, handle, async_finish=False, allocate_on_comm_stream=False):
        """Forward pass of fused combine.

        Args:
            ctx: Autograd context; receives ``handle``, ``group`` and the two flags for backward.
            x: Tensor handed to ``buffer.combine``.
            group: Process group used to look up the communication buffer.
            handle: Communication handle forwarded to ``buffer.combine``.
            async_finish: When true, an ``EventOverlap`` is created before the call and the
                current stream waits on the returned event afterwards.
            allocate_on_comm_stream: Forwarded unchanged to ``buffer.combine``.

        Returns:
            tuple: ``(combined_x, None)``.
        """
        previous_event = None
        if async_finish:
            previous_event = EventOverlap(EventHandle())
        buffer = get_buffer(group, get_hidden_bytes(x))
        combined_x, _, after_event = buffer.combine(
            x,
            handle=handle,
            async_finish=async_finish,
            previous_event=previous_event,
            allocate_on_comm_stream=allocate_on_comm_stream,
        )
        # Make sure current stream is synchronized
        if async_finish:
            after_event.current_stream_wait()

        ctx.handle = handle
        ctx.group = group
        ctx.async_finish = async_finish
        ctx.allocate_on_comm_stream = allocate_on_comm_stream
        return combined_x, None

    @staticmethod
    def backward(ctx, grad_output, _grad_unused=None):
        """Backward pass of fused combine.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output: Gradient with respect to ``combined_x``.
            _grad_unused: Gradient slot for forward's second output, which is always
                ``None``; it is ignored. (It was previously named ``previous_event`` and
                overwritten immediately, which hid the fact that it is a gradient slot.)

        Returns:
            tuple: One entry per forward input ``(x, group, handle, async_finish,
            allocate_on_comm_stream)``; only ``x`` receives a gradient.
        """
        previous_event = None
        if ctx.async_finish:
            previous_event = EventOverlap(EventHandle())
        buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output))
        grad_x, _, _, _, _, after_event = buffer.dispatch(
            grad_output.contiguous(),
            handle=ctx.handle,
            previous_event=previous_event,
            async_finish=ctx.async_finish,
            allocate_on_comm_stream=ctx.allocate_on_comm_stream,
        )
        # Make sure current stream is synchronized
        if ctx.async_finish:
            after_event.current_stream_wait()
        return grad_x, None, None, None, None
def fused_combine(x, group, handle, async_finish=False, allocate_on_comm_stream=False):
    """Run the ``FusedCombine`` autograd function on ``x``.

    Args:
        x: Input tensor
        group: Process group
        handle: Communication handle
        async_finish: Whether to finish asynchronously
        allocate_on_comm_stream: Whether to allocate on communication stream

    Returns:
        Whatever ``FusedCombine.apply`` returns for these arguments.
    """
    combine_args = (x, group, handle, async_finish, allocate_on_comm_stream)
    return FusedCombine.apply(*combine_args)
# Assign a block to a row([1,topk]), generate a local routing map([1,num_of_local_experts])
@triton.jit
def _indices_to_multihot_kernel(
    indices_ptr,
    probs_in_indices_ptr,
    multihot_indices_ptr,  # bool
    probs_in_multihot_ptr,
    position_map_ptr,
    num_of_local_experts: tl.constexpr,
    num_of_local_experts_next_power_of_2: tl.constexpr,
    topk: tl.constexpr,
    topk_next_power_of_2: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,  # noqa: N803
):
    """Triton kernel for converting indices to multihot representation.

    Each program instance handles one row (``tl.program_id(0)``). Rows are addressed
    as ``row * topk + col`` for the inputs and ``row * num_of_local_experts + col``
    for the outputs, i.e. the buffers are read and written as densely packed rows.
    ``BLOCK_SIZE`` is accepted but not referenced in the kernel body.

    Input:
        indices: [num_of_tokens, topk]
        probs_in_indices: [num_of_tokens, topk]
    Output:
        multihot_indices: [num_of_tokens, num_of_local_experts]
        probs_in_multihot: [num_of_tokens, num_of_local_experts]
        position_map: [num_of_tokens, num_of_local_experts]

    An index is written only if it is not -1 and is smaller than
    num_of_local_experts; other entries keep their initial value
    (0 for multihot_indices and probs_in_multihot, -1 for position_map).

    Assume that topk = 2 , num_of_local_experts = 4, num_of_tokens = 2,
    then the kernel can process the following conversion:

    Input Example:
        indices = [
            [0, 1],
            [1, 2]
        ]
        probs_in_indices = [
            [0.1, 0.2],
            [0.3, 0.4]
        ]
    Output Example:
        multihot_indices = [
            [1, 1, 0, 0],
            [0, 1, 1, 0]
        ]
        probs_in_multihot = [
            [0.1, 0.2, 0.0, 0.0],
            [0.0, 0.3, 0.4, 0.0]
        ]
        position_map = [
            [0, 1, -1, -1],
            [-1, 0, 1, -1]
        ]
    """
    # Prepare the [0, topk) row; lanes beyond topk (power-of-2 padding) are marked -1
    topk_row = tl.arange(0, topk_next_power_of_2)
    topk_row = tl.where(topk_row < topk, topk_row, -1)
    topk_row_mask = topk_row != -1
    # Prepare the [0, num_of_local_experts) row; padding lanes are marked -1
    num_exp_row = tl.arange(0, num_of_local_experts_next_power_of_2)
    num_exp_row = tl.where(num_exp_row < num_of_local_experts, num_exp_row, -1)
    num_exp_row_mask = num_exp_row != -1

    # Load a [1, topk] row from the indices buffer
    row_idx = tl.program_id(0)
    indices_row = tl.load(indices_ptr + row_idx * topk + topk_row, mask=topk_row_mask)
    # Force padding lanes to -1 so they are excluded by the validity mask below
    indices_row = tl.where(topk_row_mask, indices_row, -1)
    probs_row = tl.load(probs_in_indices_ptr + row_idx * topk + topk_row, mask=topk_row_mask)

    # Get the top-k slot of each index in indices_row, which is saved for backward
    position_row = tl.where(indices_row != -1, topk_row, -1)
    # Mask of the valid indices
    mask = (indices_row != -1) & (indices_row < num_of_local_experts)

    row_idx_offset = row_idx * num_of_local_experts
    # Store to initialize
    tl.store(multihot_indices_ptr + row_idx_offset + num_exp_row, 0, mask=num_exp_row_mask)
    tl.store(probs_in_multihot_ptr + row_idx_offset + num_exp_row, 0, mask=num_exp_row_mask)
    tl.store(position_map_ptr + row_idx_offset + num_exp_row, -1, mask=num_exp_row_mask)
    # Use barrier to make sure the initialization is done
    tl.debug_barrier()
    # Store the indices and probs_in_indices
    tl.store(multihot_indices_ptr + row_idx_offset + indices_row, 1, mask)
    tl.store(probs_in_multihot_ptr + row_idx_offset + indices_row, probs_row, mask)
    # Store the position of the position_row for backwards
    tl.store(position_map_ptr + row_idx_offset + indices_row, position_row, mask)
class IndicesToMultihot(torch.autograd.Function):
    """Convert moe topk indices to multihot representation.

    This class implements a custom forward and backward propagation
    operation for efficiently converting indices to multihot
    representation.
    It is an experimental feature and may change in future versions.
    """

    @staticmethod
    def _next_power_of_2(value):
        """Return the smallest power of two that is >= ``value``.

        Uses exact integer arithmetic instead of ``2 ** ceil(log2(value))``.

        Raises:
            ValueError: If ``value`` is smaller than 1.
        """
        if value < 1:
            raise ValueError(f"expected a positive size, got {value}")
        return 1 << (value - 1).bit_length()

    @staticmethod
    def forward(ctx, indices, probs_indices, num_of_local_experts):  # noqa: D417
        """Forward function for IndicesToMultihot.

        Convert indices to multihot representation.

        Args:
            indices: [num_of_tokens, topk]
            probs_indices: [num_of_tokens, topk]
            num_of_local_experts: int

        Returns:
            multihot_indices: [num_of_tokens, num_of_local_experts]
            probs_in_multihot: [num_of_tokens, num_of_local_experts]
        """
        assert HAVE_TRITON, "Triton is not installed"
        num_of_tokens = indices.shape[0]
        assert indices.shape == probs_indices.shape, "indices and probs_indices must have the same shape"
        topk = indices.shape[1]
        device = indices.device
        # The kernel addresses both inputs as ``row * topk + col``, so they must be
        # densely packed; backward already does the same for its gradient input.
        indices = indices.contiguous()
        probs_indices = probs_indices.contiguous()
        multihot_indices = torch.empty((num_of_tokens, num_of_local_experts), dtype=torch.bool, device=device)
        probs_in_multihot = torch.empty(
            (num_of_tokens, num_of_local_experts),
            dtype=probs_indices.dtype,
            device=device,
        )
        position_map = torch.empty((num_of_tokens, num_of_local_experts), dtype=torch.int32, device=device)
        # Compute the next power of 2 for the topk and num_of_local_experts
        topk_next_power_of_2 = IndicesToMultihot._next_power_of_2(topk)
        num_of_local_experts_next_power_of_2 = IndicesToMultihot._next_power_of_2(num_of_local_experts)
        grid = (num_of_tokens,)
        _indices_to_multihot_kernel[grid](
            indices,
            probs_indices,
            multihot_indices,
            probs_in_multihot,
            position_map,
            num_of_local_experts,
            num_of_local_experts_next_power_of_2,
            topk,
            topk_next_power_of_2,
            BLOCK_SIZE=32,  # use only 1 warp per block
            num_warps=1,
        )

        # position_map records, per expert column, which top-k slot fed it; backward
        # uses it to route gradients back to the [num_of_tokens, topk] layout.
        ctx.save_for_backward(position_map)
        ctx.num_of_tokens = num_of_tokens
        ctx.num_of_local_experts = num_of_local_experts
        ctx.topk = topk
        return multihot_indices, probs_in_multihot

    @staticmethod
    def backward(ctx, grad_multihot_indices, grad_probs_in_multihot):  # noqa: D417
        """Backward function for IndicesToMultihot.

        Convert multihot probs representation to indices.
        indices is ignored in the backward function.

        Args:
            grad_multihot_indices: [num_of_tokens, num_of_local_experts]
            grad_probs_in_multihot: [num_of_tokens, num_of_local_experts]

        Returns:
            tuple: ``(None, grad_probs_indices, None)`` where grad_probs_indices is
            [num_of_tokens, topk]; ``indices`` and ``num_of_local_experts`` get no gradient.
        """
        position_map = ctx.saved_tensors[0]
        num_of_tokens = ctx.num_of_tokens
        num_of_local_experts = ctx.num_of_local_experts
        topk = ctx.topk

        # Initialize the gradient of the indices and probs_indices
        grad_probs_indices = torch.empty(
            (num_of_tokens, topk), dtype=grad_probs_in_multihot.dtype, device=grad_probs_in_multihot.device
        )
        # Compute the next power of 2 for the topk and num_of_local_experts
        topk_next_power_of_2 = IndicesToMultihot._next_power_of_2(topk)
        num_of_local_experts_next_power_of_2 = IndicesToMultihot._next_power_of_2(num_of_local_experts)

        grid = (num_of_tokens,)
        _multihot_to_indices_kernel[grid](
            # if the grad_probs_in_multihot is all-one/all-zero,
            # overlapping stride will cause error without contiguous()
            grad_probs_in_multihot.contiguous(),
            position_map,
            grad_probs_indices,
            num_of_local_experts,
            num_of_local_experts_next_power_of_2,
            topk,
            topk_next_power_of_2,
            BLOCK_SIZE=32,  # use only 1 warp per block
            num_warps=1,
        )
        return None, grad_probs_indices, None
class FusedTokenRouter:
    """TokenDispatcher using DeepEP fused communication and Triton index conversion.

    Dispatch flow:
        1. ``fused_dispatch`` — DeepEP all-to-all sends tokens to expert-owning ranks.
        2. ``fused_indices_to_multihot`` — Triton kernel converts sparse ``[N, top_k]``
           indices to dense ``[N, num_local_experts]`` mask with differentiable probs.
        3. ``moe_permute(map_type="mask")`` — TE sorts received tokens by local expert.

    Combine flow:
        1. ``moe_unpermute(map_type="mask")`` — TE unsorts and applies routing weights.
        2. ``fused_combine`` — DeepEP reverse all-to-all sends results back.

    Args:
        num_experts: Total number of experts (global, across all EP ranks).
        num_local_experts: Number of experts hosted on this rank.
        hidden_size: Hidden dimension size.
        ep_size: Expert parallel world size.
    """

    def __init__(self, num_experts: int, num_local_experts: int, hidden_size: int, ep_size: int):
        """Initialize the FusedTokenRouter.

        Raises:
            ImportError: If ``fused_dispatch``/``fused_combine`` are ``None`` (deep_ep missing)
                or Triton is not available.
        """
        if fused_dispatch is None or fused_combine is None:
            raise ImportError("deep_ep is required for FusedTokenRouter. Install it with: pip install deep_ep")
        if not HAVE_TRITON:
            raise ImportError(
                "Triton is required for FusedTokenRouter (used by fused_indices_to_multihot). "
                "Install it with: pip install triton"
            )
        self.num_experts = num_experts
        self.num_local_experts = num_local_experts
        self.hidden_size = hidden_size
        self.ep_size = ep_size
        # Must be provided through set_ep_group() before dispatch()/combine() are used.
        self._ep_group: dist.ProcessGroup | None = None

    def set_ep_group(self, ep_group: dist.ProcessGroup) -> None:
        """Set the expert-parallel process group for communication."""
        self._ep_group = ep_group

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        selected_experts: torch.Tensor,
        routing_weights: torch.Tensor,
    ) -> DispatchOutput:
        """Dispatch tokens to their assigned experts via DeepEP fused all-to-all.

        Args:
            hidden_states: Flattened input tensor of shape ``[N, H]``.
            selected_experts: Expert assignments, shape ``[N, top_k]``, int.
            routing_weights: Normalized routing probabilities, shape ``[N, top_k]``, float32.

        Returns:
            DispatchOutput with expert-sorted tokens, per-expert counts, and an opaque handle.
        """
        assert self._ep_group is not None, "EP group must be set via set_ep_group() before dispatch"

        # Step 1: Fused all-to-all dispatch (DeepEP)
        recv_x, recv_indices, recv_probs, tokens_per_expert, deepep_handle = fused_dispatch(
            hidden_states,
            selected_experts,
            routing_weights.float(),  # DeepEP requires float32 probs
            self.num_experts,
            self._ep_group,
        )

        # Step 2: Convert sparse [N, top_k] indices to dense [N, num_local_experts] multihot (Triton)
        # Note: DeepEP returns local expert indices (0-based per rank), not global indices.
        multihot_mask, probs_multihot = fused_indices_to_multihot(recv_indices, recv_probs, self.num_local_experts)

        # Step 3: Permute received tokens by local expert for GroupedLinear
        num_out_tokens = int(tokens_per_expert.sum().item())
        permuted_x, row_id_map = transformer_engine.pytorch.moe_permute(
            recv_x, multihot_mask.to(torch.int32), num_out_tokens=num_out_tokens, map_type="mask"
        )

        # Everything combine() needs to undo the permutation and the all-to-all.
        handle = _FusedHandle(
            deepep_handle=deepep_handle,
            row_id_map=row_id_map,
            probs_multihot=probs_multihot,
            recv_shape=recv_x.shape,
        )

        return DispatchOutput(
            expert_input=permuted_x,
            tokens_per_expert=tokens_per_expert.tolist(),
            handle=handle,
        )

    def combine(self, expert_output: torch.Tensor, handle: _FusedHandle) -> torch.Tensor:
        """Combine expert outputs back to the original token order.

        Args:
            expert_output: Expert output tensor of shape ``[total_recv_tokens, H]``.
            handle: Opaque state from ``dispatch()``.

        Returns:
            Combined output tensor of shape ``[N, H]`` with routing weights applied.
        """
        # Same guard as dispatch(): fail with a clear message rather than passing
        # ``None`` as the process group into fused_combine.
        assert self._ep_group is not None, "EP group must be set via set_ep_group() before combine"

        # Step 1: Unpermute expert output and apply routing weights
        unpermuted = transformer_engine.pytorch.moe_unpermute(
            expert_output,
            handle.row_id_map,
            merging_probs=handle.probs_multihot,
            restore_shape=handle.recv_shape,
            map_type="mask",
        )

        # Step 2: Fused all-to-all combine (reverse dispatch)
        combined, _ = fused_combine(unpermuted, self._ep_group, handle.deepep_handle)

        return combined
+expert_parallel_size: 1 + +dataset: + tokenizer_name_or_path: ./tokenizers/nucleotide_fast_tokenizer + micro_batch_size: 1 + num_workers: 0 + max_seq_length: 256 + stride: 32 + text_column: sequence + mask_degenerate_bases: false + load_dataset_kwargs: + path: json + split: train + data_files: test_genomic_sequences.jsonl + streaming: true + +wandb: + name: "og2_mixtral_8x1B_sanity" + mode: "offline" + +lr_scheduler_kwargs: + num_warmup_steps: 10 + num_decay_steps: 240 + +checkpoint: + ckpt_dir: null + save_final_model: false + +logger: + frequency: 1 diff --git a/bionemo-recipes/recipes/opengenome2_mixtral_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/hydra_config/defaults.yaml new file mode 100644 index 0000000000..b4fa80d6ff --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/hydra_config/defaults.yaml @@ -0,0 +1,95 @@ +# OpenGenome2 Mixtral training defaults + +use_te: true +config_name_or_path: ??? + +use_weight_decay_grouping: true +skip_embedding_weight_decay: true +use_fp32_master_weights: true + +config_kwargs: {} + +num_train_steps: ??? +grad_acc_steps: 1 +seed: 42 + +use_meta_device: false +use_torch_compile: false +use_sequence_packing: false + +# Expert parallelism: number of GPUs per expert-parallel group. +# Must divide world_size evenly. Set > 1 to enable MoE expert parallelism. +expert_parallel_size: 1 +# Token dispatcher for EP runs. Options: "alltoall" (NCCL, always available) or +# "fused_deepep" (requires deep_ep + Triton). token_dispatcher_fallback controls +# what happens when fused_deepep is unavailable: "alltoall" to fall back silently, +# or "error" to raise immediately. +token_dispatcher: alltoall +token_dispatcher_fallback: error + +dataset: + tokenizer_name_or_path: ??? 
+ micro_batch_size: 8 + num_workers: 4 + max_seq_length: 8192 + stride: 200 + buffer_size: 50_000 + use_stateful_dataloader: false + pad_sequences_to_be_divisible_by: null + uppercase_labels: false + mask_degenerate_bases: true + load_dataset_kwargs: + path: ??? + split: "train" + streaming: true + +wandb: + name: ??? + project: null + +fp8_config: + enabled: false + fp8_recipe: transformer_engine.common.recipe.DelayedScaling + fp8_format: "HYBRID" + fp8_recipe_kwargs: {} + quantized_model_init_kwargs: + enabled: false + +adamw_kwargs: + lr: 3e-3 + fused: true + betas: [0.9, 0.95] + eps: 1e-5 + weight_decay: 0.1 + +lr_scheduler_kwargs: + num_warmup_steps: 2_000 + num_decay_steps: 498_000 + min_lr_ratio: 0.000001 + +checkpoint: + ckpt_dir: ??? + save_final_model: true + resume_from_checkpoint: true + save_every_n_steps: 50 + max_checkpoints: 5 + async_save: true + +logger: + frequency: 100 + +validation: + enabled: false + eval_interval: 500 + num_batches: 10 + data_path: null + +fp8_stats_config: + enabled: false + fp8_stats_file: ./fp8_debugging_stats.yaml + fp8_log_dir: ./log_fp8_stats + +profiler: + enabled: false + start_step: 10 + end_step: 15 diff --git a/bionemo-recipes/recipes/opengenome2_mixtral_native_te/hydra_config/og2_experiment.yaml b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/hydra_config/og2_experiment.yaml new file mode 100644 index 0000000000..6b91bc0ea2 --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/hydra_config/og2_experiment.yaml @@ -0,0 +1,65 @@ +defaults: + - defaults + - _self_ + +config_name_or_path: ./model_configs/og2-mixtral-8x1B + +config_kwargs: + attn_input_format: thd + self_attn_mask_type: padding_causal + +use_weight_decay_grouping: true +skip_embedding_weight_decay: true +use_fp32_master_weights: true +use_meta_device: false +use_torch_compile: false +use_sequence_packing: true + +num_train_steps: 2000 +grad_acc_steps: 2 + +dataset: + tokenizer_name_or_path: 
./tokenizers/nucleotide_fast_tokenizer + micro_batch_size: 2 + num_workers: 0 + max_seq_length: 1024 + stride: 200 + buffer_size: 50_000 + text_column: text + mask_degenerate_bases: true + uppercase_labels: false + load_dataset_kwargs: + path: parquet + data_files: ../../opengenome2_llama_native_te/dlcm_sanity_dataset.parquet + split: train + streaming: true + +adamw_kwargs: + lr: 1e-3 + fused: true + betas: [0.9, 0.95] + eps: 1e-5 + weight_decay: 0.1 + +lr_scheduler_kwargs: + num_warmup_steps: 100 + num_decay_steps: 1900 + min_lr_ratio: 0.01 + +checkpoint: + ckpt_dir: null + save_final_model: false + resume_from_checkpoint: false + save_every_n_steps: 0 + async_save: false + +validation: + enabled: false + +logger: + frequency: 1 + +wandb: + mode: online + project: og2-architecture-comparison + name: og2-mixtral-8x1B diff --git a/bionemo-recipes/recipes/opengenome2_mixtral_native_te/hydra_config/og2_small_thd_moe.yaml b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/hydra_config/og2_small_thd_moe.yaml new file mode 100644 index 0000000000..6082a2510e --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/hydra_config/og2_small_thd_moe.yaml @@ -0,0 +1,66 @@ +defaults: + - defaults + - _self_ + +config_name_or_path: ./model_configs/og2-mixtral-8x1B + +config_kwargs: + attn_input_format: thd + self_attn_mask_type: padding_causal + +use_sequence_packing: true +use_meta_device: false +use_fp32_master_weights: true + +dataset: + tokenizer_name_or_path: ./tokenizers/nucleotide_fast_tokenizer + micro_batch_size: 2 + num_workers: 1 + max_seq_length: 2048 + stride: 200 + buffer_size: 50_000 + mask_degenerate_bases: true + uppercase_labels: false + load_dataset_kwargs: + path: json + data_files: /data/opengenome2/json/pretraining_or_both_phases/metagenomes/data_metagenomics_train_*.jsonl.gz + split: train + streaming: true + +num_train_steps: 20_000 +grad_acc_steps: 4 + +adamw_kwargs: + lr: 3e-4 + fused: true + betas: [0.9, 0.95] + eps: 1e-8 + 
weight_decay: 0.1 + +lr_scheduler_kwargs: + num_warmup_steps: 500 + num_decay_steps: 19_500 + min_lr_ratio: 0.02 + +checkpoint: + ckpt_dir: null + save_final_model: true + resume_from_checkpoint: true + save_every_n_steps: 1_000 + async_save: true + +validation: + enabled: true + eval_interval: 200 + num_batches: 10 + data_path: /data/opengenome2/json/pretraining_or_both_phases/metagenomes/data_metagenomics_valid_chunk1.jsonl.gz + micro_batch_size: 4 + max_seq_length: null + stride: null + +logger: + frequency: 10 + +wandb: + name: og2_mixtral_8x1B_thd_moe + project: opengenome2-mixtral diff --git a/bionemo-recipes/recipes/opengenome2_mixtral_native_te/model_configs/og2-mixtral-8x1B/config.json b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/model_configs/og2-mixtral-8x1B/config.json new file mode 100644 index 0000000000..6a0f67816f --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/model_configs/og2-mixtral-8x1B/config.json @@ -0,0 +1,30 @@ +{ + "architectures": [ + "MixtralForCausalLM" + ], + "attention_dropout": 0.0, + "bos_token_id": 2, + "eos_token_id": 0, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 4096, + "max_position_embeddings": 8192, + "model_type": "mixtral", + "num_attention_heads": 16, + "num_hidden_layers": 16, + "num_key_value_heads": 8, + "num_local_experts": 8, + "num_experts_per_tok": 2, + "output_router_logits": false, + "pad_token_id": 1, + "rms_norm_eps": 1e-05, + "rope_theta": 500000.0, + "router_aux_loss_coef": 0.0, + "router_jitter_noise": 0.0, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.0", + "use_cache": true, + "vocab_size": 256 +} diff --git a/bionemo-recipes/recipes/opengenome2_mixtral_native_te/model_configs/og2-mixtral-8x7B/config.json b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/model_configs/og2-mixtral-8x7B/config.json new file mode 100644 index 0000000000..24d1928022 --- /dev/null +++ 
b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/model_configs/og2-mixtral-8x7B/config.json @@ -0,0 +1,27 @@ +{ + "architectures": [ + "MixtralForCausalLM" + ], + "attention_dropout": 0.0, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 8192, + "model_type": "mixtral", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "num_local_experts": 8, + "num_experts_per_tok": 2, + "output_router_logits": false, + "rms_norm_eps": 1e-05, + "rope_theta": 1000000.0, + "router_aux_loss_coef": 0.0, + "router_jitter_noise": 0.0, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.0", + "use_cache": true, + "vocab_size": 256 +} diff --git a/bionemo-recipes/recipes/opengenome2_mixtral_native_te/modeling_mixtral_te.py b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/modeling_mixtral_te.py new file mode 100644 index 0000000000..61430fb053 --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/modeling_mixtral_te.py @@ -0,0 +1,1170 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""TransformerEngine-optimized Mixtral model with Mixture of Experts.""" + +import logging +import os +import warnings +from collections import OrderedDict +from contextlib import nullcontext +from dataclasses import dataclass +from typing import Any, ClassVar, ContextManager, Protocol, Unpack + +import torch +import torch.distributed as dist +import torch.nn as nn +import transformer_engine.common.recipe +import transformer_engine.pytorch +import transformers +from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict +from torch.distributed.tensor import DTensor, Shard +from torch.distributed.tensor.device_mesh import DeviceMesh +from transformer_engine.pytorch.attention import InferenceParams +from transformer_engine.pytorch.attention.inference import PagedKVCacheManager +from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding +from transformers import MixtralConfig, PreTrainedModel +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding +from transformers.utils.generic import TransformersKwargs + + +logger = logging.getLogger(__name__) + + +AUTO_MAP = { + "AutoConfig": "modeling_mixtral_te.NVMixtralConfig", + "AutoModel": "modeling_mixtral_te.NVMixtralModel", + "AutoModelForCausalLM": "modeling_mixtral_te.NVMixtralForCausalLM", +} + + +class NVMixtralConfig(MixtralConfig): + """NVMixtral configuration.""" + + # Attention input format: + # "bshd" = Batch, Sequence, Head, Dimension (standard padded format) + # "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format) + attn_input_format: str = "thd" + self_attn_mask_type: str = "padding_causal" + layer_precision: list[str | None] | None = None + use_quantized_model_init: bool = False + expert_parallel_size: int = 1 + moe_aux_loss_coeff: float = 0.0 + + def __init__(self, **kwargs): + """Initialize the NVMixtralConfig with 
additional TE-related config options.""" + super().__init__(**kwargs) + + if self.layer_precision is not None: + if len(self.layer_precision) != self.num_hidden_layers: + raise ValueError(f"layer_precision must be a list of length {self.num_hidden_layers}") + for precision in self.layer_precision: + if precision not in {"fp8", "fp4", None}: + raise ValueError(f'layer_precision element must be "fp8", "fp4", or None, got {precision!r}') + + if self.num_local_experts % self.expert_parallel_size != 0: + raise ValueError( + f"num_local_experts ({self.num_local_experts}) must be divisible by " + f"expert_parallel_size ({self.expert_parallel_size})" + ) + + +@dataclass +class DispatchOutput: + """Output of TokenDispatcher.dispatch(). + + Attributes: + expert_input: Tokens sorted by local expert, shape ``[total_recv_tokens, H]``. + tokens_per_expert: Token count per local expert. + handle: Opaque state needed by ``combine()`` to reverse the dispatch. + """ + + expert_input: torch.Tensor + tokens_per_expert: list[int] + handle: Any + + +class TokenDispatcher(Protocol): + """Protocol for MoE token dispatch/combine strategies. + + Encapsulates the full dispatch cycle (permute -> communicate -> sort) and + combine cycle (unsort -> communicate -> unpermute) so that the MoE block + is agnostic to the communication backend (NCCL all-to-all, HybridEP, etc.). + """ + + def dispatch( + self, + hidden_states: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + ) -> DispatchOutput: + """Dispatch tokens to their assigned experts. + + Args: + hidden_states: Flattened input tensor of shape ``[N, H]``. + selected_experts: Expert assignments, shape ``[N, top_k]``, int. + routing_weights: Normalized routing probabilities, shape ``[N, top_k]``, float32. + + Returns: + DispatchOutput with expert-sorted tokens, per-expert counts, and an opaque handle. + """ + ... 
+ + def combine( + self, + expert_output: torch.Tensor, + handle: Any, + ) -> torch.Tensor: + """Combine expert outputs back to the original token order. + + Args: + expert_output: Expert output tensor of shape ``[total_recv_tokens, H]``. + handle: Opaque state from ``dispatch()``. + + Returns: + Combined output tensor of shape ``[N, H]`` with routing weights applied. + """ + ... + + def set_ep_group(self, ep_group: dist.ProcessGroup) -> None: + """Set the expert-parallel process group for communication.""" + ... + + +class NVMixtralPreTrainedModel(PreTrainedModel): + """Base class for NVMixtral models.""" + + config_class = NVMixtralConfig + base_model_prefix = "model" + _no_split_modules = ("NVMixtralDecoderLayer",) + _skip_keys_device_placement = ("past_key_values",) + _do_not_quantize = ("lm_head", "model.layers.*.mlp.gate") # Flag for testing that these layers are not quantized. + + def init_empty_weights(self): + """Handles moving the model from the meta device to the cuda device and initializing the weights.""" + for module in self.modules(): + if hasattr(module, "reset_parameters"): + module.reset_parameters() + + # After reset_parameters materializes GroupedLinear views on CUDA, + # re-stack them into the authoritative stacked parameters. + for module in self.modules(): + if isinstance(module, NVMixtralSparseMoeBlock): + module._restack_from_views() + + self.model.embed_tokens.to_empty(device="cuda") + self.model.embed_tokens.apply(self._init_weights) + + self.model.rotary_emb.inv_freq = LlamaRotaryEmbedding(config=self.model.config).inv_freq.to("cuda") + + self.tie_weights() + + def _init_weights(self, module): + """Initialize module weights. + + We only use this method for standard pytorch modules, TE modules handle their own weight initialization through + `init_method` parameters and the `reset_parameters` method. 
+ """ + if module.__module__.startswith("transformer_engine.pytorch"): + return + + super()._init_weights(module) + + def state_dict(self, *args, **kwargs): + """Override state_dict to filter out TransformerEngine's _extra_state keys.""" + state_dict = super().state_dict(*args, **kwargs) + return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")} + + +class NVMixtralSparseMoeBlock(nn.Module): + """Mixture of Experts block using TransformerEngine GroupedLinear.""" + + def __init__(self, config: MixtralConfig, dispatcher: TokenDispatcher | None = None): + """Initialize the sparse MoE block.""" + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + self.jitter_noise = config.router_jitter_noise + + # Expert parallelism + self.ep_size = getattr(config, "expert_parallel_size", 1) + self.num_local_experts = self.num_experts // self.ep_size + self.moe_aux_loss_coeff = getattr(config, "moe_aux_loss_coeff", 0.0) + self._aux_loss: torch.Tensor = torch.tensor(0.0) + self.initializer_range = config.initializer_range + + self.dispatcher: TokenDispatcher = dispatcher or AllToAllTokenDispatcher( + self.num_experts, + self.num_local_experts, + self.hidden_size, + self.ep_size, + ) + + device = "meta" if torch.get_default_device() == torch.device("meta") else "cuda" + + def _init_method(x): + torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range) + + # Router always outputs num_experts logits (replicated across EP ranks) + with transformer_engine.pytorch.quantized_model_init(enabled=False): + self.gate = transformer_engine.pytorch.Linear( + self.hidden_size, + self.num_experts, + bias=False, + device=device, + params_dtype=config.dtype, + init_method=_init_method, + ) + + # Expert FFNs — only num_local_experts per rank when EP > 1 + self.experts_gate_up = transformer_engine.pytorch.GroupedLinear( + 
num_gemms=self.num_local_experts, + in_features=self.hidden_size, + out_features=2 * self.intermediate_size, + bias=False, + params_dtype=config.dtype, + device=device, + init_method=_init_method, + ) + self.experts_down = transformer_engine.pytorch.GroupedLinear( + num_gemms=self.num_local_experts, + in_features=self.intermediate_size, + out_features=self.hidden_size, + bias=False, + params_dtype=config.dtype, + device=device, + init_method=_init_method, + ) + + # Stack per-expert weights into single parameters (authoritative weight store). + # GroupedLinear's _parameters dict is emptied; weight attributes are set as views + # so that reset_parameters() / _get_weight_tensors() can still find them. + self.experts_gate_up_weight = nn.Parameter( + torch.stack( + [self.experts_gate_up._parameters.pop(f"weight{i}").data for i in range(self.num_local_experts)] + ) + ) # [num_local_experts, 2*intermediate_size, hidden_size] + + self.experts_down_weight = nn.Parameter( + torch.stack([self.experts_down._parameters.pop(f"weight{i}").data for i in range(self.num_local_experts)]) + ) # [num_local_experts, hidden_size, intermediate_size] + + # Set views back on GroupedLinear so getattr(self, "weight{i}") still works + # (needed by GroupedLinear.reset_parameters and _get_weight_tensors). + self._sync_expert_views() + + def _restack_from_views(self) -> None: + """Re-create stacked parameters on CUDA after meta init. + + Called by ``init_empty_weights()`` after ``reset_parameters()`` has been called + on all TE modules. Since GroupedLinear has no registered parameters (we popped them), + its ``reset_parameters()`` cannot move them from meta to CUDA. This method explicitly + creates the stacked parameters on CUDA and reinitializes them. 
+ """ + device = torch.cuda.current_device() + for attr_name in ("experts_gate_up_weight", "experts_down_weight"): + old_param = getattr(self, attr_name) + if isinstance(old_param.data, DTensor): + # FSDP2 has sharded this param; materialize the local shard on CUDA + # and reconstruct the DTensor wrapper so FSDP2 can manage it. + local_data = old_param.data.to_local() + new_local = torch.empty(local_data.shape, dtype=local_data.dtype, device=device) + torch.nn.init.normal_(new_local, mean=0.0, std=self.initializer_range) + new_dtensor = DTensor.from_local( + new_local, + device_mesh=old_param.data.device_mesh, + placements=old_param.data.placements, + ) + setattr(self, attr_name, nn.Parameter(new_dtensor)) + else: + new_data = torch.empty_like(old_param, device=device) + torch.nn.init.normal_(new_data, mean=0.0, std=self.initializer_range) + setattr(self, attr_name, nn.Parameter(new_data)) + + # Re-sync views to point to the new stacked parameter + self._sync_expert_views() + + def _sync_expert_views(self) -> None: + """Set GroupedLinear weight attributes as views of the stacked parameters. + + GroupedLinear internally uses ``getattr(self, f"weight{i}")`` in methods like + ``reset_parameters()`` and ``_get_weight_tensors()``. After popping the original + parameters, we set views of the stacked tensor so these methods keep working. + Uses ``object.__setattr__`` to bypass ``nn.Module.__setattr__`` and avoid + re-registering them as parameters. 
+ """ + gate_up_w = self.experts_gate_up_weight + if isinstance(gate_up_w, DTensor): + gate_up_w = gate_up_w.to_local() + num_local = gate_up_w.shape[0] + for i in range(num_local): + object.__setattr__(self.experts_gate_up, f"weight{i}", gate_up_w[i]) + + down_w = self.experts_down_weight + if isinstance(down_w, DTensor): + down_w = down_w.to_local() + num_local_down = down_w.shape[0] + for i in range(num_local_down): + object.__setattr__(self.experts_down, f"weight{i}", down_w[i]) + + def set_ep_group(self, ep_group: dist.ProcessGroup, ep_mesh: DeviceMesh) -> None: + """Set the expert-parallel process group and convert stacked weights to DTensors. + + Must be called before the first forward pass when ``ep_size > 1``. + + Args: + ep_group: A ``torch.distributed.ProcessGroup`` whose world size equals ``self.ep_size``. + ep_mesh: A 1-D ``DeviceMesh`` for expert parallelism. Used to wrap stacked weights + as ``DTensor(Shard(0))`` so that DCP can save/load/reshard them automatically. + """ + self.dispatcher.set_ep_group(ep_group) + # Convert stacked parameters to DTensors with Shard(0) on the expert dimension. + # Global shape is [num_experts, ...]; each rank stores [num_local_experts, ...]. + # Guard: only wrap plain tensors; skip if already DTensors (e.g. repeated calls). + if not isinstance(self.experts_gate_up_weight.data, DTensor): + self.experts_gate_up_weight = nn.Parameter( + DTensor.from_local(self.experts_gate_up_weight.data, device_mesh=ep_mesh, placements=[Shard(0)]) + ) + if not isinstance(self.experts_down_weight.data, DTensor): + self.experts_down_weight = nn.Parameter( + DTensor.from_local(self.experts_down_weight.data, device_mesh=ep_mesh, placements=[Shard(0)]) + ) + + def _expert_ffn(self, tokens: torch.Tensor, m_splits: list[int]) -> torch.Tensor: + """Run the expert SwiGLU FFN (gate_up -> silu -> down). + + Args: + tokens: Input tensor of shape [total_tokens, H], sorted by expert. + m_splits: Number of tokens per local expert. 
+ + Returns: + Output tensor of shape [total_tokens, H]. + """ + gate_up_output = self.experts_gate_up(tokens, m_splits=m_splits) + gate_output, up_output = gate_up_output.chunk(2, dim=-1) + intermediate = torch.nn.functional.silu(gate_output) * up_output + return self.experts_down(intermediate, m_splits=m_splits) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Forward pass for the MoE block. + + Args: + hidden_states: Input tensor of shape [B, S, H] (bshd) or [T, H] (thd). + + Returns: + Output tensor of the same shape as the input. + """ + original_shape = hidden_states.shape + + # Apply multiplicative jitter noise to hidden states during training to encourage load balancing + if self.training and self.jitter_noise > 0: + hidden_states = hidden_states * torch.empty_like(hidden_states).uniform_( + 1.0 - self.jitter_noise, 1.0 + self.jitter_noise + ) + + # Flatten to [N, H] for routing + if hidden_states.dim() == 3: + hidden_states = hidden_states.reshape(-1, self.hidden_size) + + # Router: compute expert assignments + with transformer_engine.pytorch.autocast(enabled=False): + # Keep the router logits in bf16 during FP8 training + router_logits = self.gate(hidden_states) # [N, num_experts] + + routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float32) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) # [N, top_k] + # Normalize routing weights + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + + # Auxiliary load-balancing loss (switch transformer style) + if self.moe_aux_loss_coeff > 0: + num_tokens = hidden_states.shape[0] + m_splits_tensor = torch.bincount(selected_experts.reshape(-1), minlength=self.num_experts).int() + # f_i: fraction of tokens dispatched to each expert + f = m_splits_tensor.float() / (num_tokens * self.top_k) + # P_i: mean router probability per expert (over all tokens) + router_probs = 
torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float32) + p = router_probs.mean(dim=0) + self._aux_loss = self.moe_aux_loss_coeff * self.num_experts * (f * p).sum() + else: + self._aux_loss = torch.tensor(0.0, device=hidden_states.device) + + # Populate GroupedLinear weight attributes from stacked parameters. + # For EP, the stacked parameter is a DTensor; .to_local() gives the local shard. + self._sync_expert_views() + + dispatch_output = self.dispatcher.dispatch(hidden_states, selected_experts, routing_weights) + + expert_input = dispatch_output.expert_input + tokens_per_expert = dispatch_output.tokens_per_expert + + # MXFP8 requires both tensor dims divisible by 32. Upstream attention layers + # get this from the collator (pad_sequences_to_be_divisible_by=32), but after + # all-to-all dispatch the per-rank token count is data-dependent (routing + # decisions pick different expert loads). Pad here so GroupedLinear's MXFP8 + # kernels don't assert, then slice the padding off afterwards. + n_tokens = expert_input.shape[0] + mxfp8_pad = (32 - n_tokens % 32) % 32 + if mxfp8_pad: + expert_input = torch.nn.functional.pad(expert_input, (0, 0, 0, mxfp8_pad)) + # Attribute the padding tokens to the last expert so m_splits still sums correctly. 
+ tokens_per_expert = list(tokens_per_expert) + tokens_per_expert[-1] += mxfp8_pad + + expert_output = self._expert_ffn(expert_input, tokens_per_expert) + + if mxfp8_pad: + expert_output = expert_output[:n_tokens] + + output = self.dispatcher.combine(expert_output, dispatch_output.handle) + + return output.reshape(original_shape) + + +class NVMixtralDecoderLayer(nn.Module): + """Mixtral decoder layer using TE attention and MoE MLP.""" + + def __init__(self, config: MixtralConfig, layer_idx: int, dispatcher: TokenDispatcher | None = None): + """Initialize the decoder layer.""" + super().__init__() + self.hidden_size = config.hidden_size + + device = "meta" if torch.get_default_device() == torch.device("meta") else "cuda" + + def _init_method(x): + torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range) + + self.self_attention = transformer_engine.pytorch.MultiheadAttention( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_gqa_groups=config.num_key_value_heads, + bias=False, + layernorm_epsilon=config.rms_norm_eps, + attention_dropout=0, + fuse_qkv_params=True, + qkv_weight_interleaved=True, + normalization="RMSNorm", + input_layernorm=True, + qkv_format=config.attn_input_format, + attn_mask_type=config.self_attn_mask_type, + layer_number=layer_idx + 1, + params_dtype=config.dtype, + device=device, + init_method=_init_method, + output_layer_init_method=_init_method, + ) + + self.post_attention_layernorm = transformer_engine.pytorch.RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.dtype, + device=device, + ) + + self.mlp = NVMixtralSparseMoeBlock(config, dispatcher) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + rotary_pos_emb: torch.Tensor | None = None, + inference_params: InferenceParams | None = None, + **kwargs, + ) -> torch.Tensor: + """Forward pass for the decoder layer.""" + # Self attention with fused input layernorm + attn_output 
= self.self_attention( + hidden_states, + attention_mask=attention_mask, + rotary_pos_emb=rotary_pos_emb, + inference_params=inference_params, + cu_seqlens_q=kwargs.get("cu_seqlens_q", None), + cu_seqlens_kv=kwargs.get("cu_seqlens_kv", None), + cu_seqlens_q_padded=kwargs.get("cu_seqlens_q_padded", None), + cu_seqlens_kv_padded=kwargs.get("cu_seqlens_kv_padded", None), + max_seqlen_q=kwargs.get("max_seqlen_q", None), + max_seqlen_kv=kwargs.get("max_seqlen_kv", None), + pad_between_seqs=kwargs.get("pad_between_seqs", None), + ) + + # Residual connection + hidden_states = hidden_states + attn_output + + # Post-attention layernorm + MoE MLP + residual + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class NVMixtralModel(NVMixtralPreTrainedModel): + """Mixtral model implemented in Transformer Engine.""" + + def __init__( + self, + config: MixtralConfig, + fp8_recipe: transformer_engine.common.recipe.Recipe | None = None, + fp4_recipe: transformer_engine.common.recipe.Recipe | None = None, + dispatcher: TokenDispatcher | None = None, + ): + """Initialize the NVMixtral model. + + Args: + config: The configuration of the model. + fp8_recipe: The FP8 recipe for the model. + fp4_recipe: The FP4 recipe for the model. + dispatcher: The token dispatcher for the model. If None, the default AllToAllTokenDispatcher will be used. 
+ """ + super().__init__(config) + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = fp8_recipe + self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = fp4_recipe + + if self.config.layer_precision is None: + if fp8_recipe is not None and fp4_recipe is not None: + raise RuntimeError("Both FP8 and FP4 recipes provided, but no layer precision provided.") + if fp8_recipe is not None: + warnings.warn("No layer precision provided, using FP8 recipe for all layers.", UserWarning) + self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers + elif fp4_recipe is not None: + raise RuntimeError( + "FP4 recipe provided but no layer_precision configured. " + "Set layer_precision explicitly when using FP4." + ) + + if self.config.layer_precision is not None and "fp4" in self.config.layer_precision and fp4_recipe is None: + raise RuntimeError("layer_precision contains 'fp4' entries but no fp4_recipe was provided.") + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=config.dtype) + + layers: list[NVMixtralDecoderLayer] = [] + for layer_idx in range(config.num_hidden_layers): + with self.get_autocast_context(layer_idx, init=True): + layers += [NVMixtralDecoderLayer(config, layer_idx, dispatcher)] + + self.layers = nn.ModuleList(layers) + + self.norm = transformer_engine.pytorch.RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + ) + + self.rotary_emb = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads) + self.rotary_emb.inv_freq = LlamaRotaryEmbedding(config=config).inv_freq + + self.gradient_checkpointing = False + + self.post_init() + + def set_ep_groups(self, ep_group: dist.ProcessGroup, ep_mesh: DeviceMesh) -> None: + """Propagate an expert-parallel 
process group and mesh to every MoE block. + + Args: + ep_group: The EP process group to set on each ``NVMixtralSparseMoeBlock``. + ep_mesh: A 1-D ``DeviceMesh`` for expert parallelism. + """ + for layer in self.layers: + layer.mlp.set_ep_group(ep_group, ep_mesh) + + def forward( + self, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + past_key_values: InferenceParams | None = None, + inputs_embeds: torch.Tensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + """Forward pass for the NVMixtral model.""" + all_hidden_states = [] + output_hidden_states = kwargs.get("output_hidden_states", False) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + # TE-specific input handling + has_thd_input = [x in kwargs for x in ["cu_seq_lens_q", "cu_seq_lens_k", "max_length_q", "max_length_k"]] + should_pack_inputs = not any(has_thd_input) and self.config.attn_input_format == "thd" + + if should_pack_inputs: + assert attention_mask is not None, "Attention mask is required when packing BSHD inputs." 
+ batch_size = hidden_states.size(0) + padded_seq_len = input_ids.size(1) + hidden_states, indices, cu_seqlens, max_seqlen, _ = _unpad_input(hidden_states, attention_mask) + kwargs["cu_seq_lens_q"] = kwargs["cu_seq_lens_k"] = cu_seqlens + kwargs["max_length_q"] = kwargs["max_length_k"] = max_seqlen + + if self.config.attn_input_format == "thd" and hidden_states.dim() == 3 and hidden_states.size(0) == 1: + hidden_states = hidden_states.squeeze(0) + + if self.config.attn_input_format == "bshd" and attention_mask is not None and attention_mask.dim() == 2: + # Convert HF mask (1=attend, 0=pad) to TE boolean mask (True=masked, False=attend) + attention_mask = ~attention_mask[:, None, None, :].bool() + + if isinstance(past_key_values, InferenceParams): + lengths = ( + attention_mask.sum(dim=1).tolist() + if attention_mask.shape == input_ids.shape + else [1] * input_ids.shape[0] + ) + past_key_values.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths))) + + with torch.autocast(device_type="cuda", enabled=False): + te_rope_emb = self.rotary_emb(max_seq_len=self.config.max_position_embeddings) + + with self.get_autocast_context(None, outer=True): + for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + if output_hidden_states: + all_hidden_states = (*all_hidden_states, hidden_states) + + with self.get_autocast_context(layer_idx): + hidden_states = decoder_layer( + hidden_states, + attention_mask=None if self.config.attn_input_format == "thd" else attention_mask, + rotary_pos_emb=te_rope_emb, + inference_params=past_key_values, + cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), + cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), + cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), + cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), + max_seqlen_q=kwargs.get("max_length_q", None), + max_seqlen_kv=kwargs.get("max_length_k", None), + pad_between_seqs=kwargs.get("pad_between_seqs", None), + ) + + hidden_states = 
self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states = (*all_hidden_states, hidden_states) + + if should_pack_inputs: + hidden_states = _pad_input(hidden_states, indices, batch_size, padded_seq_len) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states if output_hidden_states else None, + ) + + def get_autocast_context( + self, layer_number: int | None, init: bool = False, outer: bool = False + ) -> ContextManager: + """Return the appropriate TE autocast context manager for a given layer. + + This function handles both the quantized_model_init during layer creation and the te.autocast() during layer + forward pass. + + Args: + layer_number: The 0-indexed layer number. + init: Whether to return a `quantized_model_init` context for layer initialization. + outer: Whether to return a global te.autocast() context to wrap the entire model stack. + """ + if self.config.layer_precision is None: + return nullcontext() + + if outer: + if "fp8" not in self.config.layer_precision: + return nullcontext() + if self._fp8_recipe is None: + warnings.warn("No FP8 recipe provided, using default recipe.", UserWarning) + return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp8_recipe) + + precision = self.config.layer_precision[layer_number] + recipe = {"fp8": self._fp8_recipe, "fp4": self._fp4_recipe}.get(precision) + + if init and self.config.use_quantized_model_init: + if precision in ("fp8", "fp4"): + return transformer_engine.pytorch.quantized_model_init(recipe=recipe) + return nullcontext() + + if precision == "fp8": + if recipe is None: + warnings.warn("No FP8 recipe provided, using default recipe.", UserWarning) + return transformer_engine.pytorch.autocast(enabled=True, recipe=recipe) + if precision == "fp4": + if recipe is None: + raise RuntimeError("No FP4 recipe provided, but layer precision is set to FP4.") + return 
transformer_engine.pytorch.autocast(enabled=True, recipe=recipe) + return transformer_engine.pytorch.autocast(enabled=False) + + +class NVMixtralForCausalLM(NVMixtralPreTrainedModel, transformers.GenerationMixin): + """Mixtral model with causal language head.""" + + _tied_weights_keys: ClassVar[list[str]] = [] + + def __init__( + self, + config, + fp8_recipe: transformer_engine.common.recipe.Recipe | None = None, + fp4_recipe: transformer_engine.common.recipe.Recipe | None = None, + dispatcher: TokenDispatcher | None = None, + ): + """Initialize the NVMixtralForCausalLM model. + + Args: + config: The configuration of the model. + fp8_recipe: The FP8 recipe for the model. + fp4_recipe: The FP4 recipe for the model. + dispatcher: The token dispatcher for expert parallelism. If None, the default + AllToAllTokenDispatcher will be used. + """ + super().__init__(config) + self.model = NVMixtralModel(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe, dispatcher=dispatcher) + self.vocab_size = config.vocab_size + + with transformer_engine.pytorch.quantized_model_init(enabled=False): + self.lm_head = transformer_engine.pytorch.Linear( + config.hidden_size, + config.vocab_size, + bias=False, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + + self.post_init() + + def forward( + self, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + past_key_values: tuple[tuple[torch.Tensor, ...], ...] 
| None = None, + inputs_embeds: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + shift_labels: torch.Tensor | None = None, + use_cache: bool | None = None, + cache_position: torch.Tensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + """Forward pass for the NVMixtralForCausalLM model.""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + + with transformer_engine.pytorch.autocast(enabled=False): + if hidden_states.ndim == 3: + logits = self.lm_head(hidden_states[:, slice_indices, :]) + else: + logits = self.lm_head(hidden_states[slice_indices, :]) + + loss = None + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, shift_labels=shift_labels, vocab_size=self.config.vocab_size, **kwargs + ) + + # Collect auxiliary load-balancing loss from all MoE layers + if self.config.moe_aux_loss_coeff > 0 and loss is not None: + aux_loss = sum(layer.mlp._aux_loss for layer in self.model.layers) + loss = loss + aux_loss + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def save_final_model_ep( + model: NVMixtralForCausalLM, + save_directory: str | os.PathLike, + dist_config=None, +) -> None: + """Gather all EP-sharded expert weights and save as safetensors. 
+ + Uses ``get_model_state_dict(full_state_dict=True)`` to all-gather DTensors, + matching the pattern from ``save_final_model_fsdp2`` in the llama3 checkpoint module. + + All ranks must call this function. Only rank 0 writes files. + + Args: + model: The NVMixtral model (may have DTensor expert parameters). + save_directory: Directory to save ``model.safetensors`` and config. + dist_config: Optional distributed config with ``is_main_process()`` method. + If ``None``, only rank 0 saves. + """ + from safetensors.torch import save_file + + model_state_dict = get_model_state_dict( + model, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + + # Filter out TE _extra_state keys + model_state_dict = {k: v for k, v in model_state_dict.items() if not k.endswith("_extra_state")} + + is_main = dist_config.is_main_process() if dist_config is not None else (dist.get_rank() == 0) + if is_main: + os.makedirs(save_directory, exist_ok=True) + save_file(model_state_dict, os.path.join(save_directory, "model.safetensors")) + model.config.save_pretrained(save_directory) + logger.info(f"Saved final EP model to {save_directory}") + + +# Required for torch.compile'd functions below (_pad_input, _unpad_input, _build_expert_sort_indices) +# that use data-dependent scalar values (e.g., max_seqlen_in_batch.item()) or produce tensors +# whose shape depends on input data (e.g., repeat_interleave with tensor counts). +# These must be set at module level because torch.compile traces lazily on first call, +# so a scoped setting would not be active at trace time. +torch._dynamo.config.capture_scalar_outputs = True +torch._dynamo.config.capture_dynamic_output_shape_ops = True + + +@torch.compile +def _pad_input(hidden_states, indices, batch, seqlen): + """Convert a THD tensor to a BSHD equivalent tensor. 
@torch.compile
def _unpad_input(hidden_states, attention_mask, unused_mask=None):
    """Convert a BSHD tensor to a THD equivalent tensor.

    Adapted from huggingface/transformers/modeling_flash_attention_utils.py

    Args:
        hidden_states: Tensor whose dim 0 is read as the batch size and dim 1 as
            the sequence length; all trailing dims are carried through unchanged.
        attention_mask: Mask whose nonzero entries mark the tokens to keep.
            Its dim 1 is compared against the sequence length of ``hidden_states``.
            # assumes shape (batch, seq) with 0/1 entries — TODO confirm against callers
        unused_mask: Optional mask that is added to ``attention_mask`` before the
            kept positions and per-row lengths are computed.

    Returns:
        A 5-tuple ``(tokens, indices, cu_seqlens, max_seqlen_in_batch, used_seqlens_in_batch)``:
        the selected rows of the flattened input, their flat positions, the
        zero-prefixed cumulative row lengths (int32), the largest row length as a
        Python scalar, and the per-row sums of ``attention_mask`` alone (int32).
    """
    batch_size = hidden_states.size(0)
    seq_length = hidden_states.size(1)

    # Mask and input disagree on sequence length: skip the mask entirely and
    # treat every batch row as a length-1 sequence (squeeze dim 1, cu_seqlens = 0..batch).
    # NOTE(review): presumably the single-token decode step, where the mask also
    # covers cached tokens — confirm. Also note the 5th element is the int 1 here
    # but an int32 tensor on the main path below.
    if attention_mask.shape[1] != seq_length:
        return (
            hidden_states.squeeze(1),
            torch.arange(batch_size, dtype=torch.int64, device=hidden_states.device),
            torch.arange(batch_size + 1, dtype=torch.int32, device=hidden_states.device),
            1,
            1,
        )

    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    # Flat (batch * seq) positions of every kept token.
    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
    # .item() yields a data-dependent scalar inside a compiled function.
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # Left-pad the running sum with one zero so entry i is the start offset of row i.
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

    return (
        hidden_states.reshape(-1, *hidden_states.shape[2:])[indices],
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
        used_seqlens_in_batch,
    )
@torch.compile(fullgraph=True)
def _build_expert_sort_indices(recv_counts: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Build sort and unsort index tensors for reordering received tokens by local expert.

    After all-to-all, tokens arrive grouped by source rank:
    ``[src0_exp0..src0_expL, src1_exp0..src1_expL, ...]``. ``GroupedLinear`` expects them
    grouped by expert: ``[all_exp0, all_exp1, ...]``.

    Uses only vectorized tensor operations (no ``.item()`` calls or Python-level loops)
    so that it is compatible with ``torch.compile(fullgraph=True)``.

    Worked example with ``recv_counts = [[1, 2], [3, 1]]`` (2 source ranks, 2 local experts):
    the 7 received tokens sit in source order as blocks of sizes ``[1, 2, 3, 1]``
    (s0e0, s0e1, s1e0, s1e1). Expert order needs blocks ``[1, 3, 2, 1]``
    (e0s0, e0s1, e1s0, e1s1). The per-block shifts are ``[0, +3, -2, 0]``, giving
    ``sort_indices = [0, 3, 4, 5, 1, 2, 6]`` and ``unsort_indices = [0, 4, 5, 1, 2, 3, 6]``.

    Args:
        recv_counts: Integer tensor of shape ``[ep_size, num_local_experts]`` giving the
            number of tokens received from each source rank for each local expert.

    Returns:
        A ``(sort_indices, unsort_indices)`` pair of 1-D ``int64`` tensors that can be
        used to reorder and restore the token dimension.
    """
    ep_size, num_local_experts = recv_counts.shape
    device = recv_counts.device
    num_blocks = ep_size * num_local_experts

    # Source-grouped (row-major) block offsets: [s0e0, s0e1, ..., s1e0, s1e1, ...]
    # offsets[i] is the exclusive prefix sum of the block sizes, i.e. where block i starts.
    counts_src = recv_counts.reshape(-1).long()
    offsets_src = torch.zeros(num_blocks, dtype=torch.long, device=device)
    offsets_src[1:] = counts_src[:-1].cumsum(0)

    # Expert-grouped (column-major) block offsets: [e0s0, e0s1, ..., e1s0, e1s1, ...]
    counts_exp = recv_counts.t().contiguous().reshape(-1).long()
    offsets_exp = torch.zeros(num_blocks, dtype=torch.long, device=device)
    offsets_exp[1:] = counts_exp[:-1].cumsum(0)

    # 0-dim tensor (not a Python int); it is passed directly as a size below.
    total = counts_src.sum()

    # Mapping from source block index (s * L + e) to expert block index (e * S + s)
    s_idx = torch.arange(ep_size, device=device).unsqueeze(1).expand(ep_size, num_local_experts)
    e_idx = torch.arange(num_local_experts, device=device).unsqueeze(0).expand(ep_size, num_local_experts)
    src_to_exp = (e_idx * ep_size + s_idx).reshape(-1)

    # Per-block positional shift from source layout to expert layout
    shifts = offsets_exp[src_to_exp] - offsets_src

    # Expand per-block shifts to per-token (output length depends on the data in counts_src)
    token_shifts = shifts.repeat_interleave(counts_src)

    # Map each source-grouped position to its expert-grouped destination
    src_positions = torch.arange(total, device=device)
    dst_positions = src_positions + token_shifts

    # sort_indices[exp_pos] = src_pos (gathers source tokens into expert order)
    sort_indices = torch.empty(total, dtype=torch.long, device=device)
    sort_indices[dst_positions] = src_positions

    # unsort_indices: inverse permutation (restores expert-ordered output to source order)
    # NOTE(review): because src_positions is arange(total), this inverse works out to
    # the same values as dst_positions.
    unsort_indices = torch.empty_like(sort_indices)
    unsort_indices[sort_indices] = torch.arange(total, device=device)

    return sort_indices, unsort_indices
class AllToAllTokenDispatcher:
    """TokenDispatcher using NCCL all-to-all for expert-parallel communication.

    Handles both EP=1 (no communication, just permute/unpermute) and EP>1
    (all-to-all token exchange between ranks) cases transparently. The EP path
    is taken whenever a process group has been registered via ``set_ep_group``.

    Args:
        num_experts: Total number of experts (global).
        num_local_experts: Number of experts on this rank.
        hidden_size: Hidden dimension size.
        ep_size: Expert parallel world size.
    """

    def __init__(self, num_experts: int, num_local_experts: int, hidden_size: int, ep_size: int):
        """Initialize the AllToAllTokenDispatcher."""
        self.num_experts = num_experts
        self.num_local_experts = num_local_experts
        self.hidden_size = hidden_size
        self.ep_size = ep_size
        # None means "no expert parallelism": dispatch/combine only permute locally.
        self._ep_group: dist.ProcessGroup | None = None

    def set_ep_group(self, ep_group: dist.ProcessGroup) -> None:
        """Set the expert-parallel process group for all-to-all communication."""
        self._ep_group = ep_group

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        selected_experts: torch.Tensor,
        routing_weights: torch.Tensor,
    ) -> DispatchOutput:
        """Dispatch tokens to their assigned experts via permute and optional all-to-all.

        Args:
            hidden_states: Flattened input tensor of shape ``[N, H]``.
            selected_experts: Expert assignments, shape ``[N, top_k]``, int.
            routing_weights: Normalized routing probabilities, shape ``[N, top_k]``, float32.

        Returns:
            DispatchOutput with expert-sorted tokens, per-expert counts, and an opaque handle.
        """
        # Permute tokens by expert using TE moe_permute
        permuted_hidden, row_id_map = transformer_engine.pytorch.moe_permute(
            hidden_states, selected_experts.to(torch.int32), map_type="index"
        )

        # Compute m_splits: number of tokens per expert
        m_splits_tensor = torch.bincount(selected_experts.reshape(-1), minlength=self.num_experts).int()

        if self._ep_group is not None:
            ep_group = self._ep_group

            # Token counts per expert, reshaped to [ep_size, num_local_experts]
            send_counts = m_splits_tensor.reshape(self.ep_size, self.num_local_experts)

            # Exchange per-expert token counts between EP ranks.
            # recv_counts.flatten() has to be a view for the received values to land in
            # recv_counts; that holds for the freshly allocated contiguous buffer used here.
            recv_counts = torch.empty_like(send_counts)
            dist.all_to_all_single(recv_counts.flatten(), send_counts.flatten(), group=ep_group)

            # Derive split sizes for the token all-to-all.
            # Each .tolist() copies device values to the host before the next step can run.
            input_split_sizes = send_counts.sum(dim=1).tolist()
            output_split_sizes = recv_counts.sum(dim=1).tolist()
            local_m_splits = recv_counts.sum(dim=0).int().tolist()

            # Dispatch tokens to expert-owning ranks (differentiable)
            recv_tokens = _DifferentiableAllToAll.apply(
                permuted_hidden, output_split_sizes, input_split_sizes, ep_group
            )

            # Sort received tokens by local expert index.
            # After all_to_all layout is [src0_exp0..src0_expL, src1_exp0..src1_expL, ...].
            # GroupedLinear needs [all_exp0, all_exp1, ...].
            sort_indices, unsort_indices = _build_expert_sort_indices(recv_counts)

            handle = _AllToAllHandle(
                row_id_map=row_id_map,
                routing_weights=routing_weights,
                unsort_indices=unsort_indices,
                input_split_sizes=input_split_sizes,
                output_split_sizes=output_split_sizes,
            )
            return DispatchOutput(
                expert_input=recv_tokens[sort_indices],
                tokens_per_expert=local_m_splits,
                handle=handle,
            )

        handle = _AllToAllHandle(row_id_map=row_id_map, routing_weights=routing_weights)
        return DispatchOutput(
            expert_input=permuted_hidden,
            tokens_per_expert=m_splits_tensor.tolist(),
            handle=handle,
        )

    def combine(self, expert_output: torch.Tensor, handle: _AllToAllHandle) -> torch.Tensor:
        """Combine expert outputs back to the original token order.

        Args:
            expert_output: Expert output tensor of shape ``[total_recv_tokens, H]``.
            handle: Handle from ``dispatch()`` containing state for the reverse operation.

        Returns:
            Combined output tensor of shape ``[N, H]`` with routing weights applied.

        Raises:
            RuntimeError: If an EP group is set but ``handle`` lacks the all-to-all
                state (``unsort_indices`` and both split-size lists), e.g. because it
                was produced by ``dispatch()`` before ``set_ep_group()`` was called.
        """
        if self._ep_group is not None:
            # Explicit check rather than ``assert`` so the guard survives ``python -O``.
            if (
                handle.unsort_indices is None
                or handle.input_split_sizes is None
                or handle.output_split_sizes is None
            ):
                raise RuntimeError(
                    "AllToAllTokenDispatcher.combine() needs a handle created by dispatch() "
                    "with the expert-parallel group set (unsort_indices/split sizes are missing)."
                )
            # Unsort back to source-rank-grouped order and reverse all_to_all (differentiable).
            # The split sizes are passed swapped relative to dispatch(): what was sent is now received.
            combined = _DifferentiableAllToAll.apply(
                expert_output[handle.unsort_indices],
                handle.input_split_sizes,
                handle.output_split_sizes,
                self._ep_group,
            )
        else:
            combined = expert_output

        # Unpermute and combine with routing weights (keep probs in float32 for numerical stability)
        return transformer_engine.pytorch.moe_unpermute(
            combined,
            handle.row_id_map,
            merging_probs=handle.routing_weights,
            map_type="index",
        )
+# SPDX-License-Identifier: LicenseRef-Apache2 + +"""Genomic sequence masking functions for data preprocessing.""" + +from dataclasses import dataclass +from typing import Any + +import torch + + +def _make_upper_case(tokens, lowercase_start=97, lowercase_end=122, case_diff=32): + """Replace lowercase ASCII characters with uppercase.""" + lowercase_mask = (tokens >= lowercase_start) & (tokens <= lowercase_end) + uppercase_tensor = torch.where(lowercase_mask, tokens - case_diff, tokens) + return uppercase_tensor, lowercase_mask + + +@dataclass +class GenomicDataCollator: + """Wrapper collator that adds genomic-specific masking to any base collator.""" + + base_collator: Any + uppercase_labels: bool = False + mask_degenerate_bases: bool = True + dna_tokens: tuple[int, ...] = (65, 67, 71, 84, 97, 99, 103, 116) + control_tags: tuple[int, ...] = (64, 35) + + def __call__(self, features: list) -> dict[str, Any]: + """Apply base collator, then add genomic masking.""" + batch = self.base_collator(features) + labels = batch["labels"] + + if self.uppercase_labels: + labels, _ = _make_upper_case(labels) + + if self.mask_degenerate_bases: + dna_tokens_tensor = torch.tensor(self.dna_tokens, device=labels.device) + control_tensor = torch.tensor(self.control_tags, device=labels.device) + not_dna = ~torch.isin(labels, dna_tokens_tensor) + is_control = torch.isin(labels, control_tensor) + labels[(not_dna | is_control) & (labels != -100)] = -100 + + batch["labels"] = labels + return batch diff --git a/bionemo-recipes/recipes/opengenome2_mixtral_native_te/optimizer.py b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/optimizer.py new file mode 100644 index 0000000000..880c9f30b4 --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/optimizer.py @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
def get_parameter_groups_with_weight_decay(
    model: torch.nn.Module,
    weight_decay: float,
    skip_embeddings: bool = False,
) -> list[dict]:
    """Split trainable parameters into a decayed and a non-decayed optimizer group.

    A parameter is exempt from weight decay when its name ends with ``.bias``,
    when it is one-dimensional, or (only if ``skip_embeddings`` is set) when its
    lower-cased name contains ``embed``. Parameters with ``requires_grad=False``
    are left out of both groups.

    Args:
        model: Module whose ``named_parameters()`` are partitioned.
        weight_decay: Decay value attached to the first returned group.
        skip_embeddings: Also exempt parameters whose name contains "embed".

    Returns:
        Two param-group dicts, in order: the decayed group, then the group with
        ``weight_decay`` fixed at ``0.0``.
    """

    def _is_exempt(param_name, tensor):
        # Same checks, same short-circuit order as a single `or` chain.
        if param_name.endswith(".bias"):
            return True
        if tensor.dim() == 1:
            return True
        return skip_embeddings and "embed" in param_name.lower()

    with_decay = []
    without_decay = []
    for param_name, tensor in model.named_parameters():
        if not tensor.requires_grad:
            continue
        bucket = without_decay if _is_exempt(param_name, tensor) else with_decay
        bucket.append(tensor)

    logger.info(
        "Weight decay groups: %d params with decay, %d params without decay",
        len(with_decay),
        len(without_decay),
    )

    decayed_group = {"params": with_decay, "weight_decay": weight_decay}
    exempt_group = {"params": without_decay, "weight_decay": 0.0}
    return [decayed_group, exempt_group]
+# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import time + +import nvtx +import torch + + +try: + import nvdlfw_inspect.api as debug_api + + HAS_NVDLFW_INSPECT = True +except ImportError: + debug_api = None + HAS_NVDLFW_INSPECT = False +import torchmetrics +import wandb +from distributed_config import DistributedConfig +from omegaconf import DictConfig, OmegaConf +from torch.distributed.tensor import DTensor +from tqdm import tqdm +from transformers.modeling_outputs import CausalLMOutputWithPast + + +logger = logging.getLogger(__name__) + + +class PerfLogger: + """Class to log performance metrics to stdout and wandb, and print final averaged metrics at the end of training. + + Args: + dist_config: The distributed configuration. + args: The arguments. + + Attributes: + min_loss: The minimum loss seen so far. 
+ """ + + def __init__(self, dist_config: DistributedConfig, args: DictConfig): + """Initialize the logger.""" + self._dist_config = dist_config + self._run_config = OmegaConf.to_container(args, resolve=True, throw_on_missing=True) + + self._device = torch.device(f"cuda:{dist_config.local_rank}") + self.min_loss = torch.tensor(float("inf"), device=self._device) + + self.logging_frequency = args.logger.frequency + + metrics_dict = { + "train/loss": torchmetrics.MeanMetric(), + "train/grad_norm": torchmetrics.MeanMetric(), + "train/learning_rate": torchmetrics.MeanMetric(), + "train/step_time": torchmetrics.MeanMetric(), + "train/tokens_per_second_per_gpu": torchmetrics.MeanMetric(), + "train/unpadded_tokens_per_second_per_gpu": torchmetrics.MeanMetric(), + "train/total_unpadded_tokens_per_batch": torchmetrics.SumMetric(), + "train/gpu_memory_allocated_max_gb": torchmetrics.MaxMetric(), + "train/gpu_memory_allocated_mean_gb": torchmetrics.MeanMetric(), + } + + self.metrics = torchmetrics.MetricCollection(metrics_dict) + # We move metrics to a GPU device so we can use torch.distributed to aggregate them before logging. + self.metrics.to(self._device) + self.previous_step_time = time.perf_counter() + self._profiler = None + + if self._dist_config.is_main_process(): + # Log the entire args object to wandb for experiment tracking and reproducibility. 
+ self._wandb_run = wandb.init(**args.wandb, config=self._run_config) + self._progress_bar = tqdm(total=args.num_train_steps, desc="Training") + + if args.profiler.enabled: + self._profiler = NsightProfiler( + **args.profiler, + wandb_run=self._wandb_run, + dist_config=dist_config, + ) + + # Gradient accumulation tracking + self.num_tokens = 0 + self.num_unpadded_tokens = torch.tensor(0, dtype=torch.int64, device=self._device) + self.running_loss = torch.tensor(0.0, device=self._device) + self.grad_acc_step_count = 0 + + # Whether to step debug_api.step() after each step + self.fp8_stats_enabled = args.fp8_stats_config.enabled + + @nvtx.annotate("PerfLogger.log_micro_step", color="pink") + def log_micro_step(self, step: int, batch: dict[str, torch.Tensor], outputs: CausalLMOutputWithPast): + """Store data on micro step for gradient accumulation metrics. + + Args: + step: The step number. + batch: The batch of data for the micro step. + outputs: The outputs of the micro step. + """ + if self._dist_config.local_rank == 0: + logger.debug("log_micro_step") + + assert outputs.loss is not None, "Loss is None" + + with torch.no_grad(): + self.grad_acc_step_count += 1 + self.running_loss += outputs.loss + + if step % self.logging_frequency == 0 and step > 0: + self.num_tokens += batch["input_ids"].numel() + # Use attention_mask to count unpadded tokens (works for both BSHD and THD) + if "attention_mask" in batch: + self.num_unpadded_tokens += batch["attention_mask"].sum() + else: + # Fallback for pure sequence packing with no padding: all tokens are unpadded + self.num_unpadded_tokens += batch["input_ids"].numel() + + @nvtx.annotate("PerfLogger.log_step", color="purple") + def log_step( + self, + step: int, + grad_norm: torch.Tensor | DTensor, + lr: float, + ): + """Log a step to the logger and wandb. + + Args: + step: The step number. + grad_norm: The gradient norm of the step. + lr: The learning rate of the step. 
+ """ + if self._dist_config.local_rank == 0: + logger.debug("log_step %s", step) + + with torch.no_grad(): + # Use accumulated metrics from gradient accumulation + assert self.grad_acc_step_count > 0, ( + f"Gradient accumulation steps ({self.grad_acc_step_count}) must be greater than 0, " + f"and can be incremented by log_micro_step()." + ) + + if isinstance(grad_norm, DTensor): + grad_norm = grad_norm.to_local() + + if self._profiler is not None: + self._profiler.step(step) + + if self.fp8_stats_enabled and HAS_NVDLFW_INSPECT: + debug_api.step() + + if step % self.logging_frequency == 0 and step > 0: + # Calculate average loss over all micro steps in the logging window + avg_loss = self.running_loss / self.grad_acc_step_count + self.min_loss = torch.minimum(self.min_loss, avg_loss) + + # Calculate an average step time over all steps in the logging window + now = time.perf_counter() + step_time = (now - self.previous_step_time) / self.logging_frequency + self.previous_step_time = now + + # For some reason, these trigger a CudaStreamSynchronize call, which blocks the dataloader in the next + # step. We therefore only update these once every logging_frequency steps. 
+ self.metrics["train/loss"].update(avg_loss) + self.metrics["train/learning_rate"].update(lr) + self.metrics["train/grad_norm"].update(grad_norm) + self.metrics["train/step_time"].update(step_time) + self.metrics["train/tokens_per_second_per_gpu"].update(self.num_tokens / step_time) + self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(self.num_unpadded_tokens / step_time) + self.metrics["train/total_unpadded_tokens_per_batch"].update(self.num_unpadded_tokens) + + memory_allocated = torch.cuda.memory_allocated() / (1024**3) + self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated) + self.metrics["train/gpu_memory_allocated_mean_gb"].update(memory_allocated) + + metrics = self.metrics.compute() + self.metrics.reset() + metrics = { + k: v.detach().cpu().item() if isinstance(v, torch.Tensor) and v.dim() == 0 else v + for k, v in metrics.items() + } + metrics["train/global_step"] = step + + if self._dist_config.is_main_process(): + wandb.log(metrics, step=step) + self._progress_bar.update(self.logging_frequency) + self._progress_bar.set_postfix({"loss": avg_loss.item()}) + + if self._dist_config.local_rank == 0: + logger.info(", ".join([f"{k.split('/')[1]}: {v:.3g}" for k, v in metrics.items()])) + + # Reset running loss and other tracking variables for next window + self.running_loss.zero_() + self.num_tokens = 0 + self.num_unpadded_tokens.zero_() + self.grad_acc_step_count = 0 + + def log_validation(self, step: int, val_metrics: dict): + """Log validation metrics to wandb. + + Args: + step: The current training step. + val_metrics: Dictionary with val_loss, val_ppl, val_tokens, val_batches, and optional Megatron-style. 
class NsightProfiler:
    """Drive an Nsight Systems capture window through the CUDA Profiler API.

    The capture is opened when ``step()`` sees ``start_step`` and closed when it
    sees ``end_step``. Nothing happens unless the ``NSYS_PROFILING_SESSION_ID``
    environment variable is present when the object is created.

    Args:
        enabled: Whether profiling is enabled (accepted but not read here).
        start_step: The step number at which to start profiling.
        end_step: The step number at which to end profiling.
        wandb_run: The wandb run for logging artifacts.
        dist_config: The distributed configuration.
    """

    def __init__(
        self,
        enabled: bool,
        start_step: int,
        end_step: int,
        wandb_run: wandb.Run,
        dist_config: DistributedConfig,
    ):
        """Store the schedule and detect whether the process runs under nsys."""
        self.start_step = start_step
        self.end_step = end_step
        self._wandb_run = wandb_run
        self._dist_config = dist_config

        # Capture state machine: idle -> started -> finished.
        self.current_step = 0
        self.profiling_started = False
        self.profiling_finished = False

        # The session-id variable is the signal used to detect an nsys launch.
        self.running_under_nsys = "NSYS_PROFILING_SESSION_ID" in os.environ

        if not self.running_under_nsys:
            logger.warning(
                "Not running under nsys. Profiling will be skipped. "
                "To enable profiling, run your script with: "
                "nsys profile -o output_trace --trace=cuda,nvtx,osrt,cudnn,cublas --capture-range=cudaProfilerApi "
                "--capture-range-end=stop python train_fsdp2.py profiler.enabled=true"
            )
        else:
            logger.info("Detected running under nsys - will use CUDA Profiler API for range control")

    def step(self, step_num: int):
        """Record a training step and open or close the capture if it is scheduled.

        Args:
            step_num: The current training step number.
        """
        if self.profiling_finished or not self.running_under_nsys:
            return

        self.current_step = step_num

        if self.profiling_started:
            # Capture is open: only the end step matters.
            if step_num == self.end_step:
                self._stop_profiling()
        elif step_num == self.start_step:
            # Capture not open yet: only the start step matters.
            self._start_profiling()

    def _start_profiling(self):
        """Open the capture via ``cudaProfilerStart``; a failure is logged, not raised."""
        if self.profiling_started:
            return

        logger.info(f"Starting Nsight profiling at step {self.current_step}")
        try:
            torch.cuda.cudart().cudaProfilerStart()  # type: ignore[attr-defined]
        except Exception as e:
            logger.error(f"Failed to start CUDA profiler: {e}")
        else:
            self.profiling_started = True

    def _stop_profiling(self):
        """Close the capture via ``cudaProfilerStop``; a failure is logged, not raised."""
        if self.profiling_finished or not self.profiling_started:
            return

        logger.info(f"Stopping Nsight profiling at step {self.current_step}")
        try:
            torch.cuda.cudart().cudaProfilerStop()  # type: ignore[attr-defined]
        except Exception as e:
            logger.error(f"Failed to stop CUDA profiler: {e}")
        else:
            self.profiling_started = False
            self.profiling_finished = True
#!/bin/bash
# Launch an expert-parallel training test run with torchrun.
#
# Usage:   ./run_ep_test.sh [config_name]
#          config_name defaults to test_8x1B_og2_ep2.
# Env:     NPROC_PER_NODE  number of processes to start on this node (default: 8)
# Run from inside the recipe directory on DGX.
# Example: ./run_ep_test.sh test_8x1B_og2_ep2
#          NPROC_PER_NODE=2 ./run_ep_test.sh test_8x1B_og2_ep2

# -u and pipefail in addition to -e: fail on unset variables and on errors inside pipelines.
set -euo pipefail

CONFIG="${1:-test_8x1B_og2_ep2}"
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"

export BIONEMO_DISABLE_TORCH_COMPILE_HELPERS=1
export TOKENIZERS_PARALLELISM=false
export NCCL_DEBUG=WARN
# EP>1 triggers torch._dynamo internally (DTensor/FSDP2), which needs ptxas + cuda.h
export PATH="/usr/local/cuda/bin:${PATH}"
export CPATH="/usr/local/cuda/include:${CPATH:-}"

echo "=== Starting run: ${CONFIG} ==="
echo "Time: $(date)"
# CUDA_VISIBLE_DEVICES may legitimately be unset; guard it so `set -u` does not abort.
echo "GPUs: ${CUDA_VISIBLE_DEVICES:-<unset>}"

torchrun \
    --standalone \
    --nproc_per_node="${NPROC_PER_NODE}" \
    train_fsdp2.py \
    --config-name "${CONFIG}"

echo "=== Finished run: ${CONFIG} at $(date) ==="
def get_cosine_annealing_schedule_with_warmup(
    optimizer,
    num_warmup_steps=2_000,
    num_decay_steps=500_000,
    min_lr_ratio=0.0,
    last_epoch=-1,
):
    """Cosine annealing scheduler with linear warmup.

    The returned scheduler multiplies each param group's base learning rate by:

    * ``step / num_warmup_steps`` while ``step <= num_warmup_steps`` (linear warmup);
    * ``min_lr_ratio + 0.5 * (1 + cos(pi * t)) * (1 - min_lr_ratio)`` during decay,
      where ``t = (step - num_warmup_steps) / num_decay_steps``;
    * ``min_lr_ratio`` once ``step > num_warmup_steps + num_decay_steps``.

    The factor does not depend on the optimizer's current learning rate, so the
    schedule is also well defined for a base learning rate of 0.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        num_warmup_steps: Number of linear warmup steps; ``0`` disables warmup.
        num_decay_steps: Number of cosine decay steps after warmup.
        min_lr_ratio: Final learning rate as a fraction of the base learning rate.
        last_epoch: Index of the last step, forwarded to ``LambdaLR``.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` implementing the schedule.
    """
    # Mirror the warmup guard: never divide by a zero-length decay span.
    decay_span = float(max(1, num_decay_steps))

    def lr_lambda(current_step: int) -> float:
        if num_warmup_steps > 0 and current_step <= num_warmup_steps:
            return float(current_step) / float(num_warmup_steps)
        if current_step > num_warmup_steps + num_decay_steps:
            return min_lr_ratio
        decay_ratio = float(current_step - num_warmup_steps) / decay_span
        coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
        # Same value as (min_lr + coeff * (max_lr - min_lr)) / max_lr, with max_lr
        # cancelled out so that no division by the learning rate is needed.
        return min_lr_ratio + coeff * (1.0 - min_lr_ratio)

    return LambdaLR(optimizer, lr_lambda, last_epoch)
a/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tests/conftest.py b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tests/conftest.py new file mode 100644 index 0000000000..f68db475df --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tests/conftest.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +from pathlib import Path +from unittest import mock + +import pytest +import torch + + +sys.path.append(Path(__file__).parent.parent.as_posix()) +sys.path.append(Path(__file__).parent.as_posix()) +from distributed_config import DistributedConfig + + +@pytest.fixture +def recipe_path() -> Path: + """Return the root directory of the recipe.""" + return Path(__file__).parent.parent + + +@pytest.fixture +def tokenizer_path(recipe_path): + """Get the path to the nucleotide tokenizer.""" + return str(recipe_path / "tokenizers" / "nucleotide_fast_tokenizer") + + +@pytest.fixture(autouse=True) +def debug_api_cleanup(): + """Ensure nvdlfw_inspect does not stay initialized across tests.""" + yield + try: + import nvdlfw_inspect.api as debug_api + + debug_api.end_debug() + except Exception: + pass + + +@pytest.fixture(scope="session", autouse=True) +def device_mesh(): + """Create a re-usable torch process group for testing.""" + dist_config = DistributedConfig() + device = torch.device(f"cuda:{dist_config.local_rank}") + torch.distributed.init_process_group(backend="cpu:gloo,cuda:nccl", device_id=device) + torch.cuda.set_device(dist_config.local_rank) + + with ( + mock.patch("torch.distributed.init_process_group", return_value=None), + mock.patch("torch.distributed.destroy_process_group", return_value=None), + ): + yield + + torch.distributed.destroy_process_group() + torch.cuda.empty_cache() + torch.cuda.synchronize() diff --git a/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tests/distributed_helpers.py b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tests/distributed_helpers.py new file mode 100644 index 0000000000..4b052c22cf --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tests/distributed_helpers.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared test utilities for distributed (EP/FSDP) tests in the opengenome2_mixtral_native_te recipe.""" + +import os +import sys +from dataclasses import dataclass, field +from pathlib import Path + +import torch + + +# Import NVMixtralConfig from the local recipe copy (CI uses sparse-checkout) +RECIPE_ROOT = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(RECIPE_ROOT)) + +from modeling_mixtral_te import NVMixtralConfig # noqa: E402 + + +def create_small_mixtral_config(**overrides) -> NVMixtralConfig: + """Create a small og2-style Mixtral config suitable for testing.""" + defaults = { + "hidden_size": 128, + "intermediate_size": 256, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "num_local_experts": 4, + "num_experts_per_tok": 2, + "max_position_embeddings": 128, + "vocab_size": 256, + "pad_token_id": 1, + "attn_input_format": "bshd", + "self_attn_mask_type": "causal", + "router_jitter_noise": 0.0, + } + defaults.update(overrides) + return NVMixtralConfig(**defaults) + + +def get_dummy_batch(vocab_size: int, seq_len: int = 32, batch_size: int = 2, device: str = "cuda"): + """Create a simple dummy batch for testing.""" + torch.manual_seed(42) + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) + attention_mask = torch.ones_like(input_ids) + labels = input_ids.clone() + return {"input_ids": 
input_ids, "attention_mask": attention_mask, "labels": labels} + + +@dataclass(frozen=True) +class DistributedConfig: + """Distributed environment configuration.""" + + rank: int = field(default_factory=lambda: int(os.environ.setdefault("RANK", "0"))) + local_rank: int = field(default_factory=lambda: int(os.environ.setdefault("LOCAL_RANK", "0"))) + world_size: int = field(default_factory=lambda: int(os.environ.setdefault("WORLD_SIZE", "1"))) + _master_addr: str = field(default_factory=lambda: os.environ.setdefault("MASTER_ADDR", "localhost")) + _master_port: str = field(default_factory=lambda: os.environ.setdefault("MASTER_PORT", "12355")) + + def is_main_process(self) -> bool: + """Return True if this is the global rank 0 process.""" + return self.rank == 0 diff --git a/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tests/test_dataset.py b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tests/test_dataset.py new file mode 100644 index 0000000000..64b607240c --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tests/test_dataset.py @@ -0,0 +1,413 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import itertools + +import pyarrow as pa +import pyarrow.parquet as pq +import pytest +import torch +from dataset import create_bshd_dataloader, create_thd_dataloader, create_tokenized_dataset +from distributed_config import DistributedConfig +from hydra import compose, initialize_config_dir +from transformers import AutoTokenizer + + +@pytest.fixture +def simple_parquet(tmp_path): + """Create a simple Parquet file with multiple genomic sequences for testing batching.""" + parquet_path = tmp_path / "genomic_sequences.parquet" + + sequences = [ + "A" * 1000, + "T" * 1200, + "C" * 800, + "G" * 1500, + "ATCG" * 300, + ] + + table = pa.table({"text": sequences}) + pq.write_table(table, parquet_path) + return str(parquet_path) + + +def test_dataset_loads_and_tokenizes_sequence(tokenizer_path, tmp_path): + """Test that dataset loads and tokenizes a sequence correctly with exact token verification.""" + parquet_path = tmp_path / "genomic_sequences.parquet" + sequence = "T" * 10 + table = pa.table({"text": [sequence]}) + pq.write_table(table, parquet_path) + + distributed_config = DistributedConfig(rank=0, world_size=1) + + load_dataset_kwargs = { + "path": "parquet", + "data_files": str(parquet_path), + "split": "train", + } + + tokenized_dataset, _ = create_tokenized_dataset( + distributed_config=distributed_config, + tokenizer_name_or_path=tokenizer_path, + load_dataset_kwargs=load_dataset_kwargs, + max_seq_length=20, + stride=10, + buffer_size=10_000, + ) + + sample = tokenized_dataset[0] + assert "input_ids" in sample + + tokens = sample["input_ids"] + nucleotides = tokens[1:-1] + + bos = 2 + eos = 0 + t = 84 # ASCII value of 'T' + + expected_sequence = [t] * 10 + received_sequence = nucleotides + + assert tokens[0] == bos, f"First token should be BOS (2), got {tokens[0]}" + assert tokens[-1] == eos, f"Last token should be EOS (0), got {tokens[-1]}" + assert received_sequence == expected_sequence, f"Expected {expected_sequence}, got {received_sequence}" + + +def 
test_dataloader_returns_expected_batch(tokenizer_path, tmp_path): + """Test dataloader returns exact expected batch with known input.""" + parquet_path = tmp_path / "single_sequence.parquet" + sequence = "A" * 5 + table = pa.table({"text": [sequence]}) + pq.write_table(table, parquet_path) + + distributed_config = DistributedConfig(rank=0, world_size=1) + + load_dataset_kwargs = { + "path": "parquet", + "data_files": str(parquet_path), + "split": "train", + } + + dataloader, _ = create_bshd_dataloader( + distributed_config=distributed_config, + tokenizer_name_or_path=tokenizer_path, + load_dataset_kwargs=load_dataset_kwargs, + micro_batch_size=1, + num_workers=0, + max_seq_length=7, + stride=5, + uppercase_labels=False, + mask_degenerate_bases=False, + ) + + returned_batch = next(iter(dataloader)) + + bos = 2 + eos = 0 + a = 65 # ASCII value of 'A' + + expected_input_ids = torch.tensor([[bos, a, a, a, a, a, eos]], dtype=torch.long) + expected_labels = torch.tensor([[bos, a, a, a, a, a, eos]], dtype=torch.long) + expected_attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1]], dtype=torch.long) + + assert torch.equal(returned_batch["input_ids"], expected_input_ids) + assert torch.equal(returned_batch["labels"], expected_labels) + assert torch.equal(returned_batch["attention_mask"], expected_attention_mask) + + +def test_attention_mask_aligns_with_labels(tokenizer_path, simple_parquet): + """Test attention_mask correctly identifies real vs padded positions in labels.""" + ignore_pad_token = -100 + + distributed_config = DistributedConfig(rank=0, world_size=1) + + load_dataset_kwargs = { + "path": "parquet", + "data_files": simple_parquet, + "split": "train", + } + + dataloader, _ = create_bshd_dataloader( + distributed_config=distributed_config, + tokenizer_name_or_path=tokenizer_path, + load_dataset_kwargs=load_dataset_kwargs, + micro_batch_size=2, + num_workers=0, + max_seq_length=500, + stride=100, + uppercase_labels=False, + mask_degenerate_bases=False, + ) + + 
batch = next(iter(dataloader)) + + attention_mask = batch["attention_mask"][0] + labels = batch["labels"][0] + input_ids = batch["input_ids"][0] + + real_positions = attention_mask == 1 + real_labels = labels[real_positions] + real_input_ids = input_ids[real_positions] + + assert torch.all(real_labels == real_input_ids), "Labels should match input_ids at real token positions" + assert real_labels[0].item() == 2, "First token should be BOS (2)" + assert real_labels[-1].item() == 0, "Last real token should be EOS (0)" + + assert torch.all(real_labels != ignore_pad_token), "Real tokens should not have IGNORE_PAD_TOKEN" + + padded_positions = attention_mask == 0 + if padded_positions.any(): + padded_labels = labels[padded_positions] + assert torch.all(padded_labels == ignore_pad_token) + + +def test_windowing_in_dataset_creates_multiple_samples(tokenizer_path, tmp_path): + """Test that the dataset's windowing creates expected number of samples.""" + parquet_path = tmp_path / "genomic_sequences.parquet" + sequence = "A" * 3000 + table = pa.table({"text": [sequence]}) + pq.write_table(table, parquet_path) + + distributed_config = DistributedConfig(rank=0, world_size=1) + + load_dataset_kwargs = { + "path": "parquet", + "data_files": str(parquet_path), + "split": "train", + } + + tokenized_dataset, _ = create_tokenized_dataset( + distributed_config=distributed_config, + tokenizer_name_or_path=tokenizer_path, + load_dataset_kwargs=load_dataset_kwargs, + max_seq_length=1000, + stride=800, + buffer_size=10_000, + ) + + num_samples = len(tokenized_dataset) + assert num_samples == 12, f"Expected exactly 12 windows, got {num_samples}" + + +@pytest.mark.parametrize("streaming", [False, True]) +def test_multiple_sequences_batch_correctly(tokenizer_path, simple_parquet, streaming): + """Test that multiple sequences batch together correctly.""" + distributed_config = DistributedConfig(rank=0, world_size=1) + + load_dataset_kwargs = { + "path": "parquet", + "data_files": 
simple_parquet, + "split": "train", + "streaming": streaming, + } + + dataloader, _ = create_bshd_dataloader( + distributed_config=distributed_config, + tokenizer_name_or_path=tokenizer_path, + load_dataset_kwargs=load_dataset_kwargs, + micro_batch_size=2, + num_workers=0, + max_seq_length=500, + stride=100, + buffer_size=10_000, + ) + + batch = next(iter(dataloader)) + + assert batch["input_ids"].shape[0] == 2, f"Batch should contain 2 sequences, got {batch['input_ids'].shape[0]}" + + batch_size, seq_length = batch["input_ids"].shape + assert batch["attention_mask"].shape == (batch_size, seq_length) + assert batch["labels"].shape == (batch_size, seq_length) + + +def test_batching_produces_correct_batch_size(tokenizer_path, tmp_path): + """Test that batching produces correct batch sizes with remainder.""" + parquet_path = tmp_path / "five_sequences.parquet" + sequences = ["A" * 10, "T" * 15, "C" * 12, "G" * 8, "ATCG" * 3] + table = pa.table({"text": sequences}) + pq.write_table(table, parquet_path) + + distributed_config = DistributedConfig(rank=0, world_size=1) + + load_dataset_kwargs = { + "path": "parquet", + "data_files": str(parquet_path), + "split": "train", + } + + dataloader, _ = create_bshd_dataloader( + distributed_config=distributed_config, + tokenizer_name_or_path=tokenizer_path, + load_dataset_kwargs=load_dataset_kwargs, + micro_batch_size=2, + num_workers=0, + max_seq_length=50, + stride=10, + ) + + batches = list(dataloader) + + assert len(batches) == 3, f"Expected exactly 3 batches from 5 sequences, got {len(batches)}" + assert batches[0]["input_ids"].shape[0] == 2 + assert batches[1]["input_ids"].shape[0] == 2 + assert batches[2]["input_ids"].shape[0] == 1 + + +def test_non_streaming_dataset_produces_correct_batch_size(recipe_path): + """Test that non-streaming dataset produces correct batch sizes.""" + distributed_config = DistributedConfig(rank=0, world_size=1) + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), 
version_base="1.2"): + sanity_config = compose( + config_name="L0_sanity", + overrides=["dataset.load_dataset_kwargs.streaming=False"], + ) + + dataloader, sampler = create_bshd_dataloader( + distributed_config=distributed_config, + **sanity_config.dataset, + ) + + assert isinstance(sampler, torch.utils.data.distributed.DistributedSampler) + + batches = list(itertools.islice(dataloader, 50)) + + for batch in batches: + assert batch["input_ids"].shape[0] == sanity_config.dataset.micro_batch_size + assert batch["input_ids"].shape[1] <= sanity_config.dataset.max_seq_length + + +def test_batching_produces_correct_batch_size_sequence_packing(tokenizer_path, tmp_path): + """Test that sequence packing batching works correctly.""" + parquet_path = tmp_path / "five_sequences.parquet" + sequences = ["A"] * 20 + table = pa.table({"text": sequences}) + pq.write_table(table, parquet_path) + + distributed_config = DistributedConfig(rank=0, world_size=1) + + load_dataset_kwargs = { + "path": "parquet", + "data_files": str(parquet_path), + "split": "train", + "streaming": True, + } + + dataloader, _ = create_thd_dataloader( + distributed_config=distributed_config, + tokenizer_name_or_path=tokenizer_path, + load_dataset_kwargs=load_dataset_kwargs, + token_micro_batch_size=15, + max_seq_length=15, + stride=10, + split_samples_in_token_packing=False, + ) + + batches = list(dataloader) + assert len(batches) > 0 + + for batch in batches: + torch.testing.assert_close(batch["input_ids"].squeeze(0), torch.tensor([[2, 65, 0] * 5]).flatten()) + + +def test_dataloader_with_genomic_masking(tokenizer_path, tmp_path): + """Test that create_bshd_dataloader works with genomic masking enabled.""" + parquet_path = tmp_path / "genomic_with_degenerate.parquet" + sequences = ["ACGTN", "GGTAR"] + table = pa.table({"text": sequences}) + pq.write_table(table, parquet_path) + + distributed_config = DistributedConfig(rank=0, world_size=1) + + load_dataset_kwargs = { + "path": "parquet", + "data_files": 
str(parquet_path), + "split": "train", + } + + dataloader, _ = create_bshd_dataloader( + distributed_config=distributed_config, + tokenizer_name_or_path=tokenizer_path, + load_dataset_kwargs=load_dataset_kwargs, + micro_batch_size=2, + num_workers=0, + max_seq_length=10, + stride=5, + mask_degenerate_bases=True, + ) + + batch = next(iter(dataloader)) + + assert batch["input_ids"].ndim == 2 + assert batch["labels"].ndim == 2 + + labels = batch["labels"] + assert 78 not in labels, "Degenerate N (78) should be masked" + assert 82 not in labels, "Degenerate R (82) should be masked" + + valid_dna = [65, 67, 71, 84] + assert any(tok in labels for tok in valid_dna), "Should have valid DNA tokens" + + +def test_token_packing_dataloader(tokenizer_path, tmp_path): + """Test that the token packing dataloader works.""" + parquet_path = tmp_path / "token_packing_sequences.parquet" + table = pa.table({"sequence": ["A" * 300, "C" * 280, "G" * 260, "T" * 240]}) + pq.write_table(table, parquet_path) + + load_dataset_kwargs = { + "path": "parquet", + "split": "train", + "data_files": str(parquet_path), + "streaming": True, + } + + distributed_config = DistributedConfig(rank=0, world_size=1) + + dataloader, _ = create_thd_dataloader( + distributed_config=distributed_config, + tokenizer_name_or_path=tokenizer_path, + load_dataset_kwargs=load_dataset_kwargs, + text_column="sequence", + micro_batch_size=1, + max_seq_length=1024, + ) + + batches = list(dataloader) + assert len(batches) >= 1 + assert batches[0]["input_ids"].ndim == 2 + assert batches[0]["labels"].ndim == 2 + + +@pytest.mark.parametrize( + "sequence", + [ + "ACGTACGT", + "A" * 100, + "TTTCCCGGGAAA", + ], +) +def test_tokenizer_roundtrip_decode(tokenizer_path, sequence): + """Test that encode -> decode round-trips correctly (no inserted spaces). + + The tokenizer uses a character-level WordLevel model with a Split pre-tokenizer, + so each nucleotide becomes a separate token. 
Without the Fuse decoder, decoding + inserts spaces between tokens (e.g., "AAA" -> [65,65,65] -> "A A A"). The Fuse + decoder joins them back without spaces. + """ + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + token_ids = tokenizer.encode(sequence, add_special_tokens=False) + decoded = tokenizer.decode(token_ids, skip_special_tokens=False) + assert decoded == sequence, f"Round-trip failed: '{sequence}' -> {token_ids} -> '{decoded}'" diff --git a/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tests/test_distributed_checkpointing.py b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tests/test_distributed_checkpointing.py new file mode 100644 index 0000000000..51a74258e0 --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tests/test_distributed_checkpointing.py @@ -0,0 +1,163 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Distributed checkpoint stop-go tests for the OpenGenome2 Mixtral native TE recipe.""" + +import gc +import os +import socket +import subprocess + +import pytest +import torch +from hydra import compose, initialize_config_dir +from train_fsdp2 import main as main_fsdp2 + + +os.environ["WANDB_DISABLED"] = "true" +os.environ["WANDB_MODE"] = "disabled" + + +def _reserve_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return sock.getsockname()[1] + + +os.environ["MASTER_PORT"] = str(_reserve_port()) + +requires_multi_gpu = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Test requires at least 2 GPUs", +) + + +def _compose_config(recipe_path, tmp_path, overrides): + ckpt_dir = str(tmp_path / "ckpt") + base = [ + f"checkpoint.ckpt_dir={ckpt_dir}", + f"+wandb.dir={tmp_path}", + "dataset.use_stateful_dataloader=true", + ] + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + return compose(config_name="L0_sanity", overrides=base + list(overrides)) + + +def _assert_loss_valid(loss): + assert loss is not None + loss_val = float(loss) + assert torch.isfinite(torch.tensor(loss_val)), f"Loss is not finite: {loss_val}" + + +def _assert_checkpoint_step(ckpt_subdir, step, num_ranks): + step_dir = os.path.join(ckpt_subdir, f"step_{step}") + assert os.path.isdir(step_dir), f"Step {step} directory not found: {step_dir}" + files = os.listdir(step_dir) + # FSDP2 DCP checkpoints save as .distcp files with a .metadata index, + # not the older model_rank_*/optimizer_rank_* format. 
+ distcp_files = [f for f in files if f.endswith(".distcp")] + has_metadata = ".metadata" in files + assert has_metadata, f"Missing .metadata in {step_dir}: {files}" + assert len(distcp_files) >= num_ranks, f"Expected at least {num_ranks} .distcp files in {step_dir}: {files}" + dataloader_files = [f for f in files if "dataloader" in f] + assert len(dataloader_files) >= num_ranks, ( + f"Expected dataloader files for {num_ranks} ranks in {step_dir}: {files}" + ) + + +def test_checkpoint_save_and_load_single_process_fsdp2_ep1(recipe_path, tokenizer_path, tmp_path): + """Single-process FSDP2 checkpoint save and resume with EP=1.""" + common = [ + "checkpoint.save_every_n_steps=5", + "checkpoint.async_save=false", + f"dataset.tokenizer_name_or_path={tokenizer_path}", + "expert_parallel_size=1", + ] + + cfg1 = _compose_config( + recipe_path, + tmp_path, + ["num_train_steps=10", "checkpoint.resume_from_checkpoint=false", *common], + ) + loss1 = main_fsdp2(cfg1) + gc.collect() + torch.cuda.empty_cache() + + ckpt_subdir = os.path.join(str(tmp_path / "ckpt"), "train_fsdp2") + _assert_checkpoint_step(ckpt_subdir, 5, num_ranks=1) + + cfg2 = _compose_config( + recipe_path, + tmp_path, + ["num_train_steps=15", "checkpoint.resume_from_checkpoint=true", *common], + ) + loss2 = main_fsdp2(cfg2) + gc.collect() + torch.cuda.empty_cache() + + _assert_checkpoint_step(ckpt_subdir, 5, num_ranks=1) + _assert_checkpoint_step(ckpt_subdir, 10, num_ranks=1) + _assert_loss_valid(loss1) + _assert_loss_valid(loss2) + + +@requires_multi_gpu +def test_checkpoint_save_and_load_two_processes_fsdp2_ep2(recipe_path, tokenizer_path, tmp_path): + """Multi-GPU FSDP2 checkpoint save and resume with EP=2 via torchrun.""" + ckpt_dir = str(tmp_path / "ckpt") + env = os.environ.copy() + env["WANDB_MODE"] = "disabled" + env["MASTER_PORT"] = str(_reserve_port()) + env["PATH"] = f"/usr/local/cuda/bin:{env['PATH']}" + env["CPATH"] = f"/usr/local/cuda/include:{env.get('CPATH', '')}".rstrip(":") + 
env["BIONEMO_DISABLE_TORCH_COMPILE_HELPERS"] = "1" + env["TOKENIZERS_PARALLELISM"] = "false" + env["NCCL_DEBUG"] = "WARN" + env["TRITON_PTXAS_PATH"] = "/usr/local/cuda/bin/ptxas" + + train_script = recipe_path / "train_fsdp2.py" + common = [ + f"checkpoint.ckpt_dir={ckpt_dir}", + "checkpoint.save_every_n_steps=5", + "checkpoint.async_save=false", + "dataset.use_stateful_dataloader=true", + f"dataset.tokenizer_name_or_path={tokenizer_path}", + "expert_parallel_size=2", + ] + base_cmd = ["torchrun", "--standalone", "--nproc_per_node=2", str(train_script)] + + result1 = subprocess.run( + [*base_cmd, "num_train_steps=10", "checkpoint.resume_from_checkpoint=false", *common], + check=False, + capture_output=True, + text=True, + env=env, + ) + assert result1.returncode == 0, f"Phase 1 failed: {result1.stderr}" + + ckpt_subdir = os.path.join(ckpt_dir, "train_fsdp2") + _assert_checkpoint_step(ckpt_subdir, 5, num_ranks=2) + + result2 = subprocess.run( + [*base_cmd, "num_train_steps=15", "checkpoint.resume_from_checkpoint=true", *common], + check=False, + capture_output=True, + text=True, + env=env, + ) + assert result2.returncode == 0, f"Phase 2 failed: {result2.stderr}" + + _assert_checkpoint_step(ckpt_subdir, 5, num_ranks=2) + _assert_checkpoint_step(ckpt_subdir, 10, num_ranks=2) diff --git a/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tests/test_fsdp_ep.py b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tests/test_fsdp_ep.py new file mode 100644 index 0000000000..401b8ebd55 --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tests/test_fsdp_ep.py @@ -0,0 +1,291 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for FSDP2 + Expert Parallelism (EP) in the opengenome2_mixtral_native_te recipe. + +Verifies that FSDP2 and EP can be composed together: +- FSDP=2, EP=1 (2 GPUs): Data-parallel sharding, all experts on each rank. +- FSDP=1, EP=2 (2 GPUs): Expert-parallel training, no data parallelism. +""" + +import subprocess +import sys +from pathlib import Path + + +# Import from local recipe copy (CI uses sparse-checkout, shared recipe may not exist) +RECIPE_ROOT = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(RECIPE_ROOT)) +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +import pytest # noqa: E402 +import torch # noqa: E402 +from distributed_helpers import DistributedConfig, create_small_mixtral_config, get_dummy_batch # noqa: E402 +from modeling_mixtral_te import NVMixtralForCausalLM # noqa: E402 + + +requires_2_gpus = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Test requires at least 2 GPUs", +) + + +def _distribute_state_dict(full_state_dict: dict, model: torch.nn.Module, device: torch.device) -> dict: + """Distribute a full (EP=1) state dict to match a model's DTensor sharding. + + After calling ``set_ep_groups``, expert weight parameters become DTensors with + ``Shard(0)`` placement. This function uses ``distribute_tensor`` to automatically + shard full expert weights according to those annotations, avoiding manual slicing. + + Args: + full_state_dict: Complete state dict from an EP=1 model (plain tensors). 
+ model: Target EP model whose expert parameters are already DTensors. + device: Device to move source tensors to before distributing. + """ + from torch.distributed.tensor import DTensor, distribute_tensor + + distributed_state: dict = {} + # model.state_dict() filters _extra_state keys via the NVMixtralPreTrainedModel + # override, so use nn.Module.state_dict to get the unfiltered dict that includes + # TransformerEngine _extra_state entries required by load_state_dict(strict=True). + for key, value in torch.nn.Module.state_dict(model).items(): + if key.endswith("_extra_state"): + distributed_state[key] = value + elif key not in full_state_dict: + continue + elif isinstance(value, DTensor): + distributed_state[key] = distribute_tensor( + full_state_dict[key].to(device), + value.device_mesh, + list(value.placements), + ) + else: + distributed_state[key] = full_state_dict[key] + return distributed_state + + +def _train_step(model, batch): + """Run a single forward + backward + optimizer step. + + Returns: + Tuple of (loss value, dict of gradient norms, dict of weight change norms). 
+ """ + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + + # Snapshot weights before step + pre_weights = {n: p.detach().clone() for n, p in model.named_parameters()} + + optimizer.zero_grad() + outputs = model(**batch) + loss = outputs.loss + loss.backward() + + grad_norms = {} + for name, param in model.named_parameters(): + if param.grad is not None: + g = param.grad + if hasattr(g, "full_tensor"): + g = g.full_tensor() + grad_norms[name] = g.detach().float().norm().item() + + optimizer.step() + + # Measure weight changes + weight_changes = {} + for name, param in model.named_parameters(): + pre = pre_weights[name] + cur = param.detach() + if hasattr(pre, "full_tensor"): + pre = pre.full_tensor() + if hasattr(cur, "full_tensor"): + cur = cur.full_tensor() + weight_changes[name] = (cur.float() - pre.float()).norm().item() + + return loss.detach().item(), grad_norms, weight_changes + + +# --------------------------------------------------------------------------- +# Pytest entry points — launch torchrun subprocesses +# --------------------------------------------------------------------------- + + +def _run_torchrun(test_fn_name: str, port: int, nproc: int = 2): + """Run a named worker function via torchrun.""" + recipe_dir = str(Path(__file__).resolve().parent.parent) + script = str(Path(__file__).resolve()) + cmd = [ + "torchrun", + f"--nproc_per_node={nproc}", + "--rdzv-backend=c10d", + f"--rdzv-endpoint=localhost:{port}", + script, + test_fn_name, + ] + result = subprocess.run( + cmd, + check=False, + text=True, + cwd=recipe_dir, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=300, + ) + if result.returncode != 0: + print(f"STDOUT:\n{result.stdout}") + print(f"STDERR:\n{result.stderr}") + pytest.fail(f"{test_fn_name} failed with exit code {result.returncode}") + + +@requires_2_gpus +def test_fsdp2_ep1(free_tcp_port): + """Test FSDP=2, EP=1: data-parallel training with all experts on each rank.""" + _run_torchrun("fsdp2_ep1", 
free_tcp_port, nproc=2) + + +@requires_2_gpus +def test_fsdp1_ep2(free_tcp_port): + """Test FSDP=1, EP=2: expert-parallel training without data parallelism.""" + _run_torchrun("fsdp1_ep2", free_tcp_port, nproc=2) + + +# --------------------------------------------------------------------------- +# Distributed workers executed via torchrun +# --------------------------------------------------------------------------- + + +def _worker_fsdp2_ep1(): + """FSDP=2, EP=1: weights sharded by FSDP, all experts on each rank. + + Uses a 2D device mesh (dp=2, ep=1) so that DTensor multi-dimensional + placement logic is exercised even though the EP dimension is trivial. + + 1. Init distributed, create 2D device mesh with ep=1. + 2. Create model with EP=1, set EP groups on the trivial EP sub-mesh. + 3. Wrap with FSDP2 on the DP sub-mesh. + 4. Run one training step, verify loss/gradients are finite and weights update. + """ + from torch.distributed.device_mesh import init_device_mesh + from torch.distributed.fsdp import fully_shard + + dist_config = DistributedConfig() + device = torch.device(f"cuda:{dist_config.local_rank}") + torch.cuda.set_device(device) + torch.distributed.init_process_group(backend="nccl", device_id=device) + + ep_size = 1 + dp_size = dist_config.world_size + device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, ep_size), mesh_dim_names=("dp", "ep")) + + config = create_small_mixtral_config(expert_parallel_size=ep_size) + torch.manual_seed(0) + model = NVMixtralForCausalLM(config).to(dtype=torch.bfloat16, device=device) + + # EP setup with trivial (size-1) EP sub-mesh + ep_mesh = device_mesh["ep"] + ep_group = ep_mesh.get_group() + model.model.set_ep_groups(ep_group, ep_mesh) + + # FSDP2 wrapping on DP sub-mesh + for layer in model.model.layers: + fully_shard(layer, mesh=device_mesh["dp"]) + fully_shard(model, mesh=device_mesh["dp"]) + + model.train() + batch = get_dummy_batch(config.vocab_size, device=str(device)) + + loss_val, grad_norms, 
weight_changes = _train_step(model, batch) + + assert torch.isfinite(torch.tensor(loss_val)), f"Loss is not finite: {loss_val}" + assert len(grad_norms) > 0, "No gradients computed" + for name, gnorm in grad_norms.items(): + assert torch.isfinite(torch.tensor(gnorm)), f"Gradient for {name} is not finite: {gnorm}" + assert any(wc > 0 for wc in weight_changes.values()), "No weights updated after optimizer step" + + torch.distributed.destroy_process_group() + + +def _worker_fsdp1_ep2(): + """FSDP=1, EP=2: experts sharded across ranks, trivial data parallelism. + + Uses a 2D device mesh (dp=1, ep=2) so that DTensor multi-dimensional + placement logic is exercised even though the DP dimension is trivial. + + 1. Init distributed, create 2D device mesh with dp=1. + 2. Create full EP=1 model for reference weights. + 3. Create EP=2 model, set EP groups (DTensor annotations), load via distribute_tensor. + 4. Wrap with FSDP2 on the trivial DP sub-mesh. + 5. Run one training step, verify loss/gradients are finite and weights update. 
+ """ + from torch.distributed.device_mesh import init_device_mesh + from torch.distributed.fsdp import fully_shard + + dist_config = DistributedConfig() + device = torch.device(f"cuda:{dist_config.local_rank}") + torch.cuda.set_device(device) + torch.distributed.init_process_group(backend="nccl", device_id=device) + + ep_size = dist_config.world_size + dp_size = 1 + device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, ep_size), mesh_dim_names=("dp", "ep")) + + ep_mesh = device_mesh["ep"] + ep_group = ep_mesh.get_group() + + # Get reference weights from a full EP=1 model + config_full = create_small_mixtral_config(expert_parallel_size=1) + torch.manual_seed(0) + full_model = NVMixtralForCausalLM(config_full).to(dtype=torch.bfloat16, device="cpu") + full_state_dict = {k: v.clone() for k, v in full_model.state_dict().items()} + del full_model + + # Create EP=2 model, set EP groups to create DTensor annotations, then load weights + config_ep = create_small_mixtral_config(expert_parallel_size=ep_size) + torch.manual_seed(0) + model = NVMixtralForCausalLM(config_ep).to(dtype=torch.bfloat16, device=device) + + # EP setup on EP sub-mesh first (creates DTensor annotations on expert weights) + model.model.set_ep_groups(ep_group, ep_mesh) + + # Load EP=1 weights — distribute_tensor handles expert sharding automatically + distributed_state = _distribute_state_dict(full_state_dict, model, device) + model.load_state_dict(distributed_state, strict=True) + + # FSDP2 wrapping on trivial (size-1) DP sub-mesh + for layer in model.model.layers: + fully_shard(layer, mesh=device_mesh["dp"]) + fully_shard(model, mesh=device_mesh["dp"]) + + model.train() + batch = get_dummy_batch(config_ep.vocab_size, device=str(device)) + + loss_val, grad_norms, weight_changes = _train_step(model, batch) + + assert torch.isfinite(torch.tensor(loss_val)), f"Loss is not finite: {loss_val}" + assert len(grad_norms) > 0, "No gradients computed" + for name, gnorm in grad_norms.items(): + assert 
torch.isfinite(torch.tensor(gnorm)), f"Gradient for {name} is not finite: {gnorm}" + assert any(wc > 0 for wc in weight_changes.values()), "No weights updated after optimizer step" + + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + test_name = sys.argv[1] + + workers = { + "fsdp2_ep1": _worker_fsdp2_ep1, + "fsdp1_ep2": _worker_fsdp1_ep2, + } + workers[test_name]() diff --git a/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tests/test_train.py b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tests/test_train.py new file mode 100644 index 0000000000..887bf357e6 --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tests/test_train.py @@ -0,0 +1,182 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import gc +import random + +import pytest +import torch +from hydra import compose, initialize_config_dir +from modeling_mixtral_te import NVMixtralConfig, NVMixtralForCausalLM +from optimizer import get_parameter_groups_with_weight_decay +from train_fsdp2 import main as main_fsdp2 + + +@pytest.fixture(autouse=True) +def set_seed(): + """Set random seeds for reproducibility.""" + random.seed(42) + torch.manual_seed(42) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(42) + + +def test_sanity_convergence_fsdp2_te_bshd(tmp_path, recipe_path): + """Test that FSDP2 training converges with BSHD format.""" + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + sanity_config = compose( + config_name="L0_sanity", + overrides=[ + f"+wandb.dir={tmp_path}", + f"checkpoint.ckpt_dir={tmp_path}", + "checkpoint.resume_from_checkpoint=false", + "config_kwargs.attn_input_format=bshd", + "use_sequence_packing=false", + "num_train_steps=80", + ], + ) + + final_loss = main_fsdp2(sanity_config) + gc.collect() + torch.cuda.empty_cache() + + assert torch.isfinite(torch.tensor(final_loss)), f"Final loss {final_loss} is not finite" + + +def test_sanity_convergence_fsdp2_te_thd(tmp_path, recipe_path): + """Test that FSDP2 training converges with THD format.""" + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + sanity_config = compose( + config_name="L0_sanity", + overrides=[ + f"+wandb.dir={tmp_path}", + f"checkpoint.ckpt_dir={tmp_path}", + "checkpoint.resume_from_checkpoint=false", + "use_sequence_packing=true", + "config_kwargs.attn_input_format=thd", + "config_kwargs.self_attn_mask_type=padding_causal", + "dataset.max_seq_length=256", + "num_train_steps=80", + ], + ) + + final_loss = main_fsdp2(sanity_config) + gc.collect() + torch.cuda.empty_cache() + + assert torch.isfinite(torch.tensor(final_loss)), f"Final loss {final_loss} is not finite" + + +def 
test_sanity_convergence_fsdp2_te_thd_grad_acc(tmp_path, recipe_path): + """Test FSDP2 training with THD format and gradient accumulation.""" + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + sanity_config = compose( + config_name="L0_sanity", + overrides=[ + f"+wandb.dir={tmp_path}", + f"checkpoint.ckpt_dir={tmp_path}", + "checkpoint.resume_from_checkpoint=false", + "use_sequence_packing=true", + "config_kwargs.attn_input_format=thd", + "config_kwargs.self_attn_mask_type=padding_causal", + "grad_acc_steps=2", + "num_train_steps=40", + ], + ) + + final_loss = main_fsdp2(sanity_config) + gc.collect() + torch.cuda.empty_cache() + + assert torch.isfinite(torch.tensor(final_loss)), f"Final loss {final_loss} is not finite" + + +def test_train_fsdp2_fp32_master_weights_thd(tmp_path, recipe_path): + """Test FSDP2 convergence with FP32 master weights and THD sequence packing.""" + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + sanity_config = compose( + config_name="L0_sanity", + overrides=[ + f"+wandb.dir={tmp_path}", + f"checkpoint.ckpt_dir={tmp_path}", + "checkpoint.resume_from_checkpoint=false", + "use_fp32_master_weights=true", + "fp8_config.enabled=false", + "use_sequence_packing=true", + "config_kwargs.attn_input_format=thd", + "config_kwargs.self_attn_mask_type=padding_causal", + "num_train_steps=40", + ], + ) + + final_loss = main_fsdp2(sanity_config) + gc.collect() + torch.cuda.empty_cache() + + assert torch.isfinite(torch.tensor(final_loss)), f"Final loss {final_loss} is not finite" + + +def _create_tiny_config(**overrides) -> NVMixtralConfig: + """Create a small Mixtral config for fast grouping tests.""" + kwargs = dict( + hidden_size=64, + intermediate_size=128, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + num_local_experts=4, + num_experts_per_tok=2, + vocab_size=256, + max_position_embeddings=128, + rms_norm_eps=1e-5, + 
initializer_range=0.02, + attn_input_format="bshd", + self_attn_mask_type="causal", + ) + kwargs.update(overrides) + return NVMixtralConfig(**kwargs) + + +def test_weight_decay_grouping(): + """Test that weight decay grouping correctly separates decay and no-decay params.""" + model = NVMixtralForCausalLM(_create_tiny_config()) + + param_groups = get_parameter_groups_with_weight_decay(model, weight_decay=0.1) + decay_group = param_groups[0] + no_decay_group = param_groups[1] + + assert decay_group["weight_decay"] == 0.1 + assert no_decay_group["weight_decay"] == 0.0 + assert len(decay_group["params"]) > 0 + assert len(no_decay_group["params"]) > 0 + + no_decay_set = {id(p) for p in no_decay_group["params"]} + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + if param.dim() == 1 or name.endswith(".bias"): + assert id(param) in no_decay_set, f"1D/bias param '{name}' should be in no-decay group" + + +def test_weight_decay_skip_embeddings(): + """Test that skip_embeddings=True excludes embedding weights from weight decay.""" + model = NVMixtralForCausalLM(_create_tiny_config()) + + param_groups = get_parameter_groups_with_weight_decay(model, weight_decay=0.1, skip_embeddings=True) + no_decay_set = {id(p) for p in param_groups[1]["params"]} + + for name, param in model.named_parameters(): + if "embed" in name.lower() and param.requires_grad: + assert id(param) in no_decay_set, f"Embedding param '{name}' should be in no-decay group" diff --git a/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tests/test_train_two_gpu.py b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tests/test_train_two_gpu.py new file mode 100644 index 0000000000..f96af22778 --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tests/test_train_two_gpu.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multi-GPU training tests for OpenGenome2. + +These tests validate that FSDP2 training works correctly with multiple GPUs. +They require at least 2 GPUs to run and will be skipped on single-GPU machines. +""" + +import subprocess + +import pytest +import torch + + +requires_multi_gpu = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Test requires at least 2 GPUs", +) + +# TODO(@jomitchell): Delete once https://nvbugspro.nvidia.com/bug/5458694 is fixed. +requires_datacenter_hardware = pytest.mark.skipif( + not torch.cuda.is_available() + or not any( + gpu_name in torch.cuda.get_device_name(0).upper() for gpu_name in ["H100", "H200", "B100", "B200", "B300"] + ), + reason="Test requires datacenter hardware (H100, H200, B100, B200, B300)", +) + + +def run_train_cmd(cmd, recipe_path): + """Run a training command and check for errors.""" + result = subprocess.run( + cmd, + check=False, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=240, + cwd=str(recipe_path), + ) + + if result.returncode != 0: + print(f"STDOUT:\n{result.stdout}") + print(f"STDERR:\n{result.stderr}") + pytest.fail(f"Command:\n{' '.join(cmd)}\nfailed with exit code {result.returncode}") + + +@requires_multi_gpu +def test_multi_gpu_train_fsdp2(tmp_path, recipe_path): + """Test FSDP2 training on 2 GPUs. 
+ + Validates that FSDP2 launches, shards the model, and completes training + without errors on multiple GPUs. + """ + run_train_cmd( + [ + "torchrun", + "--nproc_per_node", + "2", + "--standalone", + "train_fsdp2.py", + "--config-name", + "L0_sanity", + "num_train_steps=4", + ], + recipe_path, + ) + + +@requires_multi_gpu +def test_multi_gpu_train_fsdp2_with_checkpointing(tmp_path, recipe_path): + """Test FSDP2 training on 2 GPUs with checkpoint saving. + + Validates that sharded checkpoints are created correctly across + multiple processes without race conditions. + """ + run_train_cmd( + [ + "torchrun", + "--nproc_per_node", + "2", + "--standalone", + "train_fsdp2.py", + "--config-name", + "L0_sanity", + "num_train_steps=10", + f"checkpoint.ckpt_dir={tmp_path}", + "checkpoint.save_every_n_steps=5", + "dataset.use_stateful_dataloader=true", + ], + recipe_path, + ) + + ckpt_dir = tmp_path / "train_fsdp2" + assert ckpt_dir.exists(), f"Checkpoint directory not created: {ckpt_dir}" + assert (ckpt_dir / "step_5").exists(), "Checkpoint at step 5 not found" diff --git a/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tokenizers/nucleotide_fast_tokenizer/README.md b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tokenizers/nucleotide_fast_tokenizer/README.md new file mode 100644 index 0000000000..19fdbf195c --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tokenizers/nucleotide_fast_tokenizer/README.md @@ -0,0 +1,4 @@ +# A transformers-based tokenizer for DNA sequences + +This tokenizer is similar to the one used in the Evo-2 DNA model, and tokenizes DNA sequences using the `ord()` +function. 
diff --git a/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tokenizers/nucleotide_fast_tokenizer/special_tokens_map.json b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tokenizers/nucleotide_fast_tokenizer/special_tokens_map.json new file mode 100644 index 0000000000..bb5ae2f3d5 --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tokenizers/nucleotide_fast_tokenizer/special_tokens_map.json @@ -0,0 +1,30 @@ +{ + "bos_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "eos_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "pad_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "unk_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + } +} diff --git a/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tokenizers/nucleotide_fast_tokenizer/tokenizer.json b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tokenizers/nucleotide_fast_tokenizer/tokenizer.json new file mode 100644 index 0000000000..8d2cdba32d --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tokenizers/nucleotide_fast_tokenizer/tokenizer.json @@ -0,0 +1,398 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [ + { + "id": 0, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 1, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 2, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 3, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": 
true + } + ], + "normalizer": null, + "pre_tokenizer": { + "type": "Split", + "pattern": { + "String": "" + }, + "behavior": "Isolated", + "invert": false + }, + "post_processor": { + "type": "TemplateProcessing", + "single": [ + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + } + ], + "pair": [ + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "B", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + } + ], + "special_tokens": { + "": { + "id": "", + "ids": [ + 2 + ], + "tokens": [ + "" + ] + }, + "": { + "id": "", + "ids": [ + 0 + ], + "tokens": [ + "" + ] + } + } + }, + "decoder": { + "type": "Fuse" + }, + "model": { + "type": "WordLevel", + "vocab": { + "": 0, + "": 1, + "": 2, + "": 3, + "\u0004": 4, + "\u0005": 5, + "\u0006": 6, + "\u0007": 7, + "\b": 8, + "\t": 9, + "\n": 10, + "\u000b": 11, + "\f": 12, + "\r": 13, + "\u000e": 14, + "\u000f": 15, + "\u0010": 16, + "\u0011": 17, + "\u0012": 18, + "\u0013": 19, + "\u0014": 20, + "\u0015": 21, + "\u0016": 22, + "\u0017": 23, + "\u0018": 24, + "\u0019": 25, + "\u001a": 26, + "\u001b": 27, + "\u001c": 28, + "\u001d": 29, + "\u001e": 30, + "\u001f": 31, + " ": 32, + "!": 33, + "\"": 34, + "#": 35, + "$": 36, + "%": 37, + "&": 38, + "'": 39, + "(": 40, + ")": 41, + "*": 42, + "+": 43, + ",": 44, + "-": 45, + ".": 46, + "/": 47, + "0": 48, + "1": 49, + "2": 50, + "3": 51, + "4": 52, + "5": 53, + "6": 54, + "7": 55, + "8": 56, + "9": 57, + ":": 58, + ";": 59, + "<": 60, + "=": 61, + ">": 62, + "?": 63, + "@": 64, + "A": 65, + "B": 66, + "C": 67, + "D": 68, + "E": 69, + "F": 70, + "G": 71, + "H": 72, + "I": 73, + "J": 74, + "K": 75, + "L": 76, + "M": 77, + 
"N": 78, + "O": 79, + "P": 80, + "Q": 81, + "R": 82, + "S": 83, + "T": 84, + "U": 85, + "V": 86, + "W": 87, + "X": 88, + "Y": 89, + "Z": 90, + "[": 91, + "\\": 92, + "]": 93, + "^": 94, + "_": 95, + "`": 96, + "a": 97, + "b": 98, + "c": 99, + "d": 100, + "e": 101, + "f": 102, + "g": 103, + "h": 104, + "i": 105, + "j": 106, + "k": 107, + "l": 108, + "m": 109, + "n": 110, + "o": 111, + "p": 112, + "q": 113, + "r": 114, + "s": 115, + "t": 116, + "u": 117, + "v": 118, + "w": 119, + "x": 120, + "y": 121, + "z": 122, + "{": 123, + "|": 124, + "}": 125, + "~": 126, + "": 127, + "€": 128, + "": 129, + "‚": 130, + "ƒ": 131, + "„": 132, + "…": 133, + "†": 134, + "‡": 135, + "ˆ": 136, + "‰": 137, + "Š": 138, + "‹": 139, + "Œ": 140, + "": 141, + "Ž": 142, + "": 143, + "": 144, + "‘": 145, + "’": 146, + "“": 147, + "”": 148, + "•": 149, + "–": 150, + "—": 151, + "˜": 152, + "™": 153, + "š": 154, + "›": 155, + "œ": 156, + "": 157, + "ž": 158, + "Ÿ": 159, + " ": 160, + "¡": 161, + "¢": 162, + "£": 163, + "¤": 164, + "¥": 165, + "¦": 166, + "§": 167, + "¨": 168, + "©": 169, + "ª": 170, + "«": 171, + "¬": 172, + "­": 173, + "®": 174, + "¯": 175, + "°": 176, + "±": 177, + "²": 178, + "³": 179, + "´": 180, + "µ": 181, + "¶": 182, + "·": 183, + "¸": 184, + "¹": 185, + "º": 186, + "»": 187, + "¼": 188, + "½": 189, + "¾": 190, + "¿": 191, + "À": 192, + "Á": 193, + "Â": 194, + "Ã": 195, + "Ä": 196, + "Å": 197, + "Æ": 198, + "Ç": 199, + "È": 200, + "É": 201, + "Ê": 202, + "Ë": 203, + "Ì": 204, + "Í": 205, + "Î": 206, + "Ï": 207, + "Ð": 208, + "Ñ": 209, + "Ò": 210, + "Ó": 211, + "Ô": 212, + "Õ": 213, + "Ö": 214, + "×": 215, + "Ø": 216, + "Ù": 217, + "Ú": 218, + "Û": 219, + "Ü": 220, + "Ý": 221, + "Þ": 222, + "ß": 223, + "à": 224, + "á": 225, + "â": 226, + "ã": 227, + "ä": 228, + "å": 229, + "æ": 230, + "ç": 231, + "è": 232, + "é": 233, + "ê": 234, + "ë": 235, + "ì": 236, + "í": 237, + "î": 238, + "ï": 239, + "ð": 240, + "ñ": 241, + "ò": 242, + "ó": 243, + "ô": 244, + "õ": 245, + "ö": 
246, + "÷": 247, + "ø": 248, + "ù": 249, + "ú": 250, + "û": 251, + "ü": 252, + "ý": 253, + "þ": 254, + "ÿ": 255 + }, + "unk_token": "" + } +} diff --git a/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tokenizers/nucleotide_fast_tokenizer/tokenizer_config.json b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tokenizers/nucleotide_fast_tokenizer/tokenizer_config.json new file mode 100644 index 0000000000..5e189bec32 --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/tokenizers/nucleotide_fast_tokenizer/tokenizer_config.json @@ -0,0 +1,44 @@ +{ + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "3": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "bos_token": "", + "clean_up_tokenization_spaces": false, + "eos_token": "", + "extra_special_tokens": {}, + "model_max_length": 1000000000000000019884624838656, + "pad_token": "", + "tokenizer_class": "PreTrainedTokenizerFast", + "unk_token": "" +} diff --git a/bionemo-recipes/recipes/opengenome2_mixtral_native_te/train_fsdp2.py b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/train_fsdp2.py new file mode 100644 index 0000000000..fdb3f5d0e4 --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/train_fsdp2.py @@ -0,0 +1,480 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OpenGenome2 FSDP2 training script for Mixtral with TransformerEngine.""" + +import gc +import logging +import random +from contextlib import nullcontext +from pathlib import Path +from typing import Iterable + +import hydra +import numpy as np +import torch + + +try: + import nvdlfw_inspect.api as debug_api + + HAS_NVDLFW_INSPECT = True +except ImportError: + debug_api = None + HAS_NVDLFW_INSPECT = False +import transformer_engine +import transformer_engine.pytorch +from checkpoint import ( + _ckpt_futures, + load_checkpoint_fsdp2, + save_checkpoint_fsdp2, + save_final_model_fsdp2, + should_save_checkpoint, +) +from dataset import create_bshd_dataloader, create_thd_dataloader +from distributed_config import DistributedConfig +from fp8_debugging import initialize_fp8_debugging +from modeling_mixtral_te import NVMixtralConfig, NVMixtralForCausalLM +from omegaconf import DictConfig, OmegaConf +from optimizer import get_parameter_groups_with_weight_decay +from perf_logger import PerfLogger +from scheduler import get_cosine_annealing_schedule_with_warmup +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard +from torch.optim import AdamW +from transformer_engine.common.recipe import Format +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from 
transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM +from validation import run_validation + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +# --- BEGIN COPIED FUNCTIONS --- +# _build_dispatcher and clip_grad_norm_ep_aware are copied from: +# bionemo-recipes/recipes/mixtral_native_te/train_fsdp2.py +# Kept inline for recipe self-containment (KISS over DRY). +# --- END COPIED FUNCTIONS --- + + +def _build_dispatcher(args: DictConfig, config: NVMixtralConfig): + """Build the requested token dispatcher for EP runs. + + Returns None for the default alltoall dispatcher (handled natively by TE). + Returns a FusedTokenRouter when token_dispatcher=fused_deepep and deep_ep is available. + Falls back to alltoall (returns None) when fused_deepep is requested but unavailable, + if token_dispatcher_fallback=alltoall is set. + """ + token_dispatcher = str(getattr(args, "token_dispatcher", "alltoall")) + fallback_dispatcher = str(getattr(args, "token_dispatcher_fallback", "error")) + if config.expert_parallel_size == 1: + return None + if token_dispatcher == "alltoall": + return None + if token_dispatcher != "fused_deepep": + raise ValueError(f"Unsupported token_dispatcher: {token_dispatcher!r}. Expected 'alltoall' or 'fused_deepep'.") + + try: + from fused_token_router import FusedTokenRouter + + return FusedTokenRouter( + num_experts=config.num_local_experts, + num_local_experts=config.num_local_experts // config.expert_parallel_size, + hidden_size=config.hidden_size, + ep_size=config.expert_parallel_size, + ) + except ImportError as exc: + if fallback_dispatcher == "alltoall": + logger.warning("Fused DeepEP dispatcher unavailable (%s). Falling back to AllToAllTokenDispatcher.", exc) + return None + raise + + +def clip_grad_norm_ep_aware(params: Iterable[torch.nn.Parameter], max_norm: float, ep_size: int) -> torch.Tensor: + """Clip gradient norms, handling expert parallelism (DTensor parameters on different meshes). 
+ + When ep_size > 1, parameters may be DTensors on different device meshes (dp vs ep), + which prevents torch.nn.utils.clip_grad_norm_ from stacking norms across them. + This function computes norms per-parameter from local shards and clips accordingly. + + Args: + params: Model parameters (may include DTensor expert weights). + max_norm: Maximum gradient norm. + ep_size: Expert parallelism size. If 1, falls back to standard clip_grad_norm_. + + Returns: + Total gradient norm (approximate for ep_size > 1). + """ + if ep_size == 1: + return torch.nn.utils.clip_grad_norm_(params, max_norm=max_norm) + + # Compute per-param local norms, handling DTensor by extracting the local shard. + param_list = list(params) + norms = [] + for p in param_list: + if p.grad is None: + continue + g = p.grad.detach() + if hasattr(g, "to_local"): + g = g.to_local() # Extract local shard of DTensor gradient + norms.append(g.float().norm()) + + if not norms: + return torch.tensor(0.0) + + total_norm = torch.stack(norms).norm() + clip_coef = max_norm / (total_norm + 1e-6) + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for p in param_list: + if p.grad is not None: + p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device)) + + return total_norm + + +# --- END COPIED FUNCTIONS --- + + +def set_seed(seed: int) -> None: + """Set random seeds for reproducibility. + + For FSDP2/DTensor, ALL ranks must use the SAME seed to ensure weights + are initialized identically before sharding. + + Args: + seed: Random seed (same on all ranks). 
+ """ + random.seed(seed) + np.random.seed(seed) # noqa: NPY002 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + logger.info("Set seed to %s (same on all ranks for FSDP2)", seed) + + +@hydra.main(config_path="hydra_config", config_name="L0_sanity", version_base="1.2") +def main(args: DictConfig) -> float | None: + """Train OpenGenome2 Mixtral with TE layers using FSDP2. + + Returns: + float: The minimum loss value observed during training. + """ + # --- Distributed Setup --- + dist_config = DistributedConfig() + logger.info("Initializing distributed training: %s", dist_config) + device = torch.device(f"cuda:{dist_config.local_rank}") + torch.distributed.init_process_group(backend="cpu:gloo,cuda:nccl", device_id=device) + torch.cuda.set_device(dist_config.local_rank) + + # Set random seeds (same seed on ALL ranks for FSDP2/DTensor) + seed = getattr(args, "seed", 42) + set_seed(seed) + + # TE Debug feature logging - MUST be done BEFORE FSDP wrapping + if args.fp8_stats_config.enabled: + initialize_fp8_debugging(dist_config, **args.fp8_stats_config, fp8_enabled=args.fp8_config.enabled) + + ep_size = args.expert_parallel_size + if dist_config.world_size % ep_size != 0: + raise ValueError( + f"world_size ({dist_config.world_size}) must be divisible by expert_parallel_size ({ep_size})" + ) + dp_size = dist_config.world_size // ep_size + device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, ep_size), mesh_dim_names=("dp", "ep")) + + # --- Model Configuration --- + # Create quantization recipes -- only used if FP8 is enabled in the config. 
+ fp8_recipe = None + if args.fp8_config.enabled: + fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( + fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs + ) + + # --- Model Initialization --- + if args.use_te: + config_class = NVMixtralConfig + model_class = NVMixtralForCausalLM + else: + config_class = MixtralConfig + model_class = MixtralForCausalLM + + # Determine dtype for model initialization + use_fp32_master_weights = getattr(args, "use_fp32_master_weights", False) + model_dtype = torch.float32 if use_fp32_master_weights else torch.bfloat16 + + if use_fp32_master_weights: + logger.info("FP32 master weights enabled: model init in FP32") + + config_kwargs = OmegaConf.to_container(args.config_kwargs, resolve=True) if args.config_kwargs else {} + # Pass expert_parallel_size to config so the model initializes with the correct + # num_local_experts = num_experts // expert_parallel_size per rank. + if args.use_te: + config_kwargs["expert_parallel_size"] = ep_size + + config = config_class.from_pretrained(args.config_name_or_path, dtype=model_dtype, **config_kwargs) + + logger.info( + "Init config: std=%s, num_layers=%s, experts=%s, top_k=%s", + getattr(config, "initializer_range", 0.02), + getattr(config, "num_hidden_layers", None), + getattr(config, "num_local_experts", None), + getattr(config, "num_experts_per_tok", None), + ) + + dispatcher = _build_dispatcher(args, config) if args.use_te else None + with ( + torch.device("meta") if args.use_meta_device else nullcontext(), + transformer_engine.pytorch.quantized_model_init( + recipe=fp8_recipe, **args.fp8_config.quantized_model_init_kwargs + ), + ): + if dispatcher is not None: + model = model_class(config, dispatcher=dispatcher) + else: + model = model_class(config) + + logger.info("Initialized Model:\n%s", model) + + # --- Expert Parallelism Setup --- + # Expert parallelism setup — MUST happen before fully_shard() + # Wraps expert weights as DTensors with Shard(0) 
on the expert dimension. + if args.use_te and ep_size > 1: + ep_mesh = device_mesh["ep"] + ep_group = ep_mesh.get_group() + model.model.set_ep_groups(ep_group, ep_mesh) + + # Create MixedPrecisionPolicy for FSDP when using FP32 master weights + mp_policy = None + if use_fp32_master_weights: + mp_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + output_dtype=torch.bfloat16, + cast_forward_inputs=False, # if True, will downcast top_embeddings to param dtype (bf16) + ) + logger.info( + "MixedPrecisionPolicy: param_dtype=bf16, reduce_dtype=fp32, output_dtype=bf16, cast_forward_inputs=False" + ) + + # --- Distributed Wrapping (FSDP2) --- + # Shard transformer layers with FSDP + if mp_policy is None: + mp_policy = MixedPrecisionPolicy() + for layer in model.model.layers: + fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy) + fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy) + + if args.use_meta_device and isinstance(model, NVMixtralForCausalLM): + model.init_empty_weights() + elif args.use_meta_device and isinstance(model, MixtralForCausalLM): + model.to_empty(device=device) + model.apply(model._init_weights) + + # Assign names to layers so debug API can identify them + if args.fp8_stats_config.enabled and HAS_NVDLFW_INSPECT: + debug_api.infer_and_assign_layer_names(model) + + # --- Optimizer & Scheduler --- + # Create optimizer + adamw_kwargs = OmegaConf.to_container(args.adamw_kwargs, resolve=True) + + use_wd_grouping = getattr(args, "use_weight_decay_grouping", True) + if use_wd_grouping: + weight_decay = adamw_kwargs.pop("weight_decay", 0.1) + skip_embedding_wd = getattr(args, "skip_embedding_weight_decay", False) + param_groups = get_parameter_groups_with_weight_decay( + model=model, + weight_decay=weight_decay, + skip_embeddings=skip_embedding_wd, + ) + optimizer = AdamW(param_groups, **adamw_kwargs) # type: ignore + logger.info("Weight decay grouping enabled: wd=%s, skip_embeddings=%s", 
weight_decay, skip_embedding_wd) + else: + optimizer = AdamW(model.parameters(), **adamw_kwargs) # type: ignore + logger.info("Weight decay grouping disabled: wd=%s for all params", adamw_kwargs.get("weight_decay", 0.1)) + + scheduler = get_cosine_annealing_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) + + # --- Data Loading --- + if args.use_sequence_packing: + train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset) + else: + train_dataloader, dataset_or_sampler = create_bshd_dataloader(dist_config, **args.dataset) + + if args.use_torch_compile: + model = torch.compile(model) + + # --- Checkpoint Resume --- + # Load checkpoint if resuming + ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_fsdp2" if args.checkpoint.ckpt_dir else None + if args.checkpoint.resume_from_checkpoint and ckpt_path: + logger.info("Attempting to load checkpoint from %s", ckpt_path) + model, optimizer, scheduler, train_dataloader, start_step, epoch = load_checkpoint_fsdp2( + model=model, + optimizer=optimizer, + scheduler=scheduler, + ckpt_path=ckpt_path, + dist_config=dist_config, + dataloader=train_dataloader, # type: ignore[arg-type] + process_group=device_mesh.get_group("dp"), + expert_parallel_size=ep_size, + ) + logger.info("Checkpoint loaded, resuming from step %s, epoch %s", start_step, epoch) + else: + logger.info("No checkpoint to load, starting from scratch") + start_step = 0 + epoch = 0 + + perf_logger = PerfLogger(dist_config, args) + + # Setup validation if enabled + val_config = getattr(args, "validation", None) + val_enabled = val_config is not None and getattr(val_config, "enabled", False) + val_dataloader = None + + if val_enabled: + val_data_path = getattr(val_config, "data_path", None) + if val_data_path: + logger.info("Setting up validation dataloader from %s", val_data_path) + val_dataset_kwargs = OmegaConf.to_container(args.dataset, resolve=True) + val_dataset_kwargs["load_dataset_kwargs"] = { + "path": "json", + 
"data_files": val_data_path, + "split": "train", + "streaming": True, + } + val_dataset_kwargs["use_stateful_dataloader"] = False + val_dataset_kwargs["num_workers"] = 0 + + if hasattr(val_config, "micro_batch_size") and val_config.micro_batch_size is not None: + val_dataset_kwargs["micro_batch_size"] = val_config.micro_batch_size + + if args.use_sequence_packing: + val_dataloader, _ = create_thd_dataloader(dist_config, **val_dataset_kwargs) + else: + val_dataloader, _ = create_bshd_dataloader(dist_config, **val_dataset_kwargs) + + logger.info( + "Validation enabled: every %s steps, %s batches", val_config.eval_interval, val_config.num_batches + ) + else: + logger.warning("Validation enabled but no data_path specified, skipping validation") + val_enabled = False + + # --- Training Loop --- + gc.collect() + torch.cuda.empty_cache() + + # Training loop + logger.info("Starting training loop from step %s to %s", start_step, args.num_train_steps) + step = start_step + micro_step = 0 + + if train_dataloader is None: + raise RuntimeError("Expected train_dataloader to be initialized before training.") + + while step < args.num_train_steps: + for batch in train_dataloader: + batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 + + micro_step += 1 + + with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe): + outputs = model(**batch) + + loss = outputs.loss / args.grad_acc_steps + loss.backward() + + perf_logger.log_micro_step(step=step, batch=batch, outputs=outputs) + + # Gradient accumulation - only step optimizer after accumulating gradients + if micro_step % args.grad_acc_steps == 0: + micro_step = 0 + + total_norm = clip_grad_norm_ep_aware(model.parameters(), max_norm=1.0, ep_size=ep_size) + + optimizer.step() + scheduler.step() + optimizer.zero_grad() + + perf_logger.log_step( + step=step, + grad_norm=total_norm, + lr=optimizer.param_groups[0]["lr"], + ) + + if ckpt_path and 
should_save_checkpoint(step, args.checkpoint.save_every_n_steps): + save_checkpoint_fsdp2( + model=model, + optimizer=optimizer, + scheduler=scheduler, + ckpt_path=ckpt_path, + step=step, + epoch=epoch, + dist_config=dist_config, + dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None, + process_group=device_mesh.get_group("dp"), + expert_parallel_size=ep_size, + max_checkpoints=args.checkpoint.max_checkpoints, + async_save=args.checkpoint.async_save, + ) + + # Run validation at specified interval + if val_enabled and val_dataloader is not None and step > 0 and step % val_config.eval_interval == 0: + try: + val_metrics = run_validation( + model=model, + val_dataloader=val_dataloader, + num_batches=val_config.num_batches, + device=device, + dist_config=dist_config, + ) + perf_logger.log_validation(step, val_metrics) + except Exception as e: + logger.error(f"Validation failed at step {step}: {e}") + torch.distributed.barrier() + + step += 1 + if step >= args.num_train_steps: + break + + epoch += 1 + dataset_or_sampler.set_epoch(epoch) + + # --- Cleanup --- + # Save final model + if args.checkpoint.save_final_model and ckpt_path: + save_final_model_fsdp2( + model=model, + save_directory=ckpt_path / "final_model", + dist_config=dist_config, + ) + + # Wait for any outstanding async checkpoint saves + if args.checkpoint.async_save and "fsdp2" in _ckpt_futures and _ckpt_futures["fsdp2"] is not None: + _ckpt_futures["fsdp2"].result() + + perf_logger.finish() + torch.distributed.destroy_process_group() + + return perf_logger.min_loss + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/recipes/opengenome2_mixtral_native_te/validation.py b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/validation.py new file mode 100644 index 0000000000..a2731fa219 --- /dev/null +++ b/bionemo-recipes/recipes/opengenome2_mixtral_native_te/validation.py @@ -0,0 +1,110 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & 
@torch.no_grad()
def run_validation(
    model: torch.nn.Module,
    val_dataloader,
    num_batches: int,
    device: torch.device,
    dist_config: DistributedConfig,
) -> dict:
    """Run validation and aggregate metrics across ranks.

    Draws up to ``num_batches`` batches from ``val_dataloader``, runs the model on each
    with Transformer Engine autocast disabled, and sums the per-rank statistics with an
    all-reduce so every rank returns the same global metrics.

    Args:
        model: Model to evaluate. It is called as ``model(**batch)`` and its output is
            expected to expose a ``loss`` attribute (batches whose loss is ``None`` are
            not counted).
        val_dataloader: Iterable yielding dict batches. Iteration stops early if it is
            exhausted before ``num_batches`` batches were drawn.
        num_batches: Maximum number of batches to evaluate on this rank.
        device: Device that tensor values of each batch (and the reduction buffer) are
            placed on.
        dist_config: Distributed configuration; only ``rank`` is read here.

    Returns:
        A dict with the keys ``val_loss`` (mean of per-batch losses), ``val_ppl``
        (``exp(val_loss)``), ``val_loss_megatron`` (token-weighted mean loss),
        ``val_ppl_megatron`` (``exp(val_loss_megatron)``), ``val_tokens`` and
        ``val_batches`` (global totals over all ranks).

    Raises:
        RuntimeError: If no batch was successfully evaluated on any rank. The error is
            raised on every rank after the all-reduce, so ranks stay in lockstep.
    """
    model.eval()
    try:
        total_loss = 0.0
        total_weighted_loss = 0.0
        total_tokens = 0
        num_evaluated = 0

        val_iter = iter(val_dataloader)

        for _ in range(num_batches):
            try:
                batch = next(val_iter)
            except StopIteration:
                break

            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

            # Best-effort: a failing batch is logged and skipped rather than aborting validation.
            try:
                with transformer_engine.pytorch.autocast(enabled=False):
                    outputs = model(**batch)
                loss = outputs.loss
                if loss is not None:
                    loss_val = loss.item()
                    total_loss += loss_val
                    labels = batch.get("labels")
                    # Count only supervised positions (-100 marks ignored labels); without
                    # labels fall back to the number of input ids.
                    num_tokens = (labels != -100).sum().item() if labels is not None else batch["input_ids"].numel()
                    total_tokens += num_tokens
                    total_weighted_loss += loss_val * num_tokens
                    num_evaluated += 1
            except Exception as exc:
                logger.warning("Validation forward pass failed on rank %s: %s", dist_config.rank, exc)
                continue

        if num_evaluated == 0:
            # Do NOT raise here: the other ranks are about to enter barrier/all_reduce, and a
            # rank that bails out early would leave them waiting on a collective it never joins.
            logger.warning(
                "No validation batches were evaluated on rank %s (requested %s)",
                dist_config.rank,
                num_batches,
            )

        torch.distributed.barrier()

        # float64 keeps the summed token/batch counts exact (float32 loses integer precision
        # above 2**24, which a multi-rank token total can exceed).
        loss_tensor = torch.tensor(
            [total_loss, float(total_tokens), float(num_evaluated), total_weighted_loss],
            dtype=torch.float64,
            device=device,
        )
        torch.distributed.all_reduce(loss_tensor)
        global_loss = loss_tensor[0].item()
        global_tokens = int(loss_tensor[1].item())
        global_batches = int(loss_tensor[2].item())
        global_weighted_loss = loss_tensor[3].item()

        if global_batches == 0:
            # Decided from the reduced value, so every rank raises together.
            raise RuntimeError(
                f"All validation batches failed or were unavailable on every rank "
                f"(requested {num_batches} per rank)"
            )

        avg_loss = global_loss / max(global_batches, 1)
        perplexity = torch.exp(torch.tensor(avg_loss)).item()
        megatron_style_loss = global_weighted_loss / max(global_tokens, 1)
        megatron_ppl = torch.exp(torch.tensor(megatron_style_loss)).item()

        if dist_config.rank == 0:
            logger.info(
                "[VAL] HF loss=%.4f (ppl=%.2f) | Megatron loss=%.4f (ppl=%.2f) | batches=%d tokens=%d",
                avg_loss,
                perplexity,
                megatron_style_loss,
                megatron_ppl,
                global_batches,
                global_tokens,
            )

        return {
            "val_loss": avg_loss,
            "val_ppl": perplexity,
            "val_loss_megatron": megatron_style_loss,
            "val_ppl_megatron": megatron_ppl,
            "val_tokens": global_tokens,
            "val_batches": global_batches,
        }
    finally:
        # Always hand the model back in training mode, even when validation raised;
        # the training loop catches the error and keeps stepping the optimizer.
        model.train()
"bionemo-recipes/models/esm2/LICENSE": [ "bionemo-recipes/recipes/vllm_inference/esm2/LICENSE", ], + "bionemo-recipes/models/mixtral/modeling_mixtral_te.py": [ + "bionemo-recipes/recipes/mixtral_native_te/modeling_mixtral_te.py", + "bionemo-recipes/recipes/opengenome2_mixtral_native_te/modeling_mixtral_te.py", + ], # CodonFM model -> recipe sync "bionemo-recipes/models/codonfm/modeling_codonfm_te.py": [ "bionemo-recipes/recipes/codonfm_native_te/modeling_codonfm_te.py",