-
Notifications
You must be signed in to change notification settings - Fork 749
Generalized Tensor Parallelism (GTP) #3005
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
43a6cea
a7d0925
a532120
2b26b69
4a12eb4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,7 +6,7 @@ | |
| from __future__ import annotations | ||
|
|
||
| from collections.abc import Iterable | ||
| from contextlib import contextmanager, AbstractContextManager, ContextDecorator | ||
| from contextlib import contextmanager, AbstractContextManager, ContextDecorator, nullcontext | ||
| from functools import lru_cache | ||
| from dataclasses import dataclass | ||
| import math | ||
|
|
@@ -926,7 +926,10 @@ def fork(self, name: str = "model-parallel-rng"): | |
|
|
||
|
|
||
| def reduce_scatter_along_first_dim( | ||
| inp: torch.Tensor, tp_group: dist_group_type, async_op: bool = False | ||
| inp: torch.Tensor, | ||
| tp_group: dist_group_type, | ||
| async_op: bool = False, | ||
| output: torch.Tensor = None, | ||
| ) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]: | ||
| """Reduce-scatter the input tensor across model parallel group.""" | ||
| world_size = get_distributed_world_size(tp_group) | ||
|
|
@@ -944,7 +947,8 @@ def reduce_scatter_along_first_dim( | |
|
|
||
| dim_size[0] = dim_size[0] // world_size | ||
|
|
||
| output = torch.empty(dim_size, dtype=inp.dtype, device=torch.cuda.current_device()) | ||
| if output is None: | ||
| output = torch.empty(dim_size, dtype=inp.dtype, device=torch.cuda.current_device()) | ||
| handle = torch.distributed.reduce_scatter_tensor( | ||
| output, inp.contiguous(), group=tp_group, async_op=async_op | ||
| ) | ||
|
|
@@ -1289,11 +1293,13 @@ def _post_process_nvfp4_gather( | |
| handle = None | ||
|
|
||
| # Fix the interleaved transposed data from gathering along first dim. | ||
| out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size) | ||
| out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size) | ||
| # In-place .copy_() (not `=` rebind) to keep the storage address stable | ||
| # for CUDA graph capture — replays see the same pointer they captured. | ||
| out._columnwise_scale_inv.copy_(_swap_first_dims(columnwise_scale_inv_interleaved, world_size)) | ||
| out._columnwise_data.copy_(_swap_first_dims(columnwise_data_interleaved, world_size)) | ||
|
|
||
| # Optionally pad the scaling inverse if needed. | ||
| out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv) | ||
| # Optionally pad the scaling inverse if needed (same in-place pattern). | ||
| out._columnwise_scale_inv.copy_(pad_columnwise_scale_inv(out._columnwise_scale_inv)) | ||
|
|
||
|
|
||
| @dataclass | ||
|
|
@@ -1307,17 +1313,25 @@ class _NVFP4AllGatherAsyncHandle: | |
| async_handle: torch.distributed.Work | ||
| _synchronized: bool = False | ||
|
|
||
| def wait(self) -> None: | ||
| """Wait for the async operation to complete and post-process the tensor.""" | ||
| if self._synchronized: | ||
| return | ||
| self.async_handle.wait() | ||
| def post_process_nvfp4_gather(self) -> None: | ||
| """Fix interleaved transposed data + pad scale_inv after the async AG completes. | ||
|
|
||
| Idempotent: gated by ``_synchronized`` in :meth:`wait`. | ||
| """ | ||
| _post_process_nvfp4_gather( | ||
| self.output, | ||
| self.columnwise_data_interleaved, | ||
| self.columnwise_scale_inv_interleaved, | ||
| self.world_size, | ||
| ) | ||
|
|
||
| def wait(self) -> None: | ||
| """Wait for the async operation to complete and post-process the tensor.""" | ||
| if self._synchronized: | ||
| return | ||
| if self.async_handle is not None: | ||
| self.async_handle.wait() | ||
| self.post_process_nvfp4_gather() | ||
| self._synchronized = True | ||
|
fanshiqing marked this conversation as resolved.
|
||
|
|
||
|
Comment on lines
+1316
to
1336
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
|
||
|
|
@@ -1328,6 +1342,8 @@ def _all_gather_nvfp4( | |
| async_op: bool = False, | ||
| quantizer: NVFP4Quantizer, | ||
| out_shape: Optional[list[int]] = None, | ||
| output_tensor=None, | ||
| grouped=False, | ||
| ) -> tuple[NVFP4TensorStorage, Optional[torch.distributed.Work]]: | ||
| """All-gather NVFP4 tensor along first dimension.""" | ||
|
|
||
|
|
@@ -1391,6 +1407,12 @@ def _all_gather_nvfp4( | |
| out = quantizer(out) | ||
| return out, None | ||
|
|
||
| # Construct NVFP4 output tensor | ||
| if output_tensor is not None: | ||
| out = output_tensor | ||
| else: | ||
| out = quantizer.make_empty(out_shape, dtype=dtype, device=device) | ||
|
|
||
| # Cast input tensor to NVFP4 with required data | ||
| if not isinstance(inp, NVFP4TensorStorage): | ||
| inp = quantizer(inp) | ||
|
|
@@ -1403,17 +1425,19 @@ def _all_gather_nvfp4( | |
| ) | ||
| inp = quantizer(inp.dequantize(dtype=dtype)) | ||
|
|
||
| # Construct NVFP4 output tensor | ||
| out = quantizer.make_empty(out_shape, dtype=dtype, device=device) | ||
|
|
||
| # Coalesce NCCL collectives for gathering data and scale inverses. | ||
| with torch.distributed._coalescing_manager( | ||
| group=process_group, | ||
| device=device, | ||
| async_ops=async_op, | ||
| ) as gather_coalescing_manager: | ||
| if not grouped: | ||
| # Coalesce NCCL collectives for gathering data and scale inverses. | ||
| gather_coalescing_manager = torch.distributed._coalescing_manager( | ||
| group=process_group, | ||
| device=device, | ||
| async_ops=async_op, | ||
| ) | ||
| else: | ||
| gather_coalescing_manager = nullcontext() | ||
|
|
||
| with gather_coalescing_manager as coalesced_handle: | ||
| # Gather NVFP4 data for row-wise usage | ||
| out_columnwise_data = None | ||
| if quantizer.rowwise_usage: | ||
|
|
||
| # Remove padding from NVFP4 scale-inverses | ||
|
|
@@ -1441,8 +1465,9 @@ def _all_gather_nvfp4( | |
| group=process_group, | ||
| ) | ||
|
|
||
| # Transfer amax to output. | ||
| out._amax_rowwise = inp._amax_rowwise | ||
| # Transfer amax to output via in-place .copy_() so the storage | ||
| # address stays stable for CUDA graph capture. | ||
| out._amax_rowwise.copy_(inp._amax_rowwise) | ||
|
|
||
| # Gather the transposed NVFP4 data along first dimension. Fix format later. | ||
| if quantizer.columnwise_usage: | ||
|
|
@@ -1491,17 +1516,24 @@ def _all_gather_nvfp4( | |
| ) | ||
|
|
||
| # Transfer amax to output. | ||
| out._amax_columnwise = inp._amax_columnwise | ||
| out._amax_columnwise.copy_(inp._amax_columnwise) | ||
|
|
||
| handle = gather_coalescing_manager if async_op else None | ||
| handle = coalesced_handle if async_op else None | ||
|
|
||
| # Fixes interleaved data for transposed tensor/scale inv and pads scale inv if needed. | ||
| if async_op and quantizer.columnwise_usage: | ||
| handle = _NVFP4AllGatherAsyncHandle( | ||
| out, out_columnwise_data, out_scale_inv, world_size, handle | ||
| ) | ||
| elif quantizer.columnwise_usage: | ||
| _post_process_nvfp4_gather(out, out_columnwise_data, out_scale_inv, world_size, handle) | ||
| if quantizer.columnwise_usage: | ||
| if async_op or grouped: | ||
| # Defer post-processing: either the async op hasn't completed yet, or an | ||
| # external coalescing manager owns the NCCL ops and hasn't flushed them. | ||
| inner_handle = handle if async_op else None | ||
| handle = _NVFP4AllGatherAsyncHandle( | ||
| out, out_columnwise_data, out_scale_inv, world_size, inner_handle | ||
| ) | ||
| else: | ||
| _post_process_nvfp4_gather(out, out_columnwise_data, out_scale_inv, world_size, handle) | ||
| else: | ||
| if handle is not None: | ||
| handle.output = out | ||
|
|
||
| return out, handle | ||
|
|
||
|
|
@@ -1513,6 +1545,8 @@ def _all_gather_mxfp8( | |
| async_op: bool = False, | ||
| quantizer: MXFP8Quantizer, | ||
| out_shape: Optional[list[int]] = None, | ||
| output_tensor: torch.Tensor = None, | ||
| grouped: bool = False, | ||
| ) -> tuple[MXFP8TensorStorage, Optional[torch.distributed.Work]]: | ||
| """All-gather MXFP8 tensor along first dimension.""" | ||
|
|
||
|
|
@@ -1578,15 +1612,22 @@ def _all_gather_mxfp8( | |
| inp = quantizer(inp.dequantize(dtype=dtype)) | ||
|
|
||
| # Construct MXFP8 output tensor | ||
| out = quantizer.make_empty(out_shape, dtype=dtype, device=device) | ||
| if output_tensor is not None: | ||
| out = output_tensor | ||
| else: | ||
| out = quantizer.make_empty(out_shape, dtype=dtype, device=device) | ||
|
|
||
| # Coalesce NCCL collectives | ||
| with torch.distributed._coalescing_manager( | ||
| group=process_group, | ||
| device=device, | ||
| async_ops=async_op, | ||
| ) as coalescing_manager: | ||
| if not grouped: | ||
| # Coalesce NCCL collectives for gathering data and scale inverses. | ||
| gather_coalescing_manager = torch.distributed._coalescing_manager( | ||
| group=process_group, | ||
| device=device, | ||
| async_ops=async_op, | ||
| ) | ||
| else: | ||
| gather_coalescing_manager = nullcontext() | ||
|
|
||
| with gather_coalescing_manager as coalesced_handle: | ||
| # Gather MXFP8 data for row-wise usage | ||
| if quantizer.rowwise_usage: | ||
|
|
||
|
|
@@ -1633,7 +1674,7 @@ def _all_gather_mxfp8( | |
| group=process_group, | ||
| ) | ||
|
|
||
| handle = coalescing_manager if async_op else None | ||
| handle = coalesced_handle if async_op else None | ||
| return out, handle | ||
|
|
||
|
|
||
|
|
@@ -1642,6 +1683,8 @@ def gather_along_first_dim( | |
| process_group: dist_group_type, | ||
| async_op: bool = False, | ||
| quantizer: Optional[Quantizer] = None, | ||
| output_tensor: torch.Tensor = None, | ||
| grouped: bool = False, | ||
| ) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]: | ||
| """ | ||
| All-gather tensors and concatenate along first dimension. | ||
|
|
@@ -1732,6 +1775,8 @@ def gather_along_first_dim( | |
| async_op=async_op, | ||
| quantizer=quantizer, | ||
| out_shape=out_shape, | ||
| output_tensor=output_tensor, | ||
| grouped=grouped, | ||
| ) | ||
|
|
||
| # NVFP4 case | ||
|
|
@@ -1746,6 +1791,8 @@ def gather_along_first_dim( | |
| async_op=async_op, | ||
| quantizer=quantizer, | ||
| out_shape=out_shape, | ||
| output_tensor=output_tensor, | ||
| grouped=grouped, | ||
| ) | ||
|
|
||
| # High-precision communication for quantized tensors | ||
|
|
@@ -1775,19 +1822,20 @@ def gather_along_first_dim( | |
| inp = inp.dequantize() | ||
|
|
||
| # Communication for plain PyTorch tensors | ||
| out = torch.empty( | ||
| out_shape, | ||
| dtype=inp.dtype, | ||
| device=inp.device, | ||
| memory_format=torch.contiguous_format, | ||
| ) | ||
| if output_tensor is None: | ||
| output_tensor = torch.empty( | ||
| out_shape, | ||
| dtype=inp.dtype, | ||
| device=inp.device, | ||
| memory_format=torch.contiguous_format, | ||
| ) | ||
| handle = torch.distributed.all_gather_into_tensor( | ||
| out, | ||
| output_tensor, | ||
| inp.contiguous(), | ||
| group=process_group, | ||
| async_op=async_op, | ||
| ) | ||
| return out, handle | ||
| return output_tensor, handle | ||
|
|
||
|
|
||
| # Global cache to store symmetric memory tensors | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_post_process_nvfp4_gatherbreaks any K not a multiple of 128out._columnwise_scale_invis allocated byNVFP4Quantizer.make_emptywith shape(round_up(K, 128), round_up(ceil(M_total/16), 4))— the fully-padded shape. The intermediate result from_swap_first_dims(columnwise_scale_inv_interleaved, world_size)has the unpadded shape(K_stripped, world_size * unpadded_dim1), because the gather side strips padding before the NCCL collect. When K is not a multiple of 128 (e.g. K=64 → padded to 128), the dimensions diverge andout._columnwise_scale_inv.copy_(...)raises a RuntimeError at the first all-gather call.The pre-PR code used
=rebinding, which handled arbitrary shapes. Replacing it with.copy_()is only safe when the caller pre-allocates buffers with the correct unpadded intermediate shape — whichmake_emptydoes not do. The GTP-prefetchedoutput_tensorpath has the same problem on the step-1 copy before thepad_columnwise_scale_invcall can correct things.