1 change: 1 addition & 0 deletions megatron/core/models/gpt/gpt_model.py
@@ -472,6 +472,7 @@ def preprocess_for_fine_grained_offloading(self):
min_offloaded_tensor_size=self.config.min_offloaded_tensor_size,
delta_offload_bytes_across_pp_ranks=self.config.delta_offload_bytes_across_pp_ranks,
activation_offload_fraction=self.config.activation_offload_fraction,
max_inflight_offloads=self.config.fine_grained_offloading_max_inflight_offloads,
)
if self.disable_param_offloading:
for param in self.decoder.parameters():
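The new argument is threaded through from `TransformerConfig`. A minimal sketch of enabling it, assuming a standard config construction (only the two offloading flags below come from this diff; the rest is illustrative):

```python
from megatron.core.transformer import TransformerConfig

# Sketch, not part of the PR: enable fine-grained offloading and cap the
# number of pending D2H offloads per group name (None keeps prior behavior).
config = TransformerConfig(
    num_layers=4,
    hidden_size=1024,
    num_attention_heads=8,
    fine_grained_activation_offloading=True,
    fine_grained_offloading_max_inflight_offloads=2,
)
```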
11 changes: 11 additions & 0 deletions megatron/core/models/mamba/mamba_model.py
@@ -20,6 +20,7 @@
from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.moe.paged_stash import paged_stash_init_chunk_handler
from megatron.core.transformer.multi_token_prediction import (
MultiTokenPredictionBlock,
mtp_on_this_rank,
@@ -314,6 +315,7 @@ def preprocess_for_fine_grained_offloading(self):
min_offloaded_tensor_size=self.config.min_offloaded_tensor_size,
delta_offload_bytes_across_pp_ranks=self.config.delta_offload_bytes_across_pp_ranks,
activation_offload_fraction=self.config.activation_offload_fraction,
max_inflight_offloads=self.config.fine_grained_offloading_max_inflight_offloads,
)
if self.disable_param_offloading:
for param in self.decoder.parameters():
@@ -326,6 +328,12 @@ def preprocess_for_fine_grained_offloading(self):
off_interface.mark_not_offload(param)
self.disable_param_offloading = False

def preprocess_for_paged_stash(self):
"""Preprocess for paged stash."""
return paged_stash_init_chunk_handler(
vp_size=self.config.virtual_pipeline_model_parallel_size, vp_stage=self.vp_stage
)

def forward(
self,
input_ids: Tensor,
@@ -354,6 +362,9 @@ def forward(
if self.config.fine_grained_activation_offloading:
self.preprocess_for_fine_grained_offloading()

if self.config.moe_paged_stash:
self.preprocess_for_paged_stash()

inference_context = deprecate_inference_params(inference_context, inference_params)

in_inference_mode = inference_context is not None and not self.training
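`preprocess_for_paged_stash` mirrors the offloading preprocess: it runs once at the top of every microbatch forward. A hedged sketch of the equivalent free-standing call (`model.vp_stage` is assumed to be the model chunk's virtual-pipeline stage, as in the method above):

```python
from megatron.core.transformer.moe.paged_stash import paged_stash_init_chunk_handler

# Sketch of the hook added above: re-initialize the paged-stash chunk handler
# for this virtual-pipeline stage before the microbatch's forward pass runs.
if config.moe_paged_stash:
    paged_stash_init_chunk_handler(
        vp_size=config.virtual_pipeline_model_parallel_size,
        vp_stage=model.vp_stage,
    )
```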
megatron/core/pipeline_parallel/fine_grained_activation_offload.py
@@ -1,8 +1,8 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

from collections import deque
from collections import defaultdict, deque
from contextlib import nullcontext
from typing import Any, Dict, Tuple
from typing import Any, Dict, Optional, Tuple

import torch
from torch.autograd.graph import saved_tensors_hooks
@@ -665,6 +665,7 @@ def init_model_chunk_offload_handler(
min_offloaded_tensor_size=1024 * 1024,
delta_offload_bytes_across_pp_ranks=0,
activation_offload_fraction: float = 1.0,
max_inflight_offloads: Optional[int] = None,
):
"""
Initialize a chunk offload handler for a model chunk (microbatch).
@@ -677,6 +678,9 @@ def init_model_chunk_offload_handler(
delta_offload_bytes_across_pp_ranks:
Difference of offload bytes across PP ranks to balance the offload load.
activation_offload_fraction: Fraction of eligible groups to offload, in range [0, 1].
max_inflight_offloads: If set, cap the number of pending offloads per group
name before the main stream waits on their events; see
``fine_grained_offloading_max_inflight_offloads`` on ``TransformerConfig``.
"""
if not self._is_warmup:
return
@@ -700,7 +704,11 @@ def init_model_chunk_offload_handler(
self.flush()

# Use shared CPU tensor pool for better reuse across chunks
cur_chunk = ChunkOffloadHandler(min_offloaded_tensor_size, self._cpu_tensor_pool)
cur_chunk = ChunkOffloadHandler(
min_offloaded_tensor_size,
self._cpu_tensor_pool,
max_inflight_offloads=max_inflight_offloads,
)
debug_rank(f"init_model_chunk_offload_handler {cur_chunk}")
self._stages[cur_vpp_rank].append(cur_chunk)
# For the last stage, push immediately and flush
@@ -824,7 +832,12 @@ def reload(self, state, non_blocking=None):
self.cpu_tensor_pool.free(cpu_backup)
return gpu_tensor

def __init__(self, min_offloaded_tensor_size, cpu_tensor_pool):
def __init__(
self,
min_offloaded_tensor_size,
cpu_tensor_pool,
max_inflight_offloads: Optional[int] = None,
):
self.do_offload = True

# Group management for batching offload/reload operations
@@ -847,6 +860,10 @@ def __init__(self, min_offloaded_tensor_size, cpu_tensor_pool):
self.min_offloaded_tensor_size = min_offloaded_tensor_size
self.cpu_tensor_pool = cpu_tensor_pool
self.is_warmup = True
# Max per-group-name inflight offloads not yet joined on the main stream (None = off).
self._max_inflight_offloads = max_inflight_offloads
# group_name -> FIFO of offload events for that name (same cap for every name).
self._offload_pending_by_name: Dict[str, deque] = defaultdict(deque)

def reset(self):
"""Reset the chunk offload handler."""
@@ -855,6 +872,9 @@ def reset(self):
self._groups_to_reload = []
self._tensor_count_current_group = 0
self._reloading_group = []
# Clear the pending-event FIFO at iter boundary so we never wait on
# an event recorded in a previous (non-captured) iteration.
self._offload_pending_by_name.clear()

def find_group_with_name(
self, groups: list[OffloadTensorGroup], name: str, start_index: int = 0
@@ -953,6 +973,14 @@ def bulk_offload_group(self, group_to_offload):
group_to_offload.record_offload_event(self.d2h_stream)
self._groups_to_offload.pop()
nvtx_range_pop(nvtx_msg)
# Under full-iteration CUDA-graph capture, the main stream may not wait on
# d2h events on its own. When the optional max-inflight cap is set, each
# group's offload event is enqueued, and the main stream waits on older
# events for this group name once its pending count exceeds the cap
# (each name is tracked separately).
if self._max_inflight_offloads is not None:
gname = group_to_offload._name
self._offload_pending_by_name[gname].append(group_to_offload._offload_event)
self._drain_offload_pending(gname)

def get_max_deduplicated_groups(self):
"""Get the maximum number of deduplicated groups."""
@@ -1036,6 +1064,18 @@ def bulk_offload(self, name, forced_released_tensors):
release_tensor.record_stream(cur_stream)
release_tensor.untyped_storage().resize_(0)

def _drain_offload_pending(self, group_name: str) -> None:
"""For ``group_name``, have the main stream wait on older D2H events
when that name's pending count exceeds ``_max_inflight_offloads``
(same cap for every name; 0 = wait on each commit for that name)."""
if self._max_inflight_offloads is None:
return
cur = torch.cuda.current_stream()
q = self._offload_pending_by_name[group_name]
while len(q) > self._max_inflight_offloads:
old_evt = q.popleft()
cur.wait_event(old_evt)

def on_group_commit_forward(self, name, forced_released_tensors):
"""Called at the end of a layer group's forward pass to trigger offloading."""
if not self.do_offload:
@@ -1304,6 +1344,7 @@ def init_chunk_handler(
min_offloaded_tensor_size,
delta_offload_bytes_across_pp_ranks,
activation_offload_fraction,
max_inflight_offloads: Optional[int] = None,
):
"""Initialize the chunk handler, called at the start of a microbatch forward pass."""
PipelineOffloadManager.get_instance().init_model_chunk_offload_handler(
Expand All @@ -1313,6 +1354,7 @@ def init_chunk_handler(
min_offloaded_tensor_size,
delta_offload_bytes_across_pp_ranks,
activation_offload_fraction,
max_inflight_offloads=max_inflight_offloads,
)

@staticmethod
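Stripped of the Megatron plumbing, the cap is a bounded producer queue keyed by group name: record an event per D2H copy, and make the consuming (main) stream wait on the oldest event once a name's queue grows past the cap. A self-contained sketch of that pattern, with illustrative names only:

```python
from collections import defaultdict, deque

import torch


class BoundedInflightOffloader:
    """Sketch: allow at most `max_inflight` un-joined D2H copies per group name."""

    def __init__(self, max_inflight: int):
        self.max_inflight = max_inflight
        self.d2h_stream = torch.cuda.Stream()
        self.pending = defaultdict(deque)  # group name -> FIFO of CUDA events

    def offload(self, name: str, gpu_tensor: torch.Tensor) -> torch.Tensor:
        cpu_buf = torch.empty(
            gpu_tensor.shape, dtype=gpu_tensor.dtype, pin_memory=True
        )
        # The copy stream must observe the producer's work before reading.
        self.d2h_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.d2h_stream):
            cpu_buf.copy_(gpu_tensor, non_blocking=True)
            event = torch.cuda.Event()
            event.record(self.d2h_stream)
        queue = self.pending[name]
        queue.append(event)
        # Throttle, mirroring _drain_offload_pending above: the main stream
        # waits on the oldest copies once over the cap (0 = wait every time).
        while len(queue) > self.max_inflight:
            torch.cuda.current_stream().wait_event(queue.popleft())
        return cpu_buf
```

Keeping one FIFO per name matches the handler above: each layer-group name gets the same cap, so one slow group cannot consume another group's budget.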
19 changes: 18 additions & 1 deletion megatron/core/ssm/mamba_mixer.py
@@ -28,6 +28,9 @@
from megatron.core.ssm.ops.causal_conv1d_triton import causal_conv1d_update
from megatron.core.ssm.ops.mamba_ssm import selective_state_update
from megatron.core.tensor_parallel import get_cuda_rng_tracker
from megatron.core.pipeline_parallel.fine_grained_activation_offload import (
FineGrainedActivationOffloadingInterface as off_interface,
)
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
@@ -271,6 +274,11 @@ def __init__(
]
setattr(self.in_proj.weight, "partition_sizes", in_proj_partition_sizes)

self.offload_ssm_training = (
self.config.fine_grained_activation_offloading
and "mamba_ssm_training" in self.config.offload_modules
)

if not self.use_mem_eff_path:
log_single_rank(
logger,
@@ -461,7 +469,16 @@ def forward(
y = self._ssm_prefill(zxBCdt, conv_state=conv_state, ssm_state=ssm_state)
else:
assert ssm_state is None
y = self._ssm_training(zxBCdt, packed_seq_params)
ssm_training_offload_manager = off_interface(
self.offload_ssm_training, zxBCdt, "mamba_ssm_training"
)
with ssm_training_offload_manager as zxBCdt:
y = self._ssm_training(zxBCdt, packed_seq_params)
y = ssm_training_offload_manager.group_offload(
y,
forced_released_tensors=[zxBCdt],
delay_offload=self.config.delay_offload_until_cuda_graph,
)

out, out_bias = self.out_proj(y)

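The new path is opt-in through the existing offload-module list. A hedged enablement example (the two flag values come from this diff; where the config is built is up to the caller):

```python
# Sketch: offload SSM-training activations in MambaMixer. The string must
# match the "mamba_ssm_training" check in __init__ above.
config.fine_grained_activation_offloading = True
config.offload_modules = ["mamba_ssm_training"]
```

Design-wise, the context manager yields the (possibly hook-wrapped) `zxBCdt`, and `group_offload` commits the group after the SSM compute, passing `zxBCdt` as a forced-release tensor so the handler can free its GPU storage (see the `forced_released_tensors` handling above) until it is reloaded for backward.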
19 changes: 16 additions & 3 deletions megatron/core/transformer/cuda_graphs.py
@@ -53,6 +53,19 @@
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.utils import make_weak_ref

def _set_skip_fp8_weight_update(value: bool):
"""Compat: old TE exposed FP8GlobalStateManager.set_skip_fp8_weight_update_tensor;
newer TE stores the flag on FP8GlobalStateManager.quantization_state."""
if hasattr(FP8GlobalStateManager, "set_skip_fp8_weight_update_tensor"):
FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(value)
return
qstate = FP8GlobalStateManager.quantization_state
if qstate.skip_fp8_weight_update_tensor is None:
qstate.skip_fp8_weight_update_tensor = torch.empty(
1, dtype=torch.float32, device="cuda"
)
qstate.skip_fp8_weight_update_tensor.fill_(value)

HAVE_TE_GRAPHS = True
except:
HAVE_TE_GRAPHS = False
@@ -608,7 +621,7 @@ def forward(ctx, runner, is_first_microbatch, *inputs):
# Note that FP8GlobalStateManager.is_first_fp8_module() is inaccurate as each
# layer may be in its own fp8 context, when the fp8 recipe != delayed_scaling
if runner.is_first_layer and (runner.fp8_param_cache_updated != is_first_microbatch):
FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(not is_first_microbatch)
_set_skip_fp8_weight_update(not is_first_microbatch)
runner.fp8_param_cache_updated = is_first_microbatch

runner.fwd_graph.replay()
@@ -735,13 +748,13 @@ def __init__(

if self.fp8_enabled:
self.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)
_set_skip_fp8_weight_update(False)

if self.fp4_enabled:
from megatron.core.fp4_utils import get_fp4_recipe # to avoid circular import

self.fp4_recipe = get_fp4_recipe(self.base_module.config)
FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)
_set_skip_fp8_weight_update(False)

def __str__(self):
return "%s; hid %s" % (
19 changes: 14 additions & 5 deletions megatron/core/transformer/moe/experts.py
@@ -808,12 +808,18 @@ def _unsupported(reason):
if not isinstance(self.linear_fc2, te.pytorch.GroupedLinear):
return _unsupported(f"linear_fc2 is {type(self.linear_fc2).__name__}")

# Check activation: SwiGLU or quick GEGLU (ScaledClampedQGeGLU, TE >= 2.15)
# Check activation: SwiGLU, quick GEGLU (TE >= 2.15), or squared-relu (TE mxfp8-srelu branch)
# Use config.activation_func instead of self.activation_func because when
# use_te_activation_func is True, self.activation_func is a TE module, not the raw function.
if not self.config.gated_linear_unit:
if self.config.activation_func == squared_relu:
try:
from transformer_engine.pytorch.ops import ScaledSReLU # noqa: F401
except ImportError:
return _unsupported("squared_relu needs TE with ScaledSReLU op")
# squared_relu is non-GLU; skip the gated_linear_unit check
elif not self.config.gated_linear_unit:
return _unsupported("gated_linear_unit not enabled")
if self.config.activation_func == F.silu:
elif self.config.activation_func == F.silu:
pass # SwiGLU — supported
elif self.config.activation_func == quick_gelu:
try:
@@ -888,7 +894,7 @@ def _make_fused_ops(self) -> torch.nn.Module:
setattr(op, "bias", getattr(self.linear_fc1, "bias"))
ops.append(op)

# Activation and post-multiply probs (SwiGLU or clamped quick-GEGL)
# Activation and post-multiply probs (SwiGLU, clamped quick-GEGLU, or squared-ReLU)
glu_interleave = self.config.moe_mlp_glu_interleave_size
if self.config.activation_func == F.silu and self.config.gated_linear_unit:
op = te.pytorch.ops.ScaledSwiGLU(glu_interleave_size=glu_interleave)
@@ -900,9 +906,12 @@ def _make_fused_ops(self) -> torch.nn.Module:
)
else:
op = te.pytorch.ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave)
elif self.config.activation_func == squared_relu:
op = te.pytorch.ops.ScaledSReLU()
else:
raise RuntimeError(
"_make_fused_ops expected SwiGLU or quick_gelu with gated_linear_unit; "
"_make_fused_ops expected SwiGLU, quick_gelu (with gated_linear_unit), "
"or squared_relu; "
"call _is_fused_impl_supported() before constructing fused ops."
)
ops.append(op)
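Because `ScaledSReLU` only exists on newer TE builds, the support check probes the import before the fused path is chosen. The probe can also be hoisted to import time; a minimal sketch:

```python
# Sketch: probe once instead of per-call; op name as used in this diff.
try:
    from transformer_engine.pytorch.ops import ScaledSReLU
    HAVE_SCALED_SRELU = True
except ImportError:
    ScaledSReLU = None
    HAVE_SCALED_SRELU = False
```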
22 changes: 17 additions & 5 deletions megatron/core/transformer/moe/moe_utils.py
@@ -252,7 +252,14 @@ class MoEAuxLossAutoScaler(torch.autograd.Function):

@staticmethod
def forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor) -> torch.Tensor:
"""Preserve the aux_loss by storing it in the context to avoid garbage collection.
"""Preserve the aux_loss metadata so backward can construct a gradient tensor.

We store shape/dtype/device on ctx instead of using save_for_backward.
Under partial CUDA graph capture, the actual aux_loss tensor can live in
a CG memory pool that gets recycled on replay, causing saved_tensors to
appear "freed" on the second backward pass. Since the backward only needs
aux_loss's shape (for torch.ones_like-style grad construction), saving
metadata is sufficient and CG-safe.

Args:
output (torch.Tensor): The output tensor.
@@ -261,7 +268,9 @@ def forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: The output tensor.
"""
ctx.save_for_backward(aux_loss)
ctx.aux_loss_shape = aux_loss.shape
ctx.aux_loss_dtype = aux_loss.dtype
ctx.aux_loss_device = aux_loss.device
return output

@staticmethod
@@ -275,13 +284,16 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled auxiliary loss
gradient.
"""
(aux_loss,) = ctx.saved_tensors
if MoEAuxLossAutoScaler.main_loss_backward_scale is None:
MoEAuxLossAutoScaler.main_loss_backward_scale = torch.tensor(
1.0, device=aux_loss.device
1.0, device=ctx.aux_loss_device
)
aux_loss_backward_scale = MoEAuxLossAutoScaler.main_loss_backward_scale
scaled_aux_loss_grad = torch.ones_like(aux_loss) * aux_loss_backward_scale
scaled_aux_loss_grad = torch.ones(
ctx.aux_loss_shape,
dtype=ctx.aux_loss_dtype,
device=ctx.aux_loss_device,
) * aux_loss_backward_scale
return grad_output, scaled_aux_loss_grad

@staticmethod
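The same trick generalizes to any autograd function whose backward needs only the metadata of a saved tensor. A self-contained sketch (illustrative names, not Megatron code):

```python
import torch


class MetadataOnlyGrad(torch.autograd.Function):
    """Sketch: backward needs only shape/dtype/device of `side`, so stash
    metadata on ctx instead of using save_for_backward; the tensor itself
    may live in a CUDA-graph pool that is recycled between replays."""

    @staticmethod
    def forward(ctx, output: torch.Tensor, side: torch.Tensor) -> torch.Tensor:
        ctx.side_meta = (side.shape, side.dtype, side.device)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        shape, dtype, device = ctx.side_meta
        # Equivalent to torch.ones_like(side) without keeping `side` alive.
        side_grad = torch.ones(shape, dtype=dtype, device=device)
        return grad_output, side_grad
```

The trade-off: metadata saving is only valid when backward never reads the tensor's values, which holds here because the aux-loss gradient is a constant-filled tensor.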
22 changes: 18 additions & 4 deletions megatron/core/transformer/moe/paged_stash.py
@@ -1191,21 +1191,35 @@ def _track_cfg(c):

_track_cfg(config)

# Local import avoids circular import: schedules -> paged_stash -> multi_token_prediction
# -> megatron.core (still loading).
from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer

for model_chunk in self.model:
model_with_decoder = get_attr_wrapped_model(
model_chunk, "decoder", allow_none=False, return_model_obj=True
)
_track_cfg(model_with_decoder.config)
for layer in model_with_decoder.decoder.layers:
mlp = layer.mlp
if hasattr(mlp, 'token_dispatcher') and hasattr(
transformer_layer = (
layer.mtp_model_layer
if isinstance(layer, MultiTokenPredictionLayer)
else layer
)
mlp = getattr(transformer_layer, "mlp", None)
if mlp is not None and hasattr(mlp, 'token_dispatcher') and hasattr(
mlp.token_dispatcher, 'check_over_budget'
):
self.moe_layers.append(mlp)
if model_with_decoder.mtp_process:
for layer in model_with_decoder.mtp.layers:
mlp = layer.mtp_model_layer.mlp
if hasattr(mlp, 'token_dispatcher') and hasattr(
transformer_layer = (
layer.mtp_model_layer
if isinstance(layer, MultiTokenPredictionLayer)
else layer
)
mlp = getattr(transformer_layer, "mlp", None)
if mlp is not None and hasattr(mlp, 'token_dispatcher') and hasattr(
mlp.token_dispatcher, 'check_over_budget'
):
self.moe_layers.append(mlp)
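Both loops now share the same unwrap-then-probe shape, so they could fold into one helper; a compact sketch (hypothetical name, assuming the local `MultiTokenPredictionLayer` import above is in scope):

```python
def _collect_moe_mlps(layers, moe_layers):
    """Sketch: unwrap MTP wrappers, then duck-type for a budget-aware dispatcher."""
    for layer in layers:
        inner = (
            layer.mtp_model_layer
            if isinstance(layer, MultiTokenPredictionLayer)
            else layer
        )
        mlp = getattr(inner, "mlp", None)
        if (
            mlp is not None
            and hasattr(mlp, "token_dispatcher")
            and hasattr(mlp.token_dispatcher, "check_over_budget")
        ):
            moe_layers.append(mlp)
```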