Generalized Tensor Parallelism (GTP) #3005
Conversation
Greptile SummaryThis PR adds Generalized Tensor Parallelism (GTP) to TransformerEngine's
Confidence Score: 3/5The backward path in The core wgrad path in
Important Files Changed
Sequence DiagramsequenceDiagram
participant Mcore
participant TE_Module as TE Linear/GroupedLinear
participant GTPShardedParam
participant NCCL
Note over Mcore,TE_Module: Init
Mcore->>TE_Module: register_gtp_hooks(slice_fn, finalize_fn, wrap_fn)
TE_Module->>GTPShardedParam: _gtp_slice_fn(module, name, param, expert_idx)
TE_Module->>Mcore: _gtp_wrap_fn(module, weight_names, gtp_group)
Note over Mcore,NCCL: Forward Pass
TE_Module->>GTPShardedParam: weight.setup(weight_quantizer)
GTPShardedParam->>NCCL: "all_gather_and_prefetch(fwd=True)"
TE_Module->>TE_Module: GEMM(input, gathered_weight)
TE_Module->>TE_Module: save sharded param refs for bwd
Note over Mcore,NCCL: Backward Pass
TE_Module->>GTPShardedParam: saved_weight.all_gather_and_prefetch_bwd()
TE_Module->>TE_Module: dgrad GEMM(grad_output, gathered_weight)
TE_Module->>TE_Module: wgrad GEMM(input, grad_output)
GTPShardedParam->>NCCL: wgrad_reduce_scatter(wgrad) async on rs_stream
NCCL-->>GTPShardedParam: shard lands in main_grad buffer
|
|
/te-ci L1 pytorch |
3e70bdf to
ed9ce68
Compare
Co-authored-by: Jieming Zhang <jiemingz@nvidia.com> Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
| # Fix the interleaved transposed data from gathering along first dim. | ||
| out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size) | ||
| out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size) | ||
| # In-place .copy_() (not `=` rebind) to keep the storage address stable | ||
| # for CUDA graph capture — replays see the same pointer they captured. | ||
| out._columnwise_scale_inv.copy_(_swap_first_dims(columnwise_scale_inv_interleaved, world_size)) | ||
| out._columnwise_data.copy_(_swap_first_dims(columnwise_data_interleaved, world_size)) | ||
|
|
||
| # Optionally pad the scaling inverse if needed. | ||
| out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv) | ||
| # Optionally pad the scaling inverse if needed (same in-place pattern). | ||
| out._columnwise_scale_inv.copy_(pad_columnwise_scale_inv(out._columnwise_scale_inv)) | ||
|
|
There was a problem hiding this comment.
Shape mismatch in
_post_process_nvfp4_gather breaks any K not a multiple of 128
out._columnwise_scale_inv is allocated by NVFP4Quantizer.make_empty with shape (round_up(K, 128), round_up(ceil(M_total/16), 4)) — the fully-padded shape. The intermediate result from _swap_first_dims(columnwise_scale_inv_interleaved, world_size) has the unpadded shape (K_stripped, world_size * unpadded_dim1), because the gather side strips padding before the NCCL collect. When K is not a multiple of 128 (e.g. K=64 → padded to 128), the dimensions diverge and out._columnwise_scale_inv.copy_(...) raises a RuntimeError at the first all-gather call.
The pre-PR code used = rebinding, which handled arbitrary shapes. Replacing it with .copy_() is only safe when the caller pre-allocates buffers with the correct unpadded intermediate shape — which make_empty does not do. The GTP-prefetched output_tensor path has the same problem on the step-1 copy before the pad_columnwise_scale_inv call can correct things.
Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
| @@ -1627,10 +1677,23 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: | |||
| with get_rng_state_tracker().fork(): | |||
| init_fn(param) | |||
|
|
|||
| # GTP slice: shard the freshly-init weight into a GTPShardedParam; | |||
There was a problem hiding this comment.
Wrong
expert_idx for LayerNormLinear (and GroupedLinear with bias) silently disables GTP weight slicing
expert_idx=idx uses the position of the parameter in named_parameters(recurse=False), which includes non-linear-weight parameters. For LayerNormLinear the iteration order is layer_norm_weight (idx=0), layer_norm_bias (idx=1 for non-RMSNorm), weight (idx=2 or 1). The linear weight therefore arrives at _gtp_slice_fn with expert_idx=2 (or 1 for RMSNorm) instead of expert_idx=0. A Mcore hook that maps expert_idx to a pre-registered shard slot would find no entry for idx=2 and return None, silently leaving the weight un-sharded while gtp_group is set — defeating GTP for the entire LayerNormLinear path this PR explicitly adds.
Similarly, for GroupedLinear with biases enabled, weight1 receives expert_idx=2 (interleaved with bias0), so every expert beyond the first is mis-indexed.
A correct counter only advances when gtp_sharded is not None, keeping it aligned with the weight-only registration slots.
Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
Deisgn doc: GTP.docx
Description
Core-idea: add Generalized Tensor Parallelism (GTP), which is a flexible fine-grained sharding/just-in time materialization of both activations and parameters with efficient computation-communication overlap.
Mission: improve LLM pretraining efficiency through generalized tensor parallelism, enabling high performance, memory efficiency, ease of use, and strong scalability.
Summary of features
How Mcore interacts with TE
① Mcore registers callbacks into TE at import time.
② TE calls back into Mcore runtime during te.Linear(gtp_group=…) init AND during fwd/bwd (weight.all_gather_and_prefetch / wgrad_reduce_scatter).
③ Mcore extensions forward gtp_group= at module init.
④ TE provides FP8 / MXFP8 / NVFP4 tensor types AND the quantize-then-AG / RS collectives (gather_along_first_dim, reduce_scatter_along_first_dim) — imported by Mcore runtime; GTP wraps them with its own schedule, buffer cache, and stream choreography.
Type of change
Changes
Please list the changes introduced in this PR:
wgrad_shape.
carving (with/without GTP);
Checklist: