Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
245 changes: 245 additions & 0 deletions fme/ace/stepper/test_single_module_csfno.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
"""
Parallel regression tests for the SingleModuleStepper with NoiseConditionedSFNO.

These tests verify that the forward pass and loss computation produce identical
results regardless of spatial decomposition (nproc=1 vs model-parallel).
"""

import dataclasses
import datetime
import os
from collections.abc import Mapping

import numpy as np
import pytest
import torch
import xarray as xr

from fme.ace.data_loading.batch_data import BatchData
from fme.ace.registry.stochastic_sfno import NoiseConditionedSFNOBuilder
from fme.ace.stepper.single_module import (
StepperConfig,
TrainOutput,
TrainStepper,
TrainStepperConfig,
)
from fme.core.coordinates import HybridSigmaPressureCoordinate, LatLonCoordinates
from fme.core.dataset_info import DatasetInfo
from fme.core.device import get_device
from fme.core.distributed.distributed import Distributed
from fme.core.loss import StepLossConfig
from fme.core.normalizer import NetworkAndLossNormalizationConfig, NormalizationConfig
from fme.core.optimization import NullOptimization, OptimizationConfig
from fme.core.registry.module import ModuleSelector
from fme.core.step import SingleModuleStepConfig, StepSelector
from fme.core.testing.regression import validate_tensor_dict
from fme.core.typing_ import EnsembleTensorDict

DIR = os.path.abspath(os.path.dirname(__file__))
TIMESTEP = datetime.timedelta(hours=6)


def get_dataset_info(
    img_shape=(5, 5),
) -> DatasetInfo:
    """Build a minimal DatasetInfo with zeroed lat/lon coordinates.

    Args:
        img_shape: (..., n_lat, n_lon) shape; only the trailing two
            dimensions are used to size the horizontal coordinates.
    """
    n_lat, n_lon = img_shape[-2], img_shape[-1]
    return DatasetInfo(
        horizontal_coordinates=LatLonCoordinates(
            lat=torch.zeros(n_lat),
            lon=torch.zeros(n_lon),
        ),
        vertical_coordinate=HybridSigmaPressureCoordinate(
            ak=torch.arange(7), bk=torch.arange(7)
        ),
        timestep=TIMESTEP,
    )


def _get_train_stepper(
    stepper_config: StepperConfig,
    dataset_info: DatasetInfo,
    **train_config_kwargs,
) -> TrainStepper:
    """Build a TrainStepper from a stepper config and TrainStepperConfig kwargs."""
    return TrainStepperConfig(**train_config_kwargs).get_train_stepper(
        stepper_config, dataset_info
    )


def get_regression_stepper_and_data() -> (
    tuple[TrainStepper, BatchData, tuple[int, int]]
):
    """Build a NoiseConditionedSFNO train stepper and a random input batch.

    Returns:
        A tuple of ``(train_stepper, data, img_shape)``, where ``data`` has
        already been scattered spatially via ``scatter_spatial``.
    """
    in_names = ["a", "b"]
    out_names = ["b", "c"]
    n_forward_steps = 2
    n_samples = 3
    # Full (lat, lon) grid shape before any spatial decomposition.
    img_shape = (9, 18)
    device = get_device()

    all_names = list(set(in_names + out_names))

    loss = StepLossConfig(type="AreaWeightedMSE")

    # Selector configs take plain dicts, so the nested dataclass configs
    # are converted with dataclasses.asdict.
    config = StepperConfig(
        step=StepSelector(
            type="single_module",
            config=dataclasses.asdict(
                SingleModuleStepConfig(
                    builder=ModuleSelector(
                        type="NoiseConditionedSFNO",
                        config=dataclasses.asdict(
                            NoiseConditionedSFNOBuilder(
                                embed_dim=16,
                                num_layers=2,
                                noise_embed_dim=16,
                                noise_type="isotropic",
                            )
                        ),
                    ),
                    in_names=in_names,
                    out_names=out_names,
                    normalization=NetworkAndLossNormalizationConfig(
                        network=NormalizationConfig(
                            means={n: 0.1 for n in all_names},
                            stds={n: 1.1 for n in all_names},
                        ),
                    ),
                    ocean=None,
                )
            ),
        ),
    )

    dataset_info = get_dataset_info(img_shape=img_shape)
    train_stepper = _get_train_stepper(config, dataset_info, loss=loss)
    # Data layout: (sample, time, lat, lon) with one initial-condition
    # step plus n_forward_steps forward steps.
    data = BatchData.new_on_device(
        data={
            "a": torch.randn(n_samples, n_forward_steps + 1, *img_shape).to(device),
            "b": torch.randn(n_samples, n_forward_steps + 1, *img_shape).to(device),
            "c": torch.randn(n_samples, n_forward_steps + 1, *img_shape).to(device),
        },
        time=xr.DataArray(
            np.zeros((n_samples, n_forward_steps + 1)),
            dims=["sample", "time"],
        ),
        labels=None,
        epoch=0,
        horizontal_dims=["lat", "lon"],
    )
    # Scatter the batch across spatial ranks; presumably a no-op when
    # running with nproc == 1 — confirm against scatter_spatial.
    data = data.scatter_spatial(img_shape)
    return train_stepper, data, img_shape


def flatten_dict(
    d: Mapping[str, Mapping[str, torch.Tensor]],
) -> dict[str, torch.Tensor]:
    """Flatten a two-level mapping into dotted ``"outer.inner"`` keys."""
    return {
        f"{outer}.{inner}": value
        for outer, sub in d.items()
        for inner, value in sub.items()
    }


def _get_train_output_tensor_dict(data: TrainOutput) -> dict[str, torch.Tensor]:
    """Flatten a TrainOutput's metrics, gen_data, and target_data into
    dotted keys (``"metrics.<name>"`` etc.)."""
    out: dict[str, torch.Tensor] = {}
    for name, value in data.metrics.items():
        out[f"metrics.{name}"] = value
    for name, value in data.gen_data.items():
        out[f"gen_data.{name}"] = value
    for name, value in data.target_data.items():
        # Target data must hold a single entry along dimension 1.
        assert value.shape[1] == 1
        out[f"target_data.{name}"] = value
    return out


def get_train_outputs_tensor_dict(
    step_1: TrainOutput, step_2: TrainOutput
) -> dict[str, torch.Tensor]:
    """Merge two TrainOutputs into one flat dict keyed ``step_N.<field>.<name>``."""
    per_step = {
        "step_1": _get_train_output_tensor_dict(step_1),
        "step_2": _get_train_output_tensor_dict(step_2),
    }
    return flatten_dict(per_step)


@pytest.mark.parallel
def test_stepper_train_on_batch_regression():
    """Two consecutive unoptimized train steps must match stored references
    regardless of spatial decomposition.
    """
    torch.manual_seed(0)
    train_stepper, data, img_shape = get_regression_stepper_and_data()
    optimization = NullOptimization()
    result1 = train_stepper.train_on_batch(data, optimization)
    result2 = train_stepper.train_on_batch(data, optimization)
    dist = Distributed.get_instance()
    # Gather spatially-scattered outputs back to the full grid so the
    # comparison against the reference file is decomposition-independent.
    for result in [result1, result2]:
        result.gen_data = EnsembleTensorDict(
            dist.gather_spatial(dict(result.gen_data), img_shape)
        )
        result.target_data = EnsembleTensorDict(
            dist.gather_spatial(dict(result.target_data), img_shape)
        )
    output_dict = get_train_outputs_tensor_dict(result1, result2)
    validate_tensor_dict(
        output_dict,
        os.path.join(
            DIR,
            "testdata/csfno_stepper_train_on_batch_regression.pt",
        ),
        atol=1e-4,
        rtol=1e-4,
    )


@pytest.mark.parallel
def test_stepper_train_on_batch_with_optimization_regression():
    """Two optimized train steps must match stored references regardless
    of spatial decomposition.
    """
    torch.manual_seed(0)
    stepper, batch, shape = get_regression_stepper_and_data()
    opt = OptimizationConfig(
        optimizer_type="Adam",
        lr=0.0001,
    ).build(stepper.modules, max_epochs=1)
    step_outputs = [
        stepper.train_on_batch(batch, opt),
        stepper.train_on_batch(batch, opt),
    ]
    distributed = Distributed.get_instance()
    # Gather outputs back to the full grid before comparison.
    for out in step_outputs:
        out.gen_data = EnsembleTensorDict(
            distributed.gather_spatial(dict(out.gen_data), shape)
        )
        out.target_data = EnsembleTensorDict(
            distributed.gather_spatial(dict(out.target_data), shape)
        )
    reference_path = os.path.join(
        DIR,
        "testdata/csfno_stepper_train_on_batch_with_optimization_regression.pt",
    )
    validate_tensor_dict(
        get_train_outputs_tensor_dict(step_outputs[0], step_outputs[1]),
        reference_path,
        atol=1e-2,
        rtol=1e-2,
    )


@pytest.mark.parallel
def test_stepper_predict_regression():
    """Inference (predict) output must match stored references regardless
    of spatial decomposition.
    """
    torch.manual_seed(0)
    train_stepper, batch, shape = get_regression_stepper_and_data()
    stepper = train_stepper._stepper
    initial_condition = batch.get_start(
        prognostic_names=["b"],
        n_ic_timesteps=1,
    )
    prediction, final_state = stepper.predict(
        initial_condition, batch, compute_derived_variables=True
    )
    distributed = Distributed.get_instance()
    # Gather both prediction and the returned next state to the full grid.
    gathered = {
        "output": distributed.gather_spatial(dict(prediction.data), shape),
        "next_state": distributed.gather_spatial(
            dict(final_state.as_batch_data().data), shape
        ),
    }
    validate_tensor_dict(
        flatten_dict(gathered),
        os.path.join(DIR, "testdata/csfno_stepper_predict_regression.pt"),
        atol=1e-4,
        rtol=1e-4,
    )
Binary file not shown.
Binary file not shown.
Binary file not shown.
84 changes: 79 additions & 5 deletions fme/core/distributed/model_torch_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@

import torch
import torch.distributed
import torch.distributed as pt_dist
import torch.nn as nn
import torch_harmonics.distributed as thd
from torch.amp import custom_bwd, custom_fwd
from torch.nn import SyncBatchNorm
from torch.nn.parallel import DistributedDataParallel

Expand All @@ -42,6 +44,35 @@
T = TypeVar("T")


class _AutogradAllReduce(torch.autograd.Function):
    """Autograd-aware all-reduce (sum) for spatial parallelism.

    Forward: all-reduce (sum) the input across the given process group.
    Backward: identity — gradients pass through without communication.

    This makes ``spatial_reduce_sum`` differentiable so that gradients
    flow correctly through the loss computation path::

        AreaWeightedMSELoss → area_weighted_mean → weighted_mean
        → spatial_reduce_sum (uses this function)

    Without this, the raw ``torch.distributed.all_reduce`` would break
    the autograd graph because it is an in-place, non-differentiable op.
    """

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(
        ctx,
        input: torch.Tensor,
        group: torch.distributed.ProcessGroup,
    ) -> torch.Tensor:
        # Clone first so the in-place all_reduce never mutates the
        # caller's tensor (which would trip autograd version tracking).
        output = input.clone()
        torch.distributed.all_reduce(output, group=group)
        return output

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output: torch.Tensor):
        # Identity for the tensor input; None for the non-tensor
        # ``group`` argument, which receives no gradient.
        return grad_output.clone(), None
Comment thread
mahf708 marked this conversation as resolved.
Comment on lines +66 to +73


class ModelTorchDistributed(DistributedBackend):
"""Distributed backend with spatial model parallelism.

Expand Down Expand Up @@ -307,31 +338,73 @@ def _device_ids(self) -> list[int] | None:
def wrap_module(self, module: torch.nn.Module) -> torch.nn.Module:
    """Wrap with DDP over the **data** process group.

    Spatial model parallelism is handled by:
    - Forward: communication inside model layers (distributed SHT/iSHT)
    - Backward: gradient hooks registered here that all-reduce across
      spatial ranks, so every rank sees the global-mean gradient.

    ``broadcast_buffers=False`` is required because the SHT/iSHT layers
    store precomputed Legendre polynomial buffers. DDP's default
    buffer broadcast modifies these in-place between forward calls,
    which breaks autograd's tensor-version tracking.

    Modules with no trainable parameters are returned in a DummyWrapper
    instead of DDP.
    """
    if any(p.requires_grad for p in module.parameters()):
        if using_gpu():
            # NOTE(review): DDP's output_device is normally a single
            # int/device, not a list — confirm this is intentional.
            output_device = [self._device_id]
        else:
            output_device = None
        wrapped = DistributedDataParallel(
            SyncBatchNorm.convert_sync_batchnorm(module),
            device_ids=self._device_ids,
            output_device=output_device,
            process_group=self._data_group,
            broadcast_buffers=False,
        )
        self._register_spatial_grad_hooks(wrapped)
        return wrapped
    return DummyWrapper(module)

def _register_spatial_grad_hooks(self, module: torch.nn.Module) -> None:
    """Attach per-parameter hooks that sum gradients over spatial ranks.

    With spatial decomposition, each rank back-propagates only its local
    slice of the input, so its parameter gradients are partial sums. The
    hooks registered here all-reduce those partials across the spatial
    group so every rank applies the same weight update.

    ``register_hook`` fires on the raw per-backward gradient, before it
    is accumulated into ``.grad`` and before DDP's data-parallel
    all-reduce; the two reductions act on orthogonal groups, so their
    order does not matter.
    """
    if self._h_size <= 1 and self._w_size <= 1:
        # No spatial decomposition: local gradients are already global.
        return
    group = self._spatial_group

    def _sum_over_spatial(grad: torch.Tensor) -> torch.Tensor:
        if grad is None:
            return grad

        summed = grad.contiguous().clone()
        torch.distributed.all_reduce(summed, group=group)

        # If we want mean gradient instead of sum, we want:
        # summed /= (self._h_size * self._w_size)

        return summed

    for param in module.parameters():
        if param.requires_grad:
            param.register_hook(_sum_over_spatial)

Comment on lines +396 to +399
def barrier(self) -> None:
    """Global barrier across all ranks.

    Blocks until every rank in the default process group reaches this
    point.
    """
    logger.debug("Barrier on rank %d", self._rank)
    torch.distributed.barrier(device_ids=self._device_ids)

def spatial_reduce_sum(self, tensor: torch.Tensor) -> torch.Tensor:
    """Sum ``tensor`` across the spatial process group, autograd-aware.

    Uses ``_AutogradAllReduce`` so gradients flow through the reduction
    (forward: all-reduce sum; backward: identity). Returns the input
    unchanged when there is no spatial decomposition.
    """
    if self._h_size > 1 or self._w_size > 1:
        return _AutogradAllReduce.apply(tensor, self._spatial_group)
    return tensor

def weighted_mean(
Expand All @@ -341,6 +414,7 @@ def weighted_mean(
dim: tuple[int, ...],
keepdim: bool = False,
) -> torch.Tensor:

from fme.core.metrics import weighted_sum

local_weighted_sum = weighted_sum(data, weights, dim=dim, keepdim=keepdim)
Expand Down
Loading
Loading