diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 670eecaa5e..6a2f7fee1d 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -6,7 +6,7 @@ from __future__ import annotations from collections.abc import Iterable -from contextlib import contextmanager, AbstractContextManager, ContextDecorator +from contextlib import contextmanager, AbstractContextManager, ContextDecorator, nullcontext from functools import lru_cache from dataclasses import dataclass import math @@ -926,7 +926,10 @@ def fork(self, name: str = "model-parallel-rng"): def reduce_scatter_along_first_dim( - inp: torch.Tensor, tp_group: dist_group_type, async_op: bool = False + inp: torch.Tensor, + tp_group: dist_group_type, + async_op: bool = False, + output: torch.Tensor = None, ) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]: """Reduce-scatter the input tensor across model parallel group.""" world_size = get_distributed_world_size(tp_group) @@ -944,7 +947,8 @@ def reduce_scatter_along_first_dim( dim_size[0] = dim_size[0] // world_size - output = torch.empty(dim_size, dtype=inp.dtype, device=torch.cuda.current_device()) + if output is None: + output = torch.empty(dim_size, dtype=inp.dtype, device=torch.cuda.current_device()) handle = torch.distributed.reduce_scatter_tensor( output, inp.contiguous(), group=tp_group, async_op=async_op ) @@ -1289,11 +1293,13 @@ def _post_process_nvfp4_gather( handle = None # Fix the interleaved transposed data from gathering along first dim. - out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size) - out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size) + # In-place .copy_() (not `=` rebind) to keep the storage address stable + # for CUDA graph capture — replays see the same pointer they captured. + out._columnwise_scale_inv.copy_(_swap_first_dims(columnwise_scale_inv_interleaved, world_size)) + out._columnwise_data.copy_(_swap_first_dims(columnwise_data_interleaved, world_size)) - # Optionally pad the scaling inverse if needed. - out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv) + # Optionally pad the scaling inverse if needed (same in-place pattern). + out._columnwise_scale_inv.copy_(pad_columnwise_scale_inv(out._columnwise_scale_inv)) @dataclass @@ -1307,17 +1313,25 @@ class _NVFP4AllGatherAsyncHandle: async_handle: torch.distributed.Work _synchronized: bool = False - def wait(self) -> None: - """Wait for the async operation to complete and post-process the tensor.""" - if self._synchronized: - return - self.async_handle.wait() + def post_process_nvfp4_gather(self) -> None: + """Fix interleaved transposed data + pad scale_inv after the async AG completes. + + Idempotent: gated by ``_synchronized`` in :meth:`wait`. + """ _post_process_nvfp4_gather( self.output, self.columnwise_data_interleaved, self.columnwise_scale_inv_interleaved, self.world_size, ) + + def wait(self) -> None: + """Wait for the async operation to complete and post-process the tensor.""" + if self._synchronized: + return + if self.async_handle is not None: + self.async_handle.wait() + self.post_process_nvfp4_gather() self._synchronized = True @@ -1328,6 +1342,8 @@ def _all_gather_nvfp4( async_op: bool = False, quantizer: NVFP4Quantizer, out_shape: Optional[list[int]] = None, + output_tensor=None, + grouped=False, ) -> tuple[NVFP4TensorStorage, Optional[torch.distributed.Work]]: """All-gather NVFP4 tensor along first dimension.""" @@ -1391,6 +1407,12 @@ def _all_gather_nvfp4( out = quantizer(out) return out, None + # Construct NVFP4 output tensor + if output_tensor is not None: + out = output_tensor + else: + out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + # Cast input tensor to NVFP4 with required data if not isinstance(inp, NVFP4TensorStorage): inp = quantizer(inp) @@ -1403,17 +1425,19 @@ def _all_gather_nvfp4( ) inp = quantizer(inp.dequantize(dtype=dtype)) - # Construct NVFP4 output tensor - out = quantizer.make_empty(out_shape, dtype=dtype, device=device) - - # Coalesce NCCL collectives for gathering data and scale inverses. - with torch.distributed._coalescing_manager( - group=process_group, - device=device, - async_ops=async_op, - ) as gather_coalescing_manager: + if not grouped: + # Coalesce NCCL collectives for gathering data and scale inverses. + gather_coalescing_manager = torch.distributed._coalescing_manager( + group=process_group, + device=device, + async_ops=async_op, + ) + else: + gather_coalescing_manager = nullcontext() + with gather_coalescing_manager as coalesced_handle: # Gather NVFP4 data for row-wise usage + out_columnwise_data = None if quantizer.rowwise_usage: # Remove padding from NVFP4 scale-inverses @@ -1441,8 +1465,9 @@ def _all_gather_nvfp4( group=process_group, ) - # Transfer amax to output. - out._amax_rowwise = inp._amax_rowwise + # Transfer amax to output via in-place .copy_() so the storage + # address stays stable for CUDA graph capture. + out._amax_rowwise.copy_(inp._amax_rowwise) # Gather the transposed NVFP4 data along first dimension. Fix format later. if quantizer.columnwise_usage: @@ -1491,17 +1516,24 @@ def _all_gather_nvfp4( ) # Transfer amax to output. - out._amax_columnwise = inp._amax_columnwise + out._amax_columnwise.copy_(inp._amax_columnwise) - handle = gather_coalescing_manager if async_op else None + handle = coalesced_handle if async_op else None # Fixes interleaved data for transposed tensor/scale inv and pads scale inv if needed. - if async_op and quantizer.columnwise_usage: - handle = _NVFP4AllGatherAsyncHandle( - out, out_columnwise_data, out_scale_inv, world_size, handle - ) - elif quantizer.columnwise_usage: - _post_process_nvfp4_gather(out, out_columnwise_data, out_scale_inv, world_size, handle) + if quantizer.columnwise_usage: + if async_op or grouped: + # Defer post-processing: either the async op hasn't completed yet, or an + # external coalescing manager owns the NCCL ops and hasn't flushed them. + inner_handle = handle if async_op else None + handle = _NVFP4AllGatherAsyncHandle( + out, out_columnwise_data, out_scale_inv, world_size, inner_handle + ) + else: + _post_process_nvfp4_gather(out, out_columnwise_data, out_scale_inv, world_size, handle) + else: + if handle is not None: + handle.output = out return out, handle @@ -1513,6 +1545,8 @@ def _all_gather_mxfp8( async_op: bool = False, quantizer: MXFP8Quantizer, out_shape: Optional[list[int]] = None, + output_tensor: torch.Tensor = None, + grouped: bool = False, ) -> tuple[MXFP8TensorStorage, Optional[torch.distributed.Work]]: """All-gather MXFP8 tensor along first dimension.""" @@ -1578,15 +1612,22 @@ def _all_gather_mxfp8( inp = quantizer(inp.dequantize(dtype=dtype)) # Construct MXFP8 output tensor - out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + if output_tensor is not None: + out = output_tensor + else: + out = quantizer.make_empty(out_shape, dtype=dtype, device=device) - # Coalesce NCCL collectives - with torch.distributed._coalescing_manager( - group=process_group, - device=device, - async_ops=async_op, - ) as coalescing_manager: + if not grouped: + # Coalesce NCCL collectives for gathering data and scale inverses. + gather_coalescing_manager = torch.distributed._coalescing_manager( + group=process_group, + device=device, + async_ops=async_op, + ) + else: + gather_coalescing_manager = nullcontext() + with gather_coalescing_manager as coalesced_handle: # Gather MXFP8 data for row-wise usage if quantizer.rowwise_usage: @@ -1633,7 +1674,7 @@ def _all_gather_mxfp8( group=process_group, ) - handle = coalescing_manager if async_op else None + handle = coalesced_handle if async_op else None return out, handle @@ -1642,6 +1683,8 @@ def gather_along_first_dim( process_group: dist_group_type, async_op: bool = False, quantizer: Optional[Quantizer] = None, + output_tensor: torch.Tensor = None, + grouped: bool = False, ) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]: """ All-gather tensors and concatenate along first dimension. @@ -1732,6 +1775,8 @@ def gather_along_first_dim( async_op=async_op, quantizer=quantizer, out_shape=out_shape, + output_tensor=output_tensor, + grouped=grouped, ) # NVFP4 case @@ -1746,6 +1791,8 @@ def gather_along_first_dim( async_op=async_op, quantizer=quantizer, out_shape=out_shape, + output_tensor=output_tensor, + grouped=grouped, ) # High-precision communication for quantized tensors @@ -1775,19 +1822,20 @@ def gather_along_first_dim( inp = inp.dequantize() # Communication for plain PyTorch tensors - out = torch.empty( - out_shape, - dtype=inp.dtype, - device=inp.device, - memory_format=torch.contiguous_format, - ) + if output_tensor is None: + output_tensor = torch.empty( + out_shape, + dtype=inp.dtype, + device=inp.device, + memory_format=torch.contiguous_format, + ) handle = torch.distributed.all_gather_into_tensor( - out, + output_tensor, inp.contiguous(), group=process_group, async_op=async_op, ) - return out, handle + return output_tensor, handle # Global cache to store symmetric memory tensors diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 6c7ba8a8ab..f4e8ad430d 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -66,6 +66,8 @@ "initialize_ub", "destroy_ub", "is_ub_initialized", + "maybe_wrap_gtp", + "register_gtp_hooks", "using_cublasmp_backend", "UserBufferQuantizationMode", ] @@ -81,6 +83,47 @@ layers_atomic_ring_exchange = [] +# GTP hook slots. An external integrator (currently ``megatron.experimental.gtp``) +# populates these via ``register_gtp_hooks`` at its own import time. When the +# slots stay ``None``, the ``gtp_group=`` codepath in TE modules is a no-op +# and TE has no ``from megatron...`` dependency. +_gtp_slice_fn = None +_gtp_finalize_fn = None +_gtp_wrap_fn = None + + +def register_gtp_hooks(*, slice_fn=None, finalize_fn=None, wrap_fn=None): + """Register GTP integration hooks. Hooks left as ``None`` are unchanged. + + slice_fn(module, name, param, *, expert_idx) -> GTPShardedParam | None + Fires per weight during ``reset_parameters``, before FP8 quantize. + finalize_fn(module, weight_names) -> None + Fires after the per-weight loop in ``reset_parameters``. + wrap_fn(module, weight_names, gtp_group, is_grouped=False) -> None + Fires at the end of a module's ``__init__`` to finalize GTP wiring. + """ + global _gtp_slice_fn, _gtp_finalize_fn, _gtp_wrap_fn + if slice_fn is not None: + _gtp_slice_fn = slice_fn + if finalize_fn is not None: + _gtp_finalize_fn = finalize_fn + if wrap_fn is not None: + _gtp_wrap_fn = wrap_fn + + +def maybe_wrap_gtp(module, weight_names, gtp_group, is_grouped=False): + """Finalize GTP wiring on a module if a wrap hook is registered. + + No-op when ``gtp_group`` is None or no GTP integrator has called + ``register_gtp_hooks``. Called from each TE module's ``__init__`` after + ``reset_parameters`` finishes; the per-weight slice already happened + inside ``reset_parameters`` via ``_gtp_slice_fn``. + """ + if gtp_group is None or _gtp_wrap_fn is None: + return + _gtp_wrap_fn(module, weight_names, gtp_group, is_grouped=is_grouped) + + def is_ub_initialized() -> bool: """Whether the Userbuffers communicators have been initialized.""" return _ub_initialized @@ -1680,7 +1723,10 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: if defer_init: return - for name, param in self.named_parameters(recurse=False): + # Names of GTP-sharded weights, for GroupedLinear's post-loop finalize. + _gtp_sharded_weight_names = [] + + for idx, (name, param) in enumerate(self.named_parameters(recurse=False)): # Check if parameter is a DTensor (FSDP2) or regular tensor is_dtensor = isinstance(param, DTensor) dtensor_param = param if is_dtensor else None @@ -1702,10 +1748,23 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: with get_rng_state_tracker().fork(): init_fn(param) + # GTP slice: shard the freshly-init weight into a GTPShardedParam; + # the FP8 quantize block below is skipped for it. + gtp_sharded = None + if ( + not is_dtensor + and getattr(self, "_gtp_group", None) is not None + and _gtp_slice_fn is not None + ): + gtp_sharded = _gtp_slice_fn(self, name, param, expert_idx=idx) + if gtp_sharded is not None: + param = gtp_sharded + _gtp_sharded_weight_names.append(name) + # Wrap parameters in QuantizedTensor if needed fp8_meta_index = self.param_init_meta[name].fp8_meta_index high_precision_init_val = None - if self.primary_weights_in_fp8 and fp8_meta_index is not None: + if self.primary_weights_in_fp8 and fp8_meta_index is not None and gtp_sharded is None: # Keep high-precision values on CPU if needed if self.preserve_high_precision_init_val: @@ -1735,6 +1794,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: # NOTE: Currently this can only be broken when primary weights are in Fp8 but # re-applying the nn.Parameter() wrap is a no-op when the input is already # a parameter so we always re-apply it just for extra safety. + # Skip the wrap for GTPShardedParam (Parameter.__new__ would drop attrs). if is_dtensor: # recreate the DTensor from the parameter. dtensor_param = DTensor.from_local( @@ -1745,7 +1805,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: stride=dtensor_param.stride(), ) dtensor_param = torch.nn.Parameter(dtensor_param) - else: + elif gtp_sharded is None: param = torch.nn.Parameter(param) # Keep high-precision values on CPU if needed @@ -1783,6 +1843,10 @@ def clear(self): else: self.module_setattr(name, dtensor_param) + # GroupedLinear post-loop finalize hook (no-op outside GroupedLinear). + if _gtp_sharded_weight_names and _gtp_finalize_fn is not None: + _gtp_finalize_fn(self, _gtp_sharded_weight_names) + @abstractmethod def forward(self): """Needs override.""" diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 24e76463bd..7dd82269de 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -29,6 +29,7 @@ _2X_ACC_WGRAD, ) from ._common import WeightGradStore +from .base import maybe_wrap_gtp from ..quantization import FP8GlobalStateManager, QuantizerRole from ..utils import ( divide, @@ -413,6 +414,7 @@ def forward( skip_fp8_weight_update, save_original_input, debug, + gtp_size, ) = non_tensor_args if fp8: backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override @@ -427,6 +429,14 @@ def forward( device = inp.device weight_requires_grad = weights[0].requires_grad + weights_gtp_sharded = weights + if gtp_size > 1: + weights = weights[0].batched_all_gather_and_prefetch( + fwd=True, + skip_weight_cast=is_first_microbatch is False, + cast_noop_flag=skip_fp8_weight_update, + ) + # Configure quantizers if save_original_input and isinstance(input_quantizers[0], Float8Quantizer): if FP8GlobalStateManager.get_fp8_recipe().custom(): @@ -632,12 +642,12 @@ def forward( # Python parameter attributes without keeping the parameter alive here. saved_weights = ( weights - if backward_override == "high_precision" and inp.requires_grad + if backward_override == "high_precision" and inp.requires_grad and gtp_size == 1 else [None] * num_gemms ) tensors_to_save, tensor_objects = prepare_for_saving( *inputmats, - *weights_fp8, + *weights_fp8 if gtp_size == 1 else weights_gtp_sharded, *saved_weights, *biases, ) @@ -662,6 +672,10 @@ def forward( if hasattr(weights[0], "__fsdp_param__"): # MCore FSDP creates main_grad lazily before backward ctx.main_grad_funcs = [weights[i].get_main_grad for i in range(num_gemms)] + elif gtp_size > 1: + ctx.main_grad_funcs = [ + weights_gtp_sharded[i].get_wgrad_tensor for i in range(num_gemms) + ] else: ctx.main_grad_funcs = [ lambda j=i: weights[j].main_grad for i in range(num_gemms) @@ -691,6 +705,7 @@ def forward( ctx.debug = debug ctx.save_original_input = save_original_input ctx.input_quantizers = input_quantizers + ctx.gtp_size = gtp_size # backward overrides if backward_override is not None: @@ -915,7 +930,13 @@ def backward( # Only needed when fuse_wgrad_accumulation is enabled. origin_weights = [None] * N main_grads = [None] * N - if ctx.fuse_wgrad_accumulation and ctx.weights_requires_grad: + if ctx.gtp_size > 1: + # GTP: origin_weights come from saved tensors; main_grads are + # get_wgrad_tensor scratch (do not assign to param.main_grad). + origin_weights = weights + if ctx.fuse_wgrad_accumulation and ctx.weights_requires_grad: + main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] + elif ctx.fuse_wgrad_accumulation and ctx.weights_requires_grad: origin_weight_refs = ctx.origin_weight_refs ctx.origin_weight_refs = None origin_weights = [ref() if ref is not None else None for ref in origin_weight_refs] @@ -976,13 +997,18 @@ def backward( ctx.m_splits, ) - if ctx.is_first_microbatch is not None: + if ctx.gtp_size > 1: + accumulate_wgrad_into_param_main_grad = False + elif ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch ) else: accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation + if ctx.gtp_size > 1: + weights = origin_weights[0].batched_all_gather_and_prefetch_bwd() + if ctx.requires_dgrad: dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD if ctx.fp8 or ctx.debug: @@ -1051,6 +1077,9 @@ def backward( device=ctx.device, ) wgrad_list = [wgrad_packed[i] for i in range(ctx.num_gemms)] + if ctx.gtp_size > 1: + # Gathered weights are no longer needed after dgrad GEMM. + del weights if ctx.save_original_input: inp = inputmats[0] @@ -1101,7 +1130,8 @@ def backward( use_split_accumulator=wgrad_gemm_use_split_accumulator, accumulate=( accumulate_wgrad_into_param_main_grad - if not getattr(ctx, "origin_weights_overwrite_main_grad", False) + if ctx.gtp_size == 1 + and not getattr(ctx, "origin_weights_overwrite_main_grad", False) else False ), ) @@ -1143,10 +1173,19 @@ def handle_custom_ddp_from_mcore(weight, main_grad, wgrad): wgrad = None return wgrad - wgrad_list = [ - handle_custom_ddp_from_mcore(weight, main_grad, wgrad) - for weight, main_grad, wgrad in zip(origin_weights, main_grads, wgrad_list) - ] + if ctx.gtp_size > 1: + wgrad_list = origin_weights[0].batched_wgrad_reduce_scatter(wgrad_list) + # Drop Python refs to wgrad input buffers. The async RS on rs_stream + # still holds C++ refs (via NCCL Work); those are released when + # _wait_reduce_scatter calls handle.wait() + self.handle = None. + # Without this del, main_grads keeps the tensors alive until function + # return, wasting memory during graph capture warmup. + del main_grads + else: + wgrad_list = [ + handle_custom_ddp_from_mcore(weight, main_grad, wgrad) + for weight, main_grad, wgrad in zip(origin_weights, main_grads, wgrad_list) + ] else: wgrad_list = [None] * ctx.num_gemms @@ -1265,6 +1304,7 @@ def __init__( single_grouped_weight: bool = False, single_grouped_bias: bool = False, name: Optional[str] = None, + gtp_group: Optional[dist_group_type] = None, ) -> None: super().__init__(name) @@ -1320,6 +1360,11 @@ def __init__( "Because the TP communication is handled outside of this module." ) + if gtp_group is None: + self.gtp_size = 1 + else: + self.gtp_size = get_distributed_world_size(gtp_group) + self.parallel_mode = parallel_mode if self.parallel_mode not in GemmParallelModes: raise ValueError( @@ -1371,9 +1416,18 @@ def __init__( if self.primary_weights_in_fp8: self.init_fp8_metadata(num_gemms=self.num_gemms) + self.weight_names = [f"weight{idx}" for idx in range(self.num_gemms)] is_meta = torch.device(device).type == "meta" + if gtp_group is not None: + # Stashed before reset_parameters so the slice hook can see it; + # _gtp_is_grouped routes through the GroupedLinear finalize path. + self._gtp_group = gtp_group + self._gtp_is_grouped = True + self.reset_parameters(defer_init=is_meta) + maybe_wrap_gtp(self, self.weight_names, gtp_group, is_grouped=True) + if self.wgrad_store.delay_wgrad_compute(): for name, param in self.named_parameters(): if name in ("weight", "bias"): @@ -1723,6 +1777,11 @@ def forward( weight_tensors = self._get_weight_tensors() bias_tensors = self._get_bias_tensors() + if self.gtp_size > 1: + weight_tensors[0].setup( + weight_quantizer=self._get_weight_quantizers(), + ) + quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() if debug: @@ -1775,6 +1834,7 @@ def forward( skip_fp8_weight_update, self.save_original_input, debug, + self.gtp_size, ) out, new_workspaces = linear_fn( *autograd_ctx, inp, m_splits, non_tensor_args, *weight_tensors, *bias_tensors diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 7fc96d4779..7e74282e4a 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -30,6 +30,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) +from .base import maybe_wrap_gtp from ..quantization import FP8GlobalStateManager, QuantizerRole from ..utils import ( assert_dim_for_fp8_exec, @@ -145,6 +146,7 @@ def forward( symmetric_ar_type, debug, is_fsdp2, + gtp_size, ) = non_tensor_args if fp8: backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override @@ -299,6 +301,16 @@ def forward( # ------------------------------------------------------ # Prepare weight tensor # ------------------------------------------------------ + + weight_gtp_sharded = weight + if gtp_size > 1: + weight = weight.all_gather_and_prefetch( + fwd=True, + skip_weight_cast=is_first_microbatch is False, + cast_noop_flag=skip_fp8_weight_update, + ) + out_features = weight.shape[0] + new_weight_workspace = None weightmat = weight is_weight_param_quantized = False @@ -491,8 +503,9 @@ def forward( wt_save = None tensors_to_save, tensor_objects = prepare_for_saving( inputmat, - wt_save, - weight, + # GTP: save the sharded reference only; backward re-gathers it. + wt_save if gtp_size == 1 else None, + weight if gtp_size == 1 else weight_gtp_sharded, bias, ln_weight, ln_out_to_save, @@ -519,6 +532,8 @@ def forward( if hasattr(weight, "__fsdp_param__"): # MCore FSDP creates main_grad lazily before backward ctx.main_grad_func = weight.get_main_grad + elif gtp_size > 1: + ctx.main_grad_func = weight_gtp_sharded.get_wgrad_tensor else: ctx.main_grad_func = lambda: weight.main_grad ctx.grad_input_quantizer = grad_input_quantizer @@ -561,6 +576,7 @@ def forward( qstate.is_first_fp8_module = _first_fp8_module ctx.wgrad_store = wgrad_store ctx.debug = debug + ctx.gtp_size = gtp_size # backward overrides if backward_override is not None: @@ -612,6 +628,9 @@ def backward( rsigma, ) = restore_from_func_ctx(ctx) + if ctx.gtp_size > 1: + weight = saved_weight.all_gather_and_prefetch_bwd() + # Restore from weakref to get original weight python object # (preserves attributes like main_grad, grad_added_to_main_grad, etc.) # Only needed when fuse_wgrad_accumulation is enabled. @@ -629,7 +648,7 @@ def backward( ), "weight was removed while fuse_wgrad_accumulation=True" # Since main_grad can be modified inplace, it should not be a part of saved_tensors main_grad = ctx.main_grad_func() if weight is not None else None - if main_grad is not None: + if main_grad is not None and ctx.gtp_size == 1: origin_weight.main_grad = main_grad # Gather intermediate/activation tensors if needed @@ -964,7 +983,10 @@ def backward( use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator # Figure out whether to output wgrad GEMM directly into main grad - if ctx.is_first_microbatch is not None: + if ctx.gtp_size > 1: + # GTP: accumulation happens downstream in wgrad_reduce_scatter. + accumulate_wgrad_into_param_main_grad = False + elif ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch ) @@ -1036,6 +1058,9 @@ def wgrad_gemm( # Call wgrad GEMM now wgrad, grad_bias_ = wgrad_gemm(ln_out_total, grad_output) + if ctx.gtp_size > 1: + wgrad = saved_weight.wgrad_reduce_scatter(wgrad) + # Update grad bias if needed if grad_bias is None: grad_bias = grad_bias_ @@ -1115,7 +1140,10 @@ def wgrad_gemm( if ctx.requires_wgrad: # Handle custom DDP from mcore. - if ctx.fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"): + if ctx.gtp_size > 1: + # GTP: skip — wgrad RS already produced the correct shard. + pass + elif ctx.fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"): origin_weight.grad_added_to_main_grad = True if getattr(origin_weight, "zero_out_wgrad", False): wgrad = get_dummy_wgrad( @@ -1282,6 +1310,7 @@ def __init__( delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, name: Optional[str] = None, + gtp_group: Optional[dist_group_type] = None, ) -> None: super().__init__(name) @@ -1312,6 +1341,11 @@ def __init__( self.set_tensor_parallel_group(tp_group) self.set_nccl_overlap_warning_if_tp() + if gtp_group is None: + self.gtp_size = 1 + else: + self.gtp_size = get_distributed_world_size(gtp_group) + self.parallel_mode = parallel_mode assert ( self.parallel_mode in GemmParallelModes @@ -1522,8 +1556,18 @@ def __init__( if with_fp8_params: self.init_fp8_metadata() + if gtp_group is not None: + # Stashed before reset_parameters so the slice hook can see it. + self._gtp_group = gtp_group + self._gtp_is_grouped = False + self.reset_parameters(defer_init=device == "meta") + maybe_wrap_gtp(self, self.weight_names, gtp_group) + if gtp_group is not None: + # Free the full-size backing buffer; GTP replaced it with a sharded param. + del weight_tensor + # For RPL, bias has to be added after TP collectives # So it cannot be fused with the GEMM if self.parallel_mode == "row" and self.apply_bias: @@ -1686,6 +1730,11 @@ def forward( # Get concatenated weight and bias tensors weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() + if self.gtp_size > 1: + weight_tensor.setup( + weight_quantizer=self._get_weight_quantizers(), + ) + quantizers = ( self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) if not debug @@ -1756,6 +1805,7 @@ def forward( self.symmetric_ar_type, debug, self.is_fsdp2, + self.gtp_size, ) out, ln_out, new_weight_workspace = fwd_fn( *autograd_ctx, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 6c2d98d160..22923d1328 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -30,6 +30,7 @@ _2X_ACC_WGRAD, ) from ._common import noop_cat, WeightGradStore +from .base import maybe_wrap_gtp from ..quantization import FP8GlobalStateManager, QuantizerRole from ..utils import ( cast_if_needed, @@ -155,6 +156,9 @@ class LinearFwdArgs: cpu_offloading: bool is_grad_enabled: bool + # --- Generalized tensor parallelism --- + gtp_size: int = 1 + @dataclass(slots=True) class LinearBwdArgs: @@ -224,6 +228,9 @@ class LinearBwdArgs: cpu_offloading: bool = False owns_input: bool = False + # --- Generalized tensor parallelism --- + gtp_size: int = 1 + # --- Per-backward scratch state (populated inside _linear_backward) --- ub_obj_gradout: Optional[Any] = None @@ -401,6 +408,17 @@ def _linear_forward_impl( # ------------------------------------------------------ # Prepare weight tensor # ------------------------------------------------------ + # GTP: rebind `weight` to the all-gathered tensor; `args.weight` keeps + # the GTPShardedParam reference for backward re-gather / wgrad RS. + if args.gtp_size > 1: + weight = weight.all_gather_and_prefetch( + fwd=True, + skip_weight_cast=is_first_microbatch is False, + cast_noop_flag=args.skip_fp8_weight_update, + ) + # Refresh out_features from the gathered weight (captured sharded above, pre-gather). + out_features = weight.shape[0] + new_weight_workspace = None weightmat = weight if fp8 or debug: @@ -587,6 +605,9 @@ def _linear_forward_impl( wt_save = weightmat if is_fsdp2 and weightmat is not weight: wt_save = None + # GTP: don't save the workspace; backward re-gathers it. + if args.gtp_size > 1: + wt_save = None # Dedup save slots that alias forward inputs; ``_linear_setup_ctx`` # rebuilds the refs from ``inp`` / ``weight`` / ``bias``. @@ -691,11 +712,14 @@ def _linear_setup_ctx( bwd_args.origin_weight_overwrites_main_grad = getattr(weight, "overwrite_main_grad", False) if hasattr(weight, "__fsdp_param__"): bwd_args.main_grad_func = weight.get_main_grad + elif fwd_args.gtp_size > 1: + bwd_args.main_grad_func = weight.get_wgrad_tensor else: bwd_args.main_grad_func = lambda: weight.main_grad # Misc bwd_args.cpu_offloading = fwd_args.cpu_offloading + bwd_args.gtp_size = fwd_args.gtp_size if backward_override is not None: bwd_args.fp8 = False @@ -762,7 +786,8 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. origin_weight_python_object is not None ), "weight was removed while fuse_wgrad_accumulation=True" main_grad = bwd_args.main_grad_func() - origin_weight_python_object.main_grad = main_grad + if bwd_args.gtp_size == 1: + origin_weight_python_object.main_grad = main_grad # Gather intermediate/activation tensors if needed # NOTE: weight_fp8 = weight when bwd_args.fp8 == False and torch.disttributed.FSDP already @@ -932,6 +957,12 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. dgrad = None dgrad_work = None + + # GTP: re-gather the sharded weight; runs even when requires_dgrad=False + # so the prev_w prefetch is issued for the next layer's bwd. + if bwd_args.gtp_size > 1: + weight_fp8 = saved_weight.all_gather_and_prefetch_bwd() + if bwd_args.requires_dgrad: # FSDP2: Re-create workspace from all-gathered weight when @@ -1140,7 +1171,10 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator # Figure out whether to output wgrad GEMM directly into main grad - if bwd_args.is_first_microbatch is not None: + if bwd_args.gtp_size > 1: + # GTP: accumulation happens downstream in wgrad_reduce_scatter. + accumulate_wgrad_into_param_main_grad = False + elif bwd_args.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( bwd_args.fuse_wgrad_accumulation and not bwd_args.is_first_microbatch ) @@ -1216,6 +1250,11 @@ def wgrad_gemm( # Call wgrad GEMM now wgrad, grad_bias_ = wgrad_gemm(inputmat_total, grad_output) + # GTP: reduce-scatter the freshly computed wgrad (async; overlap + # with the next layer's bwd via the cascade). + if bwd_args.gtp_size > 1: + wgrad = saved_weight.wgrad_reduce_scatter(wgrad) + # Update grad bias if needed if grad_bias is None: grad_bias = grad_bias_ @@ -1261,15 +1300,19 @@ def wgrad_gemm( origin_weight_python_object, "grad_added_to_main_grad" ): origin_weight_python_object.grad_added_to_main_grad = True + # Use the param's local shape (sharded under GTP) so the dummy wgrad + # matches the saved weight shape; main_grad_func() under GTP returns + # an unsharded scratch and would otherwise mismatch. + wgrad_shape = list(origin_weight_python_object.shape) if getattr(origin_weight_python_object, "zero_out_wgrad", False): wgrad = get_dummy_wgrad( - list(main_grad.shape), + wgrad_shape, origin_weight_python_object.dtype, zero=True, ) else: wgrad = get_dummy_wgrad( - list(main_grad.shape), + wgrad_shape, origin_weight_python_object.dtype, ) elif bwd_args.fuse_wgrad_accumulation: @@ -1485,6 +1528,7 @@ def __init__( symmetric_ar_type: Optional[str] = None, save_original_input: bool = False, name: Optional[str] = None, + gtp_group: Optional[dist_group_type] = None, ) -> None: super().__init__(name) @@ -1513,6 +1557,11 @@ def __init__( self.set_tensor_parallel_group(tp_group) self.set_nccl_overlap_warning_if_tp() + if gtp_group is None: + self.gtp_size = 1 + else: + self.gtp_size = get_distributed_world_size(gtp_group) + self.parallel_mode = parallel_mode assert ( self.parallel_mode in GemmParallelModes @@ -1698,8 +1747,18 @@ def __init__( if with_fp8_params: self.init_fp8_metadata() + if gtp_group is not None: + # Stashed before reset_parameters so the slice hook can see it. + self._gtp_group = gtp_group + self._gtp_is_grouped = False + self.reset_parameters(defer_init=device == "meta") + maybe_wrap_gtp(self, self.weight_names, gtp_group) + if gtp_group is not None: + # Free the full-size backing buffer; GTP replaced it with a sharded param. + del weight_tensor + # For RPL, bias has to be added after TP collectives # So it cannot be fused with the GEMM if self.parallel_mode == "row" and self.apply_bias: @@ -1830,6 +1889,11 @@ def forward( try: weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() + if self.gtp_size > 1: + weight_tensor.setup( + weight_quantizer=self._get_weight_quantizers(), + ) + quantizers = ( self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) if not debug @@ -1948,6 +2012,8 @@ def forward( # misc cpu_offloading=is_cpu_offload_enabled(), is_grad_enabled=is_grad_enabled, + # generalized tensor parallelism + gtp_size=self.gtp_size, ) out, new_weight_workspace = linear_fn( *autograd_ctx, diff --git a/transformer_engine/pytorch/ops/fused/grouped_mlp.py b/transformer_engine/pytorch/ops/fused/grouped_mlp.py index 39180f098e..b40e5b225e 100644 --- a/transformer_engine/pytorch/ops/fused/grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/grouped_mlp.py @@ -17,7 +17,12 @@ import transformer_engine_torch as tex from ...constants import MXFP8_BLOCK_SCALING_SIZE, NVFP4_BLOCK_SCALING_SIZE -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload +from ...cpu_offload import ( + is_cpu_offload_enabled, + mark_activation_offload, + mark_not_offload, + start_offload, +) from ...cpp_extensions import general_gemm, general_grouped_gemm_for_grouped_tensor from ...module.base import _2X_ACC_WGRAD from ...quantization import Recipe @@ -534,6 +539,7 @@ def _compute_grad_params( wgrad_output = None op_label = f"Grouped MLP fused backward ({label})" if label else "Grouped MLP fused backward" weights = fc_op._get_weight_tensors() + gtp_size = getattr(ctx, "gtp_size", 1) if fc_op.single_grouped_weight: w_list = [None] if ctx.weight_requires_grad: @@ -566,7 +572,9 @@ def _compute_grad_params( else: w_list = [None] * num_groups if ctx.weight_requires_grad: - if fc_op._accumulate_into_main_grad: + # EGTP: the GEMM produces full-sized wgrads but main_grad is sharded, so use a + # full-sized scratch buffer (the reduce-scatter below lands it in main_grad). + if fc_op._accumulate_into_main_grad and gtp_size == 1: w_list = [get_main_grad_from_param(w, op_label=op_label) for w in weights] accumulate_into_main_grad = get_accumulate_flag_in_param(weights[0]) else: @@ -582,6 +590,12 @@ def _compute_grad_params( if ctx.weight_requires_grad: # Launch or defer the GEMM delay_wgrad = fc_op.wgrad_store is not None and fc_op.wgrad_store.delay_wgrad_compute() + if gtp_size > 1 and delay_wgrad: + raise RuntimeError( + "EGTP + cuteDSL fused grouped-MLP does not support delay_wgrad / " + "overlap_dispatch_backward_with_experts_wgrad yet; set " + "delay_wgrad_compute=False." + ) if cudnn_wgrad_kernel_fn is not None: offsets = offsets if offsets.dtype == torch.int32 else offsets.to(dtype=torch.int32) gemm_fn = functools.partial( @@ -620,9 +634,14 @@ def _compute_grad_params( fc_op.wgrad_store.put([grouped_x, grouped_dy, wgrad_output], gemm_fn) else: gemm_fn(grouped_x, grouped_dy, wgrad_output) + # EGTP: reduce-scatter the full per-rank wgrads into each sharded main_grad + # (also fires the Megatron grad-accum hook). + if gtp_size > 1: + weights[0].batched_wgrad_reduce_scatter(w_list) # Need to return dummy wgrads for Megatron-LM wgrad fusion if grad is already added - if fc_op._accumulate_into_main_grad: + # (wgrad fusion, or the EGTP reduce-scatter above) so it doesn't double-add. + if fc_op._accumulate_into_main_grad or gtp_size > 1: w_list = get_dummy_wgrads_for_params(weights) elif delay_wgrad: w_list = [None] if fc_op.single_grouped_weight else [None] * num_groups @@ -885,6 +904,20 @@ def fuser_forward( num_groups = fc1_op.num_groups fc1_weight_param = fc1_op.weight if fc1_op.single_grouped_weight else fc1_op.weight0 fc2_weight_param = fc2_op.weight if fc2_op.single_grouped_weight else fc2_op.weight0 + + # EGTP: expert weights are sharded 1/N along out_features; the fused kernels read + # the full shape, so all-gather the full weight first. gtp_size==1 is a no-op. + fc1_gtp_size = getattr(fc1_weight_param, "gtp_size", 1) + fc2_gtp_size = getattr(fc2_weight_param, "gtp_size", 1) + assert fc1_gtp_size == fc2_gtp_size, "FC1/FC2 must share one EGTP group." + if fc1_gtp_size > 1: + assert not fc1_op.single_grouped_weight and not fc2_op.single_grouped_weight, ( + "EGTP + cuteDSL fused grouped-MLP only supports the discrete " + "(single_grouped_weight=False) expert-weight layout." + ) + assert fc1_op.weight0.is_routed_expert and fc1_op.weight0.weight_list is not None + assert fc2_op.weight0.is_routed_expert and fc2_op.weight0.weight_list is not None + device = fc1_weight_param.device if torch.is_autocast_enabled(): dtype = torch.get_autocast_dtype("cuda") @@ -963,7 +996,23 @@ def fuser_forward( None, ) else: - fc1_weights = [getattr(fc1_op, f"weight{idx}") for idx in range(num_groups)] + if fc1_gtp_size > 1: + # Init per-shard quantizers once (quantize-then-gather). + if fc1_op.weight0._quantizer is None: + fc1_op.weight0.setup( + weight_quantizer=[ + fc1_op.get_quantizer("forward", 2 * idx + 1) + for idx in range(num_groups) + ] + ) + # All-gather the full per-expert weights (returns a list of N full tensors). + # TODO: pass in is_first_microbatch flag to skip redundant quantization after + # the first microbatch in each training step. + fc1_weights = fc1_op.weight0.batched_all_gather_and_prefetch( + fwd=True, skip_weight_cast=False, cast_noop_flag=None + ) + else: + fc1_weights = [getattr(fc1_op, f"weight{idx}") for idx in range(num_groups)] quantized_fc1_weights = [] for idx, weight in enumerate(fc1_weights): quantizer = fc1_op.get_quantizer("forward", 2 * idx + 1) @@ -974,47 +1023,11 @@ def fuser_forward( quantized_fc1_weights.append(weight) grouped_fc1_weight = quantized_fc1_weights - # Prepare FC2 grouped weight tensor for fused kernels. - if fc2_op.single_grouped_weight: - if not isinstance(fc2_op.weight, GroupedTensor): - raise RuntimeError( - "FC2 expected GroupedTensor weight with single_grouped_weight=True." - ) - if fc2_op.weight.quantizer is not None: - fc2_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) - fc2_op.weight.quantizer = fc2_weight_quantizer - grouped_fc2_weight = fc2_op.weight - else: - if fc2_op.weight.rowwise_data is None: - raise RuntimeError("FC2 grouped weight has no rowwise_data to quantize.") - fc2_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) - grouped_fc2_weight = _group_quantize_for_grouped_mlp( - fc2_op.weight.rowwise_data.view(fc2_op.weight.logical_shape), - fc2_weight_quantizer, - num_groups, - None, - ) - else: - fc2_weights = [getattr(fc2_op, f"weight{idx}") for idx in range(num_groups)] - quantized_fc2_weights = [] - for idx, weight in enumerate(fc2_weights): - quantizer = fc2_op.get_quantizer("forward", 2 * idx + 1) - if not is_quantized_tensor(weight): - quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) - quantized_fc2_weights.append(quantizer(weight)) - else: - quantized_fc2_weights.append(weight) - grouped_fc2_weight = quantized_fc2_weights - # Some wrapper-copy paths may drop grouped storage metadata; enforce defaults. if isinstance(grouped_fc1_weight, GroupedTensor) and not hasattr( grouped_fc1_weight, "_with_gemm_swizzled_scales" ): grouped_fc1_weight._with_gemm_swizzled_scales = False - if isinstance(grouped_fc2_weight, GroupedTensor) and not hasattr( - grouped_fc2_weight, "_with_gemm_swizzled_scales" - ): - grouped_fc2_weight._with_gemm_swizzled_scales = False # Group-quantize input tensor and convert dtypes if needed fc1_input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) @@ -1239,6 +1252,57 @@ def fuser_forward( else: fc1_kernel_out = self.grouped_gemm_activation_kernel()(**fc1_activation_kwargs) + # Prepare FC2 grouped weight tensor for fused kernels. + if fc2_op.single_grouped_weight: + if not isinstance(fc2_op.weight, GroupedTensor): + raise RuntimeError( + "FC2 expected GroupedTensor weight with single_grouped_weight=True." + ) + if fc2_op.weight.quantizer is not None: + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + fc2_op.weight.quantizer = fc2_weight_quantizer + grouped_fc2_weight = fc2_op.weight + else: + if fc2_op.weight.rowwise_data is None: + raise RuntimeError("FC2 grouped weight has no rowwise_data to quantize.") + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + grouped_fc2_weight = _group_quantize_for_grouped_mlp( + fc2_op.weight.rowwise_data.view(fc2_op.weight.logical_shape), + fc2_weight_quantizer, + num_groups, + None, + ) + else: + if fc2_gtp_size > 1: + # Init per-shard quantizers once (quantize-then-gather); gated so the + # quantizer list isn't rebuilt every forward. + if fc2_op.weight0._quantizer is None: + fc2_op.weight0.setup( + weight_quantizer=[ + fc2_op.get_quantizer("forward", 2 * idx + 1) + for idx in range(num_groups) + ] + ) + # All-gather the full per-expert weights (returns a list of N full tensors). + fc2_weights = fc2_op.weight0.batched_all_gather_and_prefetch( + fwd=True, skip_weight_cast=False, cast_noop_flag=None + ) + else: + fc2_weights = [getattr(fc2_op, f"weight{idx}") for idx in range(num_groups)] + quantized_fc2_weights = [] + for idx, weight in enumerate(fc2_weights): + quantizer = fc2_op.get_quantizer("forward", 2 * idx + 1) + if not is_quantized_tensor(weight): + quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + quantized_fc2_weights.append(quantizer(weight)) + else: + quantized_fc2_weights.append(weight) + grouped_fc2_weight = quantized_fc2_weights + if isinstance(grouped_fc2_weight, GroupedTensor) and not hasattr( + grouped_fc2_weight, "_with_gemm_swizzled_scales" + ): + grouped_fc2_weight._with_gemm_swizzled_scales = False + # Unpack kernel outputs # Note: Fused kernel outputs tensors with non-contiguous # logical dims. @@ -1450,21 +1514,48 @@ def fuser_forward( grouped_fc_x.rowwise_data = None grouped_fc_x.scale_inv = None + # Per-op fine-grained offload markers. + offload_fc1_x = bool(getattr(fc1_op, "fine_grained_activation_offloading", False)) + offload_act = bool(getattr(activation_op, "fine_grained_activation_offloading", False)) + fine_grained_offload = offload_fc1_x or offload_act + saved_activations = ( + (grouped_fc1_x, offload_fc1_x), + (activation_in, offload_act), + (saved_grouped_fc2_x, offload_act), + ) + + # The hook-based offloader is opt-out, so explicitly keep the + # non-selected tensors resident (mark_not_offload sets _TE_do_not_offload). + if fine_grained_offload: + keep = [t for t, sel in saved_activations if t is not None and not sel] + if keep: + mark_not_offload(*keep) + if cpu_offloading: + # TE-native path; with no markers, offload everything saved (legacy). activation_tensors = [ - t for t in (grouped_fc1_x, activation_in, saved_grouped_fc2_x) if t is not None + t + for t, sel in saved_activations + if t is not None and (sel or not fine_grained_offload) ] start_offload(*activation_tensors) mark_activation_offload(*activation_tensors) # Save an internal layout for this joint fused op. The saved state is # intentionally not compatible with the basic GroupedLinear backward. - fc1_weight_tensors = ( - [grouped_fc1_weight] if fc1_op.single_grouped_weight else grouped_fc1_weight - ) - fc2_weight_tensors = ( - [grouped_fc2_weight] if fc2_op.single_grouped_weight else grouped_fc2_weight - ) + # EGTP: save the small sharded params; backward col-AGs the layout it needs. + if fc1_gtp_size > 1: + fc1_weight_tensors = [getattr(fc1_op, f"weight{idx}") for idx in range(num_groups)] + else: + fc1_weight_tensors = ( + [grouped_fc1_weight] if fc1_op.single_grouped_weight else grouped_fc1_weight + ) + if fc2_gtp_size > 1: + fc2_weight_tensors = [getattr(fc2_op, f"weight{idx}") for idx in range(num_groups)] + else: + fc2_weight_tensors = ( + [grouped_fc2_weight] if fc2_op.single_grouped_weight else grouped_fc2_weight + ) fc1_ctx.save_for_backward( split_sizes, base_split_offsets, @@ -1482,6 +1573,7 @@ def fuser_forward( fc1_ctx.dtype = dtype fc1_ctx.input_requires_grad = input_requires_grad fc1_ctx.weight_requires_grad = weight_requires_grad + fc1_ctx.gtp_size = fc1_gtp_size fc2_ctx.input_quantizers = [fc2_input_quantizer] fc2_ctx.grad_output_quantizers = [fc2_grad_output_quantizer] @@ -1489,6 +1581,7 @@ def fuser_forward( fc2_ctx.input_requires_grad = input_requires_grad fc2_ctx.weight_requires_grad = weight_requires_grad fc2_ctx.recompute_input_from_dsrelu = recompute_srelu_fc2_x + fc2_ctx.gtp_size = fc2_gtp_size return fc2_out, [(), (), ()] @@ -1728,6 +1821,11 @@ def fuser_backward( glu_clamp_min=self._cudnn_glu_clamp_min, ) + # EGTP: forward gathered rowwise; col-AG the columnwise layout the dgrad needs, + # right before use (one weight live at a time). + if getattr(fc2_ctx, "gtp_size", 1) > 1: + grouped_fc2_weight = fc2_op.weight0.batched_all_gather_and_prefetch_bwd() + if fc2_op.single_grouped_weight: # Clone and swizzle scales for GEMM fc2_weight_for_gemm = grouped_fc2_weight.copy() @@ -2006,6 +2104,10 @@ def fuser_backward( "use_dynamic_sched": True, } + # EGTP: col-AG the columnwise layout for the FC1 dgrad (only when dgrad runs). + if getattr(fc1_ctx, "gtp_size", 1) > 1: + grouped_fc1_weight = fc1_op.weight0.batched_all_gather_and_prefetch_bwd() + if fc1_op.single_grouped_weight: # Clone and swizzle scales for GEMM fc1_weight_for_gemm = grouped_fc1_weight.copy()