diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py
index 19de0ed5298..1a61559a870 100644
--- a/megatron/core/models/gpt/gpt_model.py
+++ b/megatron/core/models/gpt/gpt_model.py
@@ -472,6 +472,7 @@ def preprocess_for_fine_grained_offloading(self):
             min_offloaded_tensor_size=self.config.min_offloaded_tensor_size,
             delta_offload_bytes_across_pp_ranks=self.config.delta_offload_bytes_across_pp_ranks,
             activation_offload_fraction=self.config.activation_offload_fraction,
+            max_inflight_offloads=self.config.fine_grained_offloading_max_inflight_offloads,
         )
         if self.disable_param_offloading:
             for param in self.decoder.parameters():
diff --git a/megatron/core/models/mamba/mamba_model.py b/megatron/core/models/mamba/mamba_model.py
index dac241c8d2a..09ea74b21b4 100644
--- a/megatron/core/models/mamba/mamba_model.py
+++ b/megatron/core/models/mamba/mamba_model.py
@@ -20,6 +20,7 @@
 from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
 from megatron.core.transformer import TransformerConfig
 from megatron.core.transformer.enums import ModelType
+from megatron.core.transformer.moe.paged_stash import paged_stash_init_chunk_handler
 from megatron.core.transformer.multi_token_prediction import (
     MultiTokenPredictionBlock,
     mtp_on_this_rank,
@@ -314,6 +315,7 @@ def preprocess_for_fine_grained_offloading(self):
             min_offloaded_tensor_size=self.config.min_offloaded_tensor_size,
             delta_offload_bytes_across_pp_ranks=self.config.delta_offload_bytes_across_pp_ranks,
             activation_offload_fraction=self.config.activation_offload_fraction,
+            max_inflight_offloads=self.config.fine_grained_offloading_max_inflight_offloads,
         )
         if self.disable_param_offloading:
             for param in self.decoder.parameters():
@@ -326,6 +328,12 @@ def preprocess_for_fine_grained_offloading(self):
                 off_interface.mark_not_offload(param)
             self.disable_param_offloading = False
 
+    def preprocess_for_paged_stash(self):
+        """Preprocess for paged stash."""
+        return paged_stash_init_chunk_handler(
+            vp_size=self.config.virtual_pipeline_model_parallel_size, vp_stage=self.vp_stage
+        )
+
     def forward(
         self,
         input_ids: Tensor,
@@ -354,6 +362,9 @@ def forward(
         if self.config.fine_grained_activation_offloading:
             self.preprocess_for_fine_grained_offloading()
 
+        if self.config.moe_paged_stash:
+            self.preprocess_for_paged_stash()
+
         inference_context = deprecate_inference_params(inference_context, inference_params)
 
         in_inference_mode = inference_context is not None and not self.training
diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py
index 87adfc6c593..1e3601aef26 100644
--- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py
+++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py
@@ -1,8 +1,8 @@
 # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 
-from collections import deque
+from collections import defaultdict, deque
 from contextlib import nullcontext
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Optional, Tuple
 
 import torch
 from torch.autograd.graph import saved_tensors_hooks
@@ -665,6 +665,7 @@ def init_model_chunk_offload_handler(
         min_offloaded_tensor_size=1024 * 1024,
         delta_offload_bytes_across_pp_ranks=0,
         activation_offload_fraction: float = 1.0,
+        max_inflight_offloads: Optional[int] = None,
     ):
         """
         Initialize a chunk offload handler for a model chunk (microbatch).
@@ -677,6 +678,9 @@ def init_model_chunk_offload_handler(
             delta_offload_bytes_across_pp_ranks: Difference of offload bytes across PP ranks
                 to balance the offload load.
             activation_offload_fraction: Fraction of eligible groups to offload, in range [0, 1].
+            max_inflight_offloads: If set, cap pending offloads per group name before main
+                wait_event; see ``fine_grained_offloading_max_inflight_offloads`` on
+                ``TransformerConfig``.
         """
         if not self._is_warmup:
             return
@@ -700,7 +704,11 @@ def init_model_chunk_offload_handler(
             self.flush()
 
         # Use shared CPU tensor pool for better reuse across chunks
-        cur_chunk = ChunkOffloadHandler(min_offloaded_tensor_size, self._cpu_tensor_pool)
+        cur_chunk = ChunkOffloadHandler(
+            min_offloaded_tensor_size,
+            self._cpu_tensor_pool,
+            max_inflight_offloads=max_inflight_offloads,
+        )
         debug_rank(f"init_model_chunk_offload_handler {cur_chunk}")
         self._stages[cur_vpp_rank].append(cur_chunk)
         # For the last stage, push immediately and flush
@@ -824,7 +832,12 @@ def reload(self, state, non_blocking=None):
             self.cpu_tensor_pool.free(cpu_backup)
         return gpu_tensor
 
-    def __init__(self, min_offloaded_tensor_size, cpu_tensor_pool):
+    def __init__(
+        self,
+        min_offloaded_tensor_size,
+        cpu_tensor_pool,
+        max_inflight_offloads: Optional[int] = None,
+    ):
         self.do_offload = True
 
         # Group management for batching offload/reload operations
@@ -847,6 +860,10 @@ def __init__(self, min_offloaded_tensor_size, cpu_tensor_pool):
         self.min_offloaded_tensor_size = min_offloaded_tensor_size
         self.cpu_tensor_pool = cpu_tensor_pool
         self.is_warmup = True
+        # Max per-group-name inflight offloads not yet joined on the main stream (None = off).
+        self._max_inflight_offloads = max_inflight_offloads
+        # group_name -> FIFO of offload events for that name (same cap for every name).
+        self._offload_pending_by_name: Dict[str, deque] = defaultdict(deque)
 
     def reset(self):
         """Reset the chunk offload handler."""
@@ -855,6 +872,9 @@ def reset(self):
         self._groups_to_reload = []
         self._tensor_count_current_group = 0
         self._reloading_group = []
+        # Clear the pending-event FIFO at iter boundary so we never wait on
+        # an event recorded in a previous (non-captured) iteration.
+        self._offload_pending_by_name.clear()
 
     def find_group_with_name(
         self, groups: list[OffloadTensorGroup], name: str, start_index: int = 0
@@ -953,6 +973,14 @@ def bulk_offload_group(self, group_to_offload):
             group_to_offload.record_offload_event(self.d2h_stream)
             self._groups_to_offload.pop()
             nvtx_range_pop(nvtx_msg)
+            # Under full-iteration CUDA graph capture, the main stream may not wait on D2H
+            # events on its own; when max-inflight is enabled, enqueue each group's offload
+            # event and have the main stream wait on older events for this group name once
+            # its pending count exceeds the cap (each name is tracked separately).
+            if self._max_inflight_offloads is not None:
+                gname = group_to_offload._name
+                self._offload_pending_by_name[gname].append(group_to_offload._offload_event)
+                self._drain_offload_pending(gname)
 
     def get_max_deduplicated_groups(self):
         """Get the maximum number of deduplicated groups."""
@@ -1036,6 +1064,18 @@ def bulk_offload(self, name, forced_released_tensors):
                 release_tensor.record_stream(cur_stream)
                 release_tensor.untyped_storage().resize_(0)
 
+    def _drain_offload_pending(self, group_name: str) -> None:
+        """For ``group_name``, have the main stream wait on older D2H events
+        when that name's pending count exceeds ``_max_inflight_offloads``
+        (same cap for every name; 0 = wait on each commit for that name)."""
+        if self._max_inflight_offloads is None:
+            return
+        cur = torch.cuda.current_stream()
+        q = self._offload_pending_by_name[group_name]
+        while len(q) > self._max_inflight_offloads:
+            old_evt = q.popleft()
+            cur.wait_event(old_evt)
+
     def on_group_commit_forward(self, name, forced_released_tensors):
         """Called at the end of a layer group's forward pass to trigger offloading."""
         if not self.do_offload:
@@ -1304,6 +1344,7 @@ def init_chunk_handler(
         min_offloaded_tensor_size,
         delta_offload_bytes_across_pp_ranks,
         activation_offload_fraction,
+        max_inflight_offloads: Optional[int] = None,
     ):
         """Initialize the chunk handler, called at the start of a microbatch forward pass."""
         PipelineOffloadManager.get_instance().init_model_chunk_offload_handler(
@@ -1313,6 +1354,7 @@ def init_chunk_handler(
             min_offloaded_tensor_size,
             delta_offload_bytes_across_pp_ranks,
             activation_offload_fraction,
+            max_inflight_offloads=max_inflight_offloads,
         )
 
     @staticmethod
diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py
index e9ee2dd8deb..ae27f5ac10e 100644
--- a/megatron/core/ssm/mamba_mixer.py
+++ b/megatron/core/ssm/mamba_mixer.py
@@ -28,6 +28,9 @@
 from megatron.core.ssm.ops.causal_conv1d_triton import causal_conv1d_update
 from megatron.core.ssm.ops.mamba_ssm import selective_state_update
 from megatron.core.tensor_parallel import get_cuda_rng_tracker
+from megatron.core.pipeline_parallel.fine_grained_activation_offload import (
+    FineGrainedActivationOffloadingInterface as off_interface,
+)
 from megatron.core.transformer import TransformerConfig
 from megatron.core.transformer.module import MegatronModule
 from megatron.core.transformer.spec_utils import ModuleSpec, build_module
@@ -271,6 +274,11 @@ def __init__(
             ]
             setattr(self.in_proj.weight, "partition_sizes", in_proj_partition_sizes)
 
+        self.offload_ssm_training = (
+            self.config.fine_grained_activation_offloading
+            and "mamba_ssm_training" in self.config.offload_modules
+        )
+
         if not self.use_mem_eff_path:
             log_single_rank(
                 logger,
@@ -461,7 +469,16 @@ def forward(
             y = self._ssm_prefill(zxBCdt, conv_state=conv_state, ssm_state=ssm_state)
         else:
             assert ssm_state is None
-            y = self._ssm_training(zxBCdt, packed_seq_params)
+            ssm_training_offload_manager = off_interface(
+                self.offload_ssm_training, zxBCdt, "mamba_ssm_training"
+            )
+            with ssm_training_offload_manager as zxBCdt:
+                y = self._ssm_training(zxBCdt, packed_seq_params)
+            y = ssm_training_offload_manager.group_offload(
+                y,
+                forced_released_tensors=[zxBCdt],
+                delay_offload=self.config.delay_offload_until_cuda_graph,
+            )
 
         out, out_bias = self.out_proj(y)
 
diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py
index 9801e824ba2..2f103fcec93 100644
--- a/megatron/core/transformer/cuda_graphs.py
+++ b/megatron/core/transformer/cuda_graphs.py
@@ -53,6 +53,19 @@
     from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
     from transformer_engine.pytorch.utils import make_weak_ref
 
+    def _set_skip_fp8_weight_update(value: bool):
+        """Compat: old TE exposed FP8GlobalStateManager.set_skip_fp8_weight_update_tensor;
+        newer TE stores the flag on FP8GlobalStateManager.quantization_state."""
+        if hasattr(FP8GlobalStateManager, "set_skip_fp8_weight_update_tensor"):
+            FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(value)
+            return
+        qstate = FP8GlobalStateManager.quantization_state
+        if qstate.skip_fp8_weight_update_tensor is None:
+            qstate.skip_fp8_weight_update_tensor = torch.empty(
+                1, dtype=torch.float32, device="cuda"
+            )
+        qstate.skip_fp8_weight_update_tensor.fill_(value)
+
     HAVE_TE_GRAPHS = True
 except:
     HAVE_TE_GRAPHS = False
@@ -608,7 +621,7 @@ def forward(ctx, runner, is_first_microbatch, *inputs):
         # Note that FP8GlobalStateManager.is_first_fp8_module() is inacccurate as each
         # layer may be in its own fp8 context, when the fp8 recipe != delayed_scaling
         if runner.is_first_layer and (runner.fp8_param_cache_updated != is_first_microbatch):
-            FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(not is_first_microbatch)
+            _set_skip_fp8_weight_update(not is_first_microbatch)
             runner.fp8_param_cache_updated = is_first_microbatch
 
         runner.fwd_graph.replay()
@@ -735,13 +748,13 @@ def __init__(
 
         if self.fp8_enabled:
             self.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
-            FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)
+            _set_skip_fp8_weight_update(False)
 
         if self.fp4_enabled:
             from megatron.core.fp4_utils import get_fp4_recipe  # to avoid circular import
 
             self.fp4_recipe = get_fp4_recipe(self.base_module.config)
-            FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)
+            _set_skip_fp8_weight_update(False)
 
     def __str__(self):
         return "%s; hid %s" % (
diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py
index 9efff3189ac..d4004673870 100644
--- a/megatron/core/transformer/moe/experts.py
+++ b/megatron/core/transformer/moe/experts.py
@@ -808,12 +808,18 @@ def _unsupported(reason):
         if not isinstance(self.linear_fc2, te.pytorch.GroupedLinear):
             return _unsupported(f"linear_fc2 is {type(self.linear_fc2).__name__}")
 
-        # Check activation: SwiGLU or quick GEGLU (ScaledClampedQGeGLU, TE >= 2.15)
+        # Check activation: SwiGLU, quick GEGLU (TE >= 2.15), or squared-relu (TE mxfp8-srelu branch)
         # Use config.activation_func instead of self.activation_func because when
         # use_te_activation_func is True, self.activation_func is a TE module, not the raw function.
-        if not self.config.gated_linear_unit:
+        if self.config.activation_func == squared_relu:
+            try:
+                from transformer_engine.pytorch.ops import ScaledSReLU  # noqa: F401
+            except ImportError:
+                return _unsupported("squared_relu needs TE with ScaledSReLU op")
+            # squared_relu is non-GLU; skip the gated_linear_unit check
+        elif not self.config.gated_linear_unit:
             return _unsupported("gated_linear_unit not enabled")
-        if self.config.activation_func == F.silu:
+        elif self.config.activation_func == F.silu:
             pass  # SwiGLU — supported
         elif self.config.activation_func == quick_gelu:
             try:
@@ -888,7 +894,7 @@ def _make_fused_ops(self) -> torch.nn.Module:
                 setattr(op, "bias", getattr(self.linear_fc1, "bias"))
             ops.append(op)
 
-        # Activation and post-multiply probs (SwiGLU or clamped quick-GEGLU)
+        # Activation and post-multiply probs (SwiGLU, clamped quick-GEGLU, or squared-ReLU)
         glu_interleave = self.config.moe_mlp_glu_interleave_size
         if self.config.activation_func == F.silu and self.config.gated_linear_unit:
             op = te.pytorch.ops.ScaledSwiGLU(glu_interleave_size=glu_interleave)
@@ -900,9 +906,12 @@ def _make_fused_ops(self) -> torch.nn.Module:
                 )
             else:
                 op = te.pytorch.ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave)
+        elif self.config.activation_func == squared_relu:
+            op = te.pytorch.ops.ScaledSReLU()
         else:
             raise RuntimeError(
-                "_make_fused_ops expected SwiGLU or quick_gelu with gated_linear_unit; "
+                "_make_fused_ops expected SwiGLU, quick_gelu (with gated_linear_unit), "
+                "or squared_relu; "
                 "call _is_fused_impl_supported() before constructing fused ops."
             )
         ops.append(op)
diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py
index d316d23de10..472be4e4e6d 100644
--- a/megatron/core/transformer/moe/moe_utils.py
+++ b/megatron/core/transformer/moe/moe_utils.py
@@ -252,7 +252,14 @@ class MoEAuxLossAutoScaler(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor) -> torch.Tensor:
-        """Preserve the aux_loss by storing it in the context to avoid garbage collection.
+        """Preserve the aux_loss metadata so backward can construct a gradient tensor.
+
+        We store shape/dtype/device on ctx instead of using save_for_backward.
+        Under partial CUDA graph capture, the actual aux_loss tensor can live in
+        a CG memory pool that gets recycled on replay, causing saved_tensors to
+        appear "freed" on the second backward pass. Since the backward only needs
+        aux_loss's shape (for torch.ones_like-style grad construction), saving
+        metadata is sufficient and CG-safe.
 
         Args:
             output (torch.Tensor): The output tensor.
@@ -261,7 +268,9 @@ def forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor) -> torch.Tensor:
         Returns:
             torch.Tensor: The output tensor.
         """
-        ctx.save_for_backward(aux_loss)
+        ctx.aux_loss_shape = aux_loss.shape
+        ctx.aux_loss_dtype = aux_loss.dtype
+        ctx.aux_loss_device = aux_loss.device
         return output
 
     @staticmethod
@@ -275,13 +284,16 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor
             Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled auxiliary loss
                 gradient.
         """
-        (aux_loss,) = ctx.saved_tensors
         if MoEAuxLossAutoScaler.main_loss_backward_scale is None:
             MoEAuxLossAutoScaler.main_loss_backward_scale = torch.tensor(
-                1.0, device=aux_loss.device
+                1.0, device=ctx.aux_loss_device
             )
         aux_loss_backward_scale = MoEAuxLossAutoScaler.main_loss_backward_scale
-        scaled_aux_loss_grad = torch.ones_like(aux_loss) * aux_loss_backward_scale
+        scaled_aux_loss_grad = torch.ones(
+            ctx.aux_loss_shape,
+            dtype=ctx.aux_loss_dtype,
+            device=ctx.aux_loss_device,
+        ) * aux_loss_backward_scale
         return grad_output, scaled_aux_loss_grad
 
     @staticmethod
diff --git a/megatron/core/transformer/moe/paged_stash.py b/megatron/core/transformer/moe/paged_stash.py
index 03cfba7c1aa..3372416e130 100644
--- a/megatron/core/transformer/moe/paged_stash.py
+++ b/megatron/core/transformer/moe/paged_stash.py
@@ -1191,21 +1191,35 @@ def _track_cfg(c):
         _track_cfg(config)
 
+        # Local import avoids circular import: schedules -> paged_stash -> multi_token_prediction
+        # -> megatron.core (still loading).
+        from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer
+
         for model_chunk in self.model:
             model_with_decoder = get_attr_wrapped_model(
                 model_chunk, "decoder", allow_none=False, return_model_obj=True
             )
             _track_cfg(model_with_decoder.config)
             for layer in model_with_decoder.decoder.layers:
-                mlp = layer.mlp
-                if hasattr(mlp, 'token_dispatcher') and hasattr(
+                transformer_layer = (
+                    layer.mtp_model_layer
+                    if isinstance(layer, MultiTokenPredictionLayer)
+                    else layer
+                )
+                mlp = getattr(transformer_layer, "mlp", None)
+                if mlp is not None and hasattr(mlp, 'token_dispatcher') and hasattr(
                     mlp.token_dispatcher, 'check_over_budget'
                 ):
                     self.moe_layers.append(mlp)
             if model_with_decoder.mtp_process:
                 for layer in model_with_decoder.mtp.layers:
-                    mlp = layer.mtp_model_layer.mlp
-                    if hasattr(mlp, 'token_dispatcher') and hasattr(
+                    transformer_layer = (
+                        layer.mtp_model_layer
+                        if isinstance(layer, MultiTokenPredictionLayer)
+                        else layer
+                    )
+                    mlp = getattr(transformer_layer, "mlp", None)
+                    if mlp is not None and hasattr(mlp, 'token_dispatcher') and hasattr(
                         mlp.token_dispatcher, 'check_over_budget'
                     ):
                         self.moe_layers.append(mlp)
diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
index d7ae679334e..68875c70fe0 100644
--- a/megatron/core/transformer/transformer_config.py
+++ b/megatron/core/transformer/transformer_config.py
@@ -1129,6 +1129,14 @@ class TransformerConfig(ModelParallelConfig):
     activation_offload_fraction: float = 1.0
     """The fraction of the activation to be offloaded, which should be in range [0, 1]."""
 
+    fine_grained_offloading_max_inflight_offloads: Optional[int] = None
+    """Per fine-grained offloading group name, max number of inflight offloads for that name not
+    yet joined on the main stream (wait_event on D2H). The same cap applies to every name (e.g.,
+    ``moe_act`` and ``qkv_linear`` each have their own pending queue). 0 = wait after every
+    offload for that name. 1 = at most one not-yet-waited offload per name, etc. None = do not
+    insert these joins.
+    This feature is particularly useful when used with full-iteration CUDA graphs."""
+
     moe_paged_stash: bool = False
     """If True, enable paged stash for all routed-expert activations needed for backward"""
 
@@ -1707,6 +1715,7 @@ def __post_init__(self):
             "attn_norm",
             "mlp_norm",
             "qkv_linear",
+            "mamba_ssm_training",
         }
         invalid_modules = set(self.offload_modules) - allowed_modules
         assert not invalid_modules, (
@@ -1737,6 +1746,10 @@ def __post_init__(self):
             assert (
                 self.delta_offload_bytes_across_pp_ranks >= 0
             ), "delta_offload_bytes_across_pp_ranks must be non-negative."
+            if self.fine_grained_offloading_max_inflight_offloads is not None:
+                assert (
+                    self.fine_grained_offloading_max_inflight_offloads >= 0
+                ), "fine_grained_offloading_max_inflight_offloads must be non-negative when set."
         if self.moe_paged_stash:
             assert not self.cpu_offloading, "moe_paged_stash cannot be enabled with cpu_offloading."
             assert self.moe_expert_rank_capacity_factor is not None, (
@@ -2351,7 +2364,8 @@ def __post_init__(self):
         if self.fine_grained_activation_offloading:
             assert self.cuda_graph_impl == "transformer_engine" or (
-                self.cuda_graph_impl == "local" and self.cuda_graph_scope == "full_iteration"
+                self.cuda_graph_impl == "local"
+                and self.cuda_graph_scope == [CudaGraphScope.full_iteration]
             ), (
                 "fine-grained activation offloading is only supported with "
                 "transformer_engine CUDA graph implementation or local CUDA graph "
@@ -2364,6 +2378,11 @@ def __post_init__(self):
                 "cuda_graph_warmup_steps must be greater than 0 when enabling "
                 "fine-grained activation offloading."
             )
+            if CudaGraphScope.full_iteration in self.cuda_graph_scope:
+                assert self.fine_grained_offloading_max_inflight_offloads is not None, (
+                    "fine_grained_offloading_max_inflight_offloads must be set when using "
+                    "fine-grained activation offloading with full-iteration CUDA graphs."
+                )
 
         if self.moe_token_dispatcher_type in ["allgather"]:
             if self.variable_seq_lengths is True:
diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py
index ee20545110c..a230c61ac9a 100644
--- a/megatron/core/transformer/transformer_layer.py
+++ b/megatron/core/transformer/transformer_layer.py
@@ -1355,6 +1355,7 @@ def _set_offload_modules(self):
             )
             self.offload_expert_fc1 = "expert_fc1" in self.config.offload_modules
             self.offload_moe_act = "moe_act" in self.config.offload_modules
+            self.offload_mamba_ssm_training = "mamba_ssm_training" in self.config.offload_modules
         else:
             self.offload_attn_norm = False
             self.offload_qkv_linear = False
@@ -1363,6 +1364,7 @@ def _set_offload_modules(self):
             self.offload_mlp_norm = False
             self.offload_expert_fc1 = False
             self.offload_moe_act = False
+            self.offload_mamba_ssm_training = False
         # Check the compatibility of fine-grained activation offloading and cuda graph.
         if self.config.fine_grained_activation_offloading:
             if CudaGraphScope.attn in self.config.cuda_graph_scope:
diff --git a/megatron/training/config/common_config.py b/megatron/training/config/common_config.py
index 2107816bd85..0ca5633a5a1 100644
--- a/megatron/training/config/common_config.py
+++ b/megatron/training/config/common_config.py
@@ -62,6 +62,9 @@ class ProfilingConfig:
     memory_snapshot_path: str = "snapshot.pickle"
     """Specifies where to dump the memory history pickle."""
 
+    memory_snapshot_iter: int = -1
+    """If >=0, only record memory history during the iteration with this number, then dump and stop.
+    If -1, falls back to dumping every log_interval (no per-iteration timeline)."""
+
     record_shapes: bool = False
     """Record shapes of tensors in `torch.autograd.profiler.emit_nvtx` for the Nsys profiler."""
diff --git a/megatron/training/training.py b/megatron/training/training.py
index a0817e83433..00cf9f0cb7e 100644
--- a/megatron/training/training.py
+++ b/megatron/training/training.py
@@ -2469,9 +2469,23 @@ def training_log(
             total_loss_dict=total_loss_dict,
         )
     # Dump memory snapshot and print metrics to stdout.
+    if (
+        args.record_memory_history
+        and args.memory_snapshot_iter >= 0
+        and iteration == args.memory_snapshot_iter
+        and (is_last_rank() or torch.distributed.get_backend() == 'fake')
+    ):
+        snapshot = torch.cuda.memory._snapshot()
+        from pickle import dump
+
+        with open(args.memory_snapshot_path, 'wb') as f:
+            dump(snapshot, f)
+        torch.cuda.memory._record_memory_history(enabled=None)
     if iteration % args.log_interval == 0 or is_first_iteration:
-        if args.record_memory_history and (
-            is_last_rank() or torch.distributed.get_backend() == 'fake'
+        if (
+            args.record_memory_history
+            and args.memory_snapshot_iter < 0
+            and (is_last_rank() or torch.distributed.get_backend() == 'fake')
         ):
             snapshot = torch.cuda.memory._snapshot()
             from pickle import dump
@@ -3349,6 +3363,15 @@ def trace_handler(p):
                 seqlen_squared_sum_this_global_batch = 0
         else:
             ft_integration.on_training_step_start()
+            if (
+                args.record_memory_history
+                and args.memory_snapshot_iter >= 0
+                and iteration == 0
+                and (is_last_rank() or torch.distributed.get_backend() == 'fake')
+            ):
+                torch.cuda.memory._record_memory_history(
+                    enabled='all', context='all', stacks='python', max_entries=10**6
+                )
             (
                 loss_dict,
                 skipped_iter,
diff --git a/megatron/training/utils.py b/megatron/training/utils.py
index 6581193a067..2193226b3cb 100644
--- a/megatron/training/utils.py
+++ b/megatron/training/utils.py
@@ -562,6 +562,8 @@ def _broadcast(item):
         }
 
     def _broadcast_cu_seqlens(cu_seqlens):
+        if not (args.dynamic_context_parallel or args.sft):
+            return
         dev = torch.cuda.current_device()
         n = 0 if cu_seqlens is None else int(cu_seqlens.numel())
         n_tensor = torch.empty(1, dtype=torch.int64, device=dev).fill_(n)
@@ -643,11 +645,13 @@ def _broadcast_cu_seqlens(cu_seqlens):
         )
 
     def _broadcast_cu_seqlens():
+        if not (args.dynamic_context_parallel or args.sft):
+            return None
         dev = torch.cuda.current_device()
         n = torch.empty((), dtype=torch.int64, device=dev)
         _broadcast(n)
-        n = int(n.item())
+        n = 0  # patched: skip item() for cuda graph capture
 
         if n == 0:
             cu_seqlens = torch.empty(0, dtype=torch.int32, device=dev)
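
Note on the max-inflight offload cap: under full-iteration CUDA graph capture the main stream does not join the D2H copies on its own, so the per-name event FIFO in `ChunkOffloadHandler` is what bounds how many offloads can be outstanding at once. The snippet below is a minimal, standalone sketch of that pattern under stated assumptions; the class and method names (`BoundedInflightOffloader`, `offload`) are illustrative and not part of this patch, and it requires a CUDA device.

```python
from collections import defaultdict, deque
from typing import Optional

import torch


class BoundedInflightOffloader:
    """Sketch: cap the number of not-yet-joined D2H copies per group name."""

    def __init__(self, max_inflight: Optional[int] = None):
        self.max_inflight = max_inflight
        self.d2h_stream = torch.cuda.Stream()
        # group name -> FIFO of CUDA events, one per committed offload
        self.pending = defaultdict(deque)

    def offload(self, name: str, gpu_tensor: torch.Tensor) -> torch.Tensor:
        cpu_tensor = torch.empty(
            gpu_tensor.shape, dtype=gpu_tensor.dtype, device="cpu", pin_memory=True
        )
        # Copy on a side stream so compute on the main stream is not blocked.
        self.d2h_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.d2h_stream):
            cpu_tensor.copy_(gpu_tensor, non_blocking=True)
            event = torch.cuda.Event()
            event.record(self.d2h_stream)
        if self.max_inflight is not None:
            queue = self.pending[name]
            queue.append(event)
            # Join the oldest copies on the main stream once the cap is exceeded,
            # so at most `max_inflight` offloads per name stay outstanding.
            while len(queue) > self.max_inflight:
                torch.cuda.current_stream().wait_event(queue.popleft())
        return cpu_tensor
```

With a cap of 0 the main stream waits after every commit for that name; with None no joins are inserted, which matches the config default.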
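Note on the MoEAuxLossAutoScaler change: saving only shape/dtype/device on ctx keeps backward independent of the aux_loss storage, which may belong to a CUDA-graph memory pool that is recycled between replays. Below is a minimal sketch of the same save-metadata-instead-of-tensor pattern; the class name and attribute names are illustrative, not the patch's actual code.

```python
import torch


class MetadataSafeScaler(torch.autograd.Function):
    """Sketch: keep only shape/dtype/device of aux_loss on ctx so backward never
    touches a tensor whose CUDA-graph pool memory may have been recycled."""

    main_loss_backward_scale = None  # set elsewhere, e.g. from the main loss scale

    @staticmethod
    def forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor) -> torch.Tensor:
        # No save_for_backward: metadata is enough to rebuild a ones-like gradient.
        ctx.aux_shape = aux_loss.shape
        ctx.aux_dtype = aux_loss.dtype
        ctx.aux_device = aux_loss.device
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        scale = MetadataSafeScaler.main_loss_backward_scale
        if scale is None:
            scale = torch.tensor(1.0, device=ctx.aux_device)
        grad_aux = torch.ones(
            ctx.aux_shape, dtype=ctx.aux_dtype, device=ctx.aux_device
        ) * scale
        return grad_output, grad_aux
```

Calling `MetadataSafeScaler.apply(output, aux_loss)` keeps aux_loss attached to the autograd graph without pinning its storage across graph replays.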