diff --git a/physicsnemo/domain_parallel/__init__.py b/physicsnemo/domain_parallel/__init__.py index 467bfb0a70..f13e2db31a 100644 --- a/physicsnemo/domain_parallel/__init__.py +++ b/physicsnemo/domain_parallel/__init__.py @@ -47,15 +47,27 @@ # In minumum versions are met, we can import the shard tensor and spec. from ._shard_tensor_spec import ShardTensorSpec - from .shard_tensor import ShardTensor, scatter_tensor + from .shard_tensor import ( + FSDPOutputTensorAdapter, + ShardTensor, + distribute_over_domain_for_fsdp, + scatter_tensor, + wrap_for_fsdp, + ) def register_custom_ops(): + """Register all custom ShardTensor ops and shard-aware wrappers. + + Imports are deferred to this function to avoid an import cycle between + ``shard_tensor`` and the individual op modules. + """ # These imports will register the custom ops with the ShardTensor class. # It's done here to avoid an import cycle. from .custom_ops import ( # noqa: F401 _tensor_ops, mean_wrapper, sum_wrapper, + unbind_wrapper, ) from .shard_utils import register_shard_wrappers @@ -69,3 +81,6 @@ def register_custom_ops(): ShardTensor = None ShardTensorSpec = None scatter_tensor = None + distribute_over_domain_for_fsdp = None + FSDPOutputTensorAdapter = None + wrap_for_fsdp = None diff --git a/physicsnemo/domain_parallel/_shard_redistribute.py b/physicsnemo/domain_parallel/_shard_redistribute.py index 06d14bc45d..1883612909 100644 --- a/physicsnemo/domain_parallel/_shard_redistribute.py +++ b/physicsnemo/domain_parallel/_shard_redistribute.py @@ -37,7 +37,10 @@ ) import physicsnemo.domain_parallel.shard_tensor as shard_tensor -from physicsnemo.domain_parallel._shard_tensor_spec import ShardTensorSpec +from physicsnemo.domain_parallel._shard_tensor_spec import ( + ShardTensorSpec, + compute_sharding_shapes_from_chunking_global_shape, +) # TODO: # DTensor makes assumptions about sharding sizes. @@ -242,8 +245,8 @@ def _to_new_shard_dim( # But we can optimize the null-communication dist.all_to_all(recv_shapes, send_shapes, group=group) - # Turn the recv_shapes back into torch shapes: - recv_shapes = [list(torch.Size(r)) for r in recv_shapes] + # Turn the recv_shapes back into plain int shape lists. + recv_shapes = [r.tolist() for r in recv_shapes] # Create the buffers for recv: recv_buffers = [ @@ -268,7 +271,7 @@ def redistribute_local_shard_tensor( *, async_op: bool = False, is_backward: bool = False, - target_sharding_shapes: dict[int, tuple[torch.Size, ...]] | None = None, + target_sharding_shapes: dict[int, tuple[tuple[int, ...], ...]] | None = None, ) -> torch.Tensor: r"""Redistribute a local tensor between different ShardTensorSpec configurations. @@ -304,8 +307,9 @@ def redistribute_local_shard_tensor( Whether to run asynchronously. is_backward : bool, default=False Whether this is a backward pass (affects some redistribution behaviors). - target_sharding_shapes : Optional[Dict[int, Tuple[torch.Size, ...]]], optional - Target sharding shapes to use for redistribution. Default is empty dict. + target_sharding_shapes : Optional[Dict[int, Tuple[Tuple[int, ...], ...]]], optional + Target sharding shapes (plain int tuples) to use for redistribution. + Default is empty dict. Returns ------- @@ -549,7 +553,6 @@ class ShardRedistribute(torch.autograd.Function): @staticmethod def forward( - ctx: torch.autograd.function.FunctionCtx, input: "shard_tensor.ShardTensor", device_mesh: DeviceMesh, placements: tuple[Placement, ...], @@ -559,8 +562,6 @@ def forward( Parameters ---------- - ctx : torch.autograd.function.FunctionCtx - Autograd context for saving tensors/variables for backward. input : ShardTensor Input sharded tensor to redistribute. device_mesh : DeviceMesh @@ -577,9 +578,6 @@ def forward( """ current_spec = input._spec - ctx.current_spec = current_spec - ctx.async_op = async_op - if current_spec.placements != placements: # We have to assume, here, that the current spec has correct sharding_shapes. # Therefore, we can use the target placement + current sharding_shapes @@ -612,6 +610,36 @@ def forward( ) # Set the local shape: target_spec._local_shape = output.shape + + # Populate _sharding_shapes on the target spec so downstream + # consumers (especially under torch.compile) don't trip + # `_all_gather_shard_shapes` -- a blocking collective that is + # not AOT-traceable. Start from chunk semantics (pure + # arithmetic, no comms) and override preserved-shard tensor + # dims with the precomputed per-rank sizes from + # `target_sharding_shapes` so uneven sharding is preserved. + global_shape = tuple(input._spec.tensor_meta.shape) + chunk_shapes = compute_sharding_shapes_from_chunking_global_shape( + device_mesh, placements, global_shape + ) + for mesh_dim, placement in enumerate(placements): + if not isinstance(placement, Shard): + continue + tensor_dim = placement.dim + if tensor_dim in target_sharding_shapes: + mesh_size = device_mesh.size(mesh_dim) + per_rank_sizes = target_sharding_shapes[tensor_dim] + if len(per_rank_sizes) == mesh_size: + overridden = [] + for rank_size in per_rank_sizes: + rank_shape = list(global_shape) + rank_shape[tensor_dim] = int(rank_size) + overridden.append(rank_shape) + chunk_shapes[mesh_dim] = overridden + target_spec._sharding_shapes = { + mesh_dim: tuple(tuple(s) for s in shapes) + for mesh_dim, shapes in chunk_shapes.items() + } else: # use the same local tensor if placements are the same. output = input._local_tensor @@ -623,6 +651,20 @@ def forward( requires_grad=input.requires_grad, ) + @staticmethod + def setup_context(ctx, inputs, output) -> None: + r"""Save the source spec and ``async_op`` flag for the backward redistribute. + + ``DisableTorchFunctionSubclass`` shielding avoids re-entering the + ShardTensor ``__torch_function__`` fallback while reading + ``input._spec`` -- the same AOT-hostile bridge motivated the + shielding in ``ShardedSum.setup_context``. + """ + input, _device_mesh, _placements, async_op = inputs + with torch._C.DisableTorchFunctionSubclass(): + ctx.current_spec = input._spec + ctx.async_op = async_op + @staticmethod def backward( ctx: torch.autograd.function.FunctionCtx, diff --git a/physicsnemo/domain_parallel/_shard_tensor_spec.py b/physicsnemo/domain_parallel/_shard_tensor_spec.py index 1fdd0c89f7..58699d5e75 100644 --- a/physicsnemo/domain_parallel/_shard_tensor_spec.py +++ b/physicsnemo/domain_parallel/_shard_tensor_spec.py @@ -45,16 +45,26 @@ class ShardTensorSpec(DTensorSpec): ---------- _local_shape : Optional[torch.Size] The shape of the local shard of the tensor. - _sharding_shapes : Optional[dict[int, Tuple[torch.Size, ...]]] + _sharding_shapes : Optional[dict[int, Tuple[Tuple[int, ...], ...]]] Mapping from mesh dimension to shard shapes. Keys are mesh dimensions, - values are tuples of ``torch.Size`` representing shard shapes along + values are tuples of plain int tuples representing shard shapes along that dimension. Shard shapes are only tracked along the sharded dimensions, not replicated dimensions. + + Storage type note: we deliberately use plain ``tuple[int, ...]`` + rather than ``torch.Size`` here. ``torch.Size`` is special-cased by + PyTorch's symbolic shape machinery: when a ``ShardTensor`` is + fakeified by dynamo, any ``torch.Size`` stored in this dict has its + contained ints converted into unbacked ``SymInt``s. Those SymInts + then orphan whenever an op's output drops or filters + ``_sharding_shapes`` (e.g. Partial-only outputs from reductions), + producing ``PendingUnbackedSymbolNotFound`` errors during AOT + tracing. Plain Python int tuples don't trigger this path. """ _local_shape: torch.Size | None = field(default_factory=lambda: None) # This dict is a mapping from the mesh dimension to the shard shapes, _not_ the tensor index - _sharding_shapes: dict[int, tuple[torch.Size, ...]] | None = field( + _sharding_shapes: dict[int, tuple[tuple[int, ...], ...]] | None = field( default_factory=lambda: None ) @@ -101,7 +111,7 @@ def __hash__(self) -> int: def sharding_shapes( self, mesh_dim: int | None = None - ) -> dict[int, tuple[torch.Size, ...]] | tuple[torch.Size, ...]: + ) -> dict[int, tuple[tuple[int, ...], ...]] | tuple[tuple[int, ...], ...]: r"""Get the shapes of shards along specified mesh dimensions. Parameters @@ -111,7 +121,7 @@ def sharding_shapes( Returns ------- - Union[Dict[int, Tuple[torch.Size, ...]], Tuple[torch.Size, ...]] + Union[Dict[int, Tuple[Tuple[int, ...], ...]], Tuple[Tuple[int, ...], ...]] Dictionary of shard shapes by mesh dim if ``mesh_dim`` is ``None``, or tuple of shapes for the specific mesh dimension. """ @@ -309,7 +319,7 @@ def _gather_shard_shapes_for_dim( dist.all_gather(all_shapes, local_shape, group=local_group) - all_shapes = [torch.Size(s.cpu().tolist()) for s in all_shapes] + all_shapes = [tuple(s.cpu().tolist()) for s in all_shapes] if do_checks: # Check that all shapes are the same rank @@ -336,7 +346,7 @@ def _all_gather_shard_shapes( placements: tuple[Placement, ...], target_mesh: DeviceMesh, do_checks: bool = False, -) -> tuple[dict[int, tuple[torch.Size, ...]], tuple[int, ...]]: +) -> tuple[dict[int, tuple[tuple[int, ...], ...]], tuple[int, ...]]: r"""Gather shard shapes from all ranks across all sharded mesh dimensions. Parameters @@ -352,7 +362,7 @@ def _all_gather_shard_shapes( Returns ------- - Tuple[Dict[int, Tuple[torch.Size, ...]], Tuple[int, ...]] + Tuple[Dict[int, Tuple[Tuple[int, ...], ...]], Tuple[int, ...]] Tuple containing: - Dictionary mapping mesh dimensions to tuples of shard shapes. @@ -389,13 +399,13 @@ def compute_sharding_shapes_from_chunking_global_shape( mesh: DeviceMesh, placements: tuple[Placement, ...], global_shape: tuple[int, ...], -) -> dict[int, list[torch.Size]]: +) -> dict[int, list[tuple[int, ...]]]: r"""Compute shard sizes for each mesh dimension based on global shape. For each sharded dimension in the mesh, computes the chunk sizes that would result from evenly dividing the global tensor shape. Returns a - mapping from mesh dimensions to lists of ``torch.Size`` objects - representing the shape of each shard. + mapping from mesh dimensions to lists of plain int tuples representing + the shape of each shard. Parameters ---------- @@ -408,8 +418,8 @@ def compute_sharding_shapes_from_chunking_global_shape( Returns ------- - Dict[int, List[torch.Size]] - Dictionary mapping mesh dimensions to lists of ``torch.Size`` objects + Dict[int, List[Tuple[int, ...]]] + Dictionary mapping mesh dimensions to lists of plain int tuples representing shard shapes for that dimension. Raises @@ -420,45 +430,45 @@ def compute_sharding_shapes_from_chunking_global_shape( if len(placements) != mesh.ndim: raise ValueError("Number of placements must match mesh dimensions") - # First compute raw chunk sizes for each sharded dimension - temp_sharding_shapes: dict[int, list[int]] = {} - for i in range(mesh.ndim): - if isinstance(placements[i], Shard): - # Compute the chunk size for this dimension: - input_dim = global_shape[placements[i].dim] - chunked_shapes = compute_split_shapes(input_dim, mesh.size(i)) - - # for each tensor in the list - - temp_sharding_shapes[i] = chunked_shapes - - # Temp sharding shapes always has a key for each mesh dim. - # Each is a list with length = size of that mesh dim. - - # Initialize shapes for all sharded dimensions, but using the global shape. - # We will update next. - sharding_shapes = { - mesh_dim: [list(global_shape) for _ in chunks] - for mesh_dim, chunks in temp_sharding_shapes.items() + # Compute the full per-rank chunk-size lists for each sharded mesh dim + # (the same on every rank, derived purely from the global shape + + # mesh size via ``compute_split_shapes``). + chunk_sizes_per_dim: dict[int, list[int]] = {} + for m in range(mesh.ndim): + if isinstance(placements[m], Shard): + input_dim = global_shape[placements[m].dim] + chunk_sizes_per_dim[m] = compute_split_shapes(input_dim, mesh.size(m)) + + # This rank's chunk for each sharded mesh dim. Used to fill in tensor + # dims sharded along *other* mesh dims when constructing a given mesh + # dim's per-rank shape list. + this_rank_chunks: dict[int, int] = { + m: chunks[mesh.get_local_rank(m)] for m, chunks in chunk_sizes_per_dim.items() } - # Go through and reduce each mesh dim to the right shape for _this_ rank - for mesh_dim, shape_list in temp_sharding_shapes.items(): - this_rank = mesh.get_local_rank(mesh_dim) - temp_sharding_shapes[mesh_dim] = shape_list[this_rank] - - # Finally, update the sharded shape with the right chunk size: - for shape_list in sharding_shapes.values(): - for inner_mesh_dim, chunk_size in temp_sharding_shapes.items(): - tensor_dim = placements[inner_mesh_dim].dim - for shape in shape_list: - shape[tensor_dim] = chunk_size - - # Convert to immutable torch.Size - return { - mesh_dim: [torch.Size(tuple(size)) for size in sizes] - for mesh_dim, sizes in sharding_shapes.items() - } + # For each sharded mesh dim ``m``, build a list of length ``mesh.size(m)`` + # where entry ``r`` is the local shape that rank ``r`` (along mesh_dim + # ``m``) holds. Along tensor dim ``placements[m].dim`` the value is + # rank ``r``'s chunk (varies). Along tensor dims sharded by *other* + # mesh dims, we use this rank's coordinate -- matching the historical + # multi-dim semantics where ``_sharding_shapes[mesh_dim][r]`` is the + # rank-``r``-on-mesh-dim-``m`` cross-section through this rank's + # coordinates on every other mesh dim. + sharding_shapes: dict[int, list[tuple[int, ...]]] = {} + for m, chunks in chunk_sizes_per_dim.items(): + shape_list: list[tuple[int, ...]] = [] + for r, rank_chunk in enumerate(chunks): + shape = list(global_shape) + shape[placements[m].dim] = rank_chunk + for other_m, other_chunk in this_rank_chunks.items(): + if other_m == m: + continue + shape[placements[other_m].dim] = other_chunk + # Plain int tuple (not torch.Size) -- see field docstring. + shape_list.append(tuple(shape)) + sharding_shapes[m] = shape_list + + return sharding_shapes def _infer_shard_tensor_spec_from_local_chunks( @@ -581,7 +591,13 @@ def _infer_shard_tensor_spec_from_local_chunks( shape=tuple(global_shape), stride=stride, dtype=local_chunk.dtype ) - sharding_shapes = {dim: tuple(s) for dim, s in shard_shapes_by_dim.items()} + # Normalize inner shapes to plain int tuples (never torch.Size) -- see the + # ``ShardTensorSpec._sharding_shapes`` field docstring for the dynamo / + # fakeification rationale. + sharding_shapes = { + dim: tuple(tuple(inner) for inner in shapes) + for dim, shapes in shard_shapes_by_dim.items() + } return ShardTensorSpec( mesh=target_mesh, placements=placements, diff --git a/physicsnemo/domain_parallel/custom_ops/__init__.py b/physicsnemo/domain_parallel/custom_ops/__init__.py index 87b98e0d1d..c91ae3dab5 100644 --- a/physicsnemo/domain_parallel/custom_ops/__init__.py +++ b/physicsnemo/domain_parallel/custom_ops/__init__.py @@ -22,3 +22,4 @@ if ST_AVAILABLE: from . import _tensor_ops # noqa: F401 # registers unbind handlers from ._reductions import mean_wrapper, sum_wrapper + from ._tensor_ops import unbind_wrapper diff --git a/physicsnemo/domain_parallel/custom_ops/_reductions.py b/physicsnemo/domain_parallel/custom_ops/_reductions.py index b673e6bec4..3a8e2d656a 100644 --- a/physicsnemo/domain_parallel/custom_ops/_reductions.py +++ b/physicsnemo/domain_parallel/custom_ops/_reductions.py @@ -44,12 +44,17 @@ ) import torch +from torch.distributed.tensor._dtensor_spec import TensorMeta from torch.distributed.tensor.placement_types import ( Partial, Shard, ) # noqa: E402 +from physicsnemo.domain_parallel._shard_tensor_spec import ( + ShardTensorSpec, + _stride_from_contiguous_shape_C_style, +) from physicsnemo.domain_parallel.shard_tensor import ShardTensor aten = torch.ops.aten @@ -175,14 +180,14 @@ def compute_result_placements( def reduction_shape( - S: torch.Size, dim: DimT = None, keepdim: bool = False -) -> torch.Size: + S: tuple[int, ...], dim: DimT = None, keepdim: bool = False +) -> tuple[int, ...]: r"""Calculate the resulting shape after a reduction operation. Parameters ---------- - S : torch.Size - Original shape of the tensor. + S : tuple[int, ...] + Original shape of the tensor (may be a ``torch.Size`` or plain tuple). dim : DimT, optional The dimension(s) to reduce. Can be ``None``, ``int``, or iterable of ints. keepdim : bool, default=False @@ -190,12 +195,15 @@ def reduction_shape( Returns ------- - torch.Size - The shape after reduction. + tuple[int, ...] + The shape after reduction, returned as a plain int tuple (not + ``torch.Size``) so the result can be safely embedded in a + ``ShardTensorSpec._sharding_shapes`` dict without triggering dynamo's + symbolic-shape special-casing for ``torch.Size``. """ shape = list(S) if dim is None: - return torch.Size([1] * len(shape)) if keepdim else torch.Size([]) + return tuple([1] * len(shape)) if keepdim else tuple() # Use enhanced normalize_dim to handle iterable and negative dims dim = normalize_dim(dim, len(shape), handle_negatives=True) @@ -206,12 +214,12 @@ def reduction_shape( else: for d in sorted(dim, reverse=True): del shape[d] - return torch.Size(shape) + return tuple(shape) def compute_result_sharding_shapes( tensor: ShardTensor, dim: DimT, keepdim: bool -) -> dict[int, list[torch.Size]]: +) -> dict[int, list[tuple[int, ...]]]: r"""Compute sharding sizes for the result of a reduction operation. Parameters @@ -225,8 +233,8 @@ def compute_result_sharding_shapes( Returns ------- - Dict[int, List[torch.Size]] - Mapping of mesh dimensions to sharding shapes. + Dict[int, List[Tuple[int, ...]]] + Mapping of mesh dimensions to plain int tuple sharding shapes. """ if is_full_reduction(dim, tensor.ndim): return {} @@ -248,6 +256,67 @@ def compute_result_sharding_shapes( return result_sharding_shapes +def build_reduction_result( + local_result: torch.Tensor, + input_tensor: ShardTensor, + placements: list[Partial | Shard], + sharding_shapes: dict[int, list[tuple[int, ...]]], +) -> ShardTensor: + r"""Construct a ShardTensor result from a local reduction output. + + Builds the ``ShardTensorSpec`` directly from the already-computed placements + and sharding shapes, avoiding the overhead and autograd side-effects of + ``ShardTensor.from_local``. + + Parameters + ---------- + local_result : torch.Tensor + The locally-computed reduction result. + input_tensor : ShardTensor + The original input ShardTensor (used for device mesh). + placements : List[Union[Partial, Shard]] + Result placements from :func:`compute_result_placements`. + sharding_shapes : Dict[int, List[Tuple[int, ...]]] + Result sharding shapes from :func:`compute_result_sharding_shapes`. + + Returns + ------- + ShardTensor + Wrapped result with correct sharding metadata. + """ + global_shape = list(local_result.shape) + for mesh_dim, placement in enumerate(placements): + if isinstance(placement, Shard): + tensor_dim = placement.dim + global_shape[tensor_dim] = sum( + s[tensor_dim] for s in sharding_shapes[mesh_dim] + ) + + stride = _stride_from_contiguous_shape_C_style(global_shape) + spec = ShardTensorSpec( + mesh=input_tensor.device_mesh, + placements=tuple(placements), + tensor_meta=TensorMeta( + shape=tuple(global_shape), + stride=stride, + dtype=local_result.dtype, + ), + _local_shape=local_result.shape, + # Normalize to plain int tuples (never torch.Size) for the + # _sharding_shapes field; see ShardTensorSpec docstring. + _sharding_shapes={ + dim: tuple(tuple(inner) for inner in s) + for dim, s in sharding_shapes.items() + }, + ) + return ShardTensor.__new__( + ShardTensor, + local_tensor=local_result, + spec=spec, + requires_grad=input_tensor.requires_grad, + ) + + def create_sharded_grad_input( local_grad_input: torch.Tensor, original_spec: Any ) -> ShardTensor: @@ -265,11 +334,15 @@ def create_sharded_grad_input( ShardTensor A distributed tensor with the same sharding as the original input. """ - return ShardTensor.from_local( - local_grad_input, - device_mesh=original_spec.mesh, - placements=original_spec.placements, - sharding_shapes=original_spec.sharding_shapes(), + # In custom autograd backward, return the input gradient directly as a + # ShardTensor value. Avoid ``from_local`` here (which routes through a + # separate autograd Function) so the gradient is attached unambiguously to + # the original ShardTensor input. + return ShardTensor.__new__( + ShardTensor, + local_tensor=local_grad_input, + spec=original_spec, + requires_grad=False, ) @@ -333,7 +406,6 @@ class ShardedSum(ShardedReductionBase): @staticmethod def forward( - ctx: Any, tensor: ShardTensor, dim: DimT = None, keepdim: bool = False, @@ -343,8 +415,6 @@ def forward( Parameters ---------- - ctx : torch.autograd.function.FunctionCtx - The autograd context object. tensor : ShardTensor The input ShardTensor to be reduced. dim : DimT, optional @@ -358,27 +428,48 @@ def forward( ------- ShardTensor The result of sum reduction. + + Notes + ----- + The body runs under ``torch._C.DisableTorchFunctionSubclass``. + Reason: new-style autograd.Function (per-PyTorch design) executes + ``forward`` with grad-mode ON, and any property access on the + ShardTensor input (e.g. ``tensor.ndim`` -- a C-level getset + descriptor) re-enters ``__torch_function__`` -> the DTensor + fallback -> ``_ShardTensorToDTensor.apply``. The resulting + ``BackwardCFunction`` has a ``next_functions`` accessor that + raises a "legacy access pattern" error on newer PyTorch, + blocking AOTAutograd from walking the autograd graph. Shielding + these metadata-only accesses fully avoids that bridge for the + sum path. """ - dim, keepdim = ShardedReductionBase.setup_ctx(ctx, tensor, dim, keepdim) + with torch._C.DisableTorchFunctionSubclass(): + dim_n = normalize_dim(dim, tensor.ndim) + keepdim_n = bool(keepdim) - # Get local tensor - local_tensor = tensor._local_tensor - # Perform local sum - local_result = aten.sum(local_tensor, dim=dim, keepdim=keepdim, dtype=dtype) + local_result = aten.sum( + tensor._local_tensor, dim=dim_n, keepdim=keepdim_n, dtype=dtype + ) - # Compute placements for the result - placements = compute_result_placements(tensor, dim, "sum") - output_sharding_shapes = compute_result_sharding_shapes(tensor, dim, keepdim) + placements = compute_result_placements(tensor, dim_n, "sum") + sharding_shapes = compute_result_sharding_shapes(tensor, dim_n, keepdim_n) - # Create result ShardTensor - result = ShardTensor.from_local( - local_result, - tensor.device_mesh, - placements, - sharding_shapes=output_sharding_shapes, - ) + return build_reduction_result( + local_result, tensor, placements, sharding_shapes + ) - return result + @staticmethod + def setup_context(ctx, inputs, output) -> None: + r"""Save the input ShardTensorSpec + normalized dim/keepdim for backward. + + Same ``DisableTorchFunctionSubclass`` shielding as ``forward`` so the + property accesses inside ``ShardedReductionBase.setup_ctx`` (e.g. + ``tensor.ndim``, ``tensor.requires_grad``) don't bridge through the + AOT-hostile autograd Function fallback. + """ + tensor, dim, keepdim, _dtype = inputs + with torch._C.DisableTorchFunctionSubclass(): + ShardedReductionBase.setup_ctx(ctx, tensor, dim, keepdim) @staticmethod def backward( @@ -446,7 +537,6 @@ class ShardedMean(ShardedReductionBase): @staticmethod def forward( - ctx: Any, tensor: ShardTensor, dim: DimT = None, keepdim: bool = False, @@ -456,8 +546,6 @@ def forward( Parameters ---------- - ctx : torch.autograd.function.FunctionCtx - The autograd context object. tensor : ShardTensor The input ShardTensor to be reduced. dim : DimT, optional @@ -471,47 +559,71 @@ def forward( ------- ShardTensor The result of mean reduction. + + Notes + ----- + The body runs under ``torch._C.DisableTorchFunctionSubclass``. + Reason: new-style autograd.Function (per-PyTorch design) executes + ``forward`` with grad-mode ON, and any property access on the + ShardTensor input (e.g. ``tensor.ndim`` -- a C-level getset + descriptor) re-enters ``__torch_function__`` -> the DTensor + fallback -> ``_ShardTensorToDTensor.apply``. The resulting + ``BackwardCFunction`` has a ``next_functions`` accessor that + raises a "legacy access pattern" error on newer PyTorch, + blocking AOTAutograd from walking the autograd graph. Shielding + these metadata-only accesses fully avoids that bridge for the + mean path. """ - dim, keepdim = ShardedReductionBase.setup_ctx(ctx, tensor, dim, keepdim) + with torch._C.DisableTorchFunctionSubclass(): + dim_n = normalize_dim(dim, tensor.ndim) + keepdim_n = bool(keepdim) - # Get local tensor - local_tensor = tensor._local_tensor + # Get local tensor + local_tensor = tensor._local_tensor - # Compute proper weighting for mean - weight = 1.0 + # Compute proper weighting for mean + weight = 1.0 - # Normalize dimensions for consistent handling - if is_full_reduction(dim, tensor.ndim): - # For full reduction, use all dimensions - reduction_dims = set(range(tensor.ndim)) - else: - # Only use the normalized dimensions for partial reduction - reduction_dims = dim + # Normalize dimensions for consistent handling + if is_full_reduction(dim_n, tensor.ndim): + # For full reduction, use all dimensions + reduction_dims = set(range(tensor.ndim)) + else: + # Only use the normalized dimensions for partial reduction + reduction_dims = dim_n - # Calculate weight based on local vs global shape ratio for reduction dimensions - local_shape = local_tensor.shape - global_shape = tensor.shape + # Calculate weight based on local vs global shape ratio for reduction dimensions + local_shape = local_tensor.shape + global_shape = tensor.shape - for d in reduction_dims: - weight *= local_shape[d] / global_shape[d] + for d in reduction_dims: + weight *= local_shape[d] / global_shape[d] - # Perform local mean - local_result = aten.mean(local_tensor, dim=dim, keepdim=keepdim, dtype=dtype) - # Apply weighting - local_result = local_result * weight + # Perform local mean and apply weighting for uneven shards + local_result = aten.mean( + local_tensor, dim=dim_n, keepdim=keepdim_n, dtype=dtype + ) + local_result = local_result * weight - placements = compute_result_placements(tensor, dim, "sum") - output_sharding_shapes = compute_result_sharding_shapes(tensor, dim, keepdim) + placements = compute_result_placements(tensor, dim_n, "sum") + sharding_shapes = compute_result_sharding_shapes(tensor, dim_n, keepdim_n) - # Create result ShardTensor - result = ShardTensor.from_local( - local_result, - tensor.device_mesh, - placements, - sharding_shapes=output_sharding_shapes, - ) + return build_reduction_result( + local_result, tensor, placements, sharding_shapes + ) - return result + @staticmethod + def setup_context(ctx, inputs, output) -> None: + r"""Save the input ShardTensorSpec + normalized dim/keepdim for backward. + + Same ``DisableTorchFunctionSubclass`` shielding as ``forward`` so the + property accesses inside ``ShardedReductionBase.setup_ctx`` (e.g. + ``tensor.ndim``, ``tensor.requires_grad``) don't bridge through the + AOT-hostile autograd Function fallback. + """ + tensor, dim, keepdim, _dtype = inputs + with torch._C.DisableTorchFunctionSubclass(): + ShardedReductionBase.setup_ctx(ctx, tensor, dim, keepdim) @staticmethod def backward( diff --git a/physicsnemo/domain_parallel/custom_ops/_tensor_ops.py b/physicsnemo/domain_parallel/custom_ops/_tensor_ops.py index eca70a0a29..045f25af9a 100644 --- a/physicsnemo/domain_parallel/custom_ops/_tensor_ops.py +++ b/physicsnemo/domain_parallel/custom_ops/_tensor_ops.py @@ -94,10 +94,10 @@ def _unbind_output_metadata( elif p.is_partial(): raise RuntimeError("Partial placement not supported yet for unbind") - out_sharding_shapes: dict[int, list[torch.Size]] = { - mesh_dim: [ - torch.Size(list(cs[:dim]) + list(cs[dim + 1 :])) for cs in shard_shapes - ] + # Plain int tuples (never torch.Size) -- see ShardTensorSpec._sharding_shapes + # field docs for the dynamo / fakeification rationale. + out_sharding_shapes: dict[int, list[tuple[int, ...]]] = { + mesh_dim: [tuple(list(cs[:dim]) + list(cs[dim + 1 :])) for cs in shard_shapes] for mesh_dim, shard_shapes in input_spec.sharding_shapes().items() } diff --git a/physicsnemo/domain_parallel/shard_tensor.py b/physicsnemo/domain_parallel/shard_tensor.py index 0a31ca7253..6da9b912ca 100644 --- a/physicsnemo/domain_parallel/shard_tensor.py +++ b/physicsnemo/domain_parallel/shard_tensor.py @@ -16,18 +16,21 @@ from __future__ import annotations +import threading from collections.abc import Iterable, Mapping +from contextlib import contextmanager from typing import Callable, Sequence, cast -from warnings import warn import torch import torch.distributed as dist +from torch import nn from torch.distributed.device_mesh import DeviceMesh, _mesh_resources -from torch.distributed.tensor import DTensor +from torch.distributed.tensor import DTensor, distribute_module from torch.distributed.tensor._dtensor_spec import ( TensorMeta, ) from torch.distributed.tensor.placement_types import ( + Partial, Placement, Replicate, Shard, @@ -36,72 +39,307 @@ from physicsnemo.distributed import DistributedManager from physicsnemo.domain_parallel._shard_redistribute import ( ShardRedistribute, + redistribute_local_shard_tensor, ) from physicsnemo.domain_parallel._shard_tensor_spec import ( ShardTensorSpec, _infer_shard_tensor_spec_from_local_chunks, _stride_from_contiguous_shape_C_style, + compute_sharding_shapes_from_chunking_global_shape, ) -from physicsnemo.utils.profiling import annotate, profile aten = torch.ops.aten -def _shard_tensor_to_dtensor(st: "ShardTensor") -> DTensor: - r"""Convert a ShardTensor to a plain DTensor for dispatch. +# ====================================================================== - Creates a DTensor with the same internal state as the ShardTensor, - which allows DTensor's dispatch to handle it correctly. +# ============================================================================ +# Layer 1 -- Semi-private conversions (no autograd, no spec inference) +# ============================================================================ - Parameters - ---------- - st : ShardTensor - The ShardTensor to convert. - Returns - ------- - DTensor - A DTensor sharing the same ``_local_tensor`` and ``_spec``. +def _shard_tensor_to_dtensor(st: "ShardTensor") -> DTensor: + r"""Convert a ShardTensor to a plain DTensor (no autograd). + + Creates a DTensor sharing the same ``_local_tensor`` and ``_spec``. + Use for dispatch or inside backward when building a DTensor gradient. """ - dtensor = torch.Tensor._make_wrapper_subclass( - DTensor, - st._spec.tensor_meta.shape, - strides=st._spec.tensor_meta.stride, - dtype=st.dtype, - device=st.device, - layout=st.layout, - requires_grad=st.requires_grad, - ) + if hasattr(torch.Tensor, "_dtensor__new__"): + dtensor = torch.Tensor._dtensor__new__( + DTensor, st._local_tensor, st._spec, requires_grad=st.requires_grad + ) + else: + dtensor = torch.Tensor._make_wrapper_subclass( + DTensor, + st._spec.tensor_meta.shape, + strides=st._spec.tensor_meta.stride, + dtype=st.dtype, + device=st.device, + layout=st.layout, + requires_grad=st.requires_grad, + ) dtensor._local_tensor = st._local_tensor dtensor._spec = st._spec return dtensor -def _convert_args_to_dtensor(arg: object) -> object: - r"""Recursively convert ShardTensors in args to DTensors. +def _dtensor_to_shard_tensor(dtensor: DTensor, spec: ShardTensorSpec) -> "ShardTensor": + r"""Promote a DTensor to a ShardTensor (no autograd). - Parameters - ---------- - arg : object - A single argument that may be a ShardTensor, an iterable of - arguments (e.g. list, tuple), a mapping (e.g. dict) whose - values are converted, or any other value. + Callers must supply a resolved ``spec``. Use inside backward (with spec + from ctx) or after resolving a spec via :func:`_resolve_spec_for_dtensor`. + """ + if isinstance(dtensor, ShardTensor): + # Shortcut if we're already a ShardTensor: + return dtensor + st = ShardTensor.__new__( + ShardTensor, + local_tensor=dtensor._local_tensor, + spec=spec, + requires_grad=dtensor.requires_grad, + ) + return st - Returns - ------- - object - The argument with any ShardTensors replaced by DTensors. + +# ============================================================================ +# Layer 2 -- Autograd Functions (use Layer 1 inside fwd / bwd) +# ============================================================================ + + +class _DTensorToShardTensor(torch.autograd.Function): + r"""Differentiable promotion: DTensor -> ShardTensor. + + This is to always connect the graphs for the backward pass + when we have to use a fallback option. + + Forward: :func:`_dtensor_to_shard_tensor`. + Backward: :func:`_shard_tensor_to_dtensor`. + """ + + @staticmethod + def forward(dtensor: DTensor, spec: ShardTensorSpec) -> "ShardTensor": + return _dtensor_to_shard_tensor(dtensor, spec) + + @staticmethod + def setup_context(ctx, inputs, output) -> None: + # Nothing to save; backward only needs grad_output. + pass + + @staticmethod + def backward(ctx, grad_output: "ShardTensor"): + return _shard_tensor_to_dtensor(grad_output), None + + +class _ShardTensorToDTensor(torch.autograd.Function): + r"""Differentiable conversion: ShardTensor -> DTensor. + + This is to always connect the graphs for the backward pass + when we have to use a fallback option. + + Forward: :func:`_shard_tensor_to_dtensor` (caches spec). + Backward: :func:`_dtensor_to_shard_tensor` (reuses cached spec). + """ + + @staticmethod + def forward(st: "ShardTensor") -> DTensor: + return _shard_tensor_to_dtensor(st) + + @staticmethod + def setup_context(ctx, inputs, output) -> None: + (st,) = inputs + ctx.shard_tensor_spec = st._spec + + @staticmethod + def backward(ctx, grad_output: DTensor): + return (_dtensor_to_shard_tensor(grad_output, ctx.shard_tensor_spec),) + + +# ============================================================================ +# Layer 3 -- Smart single-tensor converters (auto-diff when grad_fn present) +# ============================================================================ + + +def _resolve_spec_for_dtensor( + dtensor: DTensor, input_args: tuple = () +) -> ShardTensorSpec: + r"""Resolve a ShardTensorSpec for *dtensor*. + + Tries to reuse a spec from a ShardTensor in *input_args* whose + ``tensor_meta`` and ``placements`` match. Falls back to chunk-based + inference (no communication). + """ + for arg in input_args: + if ( + isinstance(arg, ShardTensor) + and dtensor._spec.tensor_meta == arg._spec.tensor_meta + and dtensor._spec.placements == arg._spec.placements + ): + return arg._spec + return _infer_shard_tensor_spec_from_local_chunks( + dtensor._local_tensor, + dtensor._spec.mesh, + dtensor._spec.placements, + sharding_shapes="chunk", + global_shape=dtensor.shape, + ) + + +# This is a thread-safe reentry guard. +# Goal is to prevent recursion into the fallback conversion paths. +_conversion_guard = threading.local() + + +def _conversion_active() -> bool: + r"""Return whether ShardTensor<->DTensor conversion is currently active.""" + return getattr(_conversion_guard, "depth", 0) > 0 + + +@contextmanager +def _conversion_scope(): + r"""Re-entrant conversion guard for cast-down/cast-up paths.""" + previous_depth = getattr(_conversion_guard, "depth", 0) + _conversion_guard.depth = previous_depth + 1 + try: + yield + finally: + if previous_depth == 0: + delattr(_conversion_guard, "depth") + else: + _conversion_guard.depth = previous_depth + + +def _dispatch_fallback_via_dtensor( + func: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object] | None = None, +) -> object: + r"""Execute an ATen op through DTensor fallback using PURE data conversion. + + Native Autograd wraps this hook, so we must NOT build an internal graph + using .apply(). We just do the math and let PyTorch track the outer graph. + """ + with _conversion_scope(): + converted_args = tuple( + _convert_args_to_dtensor(arg, use_autograd=False) for arg in args + ) + converted_kwargs = { + k: _convert_args_to_dtensor(v, use_autograd=False) + for k, v in (kwargs or {}).items() + } + + dispatch_res = func(*converted_args, **(converted_kwargs or {})) + + with _conversion_scope(): + return _convert_results_to_shard_tensor(dispatch_res, args, use_autograd=False) + + +def _torch_function_fallback_via_dtensor( + func: Callable, + args: tuple[object, ...], + kwargs: dict[str, object] | None = None, +) -> object: + r"""Execute a __torch_function__ fallback through DTensor safely. + + Because this executes at the Python API level (above Autograd), we MUST + use autograd functions (.apply) to bridge the tracking manually. """ - # ShardTensor is defined later in this module; the isinstance check - # is safe because this function is only called at runtime. - if isinstance(arg, ShardTensor): - return _shard_tensor_to_dtensor(arg) - elif isinstance(arg, Mapping): - return type(arg)({k: _convert_args_to_dtensor(v) for k, v in arg.items()}) - elif isinstance(arg, Iterable) and not isinstance(arg, (str, bytes)): - converted = [_convert_args_to_dtensor(a) for a in arg] - return type(arg)(converted) - return arg + with _conversion_scope(): + converted_args = tuple( + _convert_args_to_dtensor(arg, use_autograd=True) for arg in args + ) + converted_kwargs = { + k: _convert_args_to_dtensor(v, use_autograd=True) + for k, v in (kwargs or {}).items() + } + + with torch._C.DisableTorchFunctionSubclass(): + result = func(*converted_args, **converted_kwargs) + + with _conversion_scope(): + return _convert_results_to_shard_tensor(result, args, use_autograd=True) + + +# ============================================================================ +# Layer 4 -- Recurse utilities (walk args / kwargs / results) +# ============================================================================ + + +def _convert_args_to_dtensor(arg: object, use_autograd: bool = False) -> object: + r"""Recursively replace ShardTensors with DTensors. + + If use_autograd is True, uses Layer 2 to preserve the graph connection. + """ + match arg: + case ShardTensor(): + if use_autograd and arg.requires_grad and torch.is_grad_enabled(): + return _ShardTensorToDTensor.apply(arg) + return _shard_tensor_to_dtensor(arg) + case DTensor(): + # DTensor can be iterable; exit early deliberately + return arg + case Mapping(): + return type(arg)( + {k: _convert_args_to_dtensor(v, use_autograd) for k, v in arg.items()} + ) + case tuple(): + return tuple(_convert_args_to_dtensor(a, use_autograd) for a in arg) + case list(): + return [_convert_args_to_dtensor(a, use_autograd) for a in arg] + case _: + return arg + + +def _convert_results_to_shard_tensor( + result: object, input_args: tuple, use_autograd: bool = False +) -> object: + r"""Recursively replace DTensors with ShardTensors in an op result. + + If use_autograd is True, uses Layer 2 to preserve the graph connection. + Handles None returns gracefully for inplace ATen operations. + """ + if result is None: + return None + + if isinstance(result, DTensor): + spec = _resolve_spec_for_dtensor(result, input_args) + + # If autograd graph connection is requested AND the DTensor actually + # requires tracking (it has a grad_fn or requires_grad is active) + if ( + use_autograd + and torch.is_grad_enabled() + and (result.grad_fn is not None or result.requires_grad) + ): + return _DTensorToShardTensor.apply(result, spec) + + return _dtensor_to_shard_tensor(result, spec) + + if isinstance(result, Mapping): + return type(result)( + { + k: _convert_results_to_shard_tensor(v, input_args, use_autograd) + for k, v in result.items() + } + ) + + # Explicit allowlist mirroring _convert_args_to_dtensor: only walk into + # plain tuple / list containers. A generic Iterable check would crash on + # things like torch.UntypedStorage (iterable over bytes) or torch.Tensor + # because their constructors don't accept a generator. Note: namedtuples + # degrade to plain tuple here, same as in the args walker. + if isinstance(result, tuple): + return tuple( + _convert_results_to_shard_tensor(d, input_args, use_autograd) + for d in result + ) + + if isinstance(result, list): + return [ + _convert_results_to_shard_tensor(d, input_args, use_autograd) + for d in result + ] + + return result class _ToTorchTensor(torch.autograd.Function): @@ -114,7 +352,6 @@ class _ToTorchTensor(torch.autograd.Function): @staticmethod def forward( - ctx: torch.autograd.function.FunctionCtx, input: "ShardTensor", grad_placements: Sequence[Placement] | None = None, ) -> torch.Tensor: @@ -122,8 +359,6 @@ def forward( Parameters ---------- - ctx : torch.autograd.function.FunctionCtx - Autograd context for saving tensors/variables for backward. input : ShardTensor ShardTensor to convert. grad_placements : Sequence[Placement], optional @@ -134,15 +369,24 @@ def forward( torch.Tensor Local tensor representation of the ShardTensor. """ - ctx.shard_tensor_spec = input._spec - ctx.grad_placements = grad_placements + # # JUST LIKE DTENSOR: + # # We need to return a fresh Tensor object there as autograd metadata + # # will be inplaced into it. So we don't want to pollute the Tensor + # # object stored in the _local_tensor of this ShardTensor. + # return local_tensor.view_as(local_tensor) + + # Force the local view to inherit the requires_grad state of the ShardTensor local_tensor = input._local_tensor + res = local_tensor.view_as(local_tensor) + res.requires_grad_(input.requires_grad) + return res - # JUST LIKE DTENSOR: - # We need to return a fresh Tensor object there as autograd metadata - # will be inplaced into it. So we don't want to pollute the Tensor - # object stored in the _local_tensor of this ShardTensor. - return local_tensor.view_as(local_tensor) + @staticmethod + def setup_context(ctx, inputs, output) -> None: + r"""Save the source ShardTensorSpec and optional grad_placements.""" + input, grad_placements = inputs + ctx.shard_tensor_spec = input._spec + ctx.grad_placements = grad_placements @staticmethod def backward( @@ -200,11 +444,11 @@ class _FromTorchTensor(torch.autograd.Function): Global shape information is inferred using collective communication on the specified device mesh. + """ @staticmethod def forward( - ctx: torch.autograd.function.FunctionCtx, local_input: torch.Tensor, device_mesh: DeviceMesh, placements: tuple[Placement, ...], @@ -214,8 +458,6 @@ def forward( Parameters ---------- - ctx : torch.autograd.function.FunctionCtx - Autograd context for saving tensors/variables for backward. local_input : torch.Tensor Local tensor to convert to ShardTensor. device_mesh : DeviceMesh @@ -237,9 +479,6 @@ def forward( ShardTensor ShardTensor constructed from the local input tensor. """ - ctx.previous_placement = placements - ctx.previous_mesh = device_mesh - # This function is simpler than the corresponding DTensor implementation on the surface # because under the hood, we have some logic here to infer the sharding shapes. shard_tensor_spec = _infer_shard_tensor_spec_from_local_chunks( @@ -254,6 +493,13 @@ def forward( return shard_tensor + @staticmethod + def setup_context(ctx, inputs, output) -> None: + r"""Save the source mesh and placements for the backward redistribute.""" + _local_input, device_mesh, placements, _sharding_shapes = inputs + ctx.previous_placement = placements + ctx.previous_mesh = device_mesh + @staticmethod def backward( ctx: torch.autograd.function.FunctionCtx, @@ -296,74 +542,7 @@ def backward( return grad_output.to_local(), None, None, None -class _PromoteDTensorToShardTensor(torch.autograd.Function): - r"""Autograd function to promote a DTensor to a ShardTensor while preserving ``grad_fn``. - - When DTensor's ``__torch_function__`` returns a non-leaf DTensor (one that - has a ``grad_fn``), creating a new ShardTensor via ``_make_wrapper_subclass`` - always produces a leaf — disconnecting it from the autograd graph. - - This function bridges that gap: the forward creates the ShardTensor wrapper, - and ``apply`` attaches a ``grad_fn`` that connects it back to the original - DTensor's graph. The backward simply passes gradients through unchanged. - - This is only used at the ``__torch_function__`` level where the DTensor - result already carries autograd state. At the ``__torch_dispatch__`` level, - promotion is safe without this because autograd wraps the result afterwards. - """ - - @staticmethod - def forward( - ctx: torch.autograd.function.FunctionCtx, - dtensor: DTensor, - spec: "ShardTensorSpec", - ) -> "ShardTensor": - r"""Create a ShardTensor from a DTensor, preserving autograd via ``apply``. - - Parameters - ---------- - ctx : torch.autograd.function.FunctionCtx - Autograd context (unused — no state needed for backward). - dtensor : DTensor - The DTensor to promote. - spec : ShardTensorSpec - The ShardTensorSpec to use for the new ShardTensor. - - Returns - ------- - ShardTensor - A new ShardTensor wrapping the same local data. - """ - return ShardTensor.__new__( - ShardTensor, - local_tensor=dtensor._local_tensor, - spec=spec, - requires_grad=False, # autograd.Function.apply handles this - ) - - @staticmethod - def backward( - ctx: torch.autograd.function.FunctionCtx, - grad_output: "ShardTensor", - ) -> tuple[DTensor, None]: - r"""Pass gradient through unchanged. - - Parameters - ---------- - ctx : torch.autograd.function.FunctionCtx - Autograd context (unused). - grad_output : ShardTensor - Gradient with respect to the ShardTensor output. - - Returns - ------- - Tuple[DTensor, None] - The gradient for the DTensor input, and ``None`` for the spec. - """ - return grad_output, None - - -class ShardTensor(DTensor): +class ShardTensor(torch.Tensor): r"""A distributed tensor class with support for uneven data sharding. Similar to PyTorch's native ``DTensor`` but with more flexibility for @@ -496,41 +675,6 @@ def __new__( *, requires_grad: bool, ) -> "ShardTensor": - r"""Construct a new ShardTensor from a local tensor and specification. - - Note that unlike ``DTensor``, ShardTensor will automatically collect - the shard size information from all participating devices. This enables - uneven and dynamic sharding. - - Parameters - ---------- - local_tensor : torch.Tensor - Local tensor to use as the data. - spec : ShardTensorSpec - ShardTensorSpec defining the sharding scheme. - requires_grad : bool - Whether the tensor requires gradients. - - Returns - ------- - ShardTensor - A new ShardTensor instance. - - Note - ---- - This implementation is heavily derived from ``torch.distributed.tensor.DTensor``. - """ - if local_tensor.requires_grad and not requires_grad: - warn( - "To construct a new ShardTensor from torch.Tensor, " - "it's recommended to use local_tensor.detach() and " - "make requires_grad consistent." - ) - - if spec.tensor_meta is None: - raise ValueError("TensorMeta should not be None!") - - # Check the sharding information is known: ret = torch.Tensor._make_wrapper_subclass( cls, spec.tensor_meta.shape, @@ -538,178 +682,348 @@ def __new__( dtype=local_tensor.dtype, device=local_tensor.device, layout=local_tensor.layout, - requires_grad=requires_grad, + requires_grad=False, ) ret._spec = spec ret._local_tensor = local_tensor - cls._enable_shard_patches = True + # Set requires_grad AFTER _spec/_local_tensor are assigned, using + # the C-level setter directly (bypassing __torch_function__ which + # would convert to DTensor and set on a temporary). + if requires_grad: + with torch._C.DisableTorchFunctionSubclass(): + torch.Tensor.requires_grad.__set__(ret, True) + cls._enable_shard_patches = True return ret def __repr__(self) -> str: - return f"ShardTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})" + return ( + "ShardTensor(" + f"local_tensor={repr(self._local_tensor)}, " + f"device_mesh={repr(self._spec.mesh)}, " + f"placements={repr(self._spec.placements)}" + ")" + ) - @classmethod - def from_dtensor(cls, dtensor: DTensor) -> "ShardTensor": - r"""Convert a DTensor to a ShardTensor. + def __str__(self) -> str: + # Avoid Tensor/DTensor string formatting paths that can re-enter dispatch. + return self.__repr__() - Assumes the DTensor is properly constructed. Since DTensor is locked - to sharding a tensor according to chunk format, the sharding sizes - can be inferred with no communication. + def __format__(self, format_spec: str) -> str: + # Format as plain Python string to bypass tensor formatting internals. + return format(str(self), format_spec) - If the DTensor is a non-leaf (has a ``grad_fn``), the autograd graph - is preserved via :class:`_PromoteDTensorToShardTensor`. + @property + def device_mesh(self) -> DeviceMesh: + """Return the :class:`DeviceMesh` that this tensor is distributed over.""" + return self._spec.mesh - Parameters - ---------- - dtensor : DTensor - DTensor to convert. + @property + def placements(self) -> tuple[Placement, ...]: + """Return the placement strategy for each mesh dimension.""" + return self._spec.placements - Returns - ------- - ShardTensor - Equivalent ShardTensor with the same local tensor and inferred spec. - """ - return cls._maybe_promote_dtensor(dtensor, ()) + def __tensor_flatten__(self): + return ["_local_tensor"], (self._spec, self.requires_grad) @staticmethod - def _maybe_promote_dtensor(dtensor, input_args): - r"""Promote a single DTensor back to ShardTensor if it matches input criteria. + def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): + spec, requires_grad = flatten_spec + local_tensor = inner_tensors["_local_tensor"] + unflatten_meta = TensorMeta( + shape=outer_size, + stride=outer_stride, + dtype=spec.tensor_meta.dtype, + ) + + # Normalize ``_sharding_shapes`` to plain ``tuple[int, ...]`` entries + # (never ``torch.Size``). Under dynamo fakeification, ``torch.Size`` + # special-casing converts the contained ints into unbacked SymInts + # that orphan whenever an op's output drops shard tracking + # (Partial / Replicate / None), producing + # ``PendingUnbackedSymbolNotFound`` during AOT tracing. + # + # If the incoming spec has no ``_sharding_shapes``, derive them from + # chunk semantics against the outer global shape -- pure arithmetic, + # no collectives. This avoids leaving the field ``None``, which would + # force the next ``sharding_shapes()`` call to ``_all_gather_shard_shapes`` + # (a blocking collective that is not AOT-traceable). + if spec._sharding_shapes is not None: + sharding_shapes = { + mesh_dim: tuple(tuple(s) for s in shapes) + for mesh_dim, shapes in spec._sharding_shapes.items() + } + else: + chunk_shapes = compute_sharding_shapes_from_chunking_global_shape( + spec.mesh, spec.placements, tuple(outer_size) + ) + sharding_shapes = { + mesh_dim: tuple(tuple(s) for s in shapes) + for mesh_dim, shapes in chunk_shapes.items() + } + + unflatten_spec = ShardTensorSpec( + mesh=spec.mesh, + placements=spec.placements, + tensor_meta=unflatten_meta, + _local_shape=local_tensor.shape, + _sharding_shapes=sharding_shapes, + ) + return ShardTensor.__new__( + ShardTensor, + local_tensor=local_tensor.requires_grad_(requires_grad), + spec=unflatten_spec, + requires_grad=requires_grad, + ) + + # -- AOTAutograd tangent coercion hooks ------------------------------------ + # AOTAutograd records the expected tangent metadata at trace time and + # validates it at backward runtime. When a forward output has a + # ``Partial`` placement (typical right after a reduction like sum/mean), + # the tangent flowing back from ``.backward()`` is materialized as + # ``Replicate`` and AOT raises: + # "During the backward, we encountered a tensor subclass where we + # guessed its metadata incorrectly." + # These two hooks mirror DTensor's implementation in + # ``torch.distributed.tensor._api`` and reconcile the two ends: + # (1) at trace time, rewrite the expected metadata so any Partial + # placement becomes Replicate (so the recorded tangent metadata + # matches what runtime will actually produce); + # (2) at runtime, redistribute the incoming tangent to whatever + # placement the expected spec demands. + + def __coerce_tangent_metadata__(self) -> "ShardTensor": + """Trace-time hook: coerce this tensor so its metadata matches a tangent. + + Returns ``self`` if no Partial placement is present (no work needed). + Otherwise redistributes Partial placements to Replicate, which is the + layout the autograd engine produces for tangents. + """ + if not any(isinstance(p, Partial) for p in self.placements): + return self + new_placements = [ + Replicate() if isinstance(p, Partial) else p for p in self.placements + ] + return self.redistribute( + device_mesh=self.device_mesh, placements=new_placements + ) - If ``dtensor`` is already a ShardTensor, it is returned as-is. Otherwise, - determines a ``ShardTensorSpec`` (reusing an input's spec when possible, - otherwise inferring one) and creates a new ShardTensor. + def __coerce_same_metadata_as_tangent__( + self, + flatten_spec: tuple, + expected_type: type | None = None, + ) -> "ShardTensor | None": + """Runtime hook: redistribute ``self`` to match the recorded tangent's + placements and ``_sharding_shapes`` (preserves uneven layouts). - When the DTensor is a non-leaf (has a ``grad_fn``), the promotion goes - through :class:`_PromoteDTensorToShardTensor` so that the autograd graph - is preserved. For leaf DTensors, direct construction is used since there - is no graph to preserve. + Returns ``None`` when ``expected_type`` differs (DTensor convention). + """ + if expected_type is not None: + return None + + (spec, _requires_grad) = flatten_spec + + if ( + self._spec.placements == spec.placements + and self._spec._sharding_shapes == spec._sharding_shapes + ): + return self + + # Bypass ``self.redistribute()`` so we can thread the recorded per-tensor-dim + # shard sizes through to the local redistribute (the public API drops them). + target_spec = ShardTensorSpec( + mesh=self.device_mesh, + placements=spec.placements, + tensor_meta=self._spec.tensor_meta, + _sharding_shapes=spec._sharding_shapes, + ) + + target_sharding_shapes_by_tensor_dim: dict[int, list[int]] = {} + if spec._sharding_shapes is not None: + for mesh_dim, placement in enumerate(spec.placements): + if isinstance(placement, Shard) and mesh_dim in spec._sharding_shapes: + shard_shapes = spec._sharding_shapes[mesh_dim] + target_sharding_shapes_by_tensor_dim[placement.dim] = [ + s[placement.dim] for s in shard_shapes + ] + + new_local = redistribute_local_shard_tensor( + self._local_tensor, + self._spec, + target_spec, + async_op=False, + target_sharding_shapes=target_sharding_shapes_by_tensor_dim, + ) + target_spec._local_shape = new_local.shape + + return ShardTensor( + new_local.contiguous(), + target_spec, + requires_grad=self.requires_grad, + ) + + # -- Autograd property overrides ------------------------------------------- + # The C-level requires_grad is authoritative for autograd engine + # decisions; we read it first and fall back to _local_tensor for the + # case where _make_wrapper_subclass didn't propagate it correctly. + # For grad, the autograd engine accumulates at the C level, so we + # check there first then fall back to _local_tensor.grad. + + @property # type: ignore[override] + def requires_grad(self) -> bool: # type: ignore[override] + """Whether this tensor requires gradient computation. + + Returns ``True`` if either the wrapper tensor or the underlying local + tensor has ``requires_grad`` set. + """ + with torch._C.DisableTorchFunctionSubclass(): + if torch.Tensor.requires_grad.__get__(self): + return True + return self._local_tensor.requires_grad + + @requires_grad.setter + def requires_grad(self, value: bool) -> None: + """Set ``requires_grad`` on both the wrapper and the local tensor.""" + with torch._C.DisableTorchFunctionSubclass(): + torch.Tensor.requires_grad.__set__(self, value) + self._local_tensor.requires_grad = value + + def requires_grad_(self, requires_grad: bool = True) -> "ShardTensor": + """Set ``requires_grad`` in-place on both the wrapper and local tensor. Parameters ---------- - dtensor : DTensor - The DTensor result to promote. - input_args : tuple - Original input arguments to search for matching ShardTensors. + requires_grad : bool, optional + Whether to enable gradient tracking. Default is ``True``. Returns ------- ShardTensor - Promoted ShardTensor (or the original if already a ShardTensor). + ``self``, for method chaining. """ - if isinstance(dtensor, ShardTensor): - return dtensor - - # Determine the ShardTensorSpec — reuse an input's spec when the - # tensor_meta and placements match (avoids communication). - spec = None - for arg in input_args: - if ( - isinstance(arg, ShardTensor) - and dtensor._spec.tensor_meta == arg._spec.tensor_meta - and dtensor._spec.placements == arg._spec.placements - ): - spec = arg._spec - break - - if spec is None: - # Infer from DTensor (no communication for chunk-based sharding). - spec = _infer_shard_tensor_spec_from_local_chunks( - dtensor._local_tensor, - dtensor._spec.mesh, - dtensor._spec.placements, - sharding_shapes="chunk", - global_shape=dtensor.shape, - ) + with torch._C.DisableTorchFunctionSubclass(): + torch.Tensor.requires_grad.__set__(self, requires_grad) + self._local_tensor.requires_grad_(requires_grad) + return self + + @property # type: ignore[override] + def is_leaf(self) -> bool: # type: ignore[override] + """Whether this tensor is a leaf in the autograd graph.""" + with torch._C.DisableTorchFunctionSubclass(): + return torch.Tensor.is_leaf.__get__(self) + + @property # type: ignore[override] + def grad_fn(self): # type: ignore[override] + """Return the stored grad_fn without re-entering ``__torch_function__``. + + Without this override, ``.grad_fn`` (a C-level getset_descriptor on + ``torch.Tensor``) re-enters ``ShardTensor.__torch_function__`` + whenever someone reads it, falls back via + :func:`_torch_function_fallback_via_dtensor`, and the fallback + constructs a *new* temporary DTensor via + ``_ShardTensorToDTensor.apply(self)`` -- whose ``.grad_fn`` (a + ``_ShardTensorToDTensorBackward`` ``BackwardCFunction`` instance) + is what the caller actually receives. On newer PyTorch that + node's ``.next_functions`` accessor raises a "legacy access + pattern" error, which is exactly what makes + ``AOTAutograd.setup_stacktrace_preservation_hooks`` (and our + own diagnostic ``dump_grad_fn_chain``) fail when they try to + walk the autograd graph of a ShardTensor output. + + Mirrors the same shielding pattern already used by ``.is_leaf`` + and ``.grad``. + """ + with torch._C.DisableTorchFunctionSubclass(): + return torch.Tensor.grad_fn.__get__(self) - # Non-leaf DTensors carry a grad_fn from the operation that produced - # them. Creating a new ShardTensor via _make_wrapper_subclass would - # discard that grad_fn (producing a leaf). Go through the autograd - # function so that apply() connects the new ShardTensor back to the - # original graph. - if dtensor.grad_fn is not None: - return _PromoteDTensorToShardTensor.apply(dtensor, spec) + @property # type: ignore[override] + def grad(self) -> "ShardTensor | None": # type: ignore[override] + """Return the accumulated gradient, wrapped as a :class:`ShardTensor`. - # Leaf DTensors (parameters, buffers, detached tensors) can be - # constructed directly — there is no autograd graph to preserve. + If no gradient has been accumulated yet, returns ``None``. + """ + with torch._C.DisableTorchFunctionSubclass(): + c_grad = torch.Tensor.grad.__get__(self) + if c_grad is not None: + if isinstance(c_grad, ShardTensor): + return c_grad + return ShardTensor.__new__( + ShardTensor, + local_tensor=c_grad._local_tensor + if isinstance(c_grad, DTensor) + else c_grad, + spec=self._spec, + requires_grad=False, + ) + local_grad = self._local_tensor.grad + if local_grad is None: + return None return ShardTensor.__new__( ShardTensor, - local_tensor=dtensor._local_tensor, - spec=spec, - requires_grad=dtensor.requires_grad, + local_tensor=local_grad, + spec=self._spec, + requires_grad=False, ) - @staticmethod - def _promote_dtensor_results(result, input_args): - r"""Promote DTensor(s) in a dispatch/function result back to ShardTensor. + @grad.setter + def grad(self, value: "ShardTensor | torch.Tensor | None") -> None: + """Set or clear the gradient on both the wrapper and local tensor.""" + if value is None: + with torch._C.DisableTorchFunctionSubclass(): + torch.Tensor.grad.__set__(self, None) + self._local_tensor.grad = None + elif isinstance(value, ShardTensor): + with torch._C.DisableTorchFunctionSubclass(): + torch.Tensor.grad.__set__(self, value) + self._local_tensor.grad = value._local_tensor + else: + with torch._C.DisableTorchFunctionSubclass(): + torch.Tensor.grad.__set__(self, value) + self._local_tensor.grad = value - Handles four cases: + @classmethod + def from_dtensor(cls, dtensor: DTensor) -> "ShardTensor": + r"""Convert a DTensor to a ShardTensor. - 1. Single DTensor — promoted via :meth:`_maybe_promote_dtensor`. - 2. Mapping (e.g. dict) — each value is promoted if it is a DTensor. - 3. Iterable of results — each DTensor element is promoted individually. - 4. Anything else — returned as-is. + Differentiable when *dtensor* is non-leaf (has a ``grad_fn``). + Spec is inferred from the DTensor (chunk-based, no communication). Parameters ---------- - result : object - The result returned by DTensor dispatch or ``__torch_function__``. - input_args : tuple - Original input arguments used for matching specs. + dtensor : DTensor + DTensor to convert. Returns ------- - object - The result with any DTensors promoted to ShardTensors. + ShardTensor + Equivalent ShardTensor with the same local tensor and inferred spec. """ - if isinstance(result, DTensor): - return ShardTensor._maybe_promote_dtensor(result, input_args) - - if isinstance(result, Mapping): - return type(result)( - { - k: ShardTensor._maybe_promote_dtensor(v, input_args) - if isinstance(v, DTensor) - else v - for k, v in result.items() - } - ) - - # Exclude str/bytes so we don't iterate over characters. - if isinstance(result, Iterable) and not isinstance(result, (str, bytes)): - return type(result)( - ShardTensor._maybe_promote_dtensor(d, input_args) - if isinstance(d, DTensor) - else d - for d in result - ) - - return result + if isinstance(dtensor, ShardTensor): + return dtensor + spec = _resolve_spec_for_dtensor(dtensor) + if dtensor.grad_fn is not None: + return _DTensorToShardTensor.apply(dtensor, spec) + return _dtensor_to_shard_tensor(dtensor, spec) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - with annotate(f"__torch_function___{func.__name__}"): - # Check for overrides: - if func in cls._function_registry and cls._enable_shard_patches: - res = cls._function_registry[func](func, types, args, kwargs) - return res - elif ( - str(func) in cls._named_function_registry and cls._enable_shard_patches - ): - res = cls._named_function_registry[str(func)](func, types, args, kwargs) - return res - # Fall back to the default behavior, but promote any DTensor - # results back to ShardTensor (matching dispatch behavior): - result = super().__torch_function__(func, types, args, kwargs) - return cls._promote_dtensor_results(result, args) + if _conversion_active(): + # When converting shard tensor to dtensor, or dtensor to shard tensor, + # we just run the function without ShardTensor dispatch. + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + if func in cls._function_registry and cls._enable_shard_patches: + return cls._function_registry[func](func, types, args, kwargs) + if str(func) in cls._named_function_registry and cls._enable_shard_patches: + return cls._named_function_registry[str(func)](func, types, args, kwargs) + res = _torch_function_fallback_via_dtensor(func, args, kwargs) + return res @classmethod - @torch._disable_dynamo - @profile def __torch_dispatch__( cls, func: torch._ops.OpOverload, @@ -717,33 +1031,14 @@ def __torch_dispatch__( args: tuple[object, ...] = (), kwargs: dict[str, object] | None = None, ) -> "ShardTensor" | Iterable["ShardTensor"] | object: - with annotate(f"__torch_dispatch___{func.__name__}"): - # Leverage DTensor Dispatch as much as possible, but, enable - # the ability to operate on this output in the future: - handler = cls._dispatch_registry.get(func) - if handler is None: - handler = cls._dispatch_registry_by_name.get(str(func)) - if handler is not None: - res = handler(*args, **kwargs) - return res - - # We assume that if we reach this point, the operator has not been - # intercepted by a wrapper or in the registry. So the DTensor - # default behavior is likely to be correct. - - # Convert ShardTensors to DTensors so DTensor's dispatcher - # receives the types it expects. - converted_args = tuple(_convert_args_to_dtensor(arg) for arg in args) - converted_kwargs = { - k: _convert_args_to_dtensor(v) for k, v in (kwargs or {}).items() - } - - dispatch_res = DTensor._op_dispatcher.dispatch( - func, converted_args, converted_kwargs - ) - - # Promote any DTensor results back to ShardTensor. - return cls._promote_dtensor_results(dispatch_res, args) + # Use a handler, if we have one: + handler = cls._dispatch_registry.get(func) + if handler is None: + handler = cls._dispatch_registry_by_name.get(str(func)) + if handler is not None: + return handler(*args, **kwargs) + # Otherwise, try the dtensor route: + return _dispatch_fallback_via_dtensor(func, args, kwargs) @staticmethod def from_local( @@ -962,9 +1257,48 @@ def backward(self, *args, **kwargs): if needs_redistribute: self = self.redistribute(placements=new_placements) + if self.grad_fn is not None: + return torch.Tensor.backward(self, *args, **kwargs) + return self.to_local().backward(*args, **kwargs) +### TODO +### Do we still need this? +### I think we do not - CJA + + +class FSDPOutputTensorAdapter(nn.Module): + """Wrap a module and convert ShardTensor outputs to torch.Tensor.""" + + def __init__(self, module: nn.Module) -> None: + super().__init__() + self.module = module + + def forward(self, *args, **kwargs): + out = self.module(*args, **kwargs) + return out.to_local() if isinstance(out, ShardTensor) else out + + +def wrap_for_fsdp(module: nn.Module) -> nn.Module: + """Return a module wrapper that exposes tensor outputs for FSDP hooks.""" + return FSDPOutputTensorAdapter(module) + + +def distribute_over_domain_for_fsdp( + module: nn.Module, + device_mesh: DeviceMesh, + partition_fn: (Callable[[str, nn.Module, DeviceMesh], None] | None) = None, +) -> nn.Module: + """Distribute a module over a domain mesh and adapt outputs for FSDP.""" + distributed_module = distribute_module( + module, + device_mesh=device_mesh, + partition_fn=partition_fn, + ) + return wrap_for_fsdp(distributed_module) + + def scatter_tensor( tensor: torch.Tensor, global_src: int, @@ -1044,12 +1378,15 @@ def scatter_tensor( # scatter along Shard dimensions. BUT, the focus is on performance of full applications # and this is a once-per-iteration cost. - # Broadcast the tensor to all ranks + # Broadcast the tensor to all ranks. + # scatter_tensor is an input-boundary utility; keep internal collectives/layout + # transforms out of autograd and construct the requested leaf explicitly. if tensor is None and not is_src: # Tensor is allowed to be none if not on the root rank tensor = torch.empty(local_meta.shape, dtype=local_meta.dtype, device=dm.device) - dist.broadcast(tensor, src=global_src, group=mesh_group) + with torch.no_grad(): + dist.broadcast(tensor, src=global_src, group=mesh_group) # Create a fully-replicated spec: spec = ShardTensorSpec( @@ -1059,18 +1396,30 @@ def scatter_tensor( _sharding_shapes={}, ) - # Make a "fully-replicated" tensor on all ranks: - st = ShardTensor.__new__( - ShardTensor, - local_tensor=tensor, - spec=spec, - requires_grad=requires_grad, - ) + with torch.no_grad(): + # Build a replicated ShardTensor and redistribute to the requested + # placements without recording autograd history. + st = ShardTensor.__new__( + ShardTensor, + local_tensor=tensor, + spec=spec, + requires_grad=False, + ) + st = st.redistribute(mesh, placements, async_op=False) - # Redistribute the tensor to the desired placements: - st = st.redistribute(mesh, placements, async_op=False) - # This is an unoptimal step but is functional: if requires_grad: - st = st.detach() - st.requires_grad = True + # 1. Ensure the local data is a clean leaf + local_leaf = st._local_tensor.detach().requires_grad_(True) + + # 2. Create the ShardTensor wrapper + st = ShardTensor.__new__( + ShardTensor, + local_tensor=local_leaf, + spec=st._spec, + requires_grad=True, + ) + + # 3. CRITICAL: Force the wrapper itself to be a leaf in the autograd graph + st = st.detach().requires_grad_(True) + return st diff --git a/physicsnemo/domain_parallel/shard_utils/__init__.py b/physicsnemo/domain_parallel/shard_utils/__init__.py index 4ce1bfc714..69b7370cb2 100644 --- a/physicsnemo/domain_parallel/shard_utils/__init__.py +++ b/physicsnemo/domain_parallel/shard_utils/__init__.py @@ -25,6 +25,11 @@ from physicsnemo.domain_parallel.shard_tensor import ShardTensor def register_shard_wrappers(): + """Import and register all shard-aware operation wrappers with ShardTensor. + + Each imported module registers its wrapper via + :meth:`ShardTensor.register_op` at import time. + """ from .attention_patches import sdpa_wrapper from .conv_patches import generic_conv_nd_wrapper from .index_ops import ( diff --git a/physicsnemo/domain_parallel/shard_utils/attention_patches.py b/physicsnemo/domain_parallel/shard_utils/attention_patches.py index 344781d83c..8079fde970 100644 --- a/physicsnemo/domain_parallel/shard_utils/attention_patches.py +++ b/physicsnemo/domain_parallel/shard_utils/attention_patches.py @@ -146,7 +146,6 @@ class RingSDPA(torch.autograd.Function): @staticmethod def forward( - ctx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -154,7 +153,7 @@ def forward( mesh: DeviceMesh, ring_config: RingPassingConfig, attn_args: dict, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: r"""Forward pass for the ring attention implementation. Overlaps communication with computation using a dedicated comm stream @@ -164,8 +163,6 @@ def forward( Parameters ---------- - ctx : torch.autograd.function.FunctionCtx - Autograd context for saving tensors/variables for backward. q : torch.Tensor Query tensor of shape :math:`(B, H, S, D)`. k : torch.Tensor @@ -183,14 +180,13 @@ def forward( Returns ------- - torch.Tensor - Output tensor of shape :math:`(B, H, S, D)`. + tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + Tuple of ``(output, global_log_sumexp, philox_seed, + philox_offset)`` of shape :math:`(B, H, S, D)` for output and the + intermediate stats needed by backward. The public wrapper + discards the extras and they are marked non-differentiable. """ - ctx.attn_args = attn_args - ctx.mesh = mesh - ctx.ring_config = ring_config - # Accumulation state (log-space for numerical stability) log_global_output = None sign_global_output = None @@ -290,6 +286,13 @@ def forward( log_global_output - global_log_sumexp ) + return stable_output, global_log_sumexp, philox_seed, philox_offset + + @staticmethod + def setup_context(ctx, inputs, output) -> None: + r"""Save inputs and forward-computed stats for the backward pass.""" + q, k, v, attn_mask, mesh, ring_config, attn_args = inputs + stable_output, global_log_sumexp, philox_seed, philox_offset = output ctx.save_for_backward( q, k, @@ -300,13 +303,19 @@ def forward( philox_seed, philox_offset, ) + ctx.attn_args = attn_args + ctx.mesh = mesh + ctx.ring_config = ring_config ctx.grad_input_mask = (True, True, True, attn_mask is not None) - - return stable_output + ctx.mark_non_differentiable(global_log_sumexp, philox_seed, philox_offset) @staticmethod def backward( - ctx, grad_output: torch.Tensor + ctx, + grad_output: torch.Tensor, + _grad_log_sumexp: torch.Tensor | None = None, + _grad_philox_seed: torch.Tensor | None = None, + _grad_philox_offset: torch.Tensor | None = None, ) -> tuple[ torch.Tensor, torch.Tensor, @@ -507,7 +516,6 @@ class RingSDPABlocking(torch.autograd.Function): @staticmethod def forward( - ctx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -515,15 +523,13 @@ def forward( mesh: DeviceMesh, ring_config: RingPassingConfig, attn_args: dict, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: r"""Forward pass for the ring attention implementation. This implementation will NOT overlap the communication with the computation. Parameters ---------- - ctx : torch.autograd.function.FunctionCtx - Autograd context for saving tensors/variables for backward. q : torch.Tensor Query tensor of shape :math:`(B, H, S, D)`. k : torch.Tensor @@ -541,14 +547,13 @@ def forward( Returns ------- - torch.Tensor - Output tensor of shape :math:`(B, H, S, D)`. + tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + Tuple of ``(output, global_log_sumexp, philox_seed, + philox_offset)`` of shape :math:`(B, H, S, D)` for output and the + intermediate stats needed by backward. The public wrapper + discards the extras and they are marked non-differentiable. """ - ctx.attn_args = attn_args - ctx.mesh = mesh - ctx.ring_config = ring_config - # Create buffers to store outputs log_global_output = None sign_global_output = None @@ -589,14 +594,21 @@ def forward( global_log_sumexp = add_log_sumexp(global_log_sumexp, log_sumexp) # send k and v to the next rank: - current_k = perform_ring_iteration(current_k, ctx.mesh, ctx.ring_config) - current_v = perform_ring_iteration(current_v, ctx.mesh, ctx.ring_config) + current_k = perform_ring_iteration(current_k, mesh, ring_config) + current_v = perform_ring_iteration(current_v, mesh, ring_config) # Compute the final output stable_output = sign_global_output * torch.exp( log_global_output - global_log_sumexp ) + return stable_output, global_log_sumexp, philox_seed, philox_offset + + @staticmethod + def setup_context(ctx, inputs, output) -> None: + r"""Save inputs and forward-computed stats for the backward pass.""" + q, k, v, attn_mask, mesh, ring_config, attn_args = inputs + stable_output, global_log_sumexp, philox_seed, philox_offset = output ctx.save_for_backward( q, k, @@ -607,13 +619,19 @@ def forward( philox_seed, philox_offset, ) + ctx.attn_args = attn_args + ctx.mesh = mesh + ctx.ring_config = ring_config ctx.grad_input_mask = (True, True, True, attn_mask is not None) - - return stable_output + ctx.mark_non_differentiable(global_log_sumexp, philox_seed, philox_offset) @staticmethod def backward( - ctx, grad_output: torch.Tensor + ctx, + grad_output: torch.Tensor, + _grad_log_sumexp: torch.Tensor | None = None, + _grad_philox_seed: torch.Tensor | None = None, + _grad_philox_offset: torch.Tensor | None = None, ) -> tuple[ torch.Tensor, torch.Tensor, @@ -711,6 +729,7 @@ def backward( return grad_q, grad_k, grad_v, grad_attn_mask, None, None, None +@torch.compiler.disable(recursive=True) def ring_sdpa( q: ShardTensor, k: ShardTensor, @@ -724,6 +743,30 @@ def ring_sdpa( locally on its tensors, and then kv is passed to the next rank while receiving from the previous rank. + Notes + ----- + ``torch.compile`` around this function is currently unsupported and + will error. The ``@torch.compiler.disable`` decorator below blocks + dynamo's symbolic tracing of the body, which is *necessary* (the + overlap path uses ``torch.cuda.stream``, ``torch.cuda.Event``, + ``tensor.record_stream``, and async ``Work`` handles -- none of which + have an FX representation), but it is **not sufficient**: AOTAutograd + re-executes the captured graph against ``FunctionalTensor`` inputs + during metadata propagation, and our ``ShardTensor`` + ``__torch_function__`` dispatcher re-enters this function on the + captured SDPA node, where ``record_stream``'s alias annotation trips + PyTorch's functionalization layer. + + Eager (no ``torch.compile``) usage works as designed: this is the + overlap-K/V-with-compute attention kernel used by sharded models in + production. Compile support for sharded attention requires a separate + refactor of the ring (drop ``record_stream``, switch + ``perform_ring_iteration`` to functional p2p collectives, replace the + explicit ``cuda.stream`` overlap with implicit collective overlap). + Until then, callers that need ``torch.compile`` must either keep + sharded attention outside the compiled region or avoid sharded + attention entirely. + Parameters ---------- q : ShardTensor @@ -770,7 +813,12 @@ def ring_sdpa( else: latn_mask = None - x = RingSDPA.apply(lq, lk, lv, latn_mask, q._spec.mesh, ring_config, kwargs) + # RingSDPA returns (output, global_log_sumexp, philox_seed, philox_offset); + # the three trailing tensors are intermediate stats marked non-differentiable + # and consumed only by its backward pass. + x, _, _, _ = RingSDPA.apply( + lq, lk, lv, latn_mask, q._spec.mesh, ring_config, kwargs + ) # Convert back to ShardTensor x = ShardTensor.from_local( diff --git a/physicsnemo/domain_parallel/shard_utils/conv_patches.py b/physicsnemo/domain_parallel/shard_utils/conv_patches.py index a67ea869b4..2214f9694a 100644 --- a/physicsnemo/domain_parallel/shard_utils/conv_patches.py +++ b/physicsnemo/domain_parallel/shard_utils/conv_patches.py @@ -20,6 +20,7 @@ import torch import torch.distributed as dist +import torch.distributed._functional_collectives as funcol from torch.distributed.tensor import DTensor from torch.distributed.tensor.placement_types import ( Shard, @@ -408,16 +409,13 @@ class ConvGradReducer(torch.autograd.Function): @staticmethod def forward( - ctx: torch.autograd.function.FunctionCtx, weight_or_bias: torch.Tensor, spec: ShardTensorSpec, ) -> torch.Tensor: - r"""Forward pass that saves the spec for backward. + r"""Forward pass: return the weight/bias tensor unchanged. Parameters ---------- - ctx : torch.autograd.function.FunctionCtx - Autograd context for saving variables for backward. weight_or_bias : torch.Tensor The weight or bias tensor to pass through. spec : ShardTensorSpec @@ -428,15 +426,20 @@ def forward( torch.Tensor The input tensor unchanged. """ - ctx.spec = spec return weight_or_bias + @staticmethod + def setup_context(ctx, inputs, output) -> None: + r"""Save the input ShardTensorSpec for the backward all-reduce.""" + _weight_or_bias, spec = inputs + ctx.spec = spec + @staticmethod def backward( ctx: torch.autograd.function.FunctionCtx, grad_weight_or_bias: torch.Tensor, ) -> tuple[torch.Tensor, None]: - r"""Backward pass that performs allreduce on gradients. + r"""Backward pass: all-reduce gradients over each sharded mesh dim. Parameters ---------- @@ -452,8 +455,12 @@ def backward( """ for mesh_dim in range(ctx.spec.mesh.ndim): if ctx.spec.placements[mesh_dim].is_shard(): - group = ctx.spec.mesh.get_group(mesh_dim) - dist.all_reduce(grad_weight_or_bias, group=group) + # funcol.all_reduce returns a new tensor (AsyncCollectiveTensor) + # that auto-waits when used; assigning back into the loop var + # serializes the iterations correctly. + grad_weight_or_bias = funcol.all_reduce( + grad_weight_or_bias, "sum", (ctx.spec.mesh, mesh_dim) + ) return grad_weight_or_bias, None diff --git a/physicsnemo/domain_parallel/shard_utils/halo.py b/physicsnemo/domain_parallel/shard_utils/halo.py index 9c46c52372..02a511b10d 100644 --- a/physicsnemo/domain_parallel/shard_utils/halo.py +++ b/physicsnemo/domain_parallel/shard_utils/halo.py @@ -37,6 +37,7 @@ import torch import torch.distributed as dist +import torch.distributed._functional_collectives as funcol from torch.autograd.profiler import record_function from torch.distributed.device_mesh import DeviceMesh @@ -170,7 +171,6 @@ class HaloPadding(torch.autograd.Function): @staticmethod def forward( - ctx: torch.autograd.function.FunctionCtx, tensor: torch.Tensor, mesh: DeviceMesh, config: HaloConfig, @@ -179,8 +179,6 @@ def forward( Parameters ---------- - ctx : torch.autograd.function.FunctionCtx - Autograd context for saving tensors/variables for backward. tensor : torch.Tensor Tensor to apply halo padding to. mesh : DeviceMesh @@ -193,18 +191,15 @@ def forward( torch.Tensor Padded tensor with halos added locally to each chunk. """ + return halo_padding_fwd_primitive(tensor, mesh, config) - # Save context for backward pass + @staticmethod + def setup_context(ctx, inputs, output) -> None: + r"""Save ``mesh`` and ``config`` for the backward pass.""" + _tensor, mesh, config = inputs ctx.mesh = mesh ctx.config = config - padded_tensor = halo_padding_fwd_primitive( - tensor, - mesh, - config, - ) - return padded_tensor - @staticmethod def backward( ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor @@ -224,13 +219,10 @@ def backward( Tuple of (gradient for input tensor, ``None`` for mesh, ``None`` for config). """ - mesh = ctx.mesh - config = ctx.config - grad_input = halo_padding_bwd_primitive( grad_output, - mesh, - config, + ctx.mesh, + ctx.config, ) return grad_input, None, None @@ -249,7 +241,6 @@ class UnHaloPadding(torch.autograd.Function): @staticmethod def forward( - ctx: torch.autograd.function.FunctionCtx, tensor: torch.Tensor, mesh: DeviceMesh, config: HaloConfig, @@ -262,8 +253,6 @@ def forward( Parameters ---------- - ctx : torch.autograd.function.FunctionCtx - Autograd context for saving tensors/variables for backward. tensor : torch.Tensor Tensor to remove halo padding from. mesh : DeviceMesh @@ -277,10 +266,6 @@ def forward( Tensor with halo regions removed. """ - # Save context for backward pass - ctx.mesh = mesh - ctx.config = config - # Chop off the halos _left, unpadded_tensor, _right = slice_halo_regions( tensor, @@ -288,11 +273,43 @@ def forward( config, ) - ctx.left_shape = _left.shape - ctx.right_shape = _right.shape - return unpadded_tensor + @staticmethod + def setup_context(ctx, inputs, output) -> None: + r"""Save ``mesh``, ``config`` and the left/right halo shapes for backward. + + The left/right halo shapes are derived from the input tensor's shape + along ``config.tensor_dim`` and the rank in the mesh, matching the + slicing performed by ``slice_halo_regions``. + """ + tensor, mesh, config = inputs + + # Reconstruct the left/right slice boundaries on this rank without + # actually running ``slice_halo_regions`` a second time. + local_group = mesh.get_group(config.mesh_dim) + local_rank = mesh.get_local_rank(config.mesh_dim) + local_size = dist.get_world_size(group=local_group) + + dim_shape = tensor.shape[config.tensor_dim] + + start = config.halo_size if local_rank != 0 else config.edge_padding_size + end = ( + dim_shape - config.halo_size + if local_rank != local_size - 1 + else dim_shape - config.edge_padding_size + ) + + left_shape = list(tensor.shape) + left_shape[config.tensor_dim] = start + right_shape = list(tensor.shape) + right_shape[config.tensor_dim] = dim_shape - end + + ctx.mesh = mesh + ctx.config = config + ctx.left_shape = torch.Size(left_shape) + ctx.right_shape = torch.Size(right_shape) + @staticmethod def backward( ctx: torch.autograd.function.FunctionCtx, @@ -630,45 +647,61 @@ def perform_halo_collective( req.wait() elif method == "a2a": - # All-to-all communication - all_to_all_send = [ - torch.empty(0, dtype=dtype, device=device) for _ in range(local_size) - ] - all_to_all_recv = [ - torch.empty(0, dtype=dtype, device=device) for _ in range(local_size) - ] - - # Set up send/recv buffers + # This has to be funcol collectives, below, to be + # conmpatible with torc.compile. + # + # Symmetric-halo assumption: what I receive from neighbor R has the + # same numel as what I send to R (true for uniform halo sizes used + # by conv / natten / pooling / etc). + input_split_sizes = [0] * local_size + output_split_sizes = [0] * local_size + halo_from_left_shape: torch.Size | None = None + halo_from_right_shape: torch.Size | None = None + send_chunks: list[torch.Tensor] = [] + if local_rank != 0: - # Send one left - all_to_all_send[local_rank - 1] = halo_to_left - # Receive one right (need to initialize an empty buffer of the right size): - all_to_all_recv[local_rank - 1] = torch.zeros_like( - halo_to_left - ).contiguous() + # Send one to the left; receive one (same size) from the left. + flat_left = halo_to_left.reshape(-1).contiguous() + send_chunks.append(flat_left) + input_split_sizes[local_rank - 1] = flat_left.numel() + output_split_sizes[local_rank - 1] = flat_left.numel() + halo_from_left_shape = halo_to_left.shape if local_rank != local_size - 1: - # Send one to the right: - all_to_all_send[local_rank + 1] = halo_to_right - # Receive one from the right: - all_to_all_recv[local_rank + 1] = torch.zeros_like( - halo_to_right - ).contiguous() - - # Perform exchange - with record_function("all_to_all_queue_and_wait"): - request = dist.all_to_all( - all_to_all_recv, all_to_all_send, group=local_group, async_op=async_op + # Send one to the right; receive one (same size) from the right. + flat_right = halo_to_right.reshape(-1).contiguous() + send_chunks.append(flat_right) + input_split_sizes[local_rank + 1] = flat_right.numel() + output_split_sizes[local_rank + 1] = flat_right.numel() + halo_from_right_shape = halo_to_right.shape + + # Concatenated send buffer. The cat order (left then right) matches + # the ascending destination-rank order required by all_to_all_single. + if send_chunks: + send_buf = torch.cat(send_chunks) + else: + send_buf = torch.empty(0, dtype=dtype, device=device) + + with record_function("all_to_all_single_funcol"): + recv_buf = funcol.all_to_all_single_autograd( + send_buf, + output_split_sizes, + input_split_sizes, + (mesh, mesh_dim), ) - if async_op: - # According to the docs, this will wait until the collectives are enqueued and it's safe to use the recv buffers. - request.wait() - - # Extract received halos - halo_from_left = all_to_all_recv[local_rank - 1] if local_rank != 0 else None + # Split into per-source-rank chunks (some empty), then reshape to + # the original halo tensor shapes. + recv_chunks = list(torch.split(recv_buf, output_split_sizes)) + halo_from_left = ( + recv_chunks[local_rank - 1].view(halo_from_left_shape) + if local_rank != 0 + else None + ) halo_from_right = ( - all_to_all_recv[local_rank + 1] if local_rank != local_size - 1 else None + recv_chunks[local_rank + 1].view(halo_from_right_shape) + if local_rank != local_size - 1 + else None ) return halo_from_left, halo_from_right diff --git a/physicsnemo/domain_parallel/shard_utils/index_ops.py b/physicsnemo/domain_parallel/shard_utils/index_ops.py index 25e96fc865..0bf32a72d4 100644 --- a/physicsnemo/domain_parallel/shard_utils/index_ops.py +++ b/physicsnemo/domain_parallel/shard_utils/index_ops.py @@ -47,7 +47,6 @@ class ShardedIndexSelect(torch.autograd.Function): @staticmethod def forward( - ctx: torch.autograd.function.FunctionCtx, tensor: ShardTensor, dim: int, index: ShardTensor, @@ -59,8 +58,6 @@ def forward( Parameters ---------- - ctx : torch.autograd.function.FunctionCtx - Context object to store information for backward pass. tensor : ShardTensor Input tensor to select from. dim : int @@ -81,19 +78,12 @@ def forward( # This is the simplest implementation, to enable functionality. # It could be optimized for very large tensors to ensure performace. - # We save the local version of the index and the input tensor spec for the backwards pass - - ctx.spec = tensor._spec - ctx.grad_shape = tensor._local_tensor.shape - ctx.dim = dim - # First - Make sure we have the full input tensor # Triggers an all_gather(_v) for (uneven) tensors. local_tensor = tensor.full_tensor() # Perform the index select using the local values of the index: local_index = index.to_local() - ctx.save_for_backward(index) # Get everything requested from the local index: local_values = aten.index_select(local_tensor, dim, local_index) @@ -113,15 +103,11 @@ def forward( for local_chunk_size in index_shard_sizes: this_shard_size = output_size this_shard_size[dim] = local_chunk_size[0] - # Make sure it's a tuple: - output_shard_sizes[mesh_dim].append( - torch.Size(tuple(this_shard_size)) - ) - # Make sure it's a tuple: + # Plain int tuples (never torch.Size) -- see + # ShardTensorSpec._sharding_shapes field docs. + output_shard_sizes[mesh_dim].append(tuple(this_shard_size)) output_shard_sizes[mesh_dim] = tuple(output_shard_sizes[mesh_dim]) - ctx.output_shard_sizes = output_shard_sizes - return_tensor = ShardTensor.from_local( local_values, device_mesh=tensor._spec.mesh, @@ -156,6 +142,22 @@ def forward( f"Index select is not implemented for {index_placement} sharding." ) + @staticmethod + def setup_context(ctx, inputs, output) -> None: + r"""Save the source ShardTensorSpec, local shape, dim, and index for backward. + + ``DisableTorchFunctionSubclass`` shielding avoids re-entering the + ShardTensor ``__torch_function__`` fallback while reading + ``tensor._spec`` / ``tensor._local_tensor`` -- the same AOT-hostile + bridge motivated the shielding in ``ShardedSum.setup_context``. + """ + tensor, dim, index = inputs + with torch._C.DisableTorchFunctionSubclass(): + ctx.spec = tensor._spec + ctx.grad_shape = tensor._local_tensor.shape + ctx.dim = dim + ctx.save_for_backward(index) + @staticmethod def backward( ctx: torch.autograd.function.FunctionCtx, grad_output: ShardTensor @@ -347,9 +349,8 @@ def sharded_select_helper(tensor: ShardTensor, dim: int, index: int) -> ShardTen for local_chunk_size in index_shard_sizes: local_chunk_size_list = list(local_chunk_size) local_chunk_size_list.pop(dim) - output_shard_sizes[mesh_dim].append( - torch.Size(tuple(local_chunk_size_list)) - ) + # Plain int tuples (never torch.Size) for _sharding_shapes. + output_shard_sizes[mesh_dim].append(tuple(local_chunk_size_list)) output_shard_sizes[mesh_dim] = tuple(output_shard_sizes[mesh_dim]) output_spec = ShardTensorSpec( @@ -433,9 +434,8 @@ def sharded_select_backward_helper( # We need to insert input_sizes[dim] at index: local_chunk_size_list = list(local_chunk_size) local_chunk_size_list.insert(dim, input_sizes[dim]) - output_shard_sizes[mesh_dim].append( - torch.Size(tuple(local_chunk_size_list)) - ) + # Plain int tuples (never torch.Size) for _sharding_shapes. + output_shard_sizes[mesh_dim].append(tuple(local_chunk_size_list)) output_shard_sizes[mesh_dim] = tuple(output_shard_sizes[mesh_dim]) output_spec = ShardTensorSpec( diff --git a/physicsnemo/domain_parallel/shard_utils/knn.py b/physicsnemo/domain_parallel/shard_utils/knn.py index a07db91435..d3558b8a67 100644 --- a/physicsnemo/domain_parallel/shard_utils/knn.py +++ b/physicsnemo/domain_parallel/shard_utils/knn.py @@ -257,9 +257,9 @@ def knn_sharded_wrapper( output_queries_shard_shapes = {} for mesh_dim in input_queries_spec.sharding_shapes().keys(): + # Plain int tuples (never torch.Size) for _sharding_shapes. shard_shapes = tuple( - torch.Size((s[0], k)) - for s in input_queries_spec.sharding_shapes()[mesh_dim] + (int(s[0]), int(k)) for s in input_queries_spec.sharding_shapes()[mesh_dim] ) output_queries_shard_shapes[mesh_dim] = shard_shapes diff --git a/physicsnemo/domain_parallel/shard_utils/mesh_ops.py b/physicsnemo/domain_parallel/shard_utils/mesh_ops.py index 99ea557efa..540d68fd08 100644 --- a/physicsnemo/domain_parallel/shard_utils/mesh_ops.py +++ b/physicsnemo/domain_parallel/shard_utils/mesh_ops.py @@ -82,8 +82,10 @@ def sharded_signed_distance_field( # Output shape is always (N, 1), hit point is (N, 3) input_shard_shapes = input_points._spec.sharding_shapes() + # Plain int tuples (never torch.Size) for _sharding_shapes -- see + # ShardTensorSpec field docs for the dynamo / fakeification rationale. output_shard_shapes = { - mesh_dim: tuple(torch.Size((s[0],)) for s in input_shard_shapes[mesh_dim]) + mesh_dim: tuple((int(s[0]),) for s in input_shard_shapes[mesh_dim]) for mesh_dim in input_shard_shapes.keys() } diff --git a/physicsnemo/domain_parallel/shard_utils/normalization_patches.py b/physicsnemo/domain_parallel/shard_utils/normalization_patches.py index 4685a9a49f..272cd975dd 100644 --- a/physicsnemo/domain_parallel/shard_utils/normalization_patches.py +++ b/physicsnemo/domain_parallel/shard_utils/normalization_patches.py @@ -33,7 +33,7 @@ from typing import Any, Callable import torch -import torch.distributed as dist +import torch.distributed._functional_collectives as funcol from torch.distributed.tensor import DTensor from physicsnemo.domain_parallel import ShardTensor, ShardTensorSpec @@ -69,20 +69,17 @@ class PartialGroupNorm(torch.autograd.Function): @staticmethod def forward( - ctx: Any, input: torch.Tensor, spec: ShardTensorSpec, num_groups: int, weight: torch.Tensor | None, bias: torch.Tensor | None, eps: float, - ) -> ShardTensor: + ) -> tuple[ShardTensor, torch.Tensor, torch.Tensor]: r"""Apply group normalization over a sharded tensor. Parameters ---------- - ctx : torch.autograd.function.FunctionCtx - Autograd context for saving tensors/variables for backward. input : torch.Tensor Local input tensor of shape :math:`(N, C, *)`. spec : ShardTensorSpec @@ -98,8 +95,12 @@ def forward( Returns ------- - ShardTensor - Normalized tensor of same shape as input. + tuple[ShardTensor, torch.Tensor, torch.Tensor] + Tuple of ``(normalized_output, global_mean, global_rstd)`` where + the trailing two tensors of shape ``(N, G)`` are intermediate + statistics needed by the backward pass; the public wrapper + ``group_norm_wrapper`` discards them, and ``setup_context`` + marks them non-differentiable. """ # These are local shapes: N, C = input.shape[0], input.shape[1] @@ -120,8 +121,6 @@ def forward( "Group normalization is not implemented for sharded tensors along the channel dimension" ) - group = spec.mesh.get_group(mesh_dim=0) - # Cast weight/bias to input dtype once. if weight is not None: weight = weight.to(input.dtype) @@ -143,9 +142,16 @@ def forward( local_sum = x.sum(dim=2) # (N, G) local_sum_sq = x.pow(2).sum(dim=2) # (N, G) - # Fuse into one all-reduce for lower latency. + # Fuse into one all-reduce for lower latency. We use the functional + # collective (parameterized by (mesh, mesh_dim)) rather than + # ``dist.all_reduce(group=...)`` so the AOT-captured backward graph + # holds a ``DeviceMesh`` reference instead of a C++ ``ProcessGroup`` + # ScriptObject; AOTAutograd ``deepcopy``s the backward GraphModule + # during caching and ``ProcessGroup`` has no ``__getstate__``. + # ``mesh_dim=0`` is safe here because ``mesh.ndim > 1`` is rejected + # above. packed = torch.stack([local_sum, local_sum_sq], dim=0) # (2, N, G) - dist.all_reduce(packed, group=group) + packed = funcol.all_reduce(packed, "sum", (spec.mesh, 0)) global_sum, global_sum_sq = packed[0], packed[1] global_mean = (global_sum / D_global).unsqueeze(2) # (N, G, 1) @@ -175,26 +181,55 @@ def forward( local_output = y.view(input.shape) - # -- Save for backward ---------------------------------------------- - ctx.save_for_backward(input, weight, bias) - ctx.global_mean = global_mean.squeeze(2) # (N, G) - ctx.global_rstd = global_rstd.squeeze(2) # (N, G) - ctx.num_groups = num_groups - ctx.eps = eps - ctx.spec = spec - - return ShardTensor.from_local( + shard_output = ShardTensor.from_local( local_output, spec.mesh, spec.placements, sharding_shapes=spec.sharding_shapes(), ) + # Return statistics so setup_context can save them; the public + # wrapper discards these extras and they are marked non-diff. + return shard_output, global_mean.squeeze(2), global_rstd.squeeze(2) + + @staticmethod + def setup_context(ctx, inputs, output) -> None: + r"""Save tensors and metadata for the backward pass. + + ``global_mean`` and ``global_rstd`` are intermediate statistics + computed by ``forward`` and returned as extra outputs; we save them + via ``save_for_backward`` and mark them non-differentiable so they + don't appear in the autograd graph as live tensors. + """ + input, spec, num_groups, weight, bias, eps = inputs + _shard_output, global_mean, global_rstd = output + + # Re-cast weight/bias to match the input dtype so the backward sees + # the same dtype as forward used. + if weight is not None: + weight = weight.to(input.dtype) + if bias is not None: + bias = bias.to(input.dtype) + + ctx.save_for_backward(input, weight, bias, global_mean, global_rstd) + ctx.num_groups = num_groups + ctx.eps = eps + ctx.spec = spec + ctx.mark_non_differentiable(global_mean, global_rstd) + @staticmethod def backward( - ctx: Any, grad_output: ShardTensor + ctx: Any, + grad_output: ShardTensor, + _grad_mean: torch.Tensor | None = None, + _grad_rstd: torch.Tensor | None = None, ) -> tuple[ - torch.Tensor, None, None, torch.Tensor | None, torch.Tensor | None, None + torch.Tensor, + None, + None, + torch.Tensor | None, + torch.Tensor | None, + None, ]: r"""Backward pass for distributed group normalization. @@ -218,7 +253,7 @@ def backward( Tuple containing gradients for (input, spec, num_groups, weight, bias, eps). ``None`` values indicate non-differentiable parameters. """ - input, weight, bias = ctx.saved_tensors + input, weight, bias, global_mean, global_rstd = ctx.saved_tensors num_groups = ctx.num_groups N, C = input.shape[0], input.shape[1] channels_per_group = C // num_groups @@ -230,11 +265,7 @@ def backward( if local_grad_output.dtype != input.dtype: local_grad_output = local_grad_output.to(input.dtype) - global_mean = ctx.global_mean # (N, G) - global_rstd = ctx.global_rstd # (N, G) - spec = ctx.spec - group = spec.mesh.get_group(mesh_dim=0) # Total elements in reduction dimension (correct for uneven sharding). global_spatial = spec.tensor_meta.shape[2:] @@ -267,8 +298,10 @@ def backward( sum_dx_hat = dx_hat.sum(dim=2, keepdim=True) # (N, G, 1) sum_dx_hat_y = (dx_hat * y).sum(dim=2, keepdim=True) # (N, G, 1) + # Functional collective: keeps the AOT backward graph free of raw + # ProcessGroup references (see forward for the full rationale). packed_sums = torch.cat([sum_dx_hat, sum_dx_hat_y], dim=2) # (N, G, 2) - dist.all_reduce(packed_sums, group=group) + packed_sums = funcol.all_reduce(packed_sums, "sum", (spec.mesh, 0)) sum_dx_hat = packed_sums[:, :, :1] # (N, G, 1) sum_dx_hat_y = packed_sums[:, :, 1:] # (N, G, 1) @@ -282,25 +315,26 @@ def backward( grad_weight = None grad_bias = None - if weight is not None and weight.requires_grad: + if weight is not None and ctx.needs_input_grad[3]: # grad_weight_c = sum_{n, spatial} grad_output * y (per-channel) y_c = y.view(N, C, HxW_local) grad_out_c = local_grad_output.view(N, C, HxW_local) grad_weight = (grad_out_c * y_c).sum(dim=(0, 2)) # (C,) - if bias is not None and bias.requires_grad: + if bias is not None and ctx.needs_input_grad[4]: grad_out_c = local_grad_output.view(N, C, HxW_local) grad_bias = grad_out_c.sum(dim=(0, 2)) # (C,) - # Fuse the two small all-reduces when both are needed. + # Fuse the two small all-reduces when both are needed. Same functional + # collective rationale as above. if grad_weight is not None and grad_bias is not None: packed_wb = torch.stack([grad_weight, grad_bias], dim=0) # (2, C) - dist.all_reduce(packed_wb, group=group) + packed_wb = funcol.all_reduce(packed_wb, "sum", (spec.mesh, 0)) grad_weight, grad_bias = packed_wb[0], packed_wb[1] elif grad_weight is not None: - dist.all_reduce(grad_weight, group=group) + grad_weight = funcol.all_reduce(grad_weight, "sum", (spec.mesh, 0)) elif grad_bias is not None: - dist.all_reduce(grad_bias, group=group) + grad_bias = funcol.all_reduce(grad_bias, "sum", (spec.mesh, 0)) return grad_input, None, None, grad_weight, grad_bias, None @@ -341,7 +375,10 @@ def group_norm_wrapper( bias = bias.full_tensor() output_spec = input._spec - x = PartialGroupNorm.apply( + # PartialGroupNorm returns (output, global_mean, global_rstd); the two + # extras are intermediate statistics marked non-differentiable and only + # needed by its backward pass. + x, _, _ = PartialGroupNorm.apply( input.to_local(), output_spec, num_groups, weight, bias, eps ) diff --git a/physicsnemo/domain_parallel/shard_utils/point_cloud_ops.py b/physicsnemo/domain_parallel/shard_utils/point_cloud_ops.py index 23e0ebb006..763bfdf758 100644 --- a/physicsnemo/domain_parallel/shard_utils/point_cloud_ops.py +++ b/physicsnemo/domain_parallel/shard_utils/point_cloud_ops.py @@ -119,21 +119,22 @@ def ring_ball_query( mp = indices_shard.shape[-1] d = queries.shape[-1] + # Plain int tuples (never torch.Size) for _sharding_shapes -- see + # ShardTensorSpec._sharding_shapes field docs. indices_shard_output_sharding = { 0: tuple( - torch.Size([*s[:q_shard_dim], s[q_shard_dim], mp]) + tuple([*s[:q_shard_dim], s[q_shard_dim], mp]) for s in queries_shard_sizes ), } num_neighbors_shard_output_sharding = { 0: tuple( - torch.Size([*s[:q_shard_dim], s[q_shard_dim]]) - for s in queries_shard_sizes + tuple([*s[:q_shard_dim], s[q_shard_dim]]) for s in queries_shard_sizes ), } outputs_shard_output_sharding = { 0: tuple( - torch.Size([*s[:q_shard_dim], s[q_shard_dim], mp, d]) + tuple([*s[:q_shard_dim], s[q_shard_dim], mp, d]) for s in queries_shard_sizes ), } @@ -218,14 +219,15 @@ def ringless_ball_query( q_shard_dim = queries_placement.dim if queries_placement.is_shard() else 0 for i_dim, s in queries._spec.sharding_shapes().items(): + # Plain int tuples (never torch.Size) for _sharding_shapes. indices_placement[i_dim] = tuple( - torch.Size([*_s[:q_shard_dim], _s[q_shard_dim], max_points]) for _s in s + tuple([*_s[:q_shard_dim], _s[q_shard_dim], max_points]) for _s in s ) num_neighbors_placement[i_dim] = tuple( - torch.Size([*_s[:q_shard_dim], _s[q_shard_dim]]) for _s in s + tuple([*_s[:q_shard_dim], _s[q_shard_dim]]) for _s in s ) output_points_placement[i_dim] = tuple( - torch.Size([*_s[:q_shard_dim], _s[q_shard_dim], max_points, 3]) for _s in s + tuple([*_s[:q_shard_dim], _s[q_shard_dim], max_points, 3]) for _s in s ) indices = ShardTensor.from_local( @@ -396,7 +398,6 @@ class RingBallQuery(torch.autograd.Function): @staticmethod def forward( - ctx: torch.autograd.function.FunctionCtx, points: torch.Tensor, queries: torch.Tensor, mesh: Any, @@ -409,8 +410,6 @@ def forward( Parameters ---------- - ctx : torch.autograd.function.FunctionCtx - Context for saving variables for backward pass. points : torch.Tensor First set of points. queries : torch.Tensor @@ -431,9 +430,6 @@ def forward( Tuple[torch.Tensor, torch.Tensor, torch.Tensor] Tuple of (mapping, outputs, None, num_neighbors) tensors. """ - ctx.mesh = mesh - ctx.ring_config = ring_config - # Create buffers to store outputs current_indices = None current_num_neighbors = None @@ -452,10 +448,10 @@ def forward( # For uneven point clouds, the global stide is important: strides = [s[shard_dim] for s in shard_sizes] - ctx.max_points = bq_kwargs["max_points"] - ctx.radius = bq_kwargs["radius"] - ctx.return_dists = bq_kwargs["return_dists"] - ctx.return_points = bq_kwargs["return_points"] + max_points = bq_kwargs["max_points"] + radius = bq_kwargs["radius"] + return_dists = bq_kwargs["return_dists"] + return_points = bq_kwargs["return_points"] for i in range(world_size): source_rank = (mesh_rank - i) % world_size @@ -468,10 +464,10 @@ def forward( ) = radius_search_impl( current_points, current_queries, - ctx.radius, - ctx.max_points, - ctx.return_dists, - ctx.return_points, + radius, + max_points, + return_dists, + return_points, ) # Store the result with its source rank rank_results[source_rank] = ( @@ -488,8 +484,8 @@ def forward( # Don't do a ring on the last iteration. current_points = perform_ring_iteration( current_points, - ctx.mesh, - ctx.ring_config, + mesh, + ring_config, recv_shape=shard_sizes[next_source_rank], ) @@ -511,12 +507,27 @@ def forward( ) stride += strides[r] - ctx.save_for_backward( - points, queries, current_indices, current_num_neighbors, current_out_points - ) return current_indices, current_out_points, None, current_num_neighbors + @staticmethod + def setup_context(ctx, inputs, output) -> None: + r"""Save context for the (currently unimplemented) backward pass. + + Backward is not implemented and raises ``MissingShardPatch``. We still + stash ``mesh``, ``ring_config``, and the ball-query kwargs so that if a + backward is added later it has the information it needs. + """ + _points, _queries, mesh, ring_config, _shard_sizes, _shard_dim, bq_kwargs = ( + inputs + ) + ctx.mesh = mesh + ctx.ring_config = ring_config + ctx.max_points = bq_kwargs["max_points"] + ctx.radius = bq_kwargs["radius"] + ctx.return_dists = bq_kwargs["return_dists"] + ctx.return_points = bq_kwargs["return_points"] + @staticmethod def backward( ctx: torch.autograd.function.FunctionCtx, @@ -556,16 +567,13 @@ class GradReducer(torch.autograd.Function): @staticmethod def forward( - ctx: torch.autograd.function.FunctionCtx, input: torch.Tensor, spec: ShardTensorSpec, ) -> torch.Tensor: - r"""Forward pass that saves the spec for backward. + r"""Forward pass: return the input tensor unchanged. Parameters ---------- - ctx : torch.autograd.function.FunctionCtx - Autograd context for saving variables for backward. input : torch.Tensor Input tensor to pass through. spec : ShardTensorSpec @@ -576,9 +584,14 @@ def forward( torch.Tensor The input tensor unchanged. """ - ctx.spec = spec return input + @staticmethod + def setup_context(ctx, inputs, output) -> None: + r"""Save the input ShardTensorSpec for the backward all-reduce.""" + _input, spec = inputs + ctx.spec = spec + @staticmethod def backward( ctx: torch.autograd.function.FunctionCtx, diff --git a/physicsnemo/domain_parallel/shard_utils/unary_ops.py b/physicsnemo/domain_parallel/shard_utils/unary_ops.py index 27c972857d..ac7c8c0a86 100644 --- a/physicsnemo/domain_parallel/shard_utils/unary_ops.py +++ b/physicsnemo/domain_parallel/shard_utils/unary_ops.py @@ -40,8 +40,8 @@ aten = torch.ops.aten -def unsqueeze_shape(shape: torch.Size | Sequence[int], dim: int) -> torch.Size: - r"""Return a new torch.Size with a singleton dimension inserted at ``dim``. +def unsqueeze_shape(shape: torch.Size | Sequence[int], dim: int) -> tuple[int, ...]: + r"""Return a new plain int tuple with a singleton dimension inserted at ``dim``. If ``dim`` is within the current rank, the new dimension is inserted at that index. This mirrors the behavior of ``torch.unsqueeze`` at the shape level. @@ -55,12 +55,13 @@ def unsqueeze_shape(shape: torch.Size | Sequence[int], dim: int) -> torch.Size: Returns ------- - torch.Size - A new ``torch.Size`` with the inserted dimension. + tuple[int, ...] + A plain int tuple with the inserted dimension (never a ``torch.Size``, + so it can be safely embedded in ``ShardTensorSpec._sharding_shapes``). """ o_shape = list(shape) o_shape.insert(dim, 1) - return torch.Size(tuple(o_shape)) + return tuple(o_shape) def normalize_dim(dim: int, tensor_rank: int) -> int: @@ -152,7 +153,7 @@ def unsqueeze_wrapper( output_placements.append(p) in_sharding_shapes = input._spec.sharding_shapes() - out_sharding_shapes: dict[int, list[torch.Size]] = { + out_sharding_shapes: dict[int, list[tuple[int, ...]]] = { mesh_dim: [unsqueeze_shape(s, dim) for s in in_sharding_shapes[mesh_dim]] for mesh_dim in in_sharding_shapes.keys() } diff --git a/physicsnemo/domain_parallel/shard_utils/unpooling_patches.py b/physicsnemo/domain_parallel/shard_utils/unpooling_patches.py index 07c7ab8f36..038a8f208b 100644 --- a/physicsnemo/domain_parallel/shard_utils/unpooling_patches.py +++ b/physicsnemo/domain_parallel/shard_utils/unpooling_patches.py @@ -294,8 +294,9 @@ def partial_interpolate_nd( result_shapes = {} for mesh_dim, sharding_shape in input._spec.sharding_shapes().items(): + # Plain int tuples (never torch.Size) for _sharding_shapes. updated_shapes = tuple( - torch.Size(compute_interpolate_output_shape(s, interp_kwargs)) + tuple(compute_interpolate_output_shape(s, interp_kwargs)) for s in sharding_shape ) result_shapes[mesh_dim] = updated_shapes diff --git a/physicsnemo/domain_parallel/shard_utils/view_ops.py b/physicsnemo/domain_parallel/shard_utils/view_ops.py index 42120272ab..621f450994 100644 --- a/physicsnemo/domain_parallel/shard_utils/view_ops.py +++ b/physicsnemo/domain_parallel/shard_utils/view_ops.py @@ -385,11 +385,11 @@ def _compute_view_placements( def _compute_view_sharding_shapes( - old_sharding_shapes: dict[int, tuple[torch.Size, ...]] | None, + old_sharding_shapes: dict[int, tuple[tuple[int, ...], ...]] | None, global_old: Sequence[int], global_new: Sequence[int], placements: tuple[Placement, ...], -) -> dict[int, tuple[torch.Size, ...]] | None: +) -> dict[int, tuple[tuple[int, ...], ...]] | None: r"""Compute new per-rank sharding shapes after a view. For each rank, maps its old local shape to the new local shape using the @@ -398,7 +398,7 @@ def _compute_view_sharding_shapes( Parameters ---------- - old_sharding_shapes : dict[int, tuple[torch.Size, ...]] or None + old_sharding_shapes : dict[int, tuple[tuple[int, ...], ...]] or None Old sharding shapes from the input spec. global_old : Sequence[int] Global shape before view. @@ -409,20 +409,22 @@ def _compute_view_sharding_shapes( Returns ------- - dict[int, tuple[torch.Size, ...]] or None - New sharding shapes, or ``None`` if input was ``None``. + dict[int, tuple[tuple[int, ...], ...]] or None + New sharding shapes as plain int tuples (see + ``ShardTensorSpec._sharding_shapes`` field docs), or ``None`` if + input was ``None``. """ if old_sharding_shapes is None: return None - new_sharding: dict[int, tuple[torch.Size, ...]] = {} + new_sharding: dict[int, tuple[tuple[int, ...], ...]] = {} for mesh_dim, rank_shapes in old_sharding_shapes.items(): - new_rank_shapes: list[torch.Size] = [] + new_rank_shapes: list[tuple[int, ...]] = [] for rank_shape in rank_shapes: new_shape = _compute_local_view_shape( global_old, tuple(rank_shape), global_new, placements ) - new_rank_shapes.append(torch.Size(new_shape)) + new_rank_shapes.append(tuple(new_shape)) new_sharding[mesh_dim] = tuple(new_rank_shapes) return new_sharding @@ -604,7 +606,6 @@ class ShardedView(torch.autograd.Function): @staticmethod def forward( - ctx: torch.autograd.function.FunctionCtx, tensor: ShardTensor, target_shape: tuple[int, ...], ) -> ShardTensor: @@ -612,8 +613,6 @@ def forward( Parameters ---------- - ctx : torch.autograd.function.FunctionCtx - Autograd context for saving state for backward. tensor : ShardTensor Input sharded tensor. target_shape : tuple[int, ...] @@ -624,8 +623,21 @@ def forward( ShardTensor Viewed ShardTensor. """ - ctx.input_global_shape = tuple(tensor.shape) - return _sharded_view_forward(tensor, target_shape) + out = _sharded_view_forward(tensor, target_shape) + return out + + @staticmethod + def setup_context(ctx, inputs, output) -> None: + r"""Save the input global shape so backward can view back to it. + + ``DisableTorchFunctionSubclass`` shielding avoids re-entering the + ShardTensor ``__torch_function__`` fallback while reading + ``tensor.shape`` (a C-level getset descriptor) -- the same AOT-hostile + bridge that motivated the shielding in ``ShardedSum.setup_context``. + """ + tensor, _target_shape = inputs + with torch._C.DisableTorchFunctionSubclass(): + ctx.input_global_shape = tuple(tensor.shape) @staticmethod def backward( @@ -646,6 +658,7 @@ def backward( tuple[ShardTensor, None] Gradient for the input tensor, and ``None`` for ``target_shape``. """ + return ( _sharded_view_forward(grad_output, ctx.input_global_shape), None, diff --git a/test/domain_parallel/ops/test_compile_ops.py b/test/domain_parallel/ops/test_compile_ops.py new file mode 100644 index 0000000000..bbc5ac3162 --- /dev/null +++ b/test/domain_parallel/ops/test_compile_ops.py @@ -0,0 +1,456 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Compile-traceability tests for the refactored ShardTensor autograd ops. + +These tests guarantee that the custom ``torch.autograd.Function`` subclasses +in ``physicsnemo/domain_parallel/`` can be traced through +``torch.compile(backend="aot_eager", fullgraph=True)``. The migration from +the old-style ``forward(ctx, ...)`` API to the new-style +``forward(...)`` + ``setup_context(ctx, inputs, output)`` API is the +prerequisite for AOTAutograd to traverse the backward graph; this file is +the regression net. + +The assertions here are intentionally lightweight: each test compiles a +small module that exercises one refactored op and verifies that forward +(and, where backward is supported, ``loss.backward()``) does not raise. +End-to-end numerical correctness is already covered by the sibling +non-compile tests in this directory. +""" + +from typing import Any + +import pytest +import torch +import torch.distributed as dist +from torch.distributed.tensor.placement_types import Replicate, Shard + +from physicsnemo.distributed import DistributedManager +from physicsnemo.domain_parallel import scatter_tensor +from physicsnemo.domain_parallel.shard_utils.halo import HaloConfig, unhalo_padding +from physicsnemo.domain_parallel.shard_utils.point_cloud_ops import GradReducer + + +def _scalar_loss(out: Any) -> torch.Tensor: + """Reduce arbitrary tensor-like outputs to a scalar for ``.backward()``.""" + if isinstance(out, tuple): + out = out[0] + return out.float().sum() + + +def _run_compile_fwd_bwd( + module: torch.nn.Module, + inputs: list, + *, + backward: bool = True, + fullgraph: bool = True, +) -> Any: + r"""Compile ``module`` with ``aot_eager`` and run forward (+ optional bwd). + + Smoke check that the autograd Function dispatched inside ``module`` is + AOT-traceable. Resets dynamo first so each test is independent. + """ + torch._dynamo.reset() + compiled = torch.compile(module, backend="aot_eager", fullgraph=fullgraph) + output = compiled(*inputs) + + if backward: + loss = _scalar_loss(output) + loss.backward() + return output + + +# --------------------------------------------------------------------------- +# Module wrappers +# --------------------------------------------------------------------------- + + +class MeanWrapper(torch.nn.Module): + r"""``tensor.mean(dim)`` on a ShardTensor (exercises ``ShardedMean``).""" + + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + return tensor.mean(dim=self.dim) + + +class ViewWrapper(torch.nn.Module): + r"""``tensor.view(target_shape)`` on a ShardTensor (exercises ``ShardedView``).""" + + def __init__(self, target_shape: tuple[int, ...]): + super().__init__() + self.target_shape = target_shape + + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + return tensor.view(self.target_shape) + + +class RedistributeWrapper(torch.nn.Module): + r"""``tensor.redistribute(...)`` (exercises ``ShardRedistribute``).""" + + def __init__(self, mesh, placements): + super().__init__() + self.mesh = mesh + self.placements = placements + + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + return tensor.redistribute(self.mesh, self.placements) + + +class IndexSelectWrapper(torch.nn.Module): + r"""``torch.index_select(...)`` on a ShardTensor (exercises ``ShardedIndexSelect``).""" + + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, tensor: torch.Tensor, index: torch.Tensor) -> torch.Tensor: + return torch.index_select(tensor, self.dim, index.flatten()) + + +class UnhaloPaddingWrapper(torch.nn.Module): + r"""``unhalo_padding(...)`` (exercises ``UnHaloPadding``).""" + + def __init__(self, mesh, halo_config: HaloConfig): + super().__init__() + self.mesh = mesh + self.halo_config = halo_config + + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + return unhalo_padding(tensor, self.mesh, self.halo_config) + + +class GroupNormWrapper(torch.nn.Module): + r"""``F.group_norm`` on a ShardTensor (exercises ``PartialGroupNorm``).""" + + def __init__(self, num_groups: int, num_channels: int): + super().__init__() + self.gn = torch.nn.GroupNorm(num_groups, num_channels) + + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + return self.gn(tensor) + + +class SDPAWrapper(torch.nn.Module): + r"""``F.scaled_dot_product_attention`` on sharded Q/K/V (exercises RingSDPA).""" + + def forward( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor + ) -> torch.Tensor: + return torch.nn.functional.scaled_dot_product_attention(q, k, v) + + +class GradReducerWrapper(torch.nn.Module): + r"""``GradReducer.apply(tensor, spec)`` (exercises ``GradReducer``). + + The ``spec`` is captured as a non-tensor module attribute so it is a + constant from dynamo's perspective. ``GradReducer`` is the trivial + identity in forward; the work happens in backward (all-reduce on + replicated mesh dims). + """ + + def __init__(self, spec): + super().__init__() + self.spec = spec + + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + return GradReducer.apply(tensor, self.spec) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.multigpu_static +@pytest.mark.timeout(180) +def test_compile_sharded_mean_1d(distributed_mesh): + r"""Compile + backward through ``ShardedMean`` on a sharded dim.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + dm = DistributedManager() + shape = (4, 64, 16) + original = torch.rand(shape, device=dm.device, requires_grad=True) + sharded = scatter_tensor( + original, + global_src=0, + mesh=distributed_mesh, + placements=(Shard(1),), + requires_grad=True, + ) + + _run_compile_fwd_bwd(MeanWrapper(dim=1), [sharded]) + + +@pytest.mark.multigpu_static +@pytest.mark.timeout(180) +def test_compile_sharded_view_1d(distributed_mesh): + r"""Compile + backward through ``ShardedView`` (merge last two dims).""" + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + dm = DistributedManager() + shape = (4, 64, 8, 4) + target_shape = (4, 64, 32) + original = torch.rand(shape, device=dm.device, requires_grad=True) + sharded = scatter_tensor( + original, + global_src=0, + mesh=distributed_mesh, + placements=(Shard(1),), + requires_grad=True, + ) + + _run_compile_fwd_bwd(ViewWrapper(target_shape=target_shape), [sharded]) + + +@pytest.mark.multigpu_static +@pytest.mark.timeout(180) +def test_compile_shard_redistribute_1d(distributed_mesh): + r"""Compile + backward through ``ShardRedistribute`` (Shard -> Replicate).""" + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + dm = DistributedManager() + shape = (4, 64, 16) + original = torch.rand(shape, device=dm.device, requires_grad=True) + sharded = scatter_tensor( + original, + global_src=0, + mesh=distributed_mesh, + placements=(Shard(1),), + requires_grad=True, + ) + + _run_compile_fwd_bwd( + RedistributeWrapper(mesh=distributed_mesh, placements=(Replicate(),)), + [sharded], + ) + + +@pytest.mark.multigpu_static +@pytest.mark.timeout(180) +def test_compile_shard_redistribute_2d(distributed_mesh_2d): + r"""Compile + backward through ``ShardRedistribute`` on a 2D mesh.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + dm = DistributedManager() + shape = (4, 64, 32) + original = torch.rand(shape, device=dm.device, requires_grad=True) + sharded = scatter_tensor( + original, + global_src=0, + mesh=distributed_mesh_2d, + placements=(Shard(1), Shard(2)), + requires_grad=True, + ) + + _run_compile_fwd_bwd( + RedistributeWrapper( + mesh=distributed_mesh_2d, placements=(Replicate(), Replicate()) + ), + [sharded], + ) + + +@pytest.mark.multigpu_static +@pytest.mark.timeout(180) +def test_compile_sharded_index_select_replicated_index_1d(distributed_mesh): + r"""Compile + backward through ``ShardedIndexSelect`` with a replicated index. + + A replicated ``index`` keeps the output sharding aligned with the input, + which is the cheaper / less collective-heavy code path inside the op. + """ + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + dm = DistributedManager() + shape = (32, 32, 32) + dim = 1 + n_idx = 8 + + original = torch.rand(shape, device=dm.device, requires_grad=True) + index = torch.randint(low=0, high=shape[dim] - 1, size=(n_idx,), device=dm.device) + + sharded = scatter_tensor( + original, + global_src=0, + mesh=distributed_mesh, + placements=(Shard(2),), + requires_grad=True, + ) + sharded_index = scatter_tensor( + index, + global_src=0, + mesh=distributed_mesh, + placements=(Replicate(),), + requires_grad=False, + ) + + _run_compile_fwd_bwd(IndexSelectWrapper(dim=dim), [sharded, sharded_index]) + + +@pytest.mark.multigpu_static +@pytest.mark.timeout(180) +def test_compile_unhalo_padding_1d(distributed_mesh): + r"""Compile + backward through ``UnHaloPadding``. + + Constructs a synthetic tensor that includes halo regions, then drops them + via the public ``unhalo_padding`` wrapper. Numerical correctness is + covered by ``test_padding.py``; this test only checks compile. + """ + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + dm = DistributedManager() + local_group = distributed_mesh.get_group(0) + local_size = dist.get_world_size(group=local_group) + if local_size < 2: + pytest.skip("UnHaloPadding requires at least 2 ranks on the mesh dim") + + halo_size = 2 + H = 32 + # Build a per-rank tensor that already has halos baked in: each rank gets + # a slab of size H + 2*halo_size along dim 2 (except endpoints). + tensor = torch.rand(2, 4, H + 2 * halo_size, device=dm.device, requires_grad=True) + + halo_config = HaloConfig( + mesh_dim=0, + tensor_dim=2, + halo_size=halo_size, + edge_padding_size=halo_size, + communication_method="a2a", + ) + + _run_compile_fwd_bwd( + UnhaloPaddingWrapper(mesh=distributed_mesh, halo_config=halo_config), + [tensor], + ) + + +@pytest.mark.multigpu_static +@pytest.mark.timeout(180) +def test_compile_partial_group_norm_1d(distributed_mesh): + r"""Compile + backward through ``PartialGroupNorm`` on a ShardTensor. + + Smoke-checks AOTAutograd traceability through the now-3-output + autograd Function. End-to-end numerical agreement against single-GPU + eager is covered by ``test_normalization.py``; here we only ensure + the compile + fwd + bwd path doesn't raise. + """ + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + dm = DistributedManager() + N, C, H, W = 2, 8, 32, 32 + num_groups = 4 + + original = torch.rand(N, C, H, W, device=dm.device, requires_grad=True) + sharded = scatter_tensor( + original, + global_src=0, + mesh=distributed_mesh, + placements=(Shard(2),), + requires_grad=True, + ) + + module = GroupNormWrapper(num_groups=num_groups, num_channels=C).to(dm.device) + + _run_compile_fwd_bwd(module, [sharded]) + + +# Note: ``test_compile_ring_sdpa_1d`` (smoke-test that compile around sharded +# SDPA succeeds) was removed because the overlap variant's ``record_stream`` +# call cannot survive AOTAutograd functionalization, and the +# ``@torch.compiler.disable`` on ``ring_sdpa`` only suppresses dynamo's +# tracing of the body, not AOT's later re-execution of the captured FX +# graph (which re-enters via ``ShardTensor.__torch_function__`` on the +# captured SDPA node and trips ``record_stream``). Re-enabling compile of +# sharded SDPA requires a separate refactor (drop ``record_stream``, switch +# ``perform_ring_iteration`` to functional p2p collectives, etc.). The +# limitation is documented in ``ring_sdpa``'s docstring in +# ``shard_utils/attention_patches.py``. The eager path is fully covered by +# ``test_sdpa.py`` and ``test_ring_sdpa_overlap.py``. + + +@pytest.mark.multigpu_static +@pytest.mark.timeout(60) +def test_compile_ring_sdpa_fullgraph_errors(distributed_mesh): + r"""Regression guard: ``torch.compile`` around sharded ring SDPA must error. + + Today compile around sharded SDPA fails at AOT functionalization of + ``aten::record_stream`` (an alias-annotated op called inside the + overlap variant). The exact exception type varies between PyTorch + versions, but it is *some* ``Exception``. We assert that, so if a + future refactor accidentally makes compile silently "succeed" without + actually wiring up a functional-collective ring (the only way it can + be correct under AOT), this test starts failing and forces us to + re-evaluate. + """ + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + dm = DistributedManager() + batch_size, num_heads, seq_len, head_dim = 1, 4, 128, 32 + q = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=dm.device, requires_grad=True + ) + k = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=dm.device, requires_grad=True + ) + v = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=dm.device, requires_grad=True + ) + + q_s = scatter_tensor(q, 0, distributed_mesh, (Shard(2),), requires_grad=True) + k_s = scatter_tensor(k, 0, distributed_mesh, (Shard(2),), requires_grad=True) + v_s = scatter_tensor(v, 0, distributed_mesh, (Shard(2),), requires_grad=True) + + with pytest.raises(Exception): + _run_compile_fwd_bwd(SDPAWrapper(), [q_s, k_s, v_s], fullgraph=True) + + +@pytest.mark.multigpu_static +@pytest.mark.timeout(180) +def test_compile_grad_reducer_1d(distributed_mesh): + r"""Compile + backward through ``GradReducer``. + + Forward is identity; backward all-reduces over each replicated mesh + dim. We feed a plain tensor + the spec from a Replicate-placed + ShardTensor so the backward path actually exercises the all-reduce. + """ + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + dm = DistributedManager() + + base = torch.rand(4, 16, device=dm.device) + replicated_shard = scatter_tensor( + base, + global_src=0, + mesh=distributed_mesh, + placements=(Replicate(),), + requires_grad=False, + ) + spec = replicated_shard._spec + + tensor = torch.rand(4, 16, device=dm.device, requires_grad=True) + + _run_compile_fwd_bwd(GradReducerWrapper(spec=spec), [tensor]) diff --git a/test/domain_parallel/ops/test_convolution.py b/test/domain_parallel/ops/test_convolution.py index a20ea6efe6..a26a8deeca 100644 --- a/test/domain_parallel/ops/test_convolution.py +++ b/test/domain_parallel/ops/test_convolution.py @@ -48,6 +48,33 @@ from .utils import generate_image_like_data, numerical_shard_tensor_check +@pytest.fixture(autouse=True) +def _disable_tf32_for_conv_equivalence(): + r"""Force FP32 precision in cuDNN/matmul for the duration of each conv test. + + ``numerical_shard_tensor_check`` asserts that the sharded conv output + matches the local single-GPU conv output within ``atol=rtol=1e-5``. On + Ampere+ GPUs, PyTorch defaults to TF32 (~10-bit mantissa) for cuDNN + convolutions and matmul, which gives a relative error of ~2^-10 ~ 1e-3 + that easily blows the 1e-5 budget once cuDNN picks a different algorithm + for the larger local-vs-sharded tensor shapes (consistently observed at + H=256: the H=128 cases happen to land on a higher-precision kernel and + therefore still pass). Disabling TF32 here removes the algorithm-pick + artifact and lets the equivalence check verify only the sharding + correctness, which is what this suite is for. Restored after the test + so global state isn't perturbed for other modules. + """ + matmul_prev = torch.backends.cuda.matmul.allow_tf32 + cudnn_prev = torch.backends.cudnn.allow_tf32 + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + try: + yield + finally: + torch.backends.cuda.matmul.allow_tf32 = matmul_prev + torch.backends.cudnn.allow_tf32 = cudnn_prev + + @pytest.mark.multigpu_static @pytest.mark.parametrize("H", [32, 256]) @pytest.mark.parametrize( @@ -167,7 +194,7 @@ def test_conv_transpose_1d_1dmesh( @pytest.mark.multigpu_static -@pytest.mark.parametrize("H", [32, 256]) +@pytest.mark.parametrize("H", [128, 256]) @pytest.mark.parametrize( "C_in", [ @@ -222,7 +249,7 @@ def test_conv2d_1dmesh( @pytest.mark.multigpu_static -@pytest.mark.parametrize("H", [32, 256]) +@pytest.mark.parametrize("H", [128, 256]) @pytest.mark.parametrize( "C_in", [ @@ -265,7 +292,7 @@ def test_conv_transpose_2d_1dmesh( 2, C_in, ( - H, + 2 * H, H, ), device=dm.device, @@ -293,7 +320,7 @@ def test_conv_transpose_2d_1dmesh( @pytest.mark.multigpu_static -@pytest.mark.parametrize("H", [32, 256]) +@pytest.mark.parametrize("H", [128, 256]) @pytest.mark.parametrize( "C_in", [ @@ -355,7 +382,7 @@ def test_conv2d_2dmesh( @pytest.mark.multigpu_static -@pytest.mark.parametrize("H", [32, 256]) +@pytest.mark.parametrize("H", [64, 256]) @pytest.mark.parametrize( "C_in", [ @@ -405,8 +432,8 @@ def test_conv_transpose_2d_2dmesh( 2, C_in, ( - H, - H, + 2 * H, + 2 * H, ), device=dm.device, ) diff --git a/test/domain_parallel/ops/test_ring_sdpa_overlap.py b/test/domain_parallel/ops/test_ring_sdpa_overlap.py index f6842d7c44..2b4a3866c9 100644 --- a/test/domain_parallel/ops/test_ring_sdpa_overlap.py +++ b/test/domain_parallel/ops/test_ring_sdpa_overlap.py @@ -101,7 +101,10 @@ def test_ring_sdpa_forward_matches_blocking( attn_args = {"dropout_p": 0.0, "is_causal": False, "scale": None} - out_blocking = RingSDPABlocking.apply( + # Both RingSDPA and RingSDPABlocking now return (output, *stats); the + # trailing stats are intermediate non-differentiable tensors needed by + # backward only. + out_blocking, *_ = RingSDPABlocking.apply( q, k, v, @@ -110,7 +113,7 @@ def test_ring_sdpa_forward_matches_blocking( ring_config, attn_args, ) - out_overlap = RingSDPA.apply( + out_overlap, *_ = RingSDPA.apply( q, k, v, @@ -176,7 +179,9 @@ def test_ring_sdpa_backward_matches_blocking( k_b.requires_grad_(True) v_b.requires_grad_(True) - out_b = RingSDPABlocking.apply(q_b, k_b, v_b, None, mesh, ring_config, attn_args) + out_b, *_ = RingSDPABlocking.apply( + q_b, k_b, v_b, None, mesh, ring_config, attn_args + ) loss_b = out_b.mean() loss_b.backward() @@ -185,7 +190,7 @@ def test_ring_sdpa_backward_matches_blocking( k_o = k_b.detach().clone().requires_grad_(True) v_o = v_b.detach().clone().requires_grad_(True) - out_o = RingSDPA.apply(q_o, k_o, v_o, None, mesh, ring_config, attn_args) + out_o, *_ = RingSDPA.apply(q_o, k_o, v_o, None, mesh, ring_config, attn_args) loss_o = out_o.mean() loss_o.backward() diff --git a/test/domain_parallel/ops/test_unbind.py b/test/domain_parallel/ops/test_unbind.py deleted file mode 100644 index f7fb45f4c5..0000000000 --- a/test/domain_parallel/ops/test_unbind.py +++ /dev/null @@ -1,109 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Test unbind operations on ShardTensor. We use a 3D tensor sharded along -dim 2 and test unbinding along non-sharded dimensions. Both forward -correctness and backward gradient flow are verified. -""" - -import pytest -import torch -from torch.distributed.tensor.placement_types import Shard - -from physicsnemo.distributed import DistributedManager -from physicsnemo.domain_parallel import scatter_tensor - -from .utils import numerical_shard_tensor_check - - -class UnbindSelectWrapper(torch.nn.Module): - """ - Wrapper that unbinds a tensor and returns a single element from the - result tuple. This allows reuse of ``numerical_shard_tensor_check`` - which expects a single tensor output. - """ - - def __init__(self, dim: int, index: int): - super().__init__() - self.dim = dim - self.index = index - - def forward(self, tensor: torch.Tensor): - pieces = torch.unbind(tensor, self.dim) - return pieces[self.index] - - -@pytest.mark.multigpu_static -@pytest.mark.parametrize("backward", [False, True]) -@pytest.mark.parametrize("unbind_dim,index", [(0, 0), (0, 2), (1, 3), (-3, 0), (-2, 3)]) -def test_unbind(distributed_mesh, backward, unbind_dim, index): - """Verify forward and backward via ``numerical_shard_tensor_check``.""" - - if not torch.cuda.is_available(): - pytest.skip("CUDA is not available") - - dm = DistributedManager() - shape = (4, 6, 128) - placements = (Shard(2),) - - original_tensor = torch.rand(shape, device=dm.device, requires_grad=backward) - - sharded_tensor = scatter_tensor( - original_tensor, - global_src=0, - mesh=distributed_mesh, - placements=placements, - requires_grad=True, - ) - - module = UnbindSelectWrapper(dim=unbind_dim, index=index) - - numerical_shard_tensor_check( - distributed_mesh, - module, - [sharded_tensor], - {}, - check_grads=backward, - ) - - -# -- Error tests -------------------------------------------------------------- - - -@pytest.mark.multigpu_static -def test_unbind_along_sharded_dim(distributed_mesh): - """Unbinding along the sharded dimension should raise.""" - - if not torch.cuda.is_available(): - pytest.skip("CUDA is not available") - - dm = DistributedManager() - shape = (4, 6, 128) - placements = (Shard(2),) - - original_tensor = torch.rand(shape, device=dm.device) - - sharded_tensor = scatter_tensor( - original_tensor, - global_src=0, - mesh=distributed_mesh, - placements=placements, - requires_grad=False, - ) - - with pytest.raises(RuntimeError, match="unbinding along sharding axis"): - torch.unbind(sharded_tensor, 2) diff --git a/test/domain_parallel/ops/test_view_ops.py b/test/domain_parallel/ops/test_view_ops.py index 7094074f2c..dd1a2a7d8d 100644 --- a/test/domain_parallel/ops/test_view_ops.py +++ b/test/domain_parallel/ops/test_view_ops.py @@ -534,20 +534,23 @@ def test_view_trailing_dims_1d_to_3d( distributed_mesh, backward, ): - """Test view (6,) -> (2, 3, 1) with Shard(0): trailing dim must stay in group. + """Test view (48,) -> (8, 6, 1) with Shard(0): trailing singleton in target. - With the shard on dim 0, each rank has a contiguous chunk of the 1D tensor. - The target shape has a trailing singleton (2, 3, 1). The trailing dimension - must be included in the same dimension group so that the local element - count is correct (product of local shape equals chunk_size). Without that, - the old code produced wrong local shapes (e.g. product 4 instead of 2 or 3). + The 1D tensor is sharded on dim 0. The target shape has a trailing + singleton ``(8, 6, 1)`` that falls outside the dimension group matched + by ``_match_view_dim_groups`` (which pairs ``(48,)`` with ``(8, 6)``). + The trailing ``1`` must be carried through unchanged in the local shape + so that ``product(local_shape) == chunk_size``. + + We use a tensor size (48) that divides cleanly across 2-, 4-, and 8-GPU + meshes so that every rank's chunk aligns to a row boundary in ``(8, 6)``. """ if not torch.cuda.is_available(): pytest.skip("CUDA is not available") dm = DistributedManager() - shape = (6,) - target_shape = (2, 3, 1) + shape = (48,) + target_shape = (8, 6, 1) original_tensor = torch.rand(shape, device=dm.device, requires_grad=backward) diff --git a/test/domain_parallel/ops/utils.py b/test/domain_parallel/ops/utils.py index de8052c93c..fcf04f9c92 100644 --- a/test/domain_parallel/ops/utils.py +++ b/test/domain_parallel/ops/utils.py @@ -169,7 +169,9 @@ def unparallelize_module(module): This function is for testing purposes only. Do not use in production code. """ for name, param in list(module._parameters.items()): - if isinstance(param, torch.nn.Parameter) and isinstance(param.data, DTensor): + if isinstance(param, torch.nn.Parameter) and isinstance( + param.data, (ShardTensor, DTensor) + ): # gather to replicated then unwrap local_tensor = param.data.full_tensor() # replace with a normal Parameter diff --git a/test/domain_parallel/test_compile.py b/test/domain_parallel/test_compile.py new file mode 100644 index 0000000000..0dde02b80a --- /dev/null +++ b/test/domain_parallel/test_compile.py @@ -0,0 +1,139 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Tests for ShardTensor integration with ``torch.compile`` / AOTAutograd. + +The focus is on the runtime tangent-coercion hook +``ShardTensor.__coerce_same_metadata_as_tangent__``, which AOTAutograd +invokes during the compiled backward when the runtime tangent's spec +doesn't match the recorded one. The tests cover uneven sharding, which +DTensor does not have to handle and which earlier coerce implementations +silently dropped (defaulting back to even chunking). +""" + +import pytest +import torch +from torch.distributed.tensor.placement_types import Replicate + +from physicsnemo.domain_parallel import ShardTensor +from physicsnemo.domain_parallel._shard_tensor_spec import ShardTensorSpec +from test.domain_parallel.test_redistribute import shard_tensor_factory + + +def _replicate_placements(mesh): + return [Replicate()] * mesh.ndim + + +def run_coerce_replicate_to_uneven_shard(mesh): + # Round-trip: uneven Shard -> Replicate -> coerce back to recorded uneven Shard. + st_uneven = shard_tensor_factory(mesh, uneven=True) + recorded_spec = st_uneven._spec + expected_local_shape = tuple(st_uneven._local_tensor.shape) + expected_full = st_uneven.full_tensor().clone() + + st_replicated = st_uneven.redistribute(placements=_replicate_placements(mesh)) + + coerced = st_replicated.__coerce_same_metadata_as_tangent__((recorded_spec, False)) + + assert isinstance(coerced, ShardTensor) + assert coerced._spec.placements == recorded_spec.placements + assert tuple(coerced._local_tensor.shape) == expected_local_shape + assert torch.allclose(coerced.full_tensor(), expected_full) + + +@pytest.mark.multigpu_static +@pytest.mark.timeout(120) +def test_coerce_replicate_to_uneven_shard_1d(distributed_mesh): + run_coerce_replicate_to_uneven_shard(distributed_mesh) + + +@pytest.mark.multigpu_static +@pytest.mark.timeout(120) +def test_coerce_replicate_to_uneven_shard_2d(distributed_mesh_2d): + run_coerce_replicate_to_uneven_shard(distributed_mesh_2d) + + +def run_coerce_same_placements_unknown_shapes(mesh): + # Recorded spec carries the same placements but no _sharding_shapes; the + # hook must accept it without erroring and preserve local data. + st = shard_tensor_factory(mesh, uneven=True) + expected_local_shape = tuple(st._local_tensor.shape) + expected_full = st.full_tensor().clone() + + modified_spec = ShardTensorSpec( + mesh=st._spec.mesh, + placements=st._spec.placements, + tensor_meta=st._spec.tensor_meta, + _sharding_shapes=None, + ) + + coerced = st.__coerce_same_metadata_as_tangent__((modified_spec, False)) + + assert isinstance(coerced, ShardTensor) + assert coerced._spec.placements == st._spec.placements + assert coerced._spec._sharding_shapes is None + assert tuple(coerced._local_tensor.shape) == expected_local_shape + assert torch.allclose(coerced.full_tensor(), expected_full) + + +@pytest.mark.multigpu_static +@pytest.mark.timeout(120) +def test_coerce_same_placements_unknown_shapes_1d(distributed_mesh): + run_coerce_same_placements_unknown_shapes(distributed_mesh) + + +def run_coerce_expected_type_returns_none(mesh): + # Mismatched expected_type must short-circuit to None (DTensor convention). + st = shard_tensor_factory(mesh, uneven=True) + out = st.__coerce_same_metadata_as_tangent__( + (st._spec, False), expected_type=torch.Tensor + ) + assert out is None + + +@pytest.mark.multigpu_static +@pytest.mark.timeout(120) +def test_coerce_expected_type_returns_none_1d(distributed_mesh): + run_coerce_expected_type_returns_none(distributed_mesh) + + +def _sum_squares(x): + return (x**2).sum() + + +def run_compile_backward_uneven_shard(mesh): + # Smoke test: compile + backward over an uneven ShardTensor must not raise + # AOTAutograd's "guessed metadata incorrectly" tangent error. Gradient values + # are validated by the direct __coerce_same_metadata_as_tangent__ tests. + x = shard_tensor_factory(mesh, uneven=True).detach().requires_grad_(True) + + torch._dynamo.reset() + compiled = torch.compile(_sum_squares, fullgraph=True, backend="aot_eager") + + loss = compiled(x) + loss.backward() + + +@pytest.mark.multigpu_static +@pytest.mark.timeout(180) +def test_compile_backward_uneven_shard_1d(distributed_mesh): + run_compile_backward_uneven_shard(distributed_mesh) + + +@pytest.mark.multigpu_static +@pytest.mark.timeout(180) +def test_compile_backward_uneven_shard_2d(distributed_mesh_2d): + run_compile_backward_uneven_shard(distributed_mesh_2d) diff --git a/test/domain_parallel/test_grad_sharding.py b/test/domain_parallel/test_grad_sharding.py index 871782459b..167044159d 100644 --- a/test/domain_parallel/test_grad_sharding.py +++ b/test/domain_parallel/test_grad_sharding.py @@ -275,7 +275,7 @@ def run_dtensor_to_shard_tensor_non_leaf_gradient(mesh): loss_ref.backward() assert dt.grad is not None - assert isinstance(dt.grad, DTensor) + assert isinstance(dt.grad, (ShardTensor, DTensor)) assert torch.allclose(dt.grad.full_tensor(), ref.grad) diff --git a/test/domain_parallel/test_initialization.py b/test/domain_parallel/test_initialization.py index 5c5c8fbf02..d6cd7d054b 100644 --- a/test/domain_parallel/test_initialization.py +++ b/test/domain_parallel/test_initialization.py @@ -121,6 +121,101 @@ def init_from_data_rank_worker(mesh): assert dim == local_data.shape[i] +def scatter_tensor_requires_grad_contract_worker(mesh, requires_grad: bool): + r"""Validate scatter_tensor construction contract for requires_grad modes.""" + dm = DistributedManager() + rank = dm.rank + global_shape, placements = init_global_shape_and_placements(mesh) + source = 0 + + if rank == source: + raw_data = torch.randn( + global_shape, device=torch.device(f"cuda:{dm.local_rank}") + ) + else: + raw_data = None + + st = scatter_tensor( + raw_data, + source, + mesh, + placements, + global_shape=torch.Size(global_shape), + dtype=torch.float32, + requires_grad=requires_grad, + ) + + assert st.requires_grad is requires_grad + if requires_grad: + assert st.is_leaf + + +@pytest.mark.timeout(10) +@pytest.mark.multigpu_static +@pytest.mark.parametrize("requires_grad", [False, True]) +def test_scatter_tensor_requires_grad_contract_1d(distributed_mesh, requires_grad): + scatter_tensor_requires_grad_contract_worker(distributed_mesh, requires_grad) + + +@pytest.mark.timeout(10) +@pytest.mark.multigpu_static +@pytest.mark.parametrize("requires_grad", [False, True]) +def test_scatter_tensor_requires_grad_contract_2d(distributed_mesh_2d, requires_grad): + scatter_tensor_requires_grad_contract_worker(distributed_mesh_2d, requires_grad) + + +def scatter_tensor_grad_population_worker(mesh): + r"""Validate that gradients populate for scatter_tensor(..., requires_grad=True).""" + dm = DistributedManager() + rank = dm.rank + global_shape, placements = init_global_shape_and_placements(mesh) + source = 0 + + if rank == source: + raw_data = torch.randn( + global_shape, device=torch.device(f"cuda:{dm.local_rank}") + ) + else: + raw_data = None + + st = scatter_tensor( + raw_data, + source, + mesh, + placements, + global_shape=torch.Size(global_shape), + dtype=torch.float32, + requires_grad=True, + ) + assert st.is_leaf + assert st.requires_grad + + reference = st.full_tensor().detach().requires_grad_(True) + reference_loss = (reference**2).sum() + reference_loss.backward() + + st2 = st**2 + sharded_loss = st2.sum() + sharded_loss.backward() + + assert st.grad is not None + assert st.grad._spec.placements == st._spec.placements + assert st.grad._spec.sharding_shapes() == st._spec.sharding_shapes() + assert torch.allclose(st.grad.full_tensor(), reference.grad) + + +@pytest.mark.timeout(10) +@pytest.mark.multigpu_static +def test_scatter_tensor_requires_grad_gradient_1d(distributed_mesh): + scatter_tensor_grad_population_worker(distributed_mesh) + + +@pytest.mark.timeout(10) +@pytest.mark.multigpu_static +def test_scatter_tensor_requires_grad_gradient_2d(distributed_mesh_2d): + scatter_tensor_grad_population_worker(distributed_mesh_2d) + + @pytest.mark.timeout(10) @pytest.mark.multigpu_static def test_shard_tensor_initialization_from_data_rank_1d(distributed_mesh, verbose=False): @@ -162,8 +257,6 @@ def shard_tensor_initialization_from_all_dtensor_worker(mesh): st = ShardTensor.from_dtensor(dt) - print(f"Rank {dm.rank} made shard tensors.") - dt_full = dt.full_tensor() st_full = st.full_tensor() diff --git a/test/domain_parallel/test_reductions.py b/test/domain_parallel/test_reductions.py index 8cb8931e45..2145f48af1 100644 --- a/test/domain_parallel/test_reductions.py +++ b/test/domain_parallel/test_reductions.py @@ -118,6 +118,10 @@ def test_shard_tensor_reduction( requires_grad=backward, ) + # if backward: + # assert shard_tensor.is_leaf + # assert shard_tensor.requires_grad + if verbose: print( f"Shard tensor global shape: {shard_tensor.shape} and local shape: {shard_tensor._local_tensor.shape}"