Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion physicsnemo/domain_parallel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,27 @@
# If the minimum versions are met, we can import the shard tensor and spec.

from ._shard_tensor_spec import ShardTensorSpec
from .shard_tensor import ShardTensor, scatter_tensor
from .shard_tensor import (
FSDPOutputTensorAdapter,
ShardTensor,
distribute_over_domain_for_fsdp,
scatter_tensor,
wrap_for_fsdp,
)

def register_custom_ops():
"""Register all custom ShardTensor ops and shard-aware wrappers.

Imports are deferred to this function to avoid an import cycle between
``shard_tensor`` and the individual op modules.
"""
# These imports will register the custom ops with the ShardTensor class.
# It's done here to avoid an import cycle.
from .custom_ops import ( # noqa: F401
_tensor_ops,
mean_wrapper,
sum_wrapper,
unbind_wrapper,
)
from .shard_utils import register_shard_wrappers

Expand All @@ -69,3 +81,6 @@ def register_custom_ops():
ShardTensor = None
ShardTensorSpec = None
scatter_tensor = None
distribute_over_domain_for_fsdp = None
FSDPOutputTensorAdapter = None
wrap_for_fsdp = None
66 changes: 54 additions & 12 deletions physicsnemo/domain_parallel/_shard_redistribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@
)

import physicsnemo.domain_parallel.shard_tensor as shard_tensor
from physicsnemo.domain_parallel._shard_tensor_spec import ShardTensorSpec
from physicsnemo.domain_parallel._shard_tensor_spec import (
ShardTensorSpec,
compute_sharding_shapes_from_chunking_global_shape,
)

# TODO:
# DTensor makes assumptions about sharding sizes.
Expand Down Expand Up @@ -242,8 +245,8 @@ def _to_new_shard_dim(
# But we can optimize the null-communication
dist.all_to_all(recv_shapes, send_shapes, group=group)

# Turn the recv_shapes back into torch shapes:
recv_shapes = [list(torch.Size(r)) for r in recv_shapes]
# Turn the recv_shapes back into plain int shape lists.
recv_shapes = [r.tolist() for r in recv_shapes]

# Create the buffers for recv:
recv_buffers = [
Expand All @@ -268,7 +271,7 @@ def redistribute_local_shard_tensor(
*,
async_op: bool = False,
is_backward: bool = False,
target_sharding_shapes: dict[int, tuple[torch.Size, ...]] | None = None,
target_sharding_shapes: dict[int, tuple[tuple[int, ...], ...]] | None = None,
) -> torch.Tensor:
r"""Redistribute a local tensor between different ShardTensorSpec configurations.

Expand Down Expand Up @@ -304,8 +307,9 @@ def redistribute_local_shard_tensor(
Whether to run asynchronously.
is_backward : bool, default=False
Whether this is a backward pass (affects some redistribution behaviors).
target_sharding_shapes : Optional[Dict[int, Tuple[torch.Size, ...]]], optional
Target sharding shapes to use for redistribution. Default is empty dict.
target_sharding_shapes : Optional[Dict[int, Tuple[Tuple[int, ...], ...]]], optional
Target sharding shapes (plain int tuples) to use for redistribution.
Default is ``None``.

Returns
-------
Expand Down Expand Up @@ -549,7 +553,6 @@ class ShardRedistribute(torch.autograd.Function):

@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
input: "shard_tensor.ShardTensor",
device_mesh: DeviceMesh,
placements: tuple[Placement, ...],
Expand All @@ -559,8 +562,6 @@ def forward(

Parameters
----------
ctx : torch.autograd.function.FunctionCtx
Autograd context for saving tensors/variables for backward.
input : ShardTensor
Input sharded tensor to redistribute.
device_mesh : DeviceMesh
Expand All @@ -577,9 +578,6 @@ def forward(
"""
current_spec = input._spec

ctx.current_spec = current_spec
ctx.async_op = async_op

if current_spec.placements != placements:
# We have to assume, here, that the current spec has correct sharding_shapes.
# Therefore, we can use the target placement + current sharding_shapes
Expand Down Expand Up @@ -612,6 +610,36 @@ def forward(
)
# Set the local shape:
target_spec._local_shape = output.shape

# Populate _sharding_shapes on the target spec so downstream
# consumers (especially under torch.compile) don't trip
# `_all_gather_shard_shapes` -- a blocking collective that is
# not AOT-traceable. Start from chunk semantics (pure
# arithmetic, no comms) and override preserved-shard tensor
# dims with the precomputed per-rank sizes from
# `target_sharding_shapes` so uneven sharding is preserved.
global_shape = tuple(input._spec.tensor_meta.shape)
chunk_shapes = compute_sharding_shapes_from_chunking_global_shape(
device_mesh, placements, global_shape
)
for mesh_dim, placement in enumerate(placements):
if not isinstance(placement, Shard):
continue
tensor_dim = placement.dim
if tensor_dim in target_sharding_shapes:
mesh_size = device_mesh.size(mesh_dim)
per_rank_sizes = target_sharding_shapes[tensor_dim]
if len(per_rank_sizes) == mesh_size:
overridden = []
for rank_size in per_rank_sizes:
rank_shape = list(global_shape)
rank_shape[tensor_dim] = int(rank_size)
overridden.append(rank_shape)
chunk_shapes[mesh_dim] = overridden
target_spec._sharding_shapes = {
mesh_dim: tuple(tuple(s) for s in shapes)
for mesh_dim, shapes in chunk_shapes.items()
}
else:
# use the same local tensor if placements are the same.
output = input._local_tensor
Expand All @@ -623,6 +651,20 @@ def forward(
requires_grad=input.requires_grad,
)

@staticmethod
def setup_context(ctx, inputs, output) -> None:
r"""Save the source spec and ``async_op`` flag for the backward redistribute.

``DisableTorchFunctionSubclass`` shielding avoids re-entering the
ShardTensor ``__torch_function__`` fallback while reading
``input._spec`` -- the same AOT-hostile bridge motivated the
shielding in ``ShardedSum.setup_context``.
"""
input, _device_mesh, _placements, async_op = inputs
with torch._C.DisableTorchFunctionSubclass():
ctx.current_spec = input._spec
ctx.async_op = async_op

@staticmethod
def backward(
ctx: torch.autograd.function.FunctionCtx,
Expand Down
118 changes: 67 additions & 51 deletions physicsnemo/domain_parallel/_shard_tensor_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,26 @@ class ShardTensorSpec(DTensorSpec):
----------
_local_shape : Optional[torch.Size]
The shape of the local shard of the tensor.
_sharding_shapes : Optional[dict[int, Tuple[torch.Size, ...]]]
_sharding_shapes : Optional[dict[int, Tuple[Tuple[int, ...], ...]]]
Mapping from mesh dimension to shard shapes. Keys are mesh dimensions,
values are tuples of ``torch.Size`` representing shard shapes along
values are tuples of plain int tuples representing shard shapes along
that dimension. Shard shapes are only tracked along the sharded
dimensions, not replicated dimensions.

Storage type note: we deliberately use plain ``tuple[int, ...]``
rather than ``torch.Size`` here. ``torch.Size`` is special-cased by
PyTorch's symbolic shape machinery: when a ``ShardTensor`` is
fakeified by dynamo, any ``torch.Size`` stored in this dict has its
contained ints converted into unbacked ``SymInt``s. Those SymInts
are then orphaned whenever an op's output drops or filters
``_sharding_shapes`` (e.g. Partial-only outputs from reductions),
producing ``PendingUnbackedSymbolNotFound`` errors during AOT
tracing. Plain Python int tuples don't trigger this path.
"""

_local_shape: torch.Size | None = field(default_factory=lambda: None)
# This dict is a mapping from the mesh dimension to the shard shapes, _not_ the tensor index
_sharding_shapes: dict[int, tuple[torch.Size, ...]] | None = field(
_sharding_shapes: dict[int, tuple[tuple[int, ...], ...]] | None = field(
default_factory=lambda: None
)

Expand Down Expand Up @@ -101,7 +111,7 @@ def __hash__(self) -> int:

def sharding_shapes(
self, mesh_dim: int | None = None
) -> dict[int, tuple[torch.Size, ...]] | tuple[torch.Size, ...]:
) -> dict[int, tuple[tuple[int, ...], ...]] | tuple[tuple[int, ...], ...]:
r"""Get the shapes of shards along specified mesh dimensions.

Parameters
Expand All @@ -111,7 +121,7 @@ def sharding_shapes(

Returns
-------
Union[Dict[int, Tuple[torch.Size, ...]], Tuple[torch.Size, ...]]
Union[Dict[int, Tuple[Tuple[int, ...], ...]], Tuple[Tuple[int, ...], ...]]
Dictionary of shard shapes by mesh dim if ``mesh_dim`` is ``None``,
or tuple of shapes for the specific mesh dimension.
"""
Expand Down Expand Up @@ -309,7 +319,7 @@ def _gather_shard_shapes_for_dim(

dist.all_gather(all_shapes, local_shape, group=local_group)

all_shapes = [torch.Size(s.cpu().tolist()) for s in all_shapes]
all_shapes = [tuple(s.cpu().tolist()) for s in all_shapes]

if do_checks:
# Check that all shapes are the same rank
Expand All @@ -336,7 +346,7 @@ def _all_gather_shard_shapes(
placements: tuple[Placement, ...],
target_mesh: DeviceMesh,
do_checks: bool = False,
) -> tuple[dict[int, tuple[torch.Size, ...]], tuple[int, ...]]:
) -> tuple[dict[int, tuple[tuple[int, ...], ...]], tuple[int, ...]]:
r"""Gather shard shapes from all ranks across all sharded mesh dimensions.

Parameters
Expand All @@ -352,7 +362,7 @@ def _all_gather_shard_shapes(

Returns
-------
Tuple[Dict[int, Tuple[torch.Size, ...]], Tuple[int, ...]]
Tuple[Dict[int, Tuple[Tuple[int, ...], ...]], Tuple[int, ...]]
Tuple containing:

- Dictionary mapping mesh dimensions to tuples of shard shapes.
Expand Down Expand Up @@ -389,13 +399,13 @@ def compute_sharding_shapes_from_chunking_global_shape(
mesh: DeviceMesh,
placements: tuple[Placement, ...],
global_shape: tuple[int, ...],
) -> dict[int, list[torch.Size]]:
) -> dict[int, list[tuple[int, ...]]]:
r"""Compute shard sizes for each mesh dimension based on global shape.

For each sharded dimension in the mesh, computes the chunk sizes that
would result from evenly dividing the global tensor shape. Returns a
mapping from mesh dimensions to lists of ``torch.Size`` objects
representing the shape of each shard.
mapping from mesh dimensions to lists of plain int tuples representing
the shape of each shard.

Parameters
----------
Expand All @@ -408,8 +418,8 @@ def compute_sharding_shapes_from_chunking_global_shape(

Returns
-------
Dict[int, List[torch.Size]]
Dictionary mapping mesh dimensions to lists of ``torch.Size`` objects
Dict[int, List[Tuple[int, ...]]]
Dictionary mapping mesh dimensions to lists of plain int tuples
representing shard shapes for that dimension.

Raises
Expand All @@ -420,45 +430,45 @@ def compute_sharding_shapes_from_chunking_global_shape(
if len(placements) != mesh.ndim:
raise ValueError("Number of placements must match mesh dimensions")

# First compute raw chunk sizes for each sharded dimension
temp_sharding_shapes: dict[int, list[int]] = {}
for i in range(mesh.ndim):
if isinstance(placements[i], Shard):
# Compute the chunk size for this dimension:
input_dim = global_shape[placements[i].dim]
chunked_shapes = compute_split_shapes(input_dim, mesh.size(i))

# for each tensor in the list

temp_sharding_shapes[i] = chunked_shapes

# Temp sharding shapes always has a key for each mesh dim.
# Each is a list with length = size of that mesh dim.

# Initialize shapes for all sharded dimensions, but using the global shape.
# We will update next.
sharding_shapes = {
mesh_dim: [list(global_shape) for _ in chunks]
for mesh_dim, chunks in temp_sharding_shapes.items()
# Compute the full per-rank chunk-size lists for each sharded mesh dim
# (the same on every rank, derived purely from the global shape +
# mesh size via ``compute_split_shapes``).
chunk_sizes_per_dim: dict[int, list[int]] = {}
for m in range(mesh.ndim):
if isinstance(placements[m], Shard):
input_dim = global_shape[placements[m].dim]
chunk_sizes_per_dim[m] = compute_split_shapes(input_dim, mesh.size(m))

# This rank's chunk for each sharded mesh dim. Used to fill in tensor
# dims sharded along *other* mesh dims when constructing a given mesh
# dim's per-rank shape list.
this_rank_chunks: dict[int, int] = {
m: chunks[mesh.get_local_rank(m)] for m, chunks in chunk_sizes_per_dim.items()
}

# Go through and reduce each mesh dim to the right shape for _this_ rank
for mesh_dim, shape_list in temp_sharding_shapes.items():
this_rank = mesh.get_local_rank(mesh_dim)
temp_sharding_shapes[mesh_dim] = shape_list[this_rank]

# Finally, update the sharded shape with the right chunk size:
for shape_list in sharding_shapes.values():
for inner_mesh_dim, chunk_size in temp_sharding_shapes.items():
tensor_dim = placements[inner_mesh_dim].dim
for shape in shape_list:
shape[tensor_dim] = chunk_size

# Convert to immutable torch.Size
return {
mesh_dim: [torch.Size(tuple(size)) for size in sizes]
for mesh_dim, sizes in sharding_shapes.items()
}
# For each sharded mesh dim ``m``, build a list of length ``mesh.size(m)``
# where entry ``r`` is the local shape that rank ``r`` (along mesh_dim
# ``m``) holds. Along tensor dim ``placements[m].dim`` the value is
# rank ``r``'s chunk (varies). Along tensor dims sharded by *other*
# mesh dims, we use this rank's coordinate -- matching the historical
# multi-dim semantics where ``_sharding_shapes[mesh_dim][r]`` is the
# rank-``r``-on-mesh-dim-``m`` cross-section through this rank's
# coordinates on every other mesh dim.
sharding_shapes: dict[int, list[tuple[int, ...]]] = {}
for m, chunks in chunk_sizes_per_dim.items():
shape_list: list[tuple[int, ...]] = []
for r, rank_chunk in enumerate(chunks):
shape = list(global_shape)
shape[placements[m].dim] = rank_chunk
for other_m, other_chunk in this_rank_chunks.items():
if other_m == m:
continue
shape[placements[other_m].dim] = other_chunk
# Plain int tuple (not torch.Size) -- see field docstring.
shape_list.append(tuple(shape))
sharding_shapes[m] = shape_list

return sharding_shapes


def _infer_shard_tensor_spec_from_local_chunks(
Expand Down Expand Up @@ -581,7 +591,13 @@ def _infer_shard_tensor_spec_from_local_chunks(
shape=tuple(global_shape), stride=stride, dtype=local_chunk.dtype
)

sharding_shapes = {dim: tuple(s) for dim, s in shard_shapes_by_dim.items()}
# Normalize inner shapes to plain int tuples (never torch.Size) -- see the
# ``ShardTensorSpec._sharding_shapes`` field docstring for the dynamo /
# fakeification rationale.
sharding_shapes = {
dim: tuple(tuple(inner) for inner in shapes)
for dim, shapes in shard_shapes_by_dim.items()
}
return ShardTensorSpec(
mesh=target_mesh,
placements=placements,
Expand Down
1 change: 1 addition & 0 deletions physicsnemo/domain_parallel/custom_ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@
if ST_AVAILABLE:
from . import _tensor_ops # noqa: F401 # registers unbind handlers
from ._reductions import mean_wrapper, sum_wrapper
from ._tensor_ops import unbind_wrapper
Loading
Loading