diff --git a/fme/downscaling/data/__init__.py b/fme/downscaling/data/__init__.py index d5c0d3f7c..77b7a635a 100644 --- a/fme/downscaling/data/__init__.py +++ b/fme/downscaling/data/__init__.py @@ -18,6 +18,10 @@ ClosedInterval, LatLonCoordinates, adjust_fine_coord_range, + compute_lon_roll, expand_and_fold_tensor, + roll_latlon_coords, + roll_lon_coords, + roll_lon_data, scale_tuple, ) diff --git a/fme/downscaling/data/datasets.py b/fme/downscaling/data/datasets.py index ab51fadd0..357acaa89 100644 --- a/fme/downscaling/data/datasets.py +++ b/fme/downscaling/data/datasets.py @@ -24,9 +24,12 @@ BatchedLatLonCoordinates, ClosedInterval, check_leading_dim, + compute_lon_roll, expand_and_fold_tensor, get_offset, paired_shuffle, + roll_lon_coords, + roll_lon_data, scale_tuple, ) @@ -121,28 +124,22 @@ def __init__( self._orig_coords: LatLonCoordinates = properties.horizontal_coordinates - if (self.lon_interval.stop != float("inf")) and ( - torch.any(self._orig_coords.lon < self.lon_interval.stop - 360.0) - ): - lon_max = self._orig_coords.lon.max() - raise NotImplementedError( - f"lon wraparound not implemented, received lon_max {lon_max} but " - f"expected lon_max > {self.lon_interval.stop - 360.0}" - ) - if (self.lon_interval.start != -float("inf")) and ( - torch.any(self._orig_coords.lon > self.lon_interval.start + 360.0) - ): - lon_min = self._orig_coords.lon.min() - raise NotImplementedError( - f"lon wraparound not implemented, received lon_min {lon_min} but " - f"expected lon_min < {self.lon_interval.start + 360.0}" - ) + # Detect whether the requested lon interval spans the 0°/360° seam of the + # stored coordinates (e.g. start=-5, stop=3 on 0–360° data). When it does, + # roll the data along the longitude axis so the interval is contiguous. + lon_start = ( + self.lon_interval.start if self.lon_interval.start != -float("inf") else 0.0 + ) + self._lon_roll_amount = compute_lon_roll(self._orig_coords.lon, lon_start) + rolled_lon = roll_lon_coords( + self._orig_coords.lon, self._lon_roll_amount, lon_start + ) self._lats_slice = self.lat_interval.slice_from(self._orig_coords.lat) - self._lons_slice = self.lon_interval.slice_from(self._orig_coords.lon) + self._lons_slice = self.lon_interval.slice_from(rolled_lon) self._latlon_coordinates = LatLonCoordinates( lat=self._orig_coords.lat[self._lats_slice], - lon=self._orig_coords.lon[self._lons_slice], + lon=rolled_lon[self._lons_slice], ) self._area_weights = self._latlon_coordinates.area_weights @@ -168,7 +165,7 @@ def __len__(self): def __getitem__(self, key) -> DatasetItem: batch, times, _, epoch = self.dataset[key] batch = { - k: v[ + k: roll_lon_data(v, self._lon_roll_amount)[ ..., self._lats_slice, self._lons_slice, @@ -211,6 +208,10 @@ def __getitem__(self, idx) -> BatchItem: return BatchItem(fields, time.squeeze(), self._coordinates) + @property + def latlon_coordinates(self) -> LatLonCoordinates: + return self._coordinates + @property def variable_metadata(self) -> dict[str, VariableMetadata]: if self._properties is None: diff --git a/fme/downscaling/data/static.py b/fme/downscaling/data/static.py index e39b00588..c789fc886 100644 --- a/fme/downscaling/data/static.py +++ b/fme/downscaling/data/static.py @@ -5,7 +5,7 @@ from fme.core.coordinates import LatLonCoordinates from fme.core.device import get_device -from fme.downscaling.data.utils import ClosedInterval +from fme.downscaling.data.utils import ClosedInterval, roll_lon_coords, roll_lon_data @dataclasses.dataclass @@ -163,13 +163,31 @@ def subset( lat_slice = lat_interval.slice_from(self.coords.lat) lon_slice = lon_interval.slice_from(self.coords.lon) return StaticInputs( - fields=[field.subset(lat_slice, lon_slice) for field in self.fields], + fields=[ + StaticInput(data=field.data[lat_slice, lon_slice]) + for field in self.fields + ], coords=LatLonCoordinates( lat=lat_interval.subset_of(self.coords.lat), lon=lon_interval.subset_of(self.coords.lon), ), ) + def roll(self, roll_amount: int, lon_start: float) -> "StaticInputs": + """ + Roll the data and lon coordinates of the StaticInputs by the specified amount. + """ + if roll_amount == 0: + return self + rolled_lon = roll_lon_coords(self.coords.lon, roll_amount, lon_start) + return StaticInputs( + fields=[ + StaticInput(data=roll_lon_data(f.data, roll_amount, lon_dim=-1)) + for f in self.fields + ], + coords=LatLonCoordinates(lat=self.coords.lat, lon=rolled_lon), + ) + def to_device(self) -> "StaticInputs": return StaticInputs( fields=[field.to_device() for field in self.fields], diff --git a/fme/downscaling/data/test_datasets.py b/fme/downscaling/data/test_datasets.py index f3ac263de..0af6bc340 100644 --- a/fme/downscaling/data/test_datasets.py +++ b/fme/downscaling/data/test_datasets.py @@ -244,6 +244,51 @@ def test_horizontal_subset( assert dataset.subset_latlon_coordinates.lon.shape == (expected_n_lon,) +def test_horizontal_subset_prime_meridian_spanning(): + """HorizontalSubsetDataset must handle lon intervals that cross 0°/360°.""" + # 8-point longitude grid in 0–360° convention, 45° spacing + lons = torch.tensor([0.0, 45.0, 90.0, 135.0, 180.0, 225.0, 270.0, 315.0]) + n_lat, n_lon = 4, 8 + coords = LatLonCoordinates( + lat=torch.linspace(0.0, 1.0, n_lat), + lon=lons, + ) + # Data: lon index encodes original position so we can verify the roll + data_tensor = torch.arange(n_lon, dtype=torch.float).unsqueeze(0).unsqueeze(0) + data_tensor = data_tensor.expand(1, 1, n_lat, n_lon).clone() + + datum: tuple[dict[str, torch.Tensor], xr.DataArray, set[str], int] = ( + {"x": data_tensor}, + xr.DataArray([0.0]), + set(), + 0, + ) + base_dataset = MagicMock(spec=torch.utils.data.Dataset) + properties = MagicMock(spec=DatasetProperties) + properties.horizontal_coordinates = coords + properties.all_labels = MagicMock(spec=set) + base_dataset.__getitem__.return_value = datum + + # Interval [-90, 45] spans 0° on a 0–360° grid (270°→45° going through 0°) + dataset = HorizontalSubsetDataset( + dataset=base_dataset, + properties=properties, + lat_interval=ClosedInterval(float("-inf"), float("inf")), + lon_interval=ClosedInterval(-90.0, 45.0), + ) + + # Expect 4 lon points: 270°→-90°, 315°→-45°, 0°→0°, 45°→45° + assert dataset.subset_latlon_coordinates.lon.shape == (4,) + expected_lons = torch.tensor([-90.0, -45.0, 0.0, 45.0]) + assert torch.allclose(dataset.subset_latlon_coordinates.lon, expected_lons) + + subset, _, _, _ = dataset[0] + assert subset["x"].shape == (1, 1, n_lat, 4) + # Data values should correspond to original lon indices 6, 7, 0, 1 + expected_vals = torch.tensor([6.0, 7.0, 0.0, 1.0]) + assert torch.allclose(subset["x"][0, 0, 0], expected_vals) + + def test_batch_data_from_sequence(): num_items = 3 items = get_batch_items(num_items=num_items) diff --git a/fme/downscaling/data/test_static.py b/fme/downscaling/data/test_static.py index d02bbe9ff..7e6371a08 100644 --- a/fme/downscaling/data/test_static.py +++ b/fme/downscaling/data/test_static.py @@ -178,3 +178,41 @@ def test__load_coords_from_ds(): ds = xr.Dataset(coords={"x": lon, "y": lat}) with pytest.raises(ValueError): _load_coords_from_ds(ds) + + +def test_StaticInputs_roll_shifts_data_and_coords(): + """StaticInputs.roll produces correctly rolled data and monotonic shifted coords.""" + from fme.core.coordinates import LatLonCoordinates + + # 1-degree global grid: 0.5, 1.5, ..., 359.5 + n_lon = 360 + lon = torch.arange(n_lon, dtype=torch.float32) + 0.5 + lat = torch.tensor([0.5], dtype=torch.float32) + coords = LatLonCoordinates(lat=lat, lon=lon) + data = torch.arange(n_lon, dtype=torch.float32).unsqueeze(0) # values = index + static = StaticInputs(fields=[StaticInput(data)], coords=coords) + + # Roll so domain starts at -5.5°: start_360=354.5, roll=(coords<354.5).sum()=354 + roll_amount = 354 + lon_start = -5.5 + rolled = static.roll(roll_amount, lon_start) + + # Data: original index 354 (value=354) should be at index 0 after rolling + assert rolled.fields[0].data[0, 0].item() == pytest.approx(354.0) + assert rolled.fields[0].data[0, -1].item() == pytest.approx(353.0) + + # Coordinates should be monotonically increasing and start near lon_start + assert torch.all(rolled.coords.lon[1:] > rolled.coords.lon[:-1]) + assert rolled.coords.lon[0].item() == pytest.approx(354.5 - 360) # = -5.5 + + +def test_StaticInputs_roll_zero_returns_self(): + from fme.core.coordinates import LatLonCoordinates + + coords = LatLonCoordinates( + lat=torch.tensor([0.0]), + lon=torch.tensor([0.5, 1.5, 2.5]), + ) + data = torch.zeros(1, 3) + static = StaticInputs(fields=[StaticInput(data)], coords=coords) + assert static.roll(0, 0.0) is static diff --git a/fme/downscaling/data/test_utils.py b/fme/downscaling/data/test_utils.py index 4cb94ac5f..7875edabd 100644 --- a/fme/downscaling/data/test_utils.py +++ b/fme/downscaling/data/test_utils.py @@ -5,7 +5,10 @@ from fme.downscaling.data.utils import ( ClosedInterval, adjust_fine_coord_range, + compute_lon_roll, paired_shuffle, + roll_lon_coords, + roll_lon_data, scale_slice, ) @@ -19,6 +22,65 @@ def _fine_midpoints(coarse_edges, downscale_factor): return torch.concatenate(fine_mids) +def _one_deg_lon_coords(): + """0.5, 1.5, ..., 359.5 — standard 1-degree global grid.""" + return torch.arange(0.5, 360.0, 1.0) + + +@pytest.mark.parametrize( + "lon_start, expected_roll", + [ + pytest.param(0.0, 0, id="zero_no_roll"), + pytest.param(10.0, 0, id="positive_in_range_no_roll"), + pytest.param(-5.0, 355, id="negative_start_rolls"), + pytest.param(-0.5, 359, id="just_negative_rolls"), + pytest.param(-180.0, 180, id="half_period"), + pytest.param(360.0, 0, id="exactly_360_no_roll"), + ], +) +def test_compute_lon_roll(lon_start, expected_roll): + coords = _one_deg_lon_coords() + assert compute_lon_roll(coords, lon_start) == expected_roll + + +def test_roll_lon_coords_negative_start(): + """Rolled coords for lon_start=-5 should be monotone and start near -5.""" + coords = _one_deg_lon_coords() + roll_amount = compute_lon_roll(coords, -5.0) + rolled = roll_lon_coords(coords, roll_amount, -5.0) + + assert rolled.shape == coords.shape + # monotonically increasing + assert torch.all(rolled[1:] > rolled[:-1]) + # first element is the first coord >= 355, shifted to negative convention + assert rolled[0].item() == pytest.approx(-4.5) + # last element of the original-convention "low" portion, shifted by -360 + assert rolled[-1].item() == pytest.approx(354.5) + + +def test_roll_lon_coords_zero_roll_returns_original(): + coords = _one_deg_lon_coords() + result = roll_lon_coords(coords, 0, 0.0) + assert torch.equal(result, coords) + + +def test_roll_lon_data_shifts_correctly(): + """Rolling data by r positions moves index r to index 0.""" + n = 8 + tensor = torch.arange(n, dtype=torch.float).unsqueeze(0) # shape (1, 8) + roll_amount = 3 + rolled = roll_lon_data(tensor, roll_amount, lon_dim=-1) + assert rolled.shape == tensor.shape + assert rolled[0, 0].item() == pytest.approx(3.0) # original index 3 → 0 + assert rolled[0, -1].item() == pytest.approx(2.0) # original index 2 → last + + +def test_roll_lon_data_zero_roll_returns_original(): + tensor = torch.randn(4, 8) + result = roll_lon_data(tensor, 0) + assert torch.equal(result, tensor) + + def test_paired_shuffle(): a = np.arange(5) b = a * 10 @@ -49,6 +111,66 @@ def test_adjust_fine_coord_range(downscale_factor, lat_range): assert len(subsel_fine_lat) / len(subsel_coarse_lat) == downscale_factor +def test_adjust_fine_coord_range_wrapping_lon(): + """ + adjust_fine_coord_range handles negative-start lon intervals (wrapping domain). + """ + downscale_factor = 2 # n_half_fine = 1 + # Coarse 1-degree global grid (0.5 to 359.5), subset and already rolled + # to shifted convention: coarse domain -2 to 3 (= 358 to 3 in 0-360) + coarse_lon = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0, 3.0]) + # Fine 0.5-degree global grid in 0-360 convention (0.25 to 359.75) + fine_lon = torch.arange(0.25, 360.0, 0.5) + lon_range = ClosedInterval(-2.5, 3.5) + + result = adjust_fine_coord_range( + lon_range, + full_coarse_coord=coarse_lon, + full_fine_coord=fine_lon, + downscale_factor=downscale_factor, + ) + # result should cover from 1 fine step below coarse_min=-2 to 1 fine step above + # coarse_max=3 + # coarse_min=-2 → 358 in 0-360; fine just below 358 is 357.75 → shifted = -2.25 + # coarse_max=3; fine just above 3 is 3.25 + assert result.start == pytest.approx(-2.25) + assert result.stop == pytest.approx(3.25) + + # Verify the returned interval properly covers the coarse points with + # n_half_fine padding + subsel_coarse = coarse_lon[ + (coarse_lon >= lon_range.start) & (coarse_lon <= lon_range.stop) + ] + assert len(subsel_coarse) == len(coarse_lon) # all 6 coarse points selected + + # The fine interval spans n_coarse * downscale_factor fine points + # After rolling to match the shifted convention, subset should be contiguous + from fme.downscaling.data.utils import compute_lon_roll, roll_lon_coords + + fine_roll = compute_lon_roll(fine_lon, result.start) + rolled_fine = roll_lon_coords(fine_lon, fine_roll, result.start) + subsel_fine = result.subset_of(rolled_fine) + assert len(subsel_fine) == len(subsel_coarse) * downscale_factor + + +def test_adjust_fine_coord_range_negative_lat_not_confused_with_wrapping(): + """Southern hemisphere lat (negative coarse_min) must not trigger lon-wrap logic.""" + downscale_factor = 2 + coarse_lat = torch.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]) + # Fine lat grid also spans negative values — NOT a 0-360 lon grid + fine_lat = torch.arange(-3.25, 3.5, 0.5) + lat_range = ClosedInterval(-3.5, 3.5) + result = adjust_fine_coord_range( + lat_range, + full_coarse_coord=coarse_lat, + full_fine_coord=fine_lat, + downscale_factor=downscale_factor, + ) + # Standard non-wrapping behaviour: fine_min should be below coarse_min=-3 + assert result.start < -3.0 + assert result.stop > 3.0 + + def test_adjust_fine_coord_range_raises_near_domain_boundary(): downscale_factor = 4 # n_half_fine = 2 coarse_edges = torch.linspace(0, 6, 7) diff --git a/fme/downscaling/data/utils.py b/fme/downscaling/data/utils.py index 425011a53..b47f315d8 100644 --- a/fme/downscaling/data/utils.py +++ b/fme/downscaling/data/utils.py @@ -16,6 +16,76 @@ def null_generator(num: int): yield None +def compute_lon_roll(lon_coords: torch.Tensor, lon_start: float) -> int: + """ + Compute the number of positions to roll lon_coords (leftward) so that an + interval starting at lon_start becomes contiguous in the rolled array. + + Returns 0 when lon_start is already representable without wrapping (i.e., + lon_start is within one period of the coordinate origin). + + Args: + lon_coords: 1-D tensor of monotonically increasing longitudes (e.g. 0–360°). + lon_start: Start of the desired longitude interval. + + Returns: + Roll amount r such that ``torch.roll(lon_coords, -r)`` places the first + coordinate >= (lon_start % 360) at index 0 of the rolled array. + """ + start_360 = lon_start % 360 + if abs(lon_start - start_360) < 1e-6: + return 0 + return int((lon_coords < start_360).sum().item()) + + +def roll_lon_coords( + lon_coords: torch.Tensor, roll_amount: int, lon_start: float +) -> torch.Tensor: + """ + Roll longitude coordinates and adjust values to maintain a monotonically + increasing sequence whose first element is consistent with lon_start's + sign convention. + + After rolling, coordinates that wrapped around the end of the array have + 360 added, then a global offset (a multiple of 360) shifts the whole array + so that rolled_coords[0] ≈ lon_start. + + Args: + lon_coords: 1-D tensor of monotonically increasing longitudes. + roll_amount: Value returned by :func:`compute_lon_roll`. + lon_start: The lon_start used when computing roll_amount. + + Returns: + A new tensor of the same shape with monotonically increasing values. + """ + if roll_amount == 0: + return lon_coords + n = len(lon_coords) + rolled = torch.roll(lon_coords, -roll_amount).clone() + rolled[n - roll_amount :] += 360.0 + offset = lon_start - (lon_start % 360) + return rolled + offset + + +def roll_lon_data( + tensor: torch.Tensor, roll_amount: int, lon_dim: int = -1 +) -> torch.Tensor: + """Roll a data tensor along its longitude dimension by roll_amount positions.""" + if roll_amount == 0: + return tensor + return torch.roll(tensor, -roll_amount, dims=lon_dim) + + +def roll_latlon_coords( + coords: LatLonCoordinates, roll_amount: int, lon_start: float +) -> LatLonCoordinates: + """Return a new LatLonCoordinates with lon rolled by roll_amount.""" + return LatLonCoordinates( + lat=coords.lat, + lon=roll_lon_coords(coords.lon, roll_amount, lon_start), + ) + + @dataclasses.dataclass class ClosedInterval: """ @@ -147,7 +217,14 @@ def adjust_fine_coord_range( coarse_min = full_coarse_coord[full_coarse_coord >= coord_range.start][0] coarse_max = full_coarse_coord[full_coarse_coord <= coord_range.stop][-1] - n_fine_below = int((full_fine_coord < coarse_min).sum()) + # Detect a wrapping longitude domain: coord_range uses a shifted convention + # (coarse_min is negative) while the fine coord array is in 0–360°. + # Fine points "just before" the wrapped edge live at the high end of the + # unshifted fine array (e.g. near 270° when the edge is at -90° / 270°). + is_lon_wrap = float(coarse_min) < 0 and float(full_fine_coord.min()) >= 0 + fine_ref = float(coarse_min) + (360.0 if is_lon_wrap else 0.0) + + n_fine_below = int((full_fine_coord < fine_ref).sum()) n_fine_above = int((full_fine_coord > coarse_max).sum()) if n_fine_below < n_half_fine or n_fine_above < n_half_fine: raise ValueError( @@ -159,8 +236,9 @@ def adjust_fine_coord_range( f"the domain edges." ) - fine_min = full_fine_coord[full_fine_coord < coarse_min][-n_half_fine] - fine_max = full_fine_coord[full_fine_coord > coarse_max][n_half_fine - 1] + fine_min_ref = float(full_fine_coord[full_fine_coord < fine_ref][-n_half_fine]) + fine_min = fine_min_ref - (360.0 if is_lon_wrap else 0.0) + fine_max = float(full_fine_coord[full_fine_coord > coarse_max][n_half_fine - 1]) return ClosedInterval(start=fine_min, stop=fine_max) diff --git a/fme/downscaling/inference/inference.py b/fme/downscaling/inference/inference.py index d46bc7440..418022e2c 100644 --- a/fme/downscaling/inference/inference.py +++ b/fme/downscaling/inference/inference.py @@ -12,7 +12,7 @@ from fme.core.logging_utils import LoggingConfig from ..data import DataLoaderConfig -from ..models import CheckpointModelConfig, DiffusionModel +from ..models import CheckpointModelConfig, DiffusionModel, lon_rolled_model from ..predictors import ( DenoisingMoEConfig, DenoisingMoEPredictor, @@ -107,19 +107,17 @@ def run_output_generation(self, output: DownscalingOutput): """Execute the generation loop for this output.""" logging.info(f"Generating downscaled outputs for output: {output.name}") - # initialize writer and model in loop for coord info - model = None + coarse_coords = output.data.coarse_latlon_coords + input_shape = (len(coarse_coords.lat), len(coarse_coords.lon)) + model = self._get_generation_model(input_shape=input_shape, output=output) + if isinstance(model, DiffusionModel): + model = lon_rolled_model(model, coarse_coords.lon) + writer = None total_batches = len(output.data.loader) loaded_item: LoadedSliceWorkItem for i, loaded_item in enumerate(output.data.get_generator()): - input_shape = loaded_item.batch.horizontal_shape - if model is None: - model = self._get_generation_model( - input_shape=input_shape, output=output - ) - if writer is None: fine_latlon_coords = model.get_fine_coords_for_batch(loaded_item.batch) writer = output.get_writer( diff --git a/fme/downscaling/inference/output.py b/fme/downscaling/inference/output.py index 2af28e27e..3d9c98265 100644 --- a/fme/downscaling/inference/output.py +++ b/fme/downscaling/inference/output.py @@ -268,6 +268,7 @@ def _build_gridded_data( all_times=xr_dataset.sample_start_times, dtype=slice_dataset.dtype, max_output_shape=slice_dataset.max_output_shape, + coarse_latlon_coords=dataset.latlon_coordinates, ) def _build( diff --git a/fme/downscaling/inference/test_inference.py b/fme/downscaling/inference/test_inference.py index cef49d888..993c69e21 100644 --- a/fme/downscaling/inference/test_inference.py +++ b/fme/downscaling/inference/test_inference.py @@ -202,6 +202,10 @@ def test_run_target_generation_skips_padding_items( mock_work_item.batch.lat_interval = ClosedInterval(1.0, 8.0) mock_work_item.batch.lon_interval = ClosedInterval(1.0, 8.0) mock_output_target.data.get_generator.return_value = iter([mock_work_item]) + mock_output_target.data.coarse_latlon_coords = LatLonCoordinates( + lat=torch.arange(16, dtype=torch.float32), + lon=torch.arange(16, dtype=torch.float32), + ) mock_model.downscale_factor = 2 mock_model.static_inputs.coords.lat = torch.arange(0, 18).float() diff --git a/fme/downscaling/inference/work_items.py b/fme/downscaling/inference/work_items.py index 25049cf5b..1ed7f5fdd 100644 --- a/fme/downscaling/inference/work_items.py +++ b/fme/downscaling/inference/work_items.py @@ -6,6 +6,7 @@ import xarray as xr from torch.utils.data import DataLoader +from fme.core.coordinates import LatLonCoordinates from fme.core.dataset.data_typing import VariableMetadata from fme.core.distributed import Distributed from fme.core.generics.data import SizedMap @@ -297,6 +298,7 @@ class SliceWorkItemGriddedData: all_times: xr.CFTimeIndex dtype: torch.dtype max_output_shape: tuple[int, ...] + coarse_latlon_coords: LatLonCoordinates # TODO: currently no protocol or ABC for gridded data objects # if we want to unify, we will need one and just raise diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index a0634d808..06f68f161 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -21,7 +21,9 @@ PairedBatchData, StaticInputs, adjust_fine_coord_range, + compute_lon_roll, load_coords_from_path, + roll_latlon_coords, ) from fme.downscaling.metrics_and_maths import filter_tensor_mapping, interpolate from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector @@ -721,6 +723,30 @@ def metadata(self): ) +def lon_rolled_model(model: DiffusionModel, coarse_lon: torch.Tensor) -> DiffusionModel: + """ + Return a shallow copy of model with full_fine_coords and static_inputs pre-rolled + to match the longitude convention of coarse_lon. Shares all neural network weights. + Returns model unchanged when no roll is needed. + """ + import copy + + lon_start = coarse_lon.min().item() + roll_amount = compute_lon_roll(model.full_fine_coords.lon, lon_start) + if roll_amount == 0: + return model + rolled_fine = roll_latlon_coords(model.full_fine_coords, roll_amount, lon_start) + rolled_static = ( + model.static_inputs.roll(roll_amount, lon_start) + if model.static_inputs is not None + else None + ) + result = copy.copy(model) + result.full_fine_coords = rolled_fine + result.static_inputs = rolled_static + return result + + @dataclasses.dataclass class _CheckpointModelConfigSelector: wrapper: DiffusionModelConfig diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index 9566d3403..7550fd278 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -25,6 +25,7 @@ _build_variable_loss_weight_tensor, _repeat_batch_by_samples, _separate_interleaved_samples, + lon_rolled_model, ) from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector from fme.downscaling.noise import LogNormalNoiseDistribution @@ -532,6 +533,66 @@ def test_get_fine_coords_for_batch(): assert torch.allclose(result.lon, expected_lon) +def test_lon_rolled_model_no_roll_returns_same(): + """lon_rolled_model returns the original model when no roll is needed.""" + coarse_shape = (8, 16) + fine_shape = (16, 32) + static_inputs = make_static_inputs(fine_shape) + model = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=2, + full_fine_coords=static_inputs.coords, + static_inputs=static_inputs, + ) + # coarse_lon with min >= 0: no roll needed. + coarse_lon = _get_monotonic_coordinate(coarse_shape[1], stop=fine_shape[1]) + result = lon_rolled_model(model, coarse_lon) + assert result is model + + +def test_lon_rolled_model_shifts_coords_and_shares_weights(): + """lon_rolled_model produces a shallow copy with rolled coords but shared module.""" + coarse_shape = (8, 16) + fine_shape = (16, 32) + + # Use a global-covering lon grid so that compute_lon_roll returns a non-trivial + # roll amount for negative lon_start (32 cells × 11.25° = 360°). + step = 360 / 32 + global_fine_lon = torch.arange(32) * step + step / 2 # 5.625, 16.875, ..., 354.375 + global_fine_lat = _get_monotonic_coordinate(fine_shape[0], stop=fine_shape[0]) + full_fine_coords = LatLonCoordinates(lat=global_fine_lat, lon=global_fine_lon) + static_field = torch.arange( + fine_shape[0] * fine_shape[1], dtype=torch.float32 + ).reshape(*fine_shape) + static_inputs = StaticInputs( + fields=[StaticInput(static_field)], coords=full_fine_coords + ) + model = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=2, + full_fine_coords=full_fine_coords, + static_inputs=static_inputs, + ) + + # coarse_lon with negative min (e.g. -10°) triggers a roll. + coarse_lon = torch.tensor([-10.0, -5.0, 0.0, 5.0], dtype=torch.float32) + rolled = lon_rolled_model(model, coarse_lon) + + # Shallow copy: module weights are shared + assert rolled.module is model.module + # Coords changed — lon was rolled + assert not torch.equal(rolled.full_fine_coords.lon, model.full_fine_coords.lon) + # Rolled fine lon should be monotonically increasing + assert torch.all(rolled.full_fine_coords.lon[1:] > rolled.full_fine_coords.lon[:-1]) + # First rolled fine lon value should be negative (matching the shifted convention) + assert rolled.full_fine_coords.lon[0].item() < 0 + # Static inputs are also rolled + assert rolled.static_inputs is not None # for mypy + assert not torch.equal( + rolled.static_inputs.fields[0].data, static_inputs.fields[0].data + ) + + def test_checkpoint_config_topography_raises(): with pytest.raises(ValueError): CheckpointModelConfig(