1 change: 1 addition & 0 deletions fme/ace/aggregator/train.py
@@ -64,6 +64,7 @@ def __init__(self, config: TrainAggregatorConfig, operations: GriddedOperations)

@torch.no_grad()
def record_batch(self, batch: TrainOutput):
batch = batch.gather_spatial()
self._loss += batch.metrics["loss"]
self._n_loss_batches += 1

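For context on the `gather_spatial()` call above: the aggregator accumulates scalar metrics, so each rank's per-tile loss has to be combined across the spatial group before it is recorded. A minimal sketch of that idea, assuming a plain `torch.distributed` process group for the spatial tile (the actual `TrainOutput.gather_spatial` reassembles full tensors and is not shown in this diff):

```python
import torch
import torch.distributed as dist


def combine_tile_loss(local_loss: torch.Tensor, spatial_group) -> torch.Tensor:
    """Hypothetical helper: average a per-tile scalar loss over the spatial group.

    Illustration only; whether the global loss is a mean or a sum of tile
    losses depends on how the loss itself is defined.
    """
    loss = local_loss.clone()
    dist.all_reduce(loss, group=spatial_group)  # default op is SUM
    return loss / dist.get_world_size(group=spatial_group)
```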
20 changes: 20 additions & 0 deletions fme/ace/data_loading/batch_data.py
@@ -451,6 +451,26 @@ def pin_memory(self: SelfType) -> SelfType:
self.data = {name: tensor.pin_memory() for name, tensor in self.data.items()}
return self

def to_spatial_shard(self: SelfType) -> SelfType:
"""Slice every tensor in *data* to this rank's spatial tile.

When spatial parallelism is inactive (``h_size == w_size == 1``) this
method returns *self* unchanged without calling ``scatter_spatial``.
"""
from fme.core.distributed import Distributed

dist = Distributed.get_instance()
if dist.h_size <= 1 and dist.w_size <= 1:
return self
return self.__class__(
data={k: dist.scatter_spatial(v) for k, v in self.data.items()},
time=self.time,
horizontal_dims=self.horizontal_dims,
epoch=self.epoch,
labels=self.labels,
n_ensemble=self.n_ensemble,
)


@dataclasses.dataclass
class PairedData:
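A hypothetical call site for `to_spatial_shard` above, assuming the training loop receives whole batches from the data loader and shards them just before the forward pass (`run_epoch`, `model`, and `loader` are illustrative stand-ins, not names from this PR):

```python
def run_epoch(model, loader):
    """Illustrative fragment only."""
    for batch in loader:
        # Slice each whole batch down to this rank's (h, w) tile before the
        # forward pass; returns the batch unchanged when spatial parallelism
        # is inactive.
        batch = batch.to_spatial_shard()
        prediction = model(batch.data)
        ...
```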
86 changes: 68 additions & 18 deletions fme/ace/models/makani/sfnonet.py
@@ -25,8 +25,15 @@

# get spectral transforms from torch_harmonics
import torch_harmonics as th
import torch_harmonics.distributed as thd
from torch.utils.checkpoint import checkpoint

from fme.ace.models.makani_fcn3.mpu.layers import (
DistributedInverseRealFFT2,
DistributedRealFFT2,
)
from fme.ace.models.makani_fcn3.mpu.layer_norm import DistributedInstanceNorm2d

from .layers import MLP, DropPath, EncoderDecoder, InverseRealFFT2, RealFFT2
from .spectral_convolution import FactorizedSpectralConv, SpectralConv

@@ -336,13 +343,24 @@ def __init__(
if normalization_layer == "layer_norm":
raise NotImplementedError("requires makani distributed libraries")
elif normalization_layer == "instance_norm":
norm_layer_inp = partial(
nn.InstanceNorm2d,
num_features=embed_dim,
eps=1e-6,
affine=True,
track_running_stats=False,
)
from fme.core.distributed import Distributed

dist = Distributed.get_instance()
if dist.h_size > 1 or dist.w_size > 1:
norm_layer_inp = partial(
DistributedInstanceNorm2d,
num_features=embed_dim,
eps=1e-6,
affine=True,
)
else:
norm_layer_inp = partial(
nn.InstanceNorm2d,
num_features=embed_dim,
eps=1e-6,
affine=True,
track_running_stats=False,
)
norm_layer_out = norm_layer_mid = norm_layer_inp
elif normalization_layer == "none":
norm_layer_out = norm_layer_mid = norm_layer_inp = nn.Identity
@@ -464,9 +482,17 @@ def _init_spectral_transforms(
max_modes=None,
):
"""
Initialize the spectral transforms based on the maximum number of modes to keep. Handles the computation
of local image shapes and domain parallelism, based on the
Initialize the spectral transforms based on the maximum number of
modes to keep. Automatically selects distributed variants when
spatial parallelism is active.
"""
from fme.core.distributed import Distributed

dist = Distributed.get_instance()
spatial_parallel = dist.h_size > 1 or dist.w_size > 1

if spatial_parallel:
thd.init(dist.h_group, dist.w_group)

if max_modes is not None:
modes_lat, modes_lon = max_modes
@@ -476,8 +502,12 @@

# prepare the spectral transforms
if spectral_transform == "sht":
sht_handle = th.RealSHT
isht_handle = th.InverseRealSHT
if spatial_parallel:
sht_handle = thd.DistributedRealSHT
isht_handle = thd.DistributedInverseRealSHT
else:
sht_handle = th.RealSHT
isht_handle = th.InverseRealSHT

# set up
self.trans_down = sht_handle(
@@ -493,9 +523,20 @@
self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=sht_grid_type
).float()

if spatial_parallel:
from fme.core.distributed.sht_compat import patch_distributed_sht

patch_distributed_sht(
self.trans_down, self.itrans_up, self.trans, self.itrans
)

elif spectral_transform == "fft":
fft_handle = RealFFT2
ifft_handle = InverseRealFFT2
if spatial_parallel:
fft_handle = DistributedRealFFT2
ifft_handle = DistributedInverseRealFFT2
else:
fft_handle = RealFFT2
ifft_handle = InverseRealFFT2

self.trans_down = fft_handle(
self.inp_shape[0], self.inp_shape[1], lmax=modes_lat, mmax=modes_lon
@@ -512,11 +553,20 @@
else:
raise (ValueError("Unknown spectral transform"))

# use the SHT/FFT to compute the local, downscaled grid dimensions
self.inp_shape_loc = (self.trans_down.nlat, self.trans_down.nlon)
self.out_shape_loc = (self.itrans_up.nlat, self.itrans_up.nlon)
self.h_loc = self.itrans.nlat
self.w_loc = self.itrans.nlon
# use the SHT/FFT to compute the local, downscaled grid dimensions.
# Under distributed transforms, .nlat / .nlon still report the
# *global* size, so we divide by the spatial group dimensions to
# get the local tile size used for position embeddings etc.
self.inp_shape_loc = (
self.trans_down.nlat // dist.h_size,
self.trans_down.nlon // dist.w_size,
)
self.out_shape_loc = (
self.itrans_up.nlat // dist.h_size,
self.itrans_up.nlon // dist.w_size,
)
self.h_loc = self.itrans.nlat // dist.h_size
self.w_loc = self.itrans.nlon // dist.w_size

@torch.jit.ignore
def no_weight_decay(self):
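As a concrete check of the tile-size arithmetic above, a minimal sketch of the global-to-local division (the grid dimensions and group sizes below are illustrative and assume the grid divides evenly across the spatial group, which is the same implicit assumption the integer division in `_init_spectral_transforms` makes):

```python
def local_tile_shape(nlat: int, nlon: int, h_size: int, w_size: int) -> tuple[int, int]:
    """Per-rank tile size when a global (nlat, nlon) grid is split over an
    h_size x w_size spatial group."""
    return nlat // h_size, nlon // w_size


# e.g. a 720 x 1440 grid over a 2 x 4 spatial group leaves each rank a 360 x 360 tile
assert local_tile_shape(720, 1440, h_size=2, w_size=4) == (360, 360)
```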
2 changes: 1 addition & 1 deletion fme/ace/models/makani_fcn3/mpu/layer_norm.py
@@ -1,4 +1,4 @@
# type: ignore
# mypy: ignore-errors
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
2 changes: 1 addition & 1 deletion fme/ace/models/makani_fcn3/mpu/layers.py
@@ -1,4 +1,4 @@
# type: ignore
# mypy: ignore-errors
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
61 changes: 56 additions & 5 deletions fme/ace/models/makani_fcn3/utils/comm.py
@@ -13,18 +13,69 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Communicator shim for Makani / FCN3.

Delegates to the ACE :class:`~fme.core.distributed.Distributed` singleton so
that the rest of the vendored Makani code works unchanged.

Supported group names: ``"h"``, ``"w"``, ``"spatial"``, ``"matmul"``.
``"matmul"`` parallelism is not used (always returns size 1 / rank 0).
"""

from __future__ import annotations


def _dist():
"""Lazy import to avoid circular imports at module level."""
from fme.core.distributed import Distributed

return Distributed.get_instance()


def get_size(name: str) -> int:
return 1
"""Return the number of ranks in the named group."""
d = _dist()
if name == "h":
return d.h_size
if name == "w":
return d.w_size
if name == "spatial":
return d.h_size * d.w_size
if name == "matmul":
return 1
raise ValueError(f"Unknown comm group name: {name!r}")


def get_rank(name: str) -> int:
return 0
"""Return the rank of this process in the named group."""
d = _dist()
if name == "h":
return d.h_rank
if name == "w":
return d.w_rank
if name == "spatial":
# Row-major linearised rank within the (h, w) tile.
return d.h_rank * d.w_size + d.w_rank
if name == "matmul":
return 0
raise ValueError(f"Unknown comm group name: {name!r}")


def get_group(name: str):
return None
"""Return the ``torch.distributed`` process group for the named group."""
d = _dist()
if name == "h":
return d.h_group
if name == "w":
return d.w_group
if name == "spatial":
return d.spatial_group
if name == "matmul":
return None
raise ValueError(f"Unknown comm group name: {name!r}")


def is_distributed(name: str):
return False
def is_distributed(name: str) -> bool:
"""Return whether the named group has more than one rank."""
return get_size(name) > 1
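
A quick illustration of how vendored Makani code is expected to query the shim; the group sizes shown assume a hypothetical run with `h_size=2` and `w_size=4`:

```python
from fme.ace.models.makani_fcn3.utils import comm

# Assuming a 2 x 4 spatial decomposition:
comm.get_size("h")              # -> 2
comm.get_size("w")              # -> 4
comm.get_size("spatial")        # -> 8
comm.get_size("matmul")         # -> 1 (matmul parallelism is unused)
comm.is_distributed("spatial")  # -> True
comm.get_group("matmul")        # -> None
```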