From cc5706248127f67f59d12625f6f13becf225389d Mon Sep 17 00:00:00 2001 From: mahf708 Date: Tue, 24 Feb 2026 09:05:57 -0800 Subject: [PATCH 1/4] fix --- fme/ace/aggregator/train.py | 1 + fme/ace/data_loading/batch_data.py | 20 ++ fme/ace/models/makani/sfnonet.py | 79 ++++-- fme/ace/models/makani_fcn3/utils/comm.py | 61 ++++- .../models/makani_fcn3/utils/distributed.py | 163 ++++++++++-- fme/ace/models/modulus/sfnonet.py | 89 +++++-- fme/ace/registry/sfno.py | 3 + fme/ace/registry/stochastic_sfno.py | 13 +- fme/ace/stepper/single_module.py | 24 ++ fme/core/distributed/base.py | 55 ++++ fme/core/distributed/distributed.py | 55 +++- .../distributed/model_torch_distributed.py | 114 +++++++++ .../parallel_tests/test_spatial.py | 235 ++++++++++++++++++ fme/core/gridded_ops.py | 49 +++- fme/core/models/conditional_sfno/layers.py | 83 ++++++- fme/core/models/conditional_sfno/sfnonet.py | 37 ++- 16 files changed, 1005 insertions(+), 76 deletions(-) create mode 100644 fme/core/distributed/parallel_tests/test_spatial.py diff --git a/fme/ace/aggregator/train.py b/fme/ace/aggregator/train.py index 609d01807..984850454 100644 --- a/fme/ace/aggregator/train.py +++ b/fme/ace/aggregator/train.py @@ -64,6 +64,7 @@ def __init__(self, config: TrainAggregatorConfig, operations: GriddedOperations) @torch.no_grad() def record_batch(self, batch: TrainOutput): + batch = batch.gather_spatial() self._loss += batch.metrics["loss"] self._n_loss_batches += 1 diff --git a/fme/ace/data_loading/batch_data.py b/fme/ace/data_loading/batch_data.py index f94a02a05..c523341e1 100644 --- a/fme/ace/data_loading/batch_data.py +++ b/fme/ace/data_loading/batch_data.py @@ -451,6 +451,26 @@ def pin_memory(self: SelfType) -> SelfType: self.data = {name: tensor.pin_memory() for name, tensor in self.data.items()} return self + def to_spatial_shard(self: SelfType) -> SelfType: + """Slice every tensor in *data* to this rank's spatial tile. + + When spatial parallelism is inactive the ``scatter_spatial`` call is a + no-op and this method returns *self* unchanged. + """ + from fme.core.distributed import Distributed + + dist = Distributed.get_instance() + if dist.h_size <= 1 and dist.w_size <= 1: + return self + return self.__class__( + data={k: dist.scatter_spatial(v) for k, v in self.data.items()}, + time=self.time, + horizontal_dims=self.horizontal_dims, + epoch=self.epoch, + labels=self.labels, + n_ensemble=self.n_ensemble, + ) + @dataclasses.dataclass class PairedData: diff --git a/fme/ace/models/makani/sfnonet.py b/fme/ace/models/makani/sfnonet.py index a1ebed145..caa7d5996 100644 --- a/fme/ace/models/makani/sfnonet.py +++ b/fme/ace/models/makani/sfnonet.py @@ -25,8 +25,15 @@ # get spectral transforms from torch_harmonics import torch_harmonics as th +import torch_harmonics.distributed as thd from torch.utils.checkpoint import checkpoint +from fme.ace.models.makani_fcn3.mpu.layers import ( + DistributedInverseRealFFT2, + DistributedRealFFT2, +) +from fme.ace.models.makani_fcn3.mpu.layer_norm import DistributedInstanceNorm2d + from .layers import MLP, DropPath, EncoderDecoder, InverseRealFFT2, RealFFT2 from .spectral_convolution import FactorizedSpectralConv, SpectralConv @@ -336,13 +343,24 @@ def __init__( if normalization_layer == "layer_norm": raise NotImplementedError("requires makani distributed libraries") elif normalization_layer == "instance_norm": - norm_layer_inp = partial( - nn.InstanceNorm2d, - num_features=embed_dim, - eps=1e-6, - affine=True, - track_running_stats=False, - ) + from fme.core.distributed import Distributed + + dist = Distributed.get_instance() + if dist.h_size > 1 or dist.w_size > 1: + norm_layer_inp = partial( + DistributedInstanceNorm2d, + num_features=embed_dim, + eps=1e-6, + affine=True, + ) + else: + norm_layer_inp = partial( + nn.InstanceNorm2d, + num_features=embed_dim, + eps=1e-6, + affine=True, + track_running_stats=False, + ) norm_layer_out = norm_layer_mid = norm_layer_inp elif normalization_layer == "none": norm_layer_out = norm_layer_mid = norm_layer_inp = nn.Identity @@ -464,9 +482,17 @@ def _init_spectral_transforms( max_modes=None, ): """ - Initialize the spectral transforms based on the maximum number of modes to keep. Handles the computation - of local image shapes and domain parallelism, based on the + Initialize the spectral transforms based on the maximum number of + modes to keep. Automatically selects distributed variants when + spatial parallelism is active. """ + from fme.core.distributed import Distributed + + dist = Distributed.get_instance() + spatial_parallel = dist.h_size > 1 or dist.w_size > 1 + + if spatial_parallel: + thd.init(dist.h_group, dist.w_group) if max_modes is not None: modes_lat, modes_lon = max_modes @@ -476,8 +502,12 @@ def _init_spectral_transforms( # prepare the spectral transforms if spectral_transform == "sht": - sht_handle = th.RealSHT - isht_handle = th.InverseRealSHT + if spatial_parallel: + sht_handle = thd.DistributedRealSHT + isht_handle = thd.DistributedInverseRealSHT + else: + sht_handle = th.RealSHT + isht_handle = th.InverseRealSHT # set up self.trans_down = sht_handle( @@ -494,8 +524,12 @@ def _init_spectral_transforms( ).float() elif spectral_transform == "fft": - fft_handle = RealFFT2 - ifft_handle = InverseRealFFT2 + if spatial_parallel: + fft_handle = DistributedRealFFT2 + ifft_handle = DistributedInverseRealFFT2 + else: + fft_handle = RealFFT2 + ifft_handle = InverseRealFFT2 self.trans_down = fft_handle( self.inp_shape[0], self.inp_shape[1], lmax=modes_lat, mmax=modes_lon @@ -512,11 +546,20 @@ def _init_spectral_transforms( else: raise (ValueError("Unknown spectral transform")) - # use the SHT/FFT to compute the local, downscaled grid dimensions - self.inp_shape_loc = (self.trans_down.nlat, self.trans_down.nlon) - self.out_shape_loc = (self.itrans_up.nlat, self.itrans_up.nlon) - self.h_loc = self.itrans.nlat - self.w_loc = self.itrans.nlon + # use the SHT/FFT to compute the local, downscaled grid dimensions. + # Under distributed transforms, .nlat / .nlon still report the + # *global* size, so we divide by the spatial group dimensions to + # get the local tile size used for position embeddings etc. + self.inp_shape_loc = ( + self.trans_down.nlat // dist.h_size, + self.trans_down.nlon // dist.w_size, + ) + self.out_shape_loc = ( + self.itrans_up.nlat // dist.h_size, + self.itrans_up.nlon // dist.w_size, + ) + self.h_loc = self.itrans.nlat // dist.h_size + self.w_loc = self.itrans.nlon // dist.w_size @torch.jit.ignore def no_weight_decay(self): diff --git a/fme/ace/models/makani_fcn3/utils/comm.py b/fme/ace/models/makani_fcn3/utils/comm.py index 744dc0b2d..3c58015b7 100644 --- a/fme/ace/models/makani_fcn3/utils/comm.py +++ b/fme/ace/models/makani_fcn3/utils/comm.py @@ -13,18 +13,69 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +Communicator shim for Makani / FCN3. + +Delegates to the ACE :class:`~fme.core.distributed.Distributed` singleton so +that the rest of the vendored Makani code works unchanged. + +Supported group names: ``"h"``, ``"w"``, ``"spatial"``, ``"matmul"``. +``"matmul"`` parallelism is not used (always returns size 1 / rank 0). +""" + +from __future__ import annotations + + +def _dist(): + """Lazy import to avoid circular imports at module level.""" + from fme.core.distributed import Distributed + + return Distributed.get_instance() + def get_size(name: str) -> int: - return 1 + """Return the number of ranks in the named group.""" + d = _dist() + if name == "h": + return d.h_size + if name == "w": + return d.w_size + if name == "spatial": + return d.h_size * d.w_size + if name == "matmul": + return 1 + raise ValueError(f"Unknown comm group name: {name!r}") def get_rank(name: str) -> int: - return 0 + """Return the rank of this process in the named group.""" + d = _dist() + if name == "h": + return d.h_rank + if name == "w": + return d.w_rank + if name == "spatial": + # Row-major linearised rank within the (h, w) tile. + return d.h_rank * d.w_size + d.w_rank + if name == "matmul": + return 0 + raise ValueError(f"Unknown comm group name: {name!r}") def get_group(name: str): - return None + """Return the ``torch.distributed`` process group for the named group.""" + d = _dist() + if name == "h": + return d.h_group + if name == "w": + return d.w_group + if name == "spatial": + return d.spatial_group + if name == "matmul": + return None + raise ValueError(f"Unknown comm group name: {name!r}") -def is_distributed(name: str): - return False +def is_distributed(name: str) -> bool: + """Return whether the named group has more than one rank.""" + return get_size(name) > 1 diff --git a/fme/ace/models/makani_fcn3/utils/distributed.py b/fme/ace/models/makani_fcn3/utils/distributed.py index e9af79a72..12e03e350 100644 --- a/fme/ace/models/makani_fcn3/utils/distributed.py +++ b/fme/ace/models/makani_fcn3/utils/distributed.py @@ -1,31 +1,158 @@ -# parallel helpers +# parallel helpers — distributed primitives for Makani / FCN3. +# +# Delegates to ``torch.distributed`` via the ACE ``comm`` shim so +# that operations like Welford-based DistributedInstanceNorm2d and +# DistributedRealFFT2 work correctly under spatial parallelism. +# When spatial parallelism is inactive (group size 1), every +# operation degrades to a no-op. + +from __future__ import annotations + from typing import Any, List, Tuple +import torch +from torch_harmonics.distributed import compute_split_shapes as _thd_split_shapes + + +def _get_group(name: str): + """Lazy import to avoid circular imports.""" + from fme.ace.models.makani_fcn3.utils import comm + + return comm.get_group(name) + + +def _get_size(name: str) -> int: + from fme.ace.models.makani_fcn3.utils import comm + + return comm.get_size(name) + + +# ----------------------------------------------------------------------- +# Autograd functions for correct gradient flow +# ----------------------------------------------------------------------- + + +class _CopyToParallelRegion(torch.autograd.Function): + """Identity in forward; all-reduce (SUM) in backward.""" + + @staticmethod + def forward(ctx, input_: torch.Tensor, group_name: str) -> torch.Tensor: + ctx.group_name = group_name + return input_ + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + import torch.distributed as td + + pg = _get_group(ctx.group_name) + if pg is not None and td.get_world_size(group=pg) > 1: + td.all_reduce(grad_output, group=pg) + return grad_output, None + + +class _GatherFromParallelRegion(torch.autograd.Function): + """All-gather in forward; scatter (select local chunk) in backward.""" -class NotDistributed: - def compute_split_shapes(self, size: int, num_chunks: int) -> List[int]: - return [size] + @staticmethod + def forward( + ctx, + input_: torch.Tensor, + dim_: int, + shapes_: Any, + group_name: str, + ) -> torch.Tensor: + import torch.distributed as td - def reduce_from_parallel_region(self, input: Any, group: Any) -> Any | None: - return input + ctx.dim = dim_ + ctx.group_name = group_name + pg = _get_group(group_name) + if pg is None or td.get_world_size(group=pg) <= 1: + ctx.world_size = 1 + return input_ + + world_size = td.get_world_size(group=pg) + ctx.world_size = world_size + ctx.rank = td.get_rank(group=pg) + + gathered = [torch.empty_like(input_) for _ in range(world_size)] + td.all_gather(gathered, input_.contiguous(), group=pg) + return torch.cat(gathered, dim=dim_) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + if ctx.world_size <= 1: + return grad_output, None, None, None + chunks = grad_output.chunk(ctx.world_size, dim=ctx.dim) + return chunks[ctx.rank].contiguous(), None, None, None + + +class _ReduceFromParallelRegion(torch.autograd.Function): + """All-reduce (SUM) in forward; identity in backward.""" + + @staticmethod + def forward(ctx, input_: torch.Tensor, group_name: str) -> torch.Tensor: + import torch.distributed as td + + pg = _get_group(group_name) + if pg is not None and td.get_world_size(group=pg) > 1: + td.all_reduce(input_, group=pg) + return input_ + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + return grad_output, None + + +# ----------------------------------------------------------------------- +# Public helper with the same API as the original NotDistributed stub +# ----------------------------------------------------------------------- + + +class DistributedHelper: + """Distributed primitives used by Makani / FCN3 layers.""" + + @staticmethod + def compute_split_shapes(size: int, num_chunks: int) -> List[int]: + return _thd_split_shapes(size, num_chunks) + + @staticmethod + def reduce_from_parallel_region( + input_: torch.Tensor, group: str + ) -> torch.Tensor: + return _ReduceFromParallelRegion.apply(input_, group) + + @staticmethod def scatter_to_parallel_region( - self, input: Any, dim: Any, group: Any - ) -> Any | None: - return input + input_: torch.Tensor, dim: int, group: str + ) -> torch.Tensor: + import torch.distributed as td + + pg = _get_group(group) + if pg is None or td.get_world_size(group=pg) <= 1: + return input_ + rank = td.get_rank(group=pg) + world_size = td.get_world_size(group=pg) + chunks = input_.chunk(world_size, dim=dim) + return chunks[rank].contiguous() + @staticmethod def gather_from_parallel_region( - self, input: Any, dim: Any, shapes: Any, group: Any - ) -> Any | None: - return input + input_: torch.Tensor, dim: int, shapes: Any, group: str + ) -> torch.Tensor: + return _GatherFromParallelRegion.apply(input_, dim, shapes, group) - def copy_to_parallel_region(self, input: Any, group: Any) -> Any | None: - return input + @staticmethod + def copy_to_parallel_region( + input_: torch.Tensor, group: str + ) -> torch.Tensor: + return _CopyToParallelRegion.apply(input_, group) + @staticmethod def split_tensor_along_dim( - self, tensor: Any, dim: Any, num_chunks: Any - ) -> Tuple[Any, ...]: - return tensor + tensor: torch.Tensor, dim: int, num_chunks: int + ) -> Tuple[torch.Tensor, ...]: + return tensor.chunk(num_chunks, dim=dim) -dist = NotDistributed() +dist = DistributedHelper() diff --git a/fme/ace/models/modulus/sfnonet.py b/fme/ace/models/modulus/sfnonet.py index de66056c8..1295996e8 100644 --- a/fme/ace/models/modulus/sfnonet.py +++ b/fme/ace/models/modulus/sfnonet.py @@ -25,6 +25,8 @@ import torch_harmonics.distributed as thd from torch.utils.checkpoint import checkpoint +from fme.core.distributed import Distributed + from .initialization import trunc_normal_ # wrap fft, to unify interface to spectral transforms @@ -463,10 +465,26 @@ def __init__( data_grid = params.data_grid if hasattr(params, "data_grid") else "equiangular" # self.pretrain_encoding = params.pretrain_encoding if hasattr(params, "pretrain_encoding") else False + # spatial parallelism + dist = Distributed.get_instance() + sp_active = dist.h_size > 1 or dist.w_size > 1 + # compute the downscaled image size self.h = int(self.img_shape[0] // self.scale_factor) self.w = int(self.img_shape[1] // self.scale_factor) + # guards for unsupported SP configurations + if sp_active and self.scale_factor > 1: + raise ValueError( + "SphericalFourierNeuralOperatorNet does not support " + "spatial parallelism with scale_factor > 1." + ) + if sp_active and self.residual_filter_factor > 1: + raise ValueError( + "SphericalFourierNeuralOperatorNet does not support " + "spatial parallelism with residual_filter_factor > 1." + ) + # Compute the maximum frequencies in h and in w modes_lat = int(self.h * self.hard_thresholding_fraction) modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction) @@ -497,8 +515,13 @@ def __init__( # prepare the spectral transforms if self.spectral_transform == "sht": - sht_handle = th.RealSHT - isht_handle = th.InverseRealSHT + if sp_active: + thd.init(dist.h_group, dist.w_group) + sht_handle = thd.DistributedRealSHT + isht_handle = thd.DistributedInverseRealSHT + else: + sht_handle = th.RealSHT + isht_handle = th.InverseRealSHT # set up self.trans_down = sht_handle( @@ -519,8 +542,18 @@ def __init__( raise NotImplementedError( "Residual filter factor is not implemented for FFT spectral transform" ) - fft_handle = th.RealFFT2 - ifft_handle = th.InverseRealFFT2 + + if sp_active: + from fme.ace.models.makani_fcn3.mpu.layers import ( + DistributedInverseRealFFT2, + DistributedRealFFT2, + ) + + fft_handle = DistributedRealFFT2 + ifft_handle = DistributedInverseRealFFT2 + else: + fft_handle = th.RealFFT2 + ifft_handle = th.InverseRealFFT2 # effective image size: self.img_shape_eff = ( @@ -548,10 +581,21 @@ def __init__( raise (ValueError("Unknown spectral transform")) # use the SHT/FFT to compute the local, downscaled grid dimensions - self.img_shape_loc = (self.trans_down.nlat, self.trans_down.nlon) - self.img_shape_eff = (self.trans_down.nlat, self.trans_down.nlon) - self.h_loc = self.itrans.nlat - self.w_loc = self.itrans.nlon + # Note: for distributed transforms, .nlat/.nlon return GLOBAL sizes, + # so we divide by the spatial group sizes to get local tile dimensions. + if sp_active: + self.img_shape_loc = ( + self.trans_down.nlat // dist.h_size, + self.trans_down.nlon // dist.w_size, + ) + self.img_shape_eff = self.img_shape_loc + self.h_loc = self.itrans.nlat // dist.h_size + self.w_loc = self.itrans.nlon // dist.w_size + else: + self.img_shape_loc = (self.trans_down.nlat, self.trans_down.nlon) + self.img_shape_eff = (self.trans_down.nlat, self.trans_down.nlon) + self.h_loc = self.itrans.nlat + self.w_loc = self.itrans.nlon # determine activation function if self.activation_function == "relu": @@ -591,13 +635,25 @@ def __init__( nn.LayerNorm, normalized_shape=(self.h_loc, self.w_loc), eps=1e-6 ) elif self.normalization_layer == "instance_norm": - norm_layer0 = partial( - nn.InstanceNorm2d, - num_features=self.embed_dim, - eps=1e-6, - affine=True, - track_running_stats=False, - ) + if sp_active: + from fme.ace.models.makani_fcn3.mpu.layer_norm import ( + DistributedInstanceNorm2d, + ) + + norm_layer0 = partial( + DistributedInstanceNorm2d, + num_features=self.embed_dim, + eps=1e-6, + affine=True, + ) + else: + norm_layer0 = partial( + nn.InstanceNorm2d, + num_features=self.embed_dim, + eps=1e-6, + affine=True, + track_running_stats=False, + ) norm_layer1 = norm_layer0 elif self.normalization_layer == "none": norm_layer0 = nn.Identity @@ -678,8 +734,9 @@ def __init__( 1, self.embed_dim, self.img_shape_loc[0], self.img_shape_loc[1] ) ) - # self.pos_embed = nn.Parameter( torch.zeros(1, self.embed_dim, self.img_shape_eff[0], self.img_shape_eff[1]) ) self.pos_embed.is_shared_mp = ["matmul"] + if sp_active: + self.pos_embed.sharded_dims_mp = [None, None, "h", "w"] trunc_normal_(self.pos_embed, std=0.02) self.apply(self._init_weights) diff --git a/fme/ace/registry/sfno.py b/fme/ace/registry/sfno.py index 9fc737f58..b1e727942 100644 --- a/fme/ace/registry/sfno.py +++ b/fme/ace/registry/sfno.py @@ -47,6 +47,9 @@ def build( n_out_channels: int, dataset_info: DatasetInfo, ): + from fme.core.distributed import Distributed + + dist = Distributed.get_instance() if len(dataset_info.all_labels) > 0: raise ValueError( "SphericalFourierNeuralOperatorNet does not support labels" diff --git a/fme/ace/registry/stochastic_sfno.py b/fme/ace/registry/stochastic_sfno.py index 8be3f0193..bf22f0d54 100644 --- a/fme/ace/registry/stochastic_sfno.py +++ b/fme/ace/registry/stochastic_sfno.py @@ -85,13 +85,20 @@ def forward( ) -> torch.Tensor: x = x.reshape(-1, *x.shape[-3:]) if self.noise_type == "isotropic": - lmax = self.conditional_model.itrans_up.lmax - mmax = self.conditional_model.itrans_up.mmax + import torch_harmonics.distributed as thd + + itrans = self.conditional_model.itrans_up + if isinstance(itrans, thd.DistributedInverseRealSHT): + lmax = itrans.lmax_local + mmax = itrans.mmax_local + else: + lmax = itrans.lmax + mmax = itrans.mmax noise = isotropic_noise( (x.shape[0], self.embed_dim), lmax, mmax, - self.conditional_model.itrans_up, + itrans, device=x.device, ) elif self.noise_type == "gaussian": diff --git a/fme/ace/stepper/single_module.py b/fme/ace/stepper/single_module.py index a991d2742..032feda2a 100644 --- a/fme/ace/stepper/single_module.py +++ b/fme/ace/stepper/single_module.py @@ -455,6 +455,29 @@ def compute_derived_variables( def get_metrics(self) -> TensorDict: return self.metrics + def gather_spatial(self) -> "TrainOutput": + """All-gather spatial shards back to full resolution. + + Returns *self* unchanged when spatial parallelism is inactive. + """ + from fme.core.distributed import Distributed + + dist = Distributed.get_instance() + if dist.h_size <= 1 and dist.w_size <= 1: + return self + return TrainOutput( + metrics=self.metrics, + gen_data=EnsembleTensorDict( + {k: dist.gather_spatial(v) for k, v in self.gen_data.items()} + ), + target_data=EnsembleTensorDict( + {k: dist.gather_spatial(v) for k, v in self.target_data.items()} + ), + time=self.time, + normalize=self.normalize, + derive_func=self.derive_func, + ) + def stack_list_of_tensor_dicts( dict_list: list[TensorDict], @@ -1525,6 +1548,7 @@ def train_on_batch( and the normalized batch data. """ self._init_for_epoch(data.epoch) + data = data.to_spatial_shard() metrics: dict[str, float] = {} input_data = data.get_start(self._prognostic_names, self.n_ic_timesteps) target_data = self._stepper.get_forward_data( diff --git a/fme/core/distributed/base.py b/fme/core/distributed/base.py index cb928cdb7..e9323bdc4 100644 --- a/fme/core/distributed/base.py +++ b/fme/core/distributed/base.py @@ -111,3 +111,58 @@ def barrier(self): ... @abstractmethod def shutdown(self): ... + + # ------------------------------------------------------------------ + # Spatial parallelism API (concrete defaults → no-ops) + # ------------------------------------------------------------------ + + @property + def h_rank(self) -> int: + """Rank along the h spatial dimension. Default: 0.""" + return 0 + + @property + def w_rank(self) -> int: + """Rank along the w spatial dimension. Default: 0.""" + return 0 + + @property + def h_size(self) -> int: + """Number of ranks along h. Default: 1.""" + return 1 + + @property + def w_size(self) -> int: + """Number of ranks along w. Default: 1.""" + return 1 + + @property + def h_group(self): + """Process group for the h dimension. Default: None.""" + return None + + @property + def w_group(self): + """Process group for the w dimension. Default: None.""" + return None + + @property + def spatial_group(self): + """Process group spanning all spatial (h × w) peers. Default: None.""" + return None + + def scatter_spatial( + self, tensor: torch.Tensor, h_dim: int = -2, w_dim: int = -1 + ) -> torch.Tensor: + """Slice *tensor* to the local spatial shard. No-op by default.""" + return tensor + + def gather_spatial( + self, tensor: torch.Tensor, h_dim: int = -2, w_dim: int = -1 + ) -> torch.Tensor: + """All-gather spatial shards back to full resolution. No-op by default.""" + return tensor + + def reduce_sum_spatial(self, tensor: torch.Tensor) -> torch.Tensor: + """All-reduce (SUM) over spatial peers. No-op by default.""" + return tensor diff --git a/fme/core/distributed/distributed.py b/fme/core/distributed/distributed.py index 274afcb7c..d056bfa73 100644 --- a/fme/core/distributed/distributed.py +++ b/fme/core/distributed/distributed.py @@ -159,8 +159,8 @@ def get_sampler( return torch.utils.data.DistributedSampler( dataset, shuffle=shuffle, - num_replicas=self._distributed.total_ranks, - rank=self._distributed.rank, + num_replicas=self._distributed.total_data_parallel_ranks, + rank=self._distributed.data_parallel_rank, seed=self._seed, drop_last=drop_last, ) @@ -363,5 +363,56 @@ def get_seed(self) -> int: def shutdown(self): return self._distributed.shutdown() + # ------------------------------------------------------------------ + # Spatial parallelism delegation + # ------------------------------------------------------------------ + + @property + def h_rank(self) -> int: + return self._distributed.h_rank + + @property + def w_rank(self) -> int: + return self._distributed.w_rank + + @property + def h_size(self) -> int: + return self._distributed.h_size + + @property + def w_size(self) -> int: + return self._distributed.w_size + + @property + def h_group(self): + """Process group for the h spatial dimension.""" + return self._distributed.h_group + + @property + def w_group(self): + """Process group for the w spatial dimension.""" + return self._distributed.w_group + + @property + def spatial_group(self): + """Process group spanning all spatial (h × w) peers.""" + return self._distributed.spatial_group + + def scatter_spatial( + self, tensor: torch.Tensor, h_dim: int = -2, w_dim: int = -1 + ) -> torch.Tensor: + """Slice *tensor* to the local spatial shard (no-op if non-spatial).""" + return self._distributed.scatter_spatial(tensor, h_dim, w_dim) + + def gather_spatial( + self, tensor: torch.Tensor, h_dim: int = -2, w_dim: int = -1 + ) -> torch.Tensor: + """All-gather spatial shards back to full resolution.""" + return self._distributed.gather_spatial(tensor, h_dim, w_dim) + + def reduce_sum_spatial(self, tensor: torch.Tensor) -> torch.Tensor: + """All-reduce (SUM) over spatial peers only.""" + return self._distributed.reduce_sum_spatial(tensor) + singleton: Distributed | None = None diff --git a/fme/core/distributed/model_torch_distributed.py b/fme/core/distributed/model_torch_distributed.py index d4ae5025e..8838bcf96 100644 --- a/fme/core/distributed/model_torch_distributed.py +++ b/fme/core/distributed/model_torch_distributed.py @@ -270,3 +270,117 @@ def shutdown(self): self.barrier() logger.debug("Shutting down rank %d", self._rank) DistributedManager.cleanup() + + # ------------------------------------------------------------------ + # Spatial parallelism overrides + # ------------------------------------------------------------------ + + @property + def h_rank(self) -> int: + return self._h_rank + + @property + def w_rank(self) -> int: + return self._w_rank + + @property + def h_size(self) -> int: + return self._h_size + + @property + def w_size(self) -> int: + return self._w_size + + @property + def h_group(self): + return self._h_group + + @property + def w_group(self): + return self._w_group + + @property + def spatial_group(self): + """Flat (h, w) process group for all-gather / all-reduce.""" + if not hasattr(self, "_spatial_group"): + # Flatten the (h, w) sub-mesh into a single group. + self._spatial_group = self._dm.get_mesh_group( + self._mesh["h", "w"] + ) + return self._spatial_group + + def scatter_spatial( + self, tensor: torch.Tensor, h_dim: int = -2, w_dim: int = -1 + ) -> torch.Tensor: + """Pure local slicing — pick this rank's tile from the full tensor.""" + if self._h_size == 1 and self._w_size == 1: + return tensor + + # Resolve negative dims. + ndim = tensor.ndim + h_dim = h_dim % ndim + w_dim = w_dim % ndim + + h_total = tensor.shape[h_dim] + w_total = tensor.shape[w_dim] + if h_total % self._h_size != 0: + raise ValueError( + f"h dim size {h_total} not divisible by h_size {self._h_size}" + ) + if w_total % self._w_size != 0: + raise ValueError( + f"w dim size {w_total} not divisible by w_size {self._w_size}" + ) + + h_chunk = h_total // self._h_size + w_chunk = w_total // self._w_size + + tensor = tensor.narrow( + h_dim, self._h_rank * h_chunk, h_chunk + ) + tensor = tensor.narrow( + w_dim, self._w_rank * w_chunk, w_chunk + ) + return tensor.contiguous() + + def gather_spatial( + self, tensor: torch.Tensor, h_dim: int = -2, w_dim: int = -1 + ) -> torch.Tensor: + """All-gather spatial shards and reassemble the full tensor.""" + spatial_size = self._h_size * self._w_size + if spatial_size == 1: + return tensor + + # Resolve negative dims. + ndim = tensor.ndim + h_dim = h_dim % ndim + w_dim = w_dim % ndim + + # All-gather across the flat spatial group. + gather_list = [ + torch.empty_like(tensor) for _ in range(spatial_size) + ] + torch.distributed.all_gather( + gather_list, tensor.contiguous(), group=self.spatial_group + ) + + # Reassemble the 2-D tile grid. + # Spatial ranks are laid out in row-major (h, w) order. + rows = [] + for hi in range(self._h_size): + row_tiles = [ + gather_list[hi * self._w_size + wi] + for wi in range(self._w_size) + ] + rows.append(torch.cat(row_tiles, dim=w_dim)) + return torch.cat(rows, dim=h_dim) + + def reduce_sum_spatial(self, tensor: torch.Tensor) -> torch.Tensor: + """All-reduce (SUM) over spatial peers.""" + spatial_size = self._h_size * self._w_size + if spatial_size == 1: + return tensor + torch.distributed.all_reduce( + tensor, op=torch.distributed.ReduceOp.SUM, group=self.spatial_group + ) + return tensor diff --git a/fme/core/distributed/parallel_tests/test_spatial.py b/fme/core/distributed/parallel_tests/test_spatial.py new file mode 100644 index 000000000..1e517947f --- /dev/null +++ b/fme/core/distributed/parallel_tests/test_spatial.py @@ -0,0 +1,235 @@ +""" +Tests for spatial parallelism primitives. + +These tests verify scatter/gather round-trips, reduce_sum_spatial, and +correct area-weighted mean under spatial sharding. They work both +serially (NonDistributed) and in parallel (torchrun with ModelTorchDistributed). +""" + +import pytest +import torch + +from fme.core import get_device +from fme.core.distributed import Distributed + + +# ----------------------------------------------------------------------- +# scatter + gather round-trip +# ----------------------------------------------------------------------- + + +@pytest.mark.parallel +def test_scatter_gather_roundtrip_2d(): + """scatter then gather on a simple 2-D tensor recovers the original.""" + dist = Distributed.get_instance() + h, w = 8, 12 + x = torch.arange(h * w, dtype=torch.float32, device=get_device()).reshape(h, w) + local = dist.scatter_spatial(x, h_dim=0, w_dim=1) + reconstructed = dist.gather_spatial(local, h_dim=0, w_dim=1) + torch.testing.assert_close(reconstructed, x) + + +@pytest.mark.parallel +def test_scatter_gather_roundtrip_4d(): + """scatter then gather on a (B, C, H, W) tensor recovers the original.""" + dist = Distributed.get_instance() + B, C, H, W = 2, 3, 8, 12 + x = torch.randn(B, C, H, W, device=get_device()) + local = dist.scatter_spatial(x) # default h_dim=-2, w_dim=-1 + reconstructed = dist.gather_spatial(local) + torch.testing.assert_close(reconstructed, x) + + +@pytest.mark.parallel +def test_scatter_correct_local_shape(): + """Local shard has the expected reduced spatial shape.""" + dist = Distributed.get_instance() + B, C, H, W = 2, 3, 8, 12 + x = torch.randn(B, C, H, W, device=get_device()) + local = dist.scatter_spatial(x) + assert local.shape == (B, C, H // dist.h_size, W // dist.w_size) + + +@pytest.mark.parallel +def test_scatter_spatial_noop_when_single(): + """When spatial parallelism is off (both sizes==1), x is returned as-is.""" + dist = Distributed.get_instance() + x = torch.randn(4, 4, device=get_device()) + result = dist.scatter_spatial(x, h_dim=0, w_dim=1) + if dist.h_size == 1 and dist.w_size == 1: + assert result is x # exact same object + else: + # still correct in shape + assert result.shape[0] == 4 // dist.h_size + + +# ----------------------------------------------------------------------- +# reduce_sum_spatial +# ----------------------------------------------------------------------- + + +@pytest.mark.parallel +def test_reduce_sum_spatial_all_ones(): + """Summing all-ones tensors across spatial peers scales by h_size * w_size.""" + dist = Distributed.get_instance() + t = torch.ones(3, 4, device=get_device()) + result = dist.reduce_sum_spatial(t.clone()) + expected = torch.full_like(t, dist.h_size * dist.w_size) + torch.testing.assert_close(result, expected) + + +@pytest.mark.parallel +def test_reduce_sum_spatial_noop_when_single(): + """reduce_sum_spatial is identity when no spatial parallelism.""" + dist = Distributed.get_instance() + t = torch.randn(5, device=get_device()) + result = dist.reduce_sum_spatial(t.clone()) + if dist.h_size == 1 and dist.w_size == 1: + torch.testing.assert_close(result, t) + else: + # When parallel, the sum should differ from the original. + # Just verify it returned a tensor of the right shape. + assert result.shape == t.shape + + +# ----------------------------------------------------------------------- +# area-weighted mean under sharding +# ----------------------------------------------------------------------- + + +@pytest.mark.parallel +def test_area_weighted_mean_uniform_weights(): + """With uniform weights, area_weighted_mean == vanilla mean over spatial dims. + + This tests that the separate num/den all-reduce in _spatial_weighted_mean + produces the correct result even when the data is sharded. + """ + from fme.core.gridded_ops import _spatial_weighted_mean + + dist = Distributed.get_instance() + H, W = 8, 12 + x = torch.arange(H * W, dtype=torch.float32, device=get_device()).reshape(1, H, W) + weights = torch.ones(1, H, W, device=get_device()) + + # Shard data + weights to match what training loop would do. + x_local = dist.scatter_spatial(x, h_dim=-2, w_dim=-1) + w_local = dist.scatter_spatial(weights, h_dim=-2, w_dim=-1) + + result = _spatial_weighted_mean(x_local, w_local, dim=(-2, -1)) + + # Expected: plain mean over all H*W elements. + expected = x.float().mean(dim=(-2, -1)) + torch.testing.assert_close(result, expected) + + +@pytest.mark.parallel +def test_area_weighted_mean_nonuniform_weights(): + """Non-uniform (latitude-like) weights must give the correct global mean. + + Construct cosine-latitude weights (vary only in h) and verify the + weighted mean matches the full-grid computation. + """ + from fme.core.gridded_ops import _spatial_weighted_mean + + dist = Distributed.get_instance() + H, W = 8, 12 + # Data: simple latitude gradient. + data = ( + torch.arange(H, dtype=torch.float32, device=get_device()) + .unsqueeze(-1) + .expand(H, W) + .unsqueeze(0) + ) + # Weights: cosine-like (vary only in lat / h dim). + lats = torch.linspace(-90, 90, H, device=get_device()) + cos_w = torch.cos(lats * 3.14159265 / 180.0).unsqueeze(-1).expand(H, W) + cos_w = cos_w.unsqueeze(0) + + # Full-grid reference. + expected = (data * cos_w).sum(dim=(-2, -1)) / cos_w.sum(dim=(-2, -1)) + + # Sharded computation. + data_local = dist.scatter_spatial(data, h_dim=-2, w_dim=-1) + w_local = dist.scatter_spatial(cos_w, h_dim=-2, w_dim=-1) + result = _spatial_weighted_mean(data_local, w_local, dim=(-2, -1)) + + torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5) + + +# ----------------------------------------------------------------------- +# Sampler uses data-parallel rank / size +# ----------------------------------------------------------------------- + + +@pytest.mark.parallel +def test_sampler_uses_data_parallel_ranks(): + """DistributedSampler should be configured with data-parallel rank/size.""" + dist = Distributed.get_instance() + ds = torch.utils.data.TensorDataset(torch.randn(100)) + sampler = dist.get_sampler(ds, shuffle=False) + assert sampler.num_replicas == dist.total_data_parallel_ranks + assert sampler.rank == dist.data_parallel_rank + + +# ----------------------------------------------------------------------- +# Makani comm.py wiring +# ----------------------------------------------------------------------- + + +@pytest.mark.parallel +def test_comm_get_size_matches_distributed(): + """comm.get_size should agree with the Distributed singleton.""" + from fme.ace.models.makani_fcn3.utils import comm + + dist = Distributed.get_instance() + assert comm.get_size("h") == dist.h_size + assert comm.get_size("w") == dist.w_size + assert comm.get_size("spatial") == dist.h_size * dist.w_size + assert comm.get_size("matmul") == 1 + + +@pytest.mark.parallel +def test_comm_get_rank_matches_distributed(): + """comm.get_rank should agree with the Distributed singleton.""" + from fme.ace.models.makani_fcn3.utils import comm + + dist = Distributed.get_instance() + assert comm.get_rank("h") == dist.h_rank + assert comm.get_rank("w") == dist.w_rank + assert comm.get_rank("matmul") == 0 + + +@pytest.mark.parallel +def test_comm_get_group_matches_distributed(): + """comm.get_group should return the same groups as Distributed.""" + from fme.ace.models.makani_fcn3.utils import comm + + dist = Distributed.get_instance() + assert comm.get_group("h") is dist.h_group + assert comm.get_group("w") is dist.w_group + assert comm.get_group("spatial") is dist.spatial_group + assert comm.get_group("matmul") is None + + +@pytest.mark.parallel +def test_comm_is_distributed_consistent(): + """comm.is_distributed should be True iff size > 1.""" + from fme.ace.models.makani_fcn3.utils import comm + + dist = Distributed.get_instance() + expected_spatial = (dist.h_size * dist.w_size) > 1 + assert comm.is_distributed("spatial") == expected_spatial + assert comm.is_distributed("matmul") is False + + +@pytest.mark.parallel +def test_comm_unknown_name_raises(): + """Requesting an unknown group name should raise ValueError.""" + from fme.ace.models.makani_fcn3.utils import comm + + with pytest.raises(ValueError, match="Unknown comm group name"): + comm.get_size("nonexistent") + with pytest.raises(ValueError, match="Unknown comm group name"): + comm.get_rank("nonexistent") + with pytest.raises(ValueError, match="Unknown comm group name"): + comm.get_group("nonexistent") diff --git a/fme/core/gridded_ops.py b/fme/core/gridded_ops.py index 596faf8fa..e0916b723 100644 --- a/fme/core/gridded_ops.py +++ b/fme/core/gridded_ops.py @@ -16,6 +16,49 @@ from fme.core.typing_ import TensorDict, TensorMapping +def _spatial_weighted_mean( + data: torch.Tensor, + weights: torch.Tensor, + dim: tuple[int, ...], + keepdim: bool = False, +) -> torch.Tensor: + """Weighted mean that is correct under spatial sharding. + + Computes the local numerator and denominator, all-reduces both across + spatial peers via ``Distributed.reduce_sum_spatial`` (a no-op when there + is no spatial parallelism), and returns numerator / denominator. + """ + from fme.core.distributed import Distributed + + expanded = weights.expand(data.shape) + # Mask NaNs that sit behind zero-weight cells. + data = data.where(expanded != 0.0, 0.0) + num = (data * expanded).sum(dim=dim, keepdim=keepdim) + den = expanded.sum(dim=dim, keepdim=keepdim) + + dist = Distributed.get_instance() + num = dist.reduce_sum_spatial(num) + den = dist.reduce_sum_spatial(den) + return num / den + + +def _spatial_weighted_sum( + data: torch.Tensor, + weights: torch.Tensor, + dim: tuple[int, ...], + keepdim: bool = False, +) -> torch.Tensor: + """Weighted sum that is correct under spatial sharding.""" + from fme.core.distributed import Distributed + + expanded = weights.expand(data.shape) + data = data.where(expanded != 0.0, 0.0) + local = (data * expanded).sum(dim=dim, keepdim=keepdim) + + dist = Distributed.get_instance() + return dist.reduce_sum_spatial(local) + + class GriddedOperations(abc.ABC): def __eq__(self, other) -> bool: if not isinstance(other, GriddedOperations): @@ -332,7 +375,7 @@ def area_weighted_sum( name: str | None = None, ) -> torch.Tensor: area_weights = self._get_area_weights(data, name) - return metrics.weighted_sum( + return _spatial_weighted_sum( data, area_weights, dim=self.HORIZONTAL_DIMS, keepdim=keepdim ) @@ -343,7 +386,7 @@ def area_weighted_mean( name: str | None = None, ) -> torch.Tensor: area_weights = self._get_area_weights(data, name) - return metrics.weighted_mean( + return _spatial_weighted_mean( data, area_weights, dim=self.HORIZONTAL_DIMS, keepdim=keepdim ) @@ -355,7 +398,7 @@ def regional_area_weighted_mean( name: str | None = None, ) -> torch.Tensor: regional_area_weights = self._get_area_weights(data, name, regional_weights) - return metrics.weighted_mean( + return _spatial_weighted_mean( data, regional_area_weights, dim=self.HORIZONTAL_DIMS, diff --git a/fme/core/models/conditional_sfno/layers.py b/fme/core/models/conditional_sfno/layers.py index 5f6dcbe8e..1b4bf65f2 100644 --- a/fme/core/models/conditional_sfno/layers.py +++ b/fme/core/models/conditional_sfno/layers.py @@ -122,6 +122,66 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return y +class DistributedGlobalLayerNorm(nn.Module): + """LayerNorm over (C, H, W) with cross-shard all-reduce of statistics. + + Standard ``nn.LayerNorm`` computes mean/variance from the local shard + only, which diverges from single-GPU results when the spatial grid is + partitioned across ranks. This class all-reduces the first two moments + across the spatial process group so the normalisation is globally + consistent. + """ + + def __init__( + self, + normalized_shape: Tuple[int, ...], + eps: float = 1e-5, + elementwise_affine: bool = False, + ): + super().__init__() + self.normalized_shape = normalized_shape + self.n_dims = len(normalized_shape) + self.eps = eps + self.elementwise_affine = elementwise_affine + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + from fme.core.distributed import Distributed + + dist = Distributed.get_instance() + + # Dimensions to reduce: last ``n_dims`` dimensions of x. + reduce_dims = tuple(range(x.ndim - self.n_dims, x.ndim)) + + # Local element count per sample (same on every rank for a given + # reduce_dims since C is not sharded). Global count = local * spatial_size. + local_count = 1 + for d in reduce_dims: + local_count *= x.shape[d] + spatial_size = dist.h_size * dist.w_size + global_count = local_count * spatial_size + + # Local sums → all-reduce across spatial peers. + local_sum = x.sum(dim=reduce_dims, keepdim=True) + local_sum_sq = (x * x).sum(dim=reduce_dims, keepdim=True) + global_sum = dist.reduce_sum_spatial(local_sum) + global_sum_sq = dist.reduce_sum_spatial(local_sum_sq) + + mean = global_sum / global_count + var = global_sum_sq / global_count - mean * mean + + x = (x - mean) * torch.rsqrt(var + self.eps) + + if self.weight is not None: + x = x * self.weight + self.bias + return x + + class ConditionalLayerNorm(nn.Module): """ Conditional Layer Normalization as described in "AdaSpeech: Adaptive @@ -185,11 +245,24 @@ def __init__( self.W_scale_pos = None self.W_bias_pos = None if global_layer_norm: - self.norm = nn.LayerNorm( - (self.n_channels, img_shape[0], img_shape[1]), - eps=epsilon, - elementwise_affine=elementwise_affine, - ) + from fme.core.distributed import Distributed + + dist = Distributed.get_instance() + norm_shape = (self.n_channels, img_shape[0], img_shape[1]) + if dist.h_size > 1 or dist.w_size > 1: + # Use distributed variant that all-reduces mean/var + # across spatial shards for globally consistent stats. + self.norm = DistributedGlobalLayerNorm( + norm_shape, + eps=epsilon, + elementwise_affine=elementwise_affine, + ) + else: + self.norm = nn.LayerNorm( + norm_shape, + eps=epsilon, + elementwise_affine=elementwise_affine, + ) else: self.norm = ChannelLayerNorm( self.n_channels, diff --git a/fme/core/models/conditional_sfno/sfnonet.py b/fme/core/models/conditional_sfno/sfnonet.py index 29eb986f0..db6afa76a 100644 --- a/fme/core/models/conditional_sfno/sfnonet.py +++ b/fme/core/models/conditional_sfno/sfnonet.py @@ -22,6 +22,7 @@ # get spectral transforms from torch_harmonics import torch_harmonics as th +import torch_harmonics.distributed as thd from torch.utils.checkpoint import checkpoint from fme.core.benchmark.timer import Timer, NullTimer @@ -366,6 +367,10 @@ def get_lat_lon_sfnonet( embed_dim_pos=0, ), ) -> "SphericalFourierNeuralOperatorNet": + from fme.core.distributed import Distributed + + dist = Distributed.get_instance() + h, w = img_shape hard_thresholding_fraction = ( params.hard_thresholding_fraction @@ -375,28 +380,48 @@ def get_lat_lon_sfnonet( modes_lat = int(h * hard_thresholding_fraction) modes_lon = int((w // 2 + 1) * hard_thresholding_fraction) data_grid = params.data_grid if hasattr(params, "data_grid") else "equiangular" - trans_down = th.RealSHT( + + spatial_parallel = dist.h_size > 1 or dist.w_size > 1 + if spatial_parallel: + thd.init(dist.h_group, dist.w_group) + sht_cls = thd.DistributedRealSHT + isht_cls = thd.DistributedInverseRealSHT + else: + sht_cls = th.RealSHT + isht_cls = th.InverseRealSHT + + trans_down = sht_cls( *img_shape, lmax=modes_lat, mmax=modes_lon, grid=data_grid ).float() - itrans_up = th.InverseRealSHT( + itrans_up = isht_cls( *img_shape, lmax=modes_lat, mmax=modes_lon, grid=data_grid ).float() - trans = th.RealSHT( + trans = sht_cls( *img_shape, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss" ).float() - itrans = th.InverseRealSHT( + itrans = isht_cls( h, w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss" ).float() + # Under spatial parallelism the model sees the *local* shard dims. + if spatial_parallel: + local_h = h // dist.h_size + local_w = w // dist.w_size + local_img_shape: Tuple[int, int] = (local_h, local_w) + else: + local_img_shape = img_shape + def get_pos_embed(): - pos_embed = nn.Parameter(torch.zeros(1, params.embed_dim, h, w)) + pos_embed = nn.Parameter( + torch.zeros(1, params.embed_dim, local_img_shape[0], local_img_shape[1]) + ) pos_embed.is_shared_mp = ["matmul"] trunc_normal_(pos_embed, std=0.02) return pos_embed return SphericalFourierNeuralOperatorNet( params, - img_shape=img_shape, + img_shape=local_img_shape, in_chans=in_chans, out_chans=out_chans, context_config=context_config, From 58c9d98d27cc324e99bb53ca2faae33f7ee00b7a Mon Sep 17 00:00:00 2001 From: mahf708 Date: Tue, 24 Feb 2026 09:18:41 -0800 Subject: [PATCH 2/4] fix --- fme/ace/models/makani_fcn3/mpu/layer_norm.py | 2 +- fme/ace/models/makani_fcn3/mpu/layers.py | 2 +- fme/ace/registry/sfno.py | 3 --- .../distributed/model_torch_distributed.py | 19 +++++-------------- .../parallel_tests/test_spatial.py | 1 - 5 files changed, 7 insertions(+), 20 deletions(-) diff --git a/fme/ace/models/makani_fcn3/mpu/layer_norm.py b/fme/ace/models/makani_fcn3/mpu/layer_norm.py index e9bbfface..12cbc59a1 100644 --- a/fme/ace/models/makani_fcn3/mpu/layer_norm.py +++ b/fme/ace/models/makani_fcn3/mpu/layer_norm.py @@ -1,4 +1,4 @@ -# type: ignore +# mypy: ignore-errors # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # diff --git a/fme/ace/models/makani_fcn3/mpu/layers.py b/fme/ace/models/makani_fcn3/mpu/layers.py index 1608316c4..815d33992 100644 --- a/fme/ace/models/makani_fcn3/mpu/layers.py +++ b/fme/ace/models/makani_fcn3/mpu/layers.py @@ -1,4 +1,4 @@ -# type: ignore +# mypy: ignore-errors # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # diff --git a/fme/ace/registry/sfno.py b/fme/ace/registry/sfno.py index b1e727942..9fc737f58 100644 --- a/fme/ace/registry/sfno.py +++ b/fme/ace/registry/sfno.py @@ -47,9 +47,6 @@ def build( n_out_channels: int, dataset_info: DatasetInfo, ): - from fme.core.distributed import Distributed - - dist = Distributed.get_instance() if len(dataset_info.all_labels) > 0: raise ValueError( "SphericalFourierNeuralOperatorNet does not support labels" diff --git a/fme/core/distributed/model_torch_distributed.py b/fme/core/distributed/model_torch_distributed.py index 8838bcf96..a244d6195 100644 --- a/fme/core/distributed/model_torch_distributed.py +++ b/fme/core/distributed/model_torch_distributed.py @@ -304,9 +304,7 @@ def spatial_group(self): """Flat (h, w) process group for all-gather / all-reduce.""" if not hasattr(self, "_spatial_group"): # Flatten the (h, w) sub-mesh into a single group. - self._spatial_group = self._dm.get_mesh_group( - self._mesh["h", "w"] - ) + self._spatial_group = self._dm.get_mesh_group(self._mesh["h", "w"]) return self._spatial_group def scatter_spatial( @@ -335,12 +333,8 @@ def scatter_spatial( h_chunk = h_total // self._h_size w_chunk = w_total // self._w_size - tensor = tensor.narrow( - h_dim, self._h_rank * h_chunk, h_chunk - ) - tensor = tensor.narrow( - w_dim, self._w_rank * w_chunk, w_chunk - ) + tensor = tensor.narrow(h_dim, self._h_rank * h_chunk, h_chunk) + tensor = tensor.narrow(w_dim, self._w_rank * w_chunk, w_chunk) return tensor.contiguous() def gather_spatial( @@ -357,9 +351,7 @@ def gather_spatial( w_dim = w_dim % ndim # All-gather across the flat spatial group. - gather_list = [ - torch.empty_like(tensor) for _ in range(spatial_size) - ] + gather_list = [torch.empty_like(tensor) for _ in range(spatial_size)] torch.distributed.all_gather( gather_list, tensor.contiguous(), group=self.spatial_group ) @@ -369,8 +361,7 @@ def gather_spatial( rows = [] for hi in range(self._h_size): row_tiles = [ - gather_list[hi * self._w_size + wi] - for wi in range(self._w_size) + gather_list[hi * self._w_size + wi] for wi in range(self._w_size) ] rows.append(torch.cat(row_tiles, dim=w_dim)) return torch.cat(rows, dim=h_dim) diff --git a/fme/core/distributed/parallel_tests/test_spatial.py b/fme/core/distributed/parallel_tests/test_spatial.py index 1e517947f..129cbb19f 100644 --- a/fme/core/distributed/parallel_tests/test_spatial.py +++ b/fme/core/distributed/parallel_tests/test_spatial.py @@ -12,7 +12,6 @@ from fme.core import get_device from fme.core.distributed import Distributed - # ----------------------------------------------------------------------- # scatter + gather round-trip # ----------------------------------------------------------------------- From 0b5f2cf9b218eb468840b183669fdb331370bd3b Mon Sep 17 00:00:00 2001 From: mahf708 Date: Tue, 24 Feb 2026 09:26:24 -0800 Subject: [PATCH 3/4] fix --- fme/core/testing/distributed.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/fme/core/testing/distributed.py b/fme/core/testing/distributed.py index 743709cea..a79bcabc6 100644 --- a/fme/core/testing/distributed.py +++ b/fme/core/testing/distributed.py @@ -43,6 +43,17 @@ def gather_irregular(self, tensor: torch.Tensor) -> list[torch.Tensor] | None: """ return self.gather(tensor) # this is single-process, can't be irregular + def reduce_sum_spatial(self, tensor: torch.Tensor) -> torch.Tensor: + return tensor + + @property + def h_size(self) -> int: + return 1 + + @property + def w_size(self) -> int: + return 1 + @contextlib.contextmanager def mock_distributed(fill_value: float = 0.0, world_size: int = 1): From 971319330c63b58651cf2c98cef986262a0ae65a Mon Sep 17 00:00:00 2001 From: mahf708 Date: Tue, 24 Feb 2026 12:06:05 -0800 Subject: [PATCH 4/4] fix --- fme/ace/models/makani/sfnonet.py | 7 +++ fme/ace/models/modulus/sfnonet.py | 7 +++ fme/ace/registry/stochastic_sfno.py | 8 ++- .../parallel_tests/test_spatial.py | 4 +- .../distributed/parallel_tests/test_step.py | 29 ++++++++--- fme/core/distributed/sht_compat.py | 50 +++++++++++++++++++ fme/core/gridded_ops.py | 7 +++ fme/core/models/conditional_sfno/sfnonet.py | 19 +++++++ fme/core/step/test_step.py | 30 ++++++++--- 9 files changed, 145 insertions(+), 16 deletions(-) create mode 100644 fme/core/distributed/sht_compat.py diff --git a/fme/ace/models/makani/sfnonet.py b/fme/ace/models/makani/sfnonet.py index caa7d5996..e291cd686 100644 --- a/fme/ace/models/makani/sfnonet.py +++ b/fme/ace/models/makani/sfnonet.py @@ -523,6 +523,13 @@ def _init_spectral_transforms( self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=sht_grid_type ).float() + if spatial_parallel: + from fme.core.distributed.sht_compat import patch_distributed_sht + + patch_distributed_sht( + self.trans_down, self.itrans_up, self.trans, self.itrans + ) + elif spectral_transform == "fft": if spatial_parallel: fft_handle = DistributedRealFFT2 diff --git a/fme/ace/models/modulus/sfnonet.py b/fme/ace/models/modulus/sfnonet.py index 1295996e8..6d86fdbc3 100644 --- a/fme/ace/models/modulus/sfnonet.py +++ b/fme/ace/models/modulus/sfnonet.py @@ -537,6 +537,13 @@ def __init__( self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss" ).float() + if sp_active: + from fme.core.distributed.sht_compat import patch_distributed_sht + + patch_distributed_sht( + self.trans_down, self.itrans_up, self.trans, self.itrans + ) + elif self.spectral_transform == "fft": if residual_filter_factor != 1: raise NotImplementedError( diff --git a/fme/ace/registry/stochastic_sfno.py b/fme/ace/registry/stochastic_sfno.py index bf22f0d54..aef2ffdcf 100644 --- a/fme/ace/registry/stochastic_sfno.py +++ b/fme/ace/registry/stochastic_sfno.py @@ -7,6 +7,7 @@ from fme.ace.registry.registry import ModuleConfig, ModuleSelector from fme.core.dataset_info import DatasetInfo +from fme.core.distributed import Distributed from fme.core.models.conditional_sfno.sfnonet import ( Context, ContextConfig, @@ -262,11 +263,16 @@ def build( embed_dim_labels=len(dataset_info.all_labels), ), ) + # Under spatial parallelism the wrapper's positional + # embedding must use local (sharded) spatial dimensions. + dist = Distributed.get_instance() + local_h = dataset_info.img_shape[0] // max(dist.h_size, 1) + local_w = dataset_info.img_shape[1] // max(dist.w_size, 1) return NoiseConditionedSFNO( sfno_net, noise_type=self.noise_type, embed_dim_noise=self.noise_embed_dim, embed_dim_pos=self.context_pos_embed_dim, embed_dim_labels=len(dataset_info.all_labels), - img_shape=dataset_info.img_shape, + img_shape=(local_h, local_w), ) diff --git a/fme/core/distributed/parallel_tests/test_spatial.py b/fme/core/distributed/parallel_tests/test_spatial.py index 129cbb19f..491e1f0a2 100644 --- a/fme/core/distributed/parallel_tests/test_spatial.py +++ b/fme/core/distributed/parallel_tests/test_spatial.py @@ -33,7 +33,9 @@ def test_scatter_gather_roundtrip_4d(): """scatter then gather on a (B, C, H, W) tensor recovers the original.""" dist = Distributed.get_instance() B, C, H, W = 2, 3, 8, 12 - x = torch.randn(B, C, H, W, device=get_device()) + # Use a fixed seed so all ranks produce the identical starting tensor. + gen = torch.Generator(device=get_device()).manual_seed(42) + x = torch.randn(B, C, H, W, device=get_device(), generator=gen) local = dist.scatter_spatial(x) # default h_dim=-2, w_dim=-1 reconstructed = dist.gather_spatial(local) torch.testing.assert_close(reconstructed, x) diff --git a/fme/core/distributed/parallel_tests/test_step.py b/fme/core/distributed/parallel_tests/test_step.py index 0d5453b2a..5d493457b 100644 --- a/fme/core/distributed/parallel_tests/test_step.py +++ b/fme/core/distributed/parallel_tests/test_step.py @@ -9,6 +9,7 @@ import dataclasses import datetime +import os import pathlib import tempfile import unittest @@ -26,6 +27,7 @@ from fme.core.coordinates import HybridSigmaPressureCoordinate, LatLonCoordinates from fme.core.corrector.atmosphere import AtmosphereCorrectorConfig, EnergyBudgetConfig from fme.core.dataset_info import DatasetInfo +from fme.core.distributed import Distributed from fme.core.distributed.non_distributed import DummyWrapper from fme.core.labels import BatchLabels from fme.core.normalizer import NetworkAndLossNormalizationConfig, NormalizationConfig @@ -37,7 +39,8 @@ from fme.core.step.step import StepABC, StepSelector from fme.core.typing_ import TensorDict -DEFAULT_IMG_SHAPE = (45, 90) +# Must be divisible by h_size (typically 2) for spatial parallelism. +DEFAULT_IMG_SHAPE = (44, 90) DATA_DIR = pathlib.Path(__file__).parent / "testdata" @@ -90,7 +93,6 @@ def get_single_module_noise_conditioned_selector( context_pos_embed_dim=2, pos_embed=False, num_layers=2, - local_blocks=[0], affine_norms=True, ) ), @@ -169,7 +171,6 @@ def get_single_module_with_atmosphere_corrector_selector( noise_embed_dim=4, noise_type="isotropic", num_layers=2, - local_blocks=[0], ) ), ), @@ -238,12 +239,20 @@ def get_tensor_dict( return data_dict +def scatter_tensor_dict(td: TensorDict) -> TensorDict: + """Scatter each tensor in the dict to the local spatial shard.""" + dist = Distributed.get_instance() + return {k: dist.scatter_spatial(v) for k, v in td.items()} + + def get_step( selector: StepSelector, img_shape: tuple[int, int], init_weights: Callable[[list[nn.Module]], None] = lambda _: None, all_labels: set[str] | None = None, ) -> StepABC: + # Ensure CUDA device is set before any tensor creation. + Distributed.get_instance() device = fme.get_device() horizontal_coordinate = LatLonCoordinates( lat=torch.zeros(img_shape[0], device=device), @@ -268,9 +277,11 @@ def test_step_applies_wrapper(config: StepSelector): img_shape = DEFAULT_IMG_SHAPE n_samples = 5 step = get_step(config, img_shape) - input_data = get_tensor_dict(step.input_names, img_shape, n_samples) - next_step_input_data = get_tensor_dict( - step.next_step_input_names, img_shape, n_samples + input_data = scatter_tensor_dict( + get_tensor_dict(step.input_names, img_shape, n_samples) + ) + next_step_input_data = scatter_tensor_dict( + get_tensor_dict(step.next_step_input_names, img_shape, n_samples) ) multi_calls = 1 if isinstance(config._step_config_instance, MultiCallStepConfig): @@ -412,6 +423,12 @@ def cache_step_output(output_data: TensorDict, checkpoint_path: pathlib.Path): SELECTOR_GETTERS.items(), ) @pytest.mark.parallel +@pytest.mark.skipif( + int(os.environ.get("FME_DISTRIBUTED_H", "1")) > 1 + or int(os.environ.get("FME_DISTRIBUTED_W", "1")) > 1, + reason="Cached regression testdata was saved without spatial parallelism; " + "model state_dict keys may differ under distributed SHT.", +) def test_step_regression( case_name, get_config: Callable[[pathlib.Path | None], StepSelector], diff --git a/fme/core/distributed/sht_compat.py b/fme/core/distributed/sht_compat.py new file mode 100644 index 000000000..fdf45d51f --- /dev/null +++ b/fme/core/distributed/sht_compat.py @@ -0,0 +1,50 @@ +""" +Compatibility shim for torch_harmonics distributed transforms. + +Some versions of torch_harmonics do not expose ``lmax_local``, +``mmax_local``, ``lpad_local``, and ``mpad_local`` as attributes on +``DistributedRealSHT`` / ``DistributedInverseRealSHT``. The codebase +(e.g. ``s2convolutions.py``) reads these attributes on the transform +objects, so we patch them in after construction when they are missing. +""" + +from __future__ import annotations + +import torch.nn as nn +import torch_harmonics.distributed as thd + + +def _patch_sht_local_attrs(transform: nn.Module) -> nn.Module: + """Add ``lmax_local``, ``mmax_local``, ``lpad_local``, ``mpad_local`` + to a distributed SHT / ISHT module if they are missing. + + The values are derived from ``l_shapes`` / ``m_shapes`` which *are* + always present (set by all torch_harmonics versions we support). + + This is a no-op when the attributes already exist. + """ + if not isinstance( + transform, (thd.DistributedRealSHT, thd.DistributedInverseRealSHT) + ): + return transform + + if not hasattr(transform, "lmax_local"): + transform.lmax_local = transform.l_shapes[transform.comm_rank_polar] + + if not hasattr(transform, "mmax_local"): + transform.mmax_local = transform.m_shapes[transform.comm_rank_azimuth] + + if not hasattr(transform, "lpad_local"): + # Pad = 0 because compute_split_shapes always sums to the total. + transform.lpad_local = 0 + + if not hasattr(transform, "mpad_local"): + transform.mpad_local = 0 + + return transform + + +def patch_distributed_sht(*transforms: nn.Module) -> None: + """Convenience: patch multiple transforms in one call.""" + for t in transforms: + _patch_sht_local_attrs(t) diff --git a/fme/core/gridded_ops.py b/fme/core/gridded_ops.py index e0916b723..009ce4a1b 100644 --- a/fme/core/gridded_ops.py +++ b/fme/core/gridded_ops.py @@ -355,6 +355,8 @@ def _get_area_weights( name: str | None = None, regional_weights: torch.Tensor | None = None, ): + from fme.core.distributed import Distributed + if data.device == torch.device("cpu"): area_weights = self._cpu_area mask_provider = self._cpu_mask_provider @@ -362,6 +364,11 @@ def _get_area_weights( area_weights = self._device_area mask_provider = self._device_mask_provider area_weights = _mask_area_weights(area_weights, mask_provider, name) + # Under spatial parallelism the stored weights are at global shape + # while data tensors are at local (sharded) shape. Scatter the + # weights so they match the data. + dist = Distributed.get_instance() + area_weights = dist.scatter_spatial(area_weights) if regional_weights is None: return area_weights if regional_weights.device.type != data.device.type: diff --git a/fme/core/models/conditional_sfno/sfnonet.py b/fme/core/models/conditional_sfno/sfnonet.py index db6afa76a..ec4666e9c 100644 --- a/fme/core/models/conditional_sfno/sfnonet.py +++ b/fme/core/models/conditional_sfno/sfnonet.py @@ -382,6 +382,20 @@ def get_lat_lon_sfnonet( data_grid = params.data_grid if hasattr(params, "data_grid") else "equiangular" spatial_parallel = dist.h_size > 1 or dist.w_size > 1 + + if spatial_parallel: + local_blocks = ( + params.local_blocks + if hasattr(params, "local_blocks") + else None + ) + if local_blocks: + raise ValueError( + "local_blocks (DISCO convolution) is not supported " + "with spatial parallelism. Set local_blocks=[] to " + "use spectral filters for all blocks." + ) + if spatial_parallel: thd.init(dist.h_group, dist.w_group) sht_cls = thd.DistributedRealSHT @@ -403,6 +417,11 @@ def get_lat_lon_sfnonet( h, w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss" ).float() + if spatial_parallel: + from fme.core.distributed.sht_compat import patch_distributed_sht + + patch_distributed_sht(trans_down, itrans_up, trans, itrans) + # Under spatial parallelism the model sees the *local* shard dims. if spatial_parallel: local_h = h // dist.h_size diff --git a/fme/core/step/test_step.py b/fme/core/step/test_step.py index f1f248459..c009d1bfc 100644 --- a/fme/core/step/test_step.py +++ b/fme/core/step/test_step.py @@ -18,6 +18,7 @@ from fme.core.coordinates import HybridSigmaPressureCoordinate, LatLonCoordinates from fme.core.corrector.atmosphere import AtmosphereCorrectorConfig, EnergyBudgetConfig from fme.core.dataset_info import DatasetInfo +from fme.core.distributed import Distributed from fme.core.distributed.non_distributed import DummyWrapper from fme.core.labels import BatchLabels from fme.core.normalizer import NetworkAndLossNormalizationConfig, NormalizationConfig @@ -31,7 +32,8 @@ from .radiation import SeparateRadiationStepConfig -DEFAULT_IMG_SHAPE = (45, 90) +# Must be divisible by h_size (typically 2) for spatial parallelism. +DEFAULT_IMG_SHAPE = (44, 90) def get_network_and_loss_normalization_config( @@ -206,7 +208,6 @@ def get_label_conditioned_selector( noise_embed_dim=4, noise_type="isotropic", num_layers=2, - local_blocks=[0], ) ), ), @@ -455,6 +456,8 @@ def get_step( init_weights: Callable[[list[nn.Module]], None] = lambda _: None, all_labels: set[str] | None = None, ) -> StepABC: + # Ensure CUDA device is set before any tensor creation. + Distributed.get_instance() device = fme.get_device() horizontal_coordinate = LatLonCoordinates( lat=torch.zeros(img_shape[0], device=device), @@ -476,10 +479,19 @@ def get_step( def test_label_conditioned_step(): selector = get_label_conditioned_selector() step = get_step(selector, DEFAULT_IMG_SHAPE, all_labels={"a", "b"}) - input_data = get_tensor_dict(step.input_names, DEFAULT_IMG_SHAPE, n_samples=1) - next_step_input_data = get_tensor_dict( - step.next_step_input_names, DEFAULT_IMG_SHAPE, n_samples=1 - ) + dist = Distributed.get_instance() + input_data = { + k: dist.scatter_spatial(v) + for k, v in get_tensor_dict( + step.input_names, DEFAULT_IMG_SHAPE, n_samples=1 + ).items() + } + next_step_input_data = { + k: dist.scatter_spatial(v) + for k, v in get_tensor_dict( + step.next_step_input_names, DEFAULT_IMG_SHAPE, n_samples=1 + ).items() + } output = step.step( args=StepArgs( input=input_data, @@ -490,8 +502,10 @@ def test_label_conditioned_step(): ), wrapper=lambda x: x, ) - assert output["diagnostic_main"].shape == (1, 45, 90) - assert output["diagnostic_rad"].shape == (1, 45, 90) + expected_h = DEFAULT_IMG_SHAPE[0] // max(dist.h_size, 1) + expected_w = DEFAULT_IMG_SHAPE[1] // max(dist.w_size, 1) + assert output["diagnostic_main"].shape == (1, expected_h, expected_w) + assert output["diagnostic_rad"].shape == (1, expected_h, expected_w) @pytest.mark.parametrize("config", HAS_NEXT_STEP_FORCING_NAME_CASES)