Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 167 additions & 16 deletions physicsnemo/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO this also needs more docstrings
from typing import List, Optional
from typing import List, Literal, Optional, Tuple

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -44,16 +43,49 @@ def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
return sections


def get_memory_format(tensor: torch.Tensor) -> torch.memory_format:
    """Returns the memory format of a tensor.

    Only two formats are distinguished. Any tensor that is not
    channels-last contiguous (including one that is not contiguous at
    all) is reported as ``torch.contiguous_format``.

    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor to inspect.

    Returns
    -------
    torch.memory_format
        ``torch.channels_last`` if the tensor is channels-last contiguous,
        otherwise ``torch.contiguous_format``.
    """
    # The channels-last check comes first, so a tensor that satisfies it is
    # always reported as channels-last.
    if tensor.is_contiguous(memory_format=torch.channels_last):
        return torch.channels_last
    return torch.contiguous_format


def pad_helper(tensor, dim, new_size, mode="zero"):
"""Util for padding tensors"""
def pad_helper(
tensor: torch.Tensor, dim: int, new_size: int, mode: Literal["zero", "conj"] = "zero"
) -> torch.Tensor:
"""Pads a tensor along a specified dimension to a new size.

Parameters
----------
tensor : torch.Tensor
Input tensor to pad.
dim : int
Dimension along which to pad. Negative indices are supported.
new_size : int
Target size of the dimension after padding.
mode : Literal["zero", "conj"], optional
Padding mode. ``"zero"`` fills new entries with zeros;
``"conj"`` fills with the conjugate-symmetric reflection of the
existing data (useful for Hermitian-symmetric spectra), by default ``"zero"``.

Returns
-------
torch.Tensor
Padded tensor with ``tensor.shape[dim] == new_size``.
"""
ndim = tensor.ndim
dim = (dim + ndim) % ndim
ndim_pad = ndim - dim
Expand All @@ -78,8 +110,24 @@ def pad_helper(tensor, dim, new_size, mode="zero"):
return tensor_pad


def truncate_helper(tensor, dim, new_size):
"""Util for truncating"""
def truncate_helper(tensor: torch.Tensor, dim: int, new_size: int) -> torch.Tensor:
"""Truncates a tensor along a specified dimension to a new size.

Parameters
----------
tensor : torch.Tensor
Input tensor to truncate.
dim : int
Dimension along which to truncate. Negative indices are supported.
new_size : int
Target size of the dimension after truncation; must be <= current size.

Returns
-------
torch.Tensor
Truncated tensor with ``tensor.shape[dim] == new_size``, preserving
the original memory format.
"""
input_format = get_memory_format(tensor)
ndim = tensor.ndim
dim = (dim + ndim) % ndim
Expand All @@ -92,7 +140,32 @@ def truncate_helper(tensor, dim, new_size):
return tensor_trunc


def split_tensor_along_dim(tensor, dim, num_chunks):
def split_tensor_along_dim(
tensor: torch.Tensor, dim: int, num_chunks: int
) -> tuple[torch.Tensor, ...]:
"""Splits a tensor along a dimension into balanced chunks.

Parameters
----------
tensor : torch.Tensor
Input tensor to split.
dim : int
Dimension along which to split.
num_chunks : int
Number of chunks to produce. The last chunk may be smaller if
``tensor.shape[dim]`` is not evenly divisible.

Returns
-------
Tuple[torch.Tensor, ...]
Tuple of ``num_chunks`` tensors.

Raises
------
ValueError
If ``dim`` is greater than or equal to the tensor's number of
dimensions, or if the dimension size is smaller than ``num_chunks``.
"""
if dim >= tensor.dim():
raise ValueError(
f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}"
Expand All @@ -111,7 +184,7 @@ def split_tensor_along_dim(tensor, dim, num_chunks):


@torch.no_grad()
def reduce_loss(loss: float, dst_rank: int = 0, mean: bool = True): # pragma: no cover
def reduce_loss(loss: float, dst_rank: int = 0, mean: bool = True) -> Optional[float]: # pragma: no cover
"""Reduces loss from all processes to destination rank for logging.

Parameters
Expand All @@ -123,6 +196,11 @@ def reduce_loss(loss: float, dst_rank: int = 0, mean: bool = True): # pragma: n
mean : bool, Optional
Calculate the mean of the losses gathered, by default True.

Returns
-------
Optional[float]
The reduced loss value on ``dst_rank``; ``None`` on all other ranks.

Raises
------
Exception
Expand Down Expand Up @@ -151,8 +229,40 @@ def reduce_loss(loss: float, dst_rank: int = 0, mean: bool = True): # pragma: n


# distributed primitives
def distributed_transpose(tensor, dim0, dim1, group=None, async_op=False):
"""Perform distributed transpose of tensor to switch sharding dimension"""
def distributed_transpose(
tensor: torch.Tensor,
dim0: int,
dim1: int,
group: Optional[dist.ProcessGroup] = None,
async_op: bool = False,
) -> tuple[list[torch.Tensor, ...], dist.Work | None]:
"""Performs a distributed transpose to switch the sharding dimension.

Splits ``tensor`` along ``dim0`` across all ranks in the process group,
exchanges the chunks via ``all_to_all``, and returns the list of received
chunks (which collectively form a tensor sharded along ``dim1``).

Parameters
----------
tensor : torch.Tensor
Local shard of the distributed tensor, sharded along ``dim0``.
dim0 : int
Current sharding dimension (the one being split for the all-to-all).
dim1 : int
Target sharding dimension after the transpose.
group : Optional[dist.ProcessGroup], optional
Process group over which the operation is performed, by default None
(uses the default group).
async_op : bool, optional
If ``True``, the all-to-all is launched asynchronously and the caller
must call ``.wait()`` on the returned work handle, by default False.

Returns
-------
Tuple[List[torch.Tensor], Optional[dist.Work]]
A 2-tuple of (received tensor chunks, work handle). The work handle
is ``None`` when ``async_op=False``.
"""
# get input format
input_format = get_memory_format(tensor)

Expand All @@ -173,8 +283,29 @@ def distributed_transpose(tensor, dim0, dim1, group=None, async_op=False):
return x_recv, req


def _reduce(input_, use_fp32=True, group=None): # pragma: no cover
"""All-reduce the input tensor across model parallel group."""
def _reduce( # pragma: no cover
input_: torch.Tensor,
use_fp32: bool = True,
group: Optional[dist.ProcessGroup] = None,
) -> torch.Tensor:
"""All-reduces the input tensor across the model parallel group.

Parameters
----------
input_ : torch.Tensor
Tensor to reduce in-place across ranks.
use_fp32 : bool, optional
If ``True`` and ``input_`` is a low-precision float (< 4 bytes per
element), the reduction is performed in FP32 and cast back to the
original dtype, by default True.
group : Optional[dist.ProcessGroup], optional
Process group over which to reduce, by default None (default group).

Returns
-------
torch.Tensor
Reduced tensor (same object as ``input_`` when no up-cast is needed).
"""

# Bypass the function if we are using only 1 GPU.
if dist.get_world_size(group=group) == 1:
Expand All @@ -193,8 +324,28 @@ def _reduce(input_, use_fp32=True, group=None): # pragma: no cover
return input_


def _split(input_, dim_, group=None): # pragma: no cover
"""Split the tensor along its last dimension and keep the corresponding slice."""
def _split( # pragma: no cover
input_: torch.Tensor,
dim_: int,
group: Optional[dist.ProcessGroup] = None,
) -> torch.Tensor:
"""Splits the tensor along ``dim_`` and returns this rank's slice.

Parameters
----------
input_ : torch.Tensor
Tensor to split across ranks.
dim_ : int
Dimension along which to split.
group : Optional[dist.ProcessGroup], optional
Process group determining the number of chunks and this rank's index,
by default None (default group).

Returns
-------
torch.Tensor
The contiguous slice of ``input_`` belonging to the current rank.
"""
# get input format
input_format = get_memory_format(input_)

Expand Down