Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions fme/downscaling/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
ClosedInterval,
LatLonCoordinates,
adjust_fine_coord_range,
compute_lon_roll,
expand_and_fold_tensor,
roll_latlon_coords,
roll_lon_coords,
roll_lon_data,
scale_tuple,
)
39 changes: 20 additions & 19 deletions fme/downscaling/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@
BatchedLatLonCoordinates,
ClosedInterval,
check_leading_dim,
compute_lon_roll,
expand_and_fold_tensor,
get_offset,
paired_shuffle,
roll_lon_coords,
roll_lon_data,
scale_tuple,
)

Expand Down Expand Up @@ -121,28 +124,22 @@ def __init__(

self._orig_coords: LatLonCoordinates = properties.horizontal_coordinates

if (self.lon_interval.stop != float("inf")) and (
torch.any(self._orig_coords.lon < self.lon_interval.stop - 360.0)
):
lon_max = self._orig_coords.lon.max()
raise NotImplementedError(
f"lon wraparound not implemented, received lon_max {lon_max} but "
f"expected lon_max > {self.lon_interval.stop - 360.0}"
)
if (self.lon_interval.start != -float("inf")) and (
torch.any(self._orig_coords.lon > self.lon_interval.start + 360.0)
):
lon_min = self._orig_coords.lon.min()
raise NotImplementedError(
f"lon wraparound not implemented, received lon_min {lon_min} but "
f"expected lon_min < {self.lon_interval.start + 360.0}"
)
# Detect whether the requested lon interval spans the 0°/360° seam of the
# stored coordinates (e.g. start=-5, stop=3 on 0–360° data). When it does,
# roll the data along the longitude axis so the interval is contiguous.
lon_start = (
self.lon_interval.start if self.lon_interval.start != -float("inf") else 0.0
)
self._lon_roll_amount = compute_lon_roll(self._orig_coords.lon, lon_start)
rolled_lon = roll_lon_coords(
self._orig_coords.lon, self._lon_roll_amount, lon_start
)

self._lats_slice = self.lat_interval.slice_from(self._orig_coords.lat)
self._lons_slice = self.lon_interval.slice_from(self._orig_coords.lon)
self._lons_slice = self.lon_interval.slice_from(rolled_lon)
self._latlon_coordinates = LatLonCoordinates(
lat=self._orig_coords.lat[self._lats_slice],
lon=self._orig_coords.lon[self._lons_slice],
lon=rolled_lon[self._lons_slice],
)
self._area_weights = self._latlon_coordinates.area_weights

Expand All @@ -168,7 +165,7 @@ def __len__(self):
def __getitem__(self, key) -> DatasetItem:
batch, times, _, epoch = self.dataset[key]
batch = {
k: v[
k: roll_lon_data(v, self._lon_roll_amount)[
...,
self._lats_slice,
self._lons_slice,
Expand Down Expand Up @@ -211,6 +208,10 @@ def __getitem__(self, idx) -> BatchItem:

return BatchItem(fields, time.squeeze(), self._coordinates)

@property
def latlon_coordinates(self) -> LatLonCoordinates:
return self._coordinates

@property
def variable_metadata(self) -> dict[str, VariableMetadata]:
if self._properties is None:
Expand Down
22 changes: 20 additions & 2 deletions fme/downscaling/data/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from fme.core.coordinates import LatLonCoordinates
from fme.core.device import get_device
from fme.downscaling.data.utils import ClosedInterval
from fme.downscaling.data.utils import ClosedInterval, roll_lon_coords, roll_lon_data


@dataclasses.dataclass
Expand Down Expand Up @@ -163,13 +163,31 @@ def subset(
lat_slice = lat_interval.slice_from(self.coords.lat)
lon_slice = lon_interval.slice_from(self.coords.lon)
return StaticInputs(
fields=[field.subset(lat_slice, lon_slice) for field in self.fields],
fields=[
StaticInput(data=field.data[lat_slice, lon_slice])
for field in self.fields
],
coords=LatLonCoordinates(
lat=lat_interval.subset_of(self.coords.lat),
lon=lon_interval.subset_of(self.coords.lon),
),
)

def roll(self, roll_amount: int, lon_start: float) -> "StaticInputs":
"""
Roll the data and lon coordinates of the StaticInputs by the specified amount.
"""
if roll_amount == 0:
return self
rolled_lon = roll_lon_coords(self.coords.lon, roll_amount, lon_start)
return StaticInputs(
fields=[
StaticInput(data=roll_lon_data(f.data, roll_amount, lon_dim=-1))
for f in self.fields
],
coords=LatLonCoordinates(lat=self.coords.lat, lon=rolled_lon),
)

def to_device(self) -> "StaticInputs":
return StaticInputs(
fields=[field.to_device() for field in self.fields],
Expand Down
45 changes: 45 additions & 0 deletions fme/downscaling/data/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,51 @@ def test_horizontal_subset(
assert dataset.subset_latlon_coordinates.lon.shape == (expected_n_lon,)


def test_horizontal_subset_prime_meridian_spanning():
"""HorizontalSubsetDataset must handle lon intervals that cross 0°/360°."""
# 8-point longitude grid in 0–360° convention, 45° spacing
lons = torch.tensor([0.0, 45.0, 90.0, 135.0, 180.0, 225.0, 270.0, 315.0])
n_lat, n_lon = 4, 8
coords = LatLonCoordinates(
lat=torch.linspace(0.0, 1.0, n_lat),
lon=lons,
)
# Data: lon index encodes original position so we can verify the roll
data_tensor = torch.arange(n_lon, dtype=torch.float).unsqueeze(0).unsqueeze(0)
data_tensor = data_tensor.expand(1, 1, n_lat, n_lon).clone()

datum: tuple[dict[str, torch.Tensor], xr.DataArray, set[str], int] = (
{"x": data_tensor},
xr.DataArray([0.0]),
set(),
0,
)
base_dataset = MagicMock(spec=torch.utils.data.Dataset)
properties = MagicMock(spec=DatasetProperties)
properties.horizontal_coordinates = coords
properties.all_labels = MagicMock(spec=set)
base_dataset.__getitem__.return_value = datum

# Interval [-90, 45] spans 0° on a 0–360° grid (270°→45° going through 0°)
dataset = HorizontalSubsetDataset(
dataset=base_dataset,
properties=properties,
lat_interval=ClosedInterval(float("-inf"), float("inf")),
lon_interval=ClosedInterval(-90.0, 45.0),
)

# Expect 4 lon points: 270°→-90°, 315°→-45°, 0°→0°, 45°→45°
assert dataset.subset_latlon_coordinates.lon.shape == (4,)
expected_lons = torch.tensor([-90.0, -45.0, 0.0, 45.0])
assert torch.allclose(dataset.subset_latlon_coordinates.lon, expected_lons)

subset, _, _, _ = dataset[0]
assert subset["x"].shape == (1, 1, n_lat, 4)
# Data values should correspond to original lon indices 6, 7, 0, 1
expected_vals = torch.tensor([6.0, 7.0, 0.0, 1.0])
assert torch.allclose(subset["x"][0, 0, 0], expected_vals)


def test_batch_data_from_sequence():
num_items = 3
items = get_batch_items(num_items=num_items)
Expand Down
38 changes: 38 additions & 0 deletions fme/downscaling/data/test_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,41 @@ def test__load_coords_from_ds():
ds = xr.Dataset(coords={"x": lon, "y": lat})
with pytest.raises(ValueError):
_load_coords_from_ds(ds)


def test_StaticInputs_roll_shifts_data_and_coords():
"""StaticInputs.roll produces correctly rolled data and monotonic shifted coords."""
from fme.core.coordinates import LatLonCoordinates

# 1-degree global grid: 0.5, 1.5, ..., 359.5
n_lon = 360
lon = torch.arange(n_lon, dtype=torch.float32) + 0.5
lat = torch.tensor([0.5], dtype=torch.float32)
coords = LatLonCoordinates(lat=lat, lon=lon)
data = torch.arange(n_lon, dtype=torch.float32).unsqueeze(0) # values = index
static = StaticInputs(fields=[StaticInput(data)], coords=coords)

# Roll so domain starts at -5.5°: start_360=354.5, roll=(coords<354.5).sum()=354
roll_amount = 354
lon_start = -5.5
rolled = static.roll(roll_amount, lon_start)

# Data: original index 354 (value=354) should be at index 0 after rolling
assert rolled.fields[0].data[0, 0].item() == pytest.approx(354.0)
assert rolled.fields[0].data[0, -1].item() == pytest.approx(353.0)

# Coordinates should be monotonically increasing and start near lon_start
assert torch.all(rolled.coords.lon[1:] > rolled.coords.lon[:-1])
assert rolled.coords.lon[0].item() == pytest.approx(354.5 - 360) # = -5.5


def test_StaticInputs_roll_zero_returns_self():
from fme.core.coordinates import LatLonCoordinates

coords = LatLonCoordinates(
lat=torch.tensor([0.0]),
lon=torch.tensor([0.5, 1.5, 2.5]),
)
data = torch.zeros(1, 3)
static = StaticInputs(fields=[StaticInput(data)], coords=coords)
assert static.roll(0, 0.0) is static
122 changes: 122 additions & 0 deletions fme/downscaling/data/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
from fme.downscaling.data.utils import (
ClosedInterval,
adjust_fine_coord_range,
compute_lon_roll,
paired_shuffle,
roll_lon_coords,
roll_lon_data,
scale_slice,
)

Expand All @@ -19,6 +22,65 @@ def _fine_midpoints(coarse_edges, downscale_factor):
return torch.concatenate(fine_mids)


def _one_deg_lon_coords():
"""0.5, 1.5, ..., 359.5 — standard 1-degree global grid."""
return torch.arange(0.5, 360.0, 1.0)


@pytest.mark.parametrize(
"lon_start, expected_roll",
[
pytest.param(0.0, 0, id="zero_no_roll"),
pytest.param(10.0, 0, id="positive_in_range_no_roll"),
pytest.param(-5.0, 355, id="negative_start_rolls"),
pytest.param(-0.5, 359, id="just_negative_rolls"),
pytest.param(-180.0, 180, id="half_period"),
pytest.param(360.0, 0, id="exactly_360_no_roll"),
],
)
def test_compute_lon_roll(lon_start, expected_roll):
coords = _one_deg_lon_coords()
assert compute_lon_roll(coords, lon_start) == expected_roll


def test_roll_lon_coords_negative_start():
"""Rolled coords for lon_start=-5 should be monotone and start near -5."""
coords = _one_deg_lon_coords()
roll_amount = compute_lon_roll(coords, -5.0)
rolled = roll_lon_coords(coords, roll_amount, -5.0)

assert rolled.shape == coords.shape
# monotonically increasing
assert torch.all(rolled[1:] > rolled[:-1])
# first element is the first coord >= 355, shifted to negative convention
assert rolled[0].item() == pytest.approx(-4.5)
# last element of the original-convention "low" portion, shifted by -360
assert rolled[-1].item() == pytest.approx(354.5)


def test_roll_lon_coords_zero_roll_returns_original():
coords = _one_deg_lon_coords()
result = roll_lon_coords(coords, 0, 0.0)
assert torch.equal(result, coords)


def test_roll_lon_data_shifts_correctly():
"""Rolling data by r positions moves index r to index 0."""
n = 8
tensor = torch.arange(n, dtype=torch.float).unsqueeze(0) # shape (1, 8)
roll_amount = 3
rolled = roll_lon_data(tensor, roll_amount, lon_dim=-1)
assert rolled.shape == tensor.shape
assert rolled[0, 0].item() == pytest.approx(3.0) # original index 3 → 0
assert rolled[0, -1].item() == pytest.approx(2.0) # original index 2 → last


def test_roll_lon_data_zero_roll_returns_original():
tensor = torch.randn(4, 8)
result = roll_lon_data(tensor, 0)
assert torch.equal(result, tensor)


def test_paired_shuffle():
a = np.arange(5)
b = a * 10
Expand Down Expand Up @@ -49,6 +111,66 @@ def test_adjust_fine_coord_range(downscale_factor, lat_range):
assert len(subsel_fine_lat) / len(subsel_coarse_lat) == downscale_factor


def test_adjust_fine_coord_range_wrapping_lon():
"""
adjust_fine_coord_range handles negative-start lon intervals (wrapping domain).
"""
downscale_factor = 2 # n_half_fine = 1
# Coarse 1-degree global grid (0.5 to 359.5), subset and already rolled
# to shifted convention: coarse domain -2 to 3 (= 358 to 3 in 0-360)
coarse_lon = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0, 3.0])
# Fine 0.5-degree global grid in 0-360 convention (0.25 to 359.75)
fine_lon = torch.arange(0.25, 360.0, 0.5)
lon_range = ClosedInterval(-2.5, 3.5)

result = adjust_fine_coord_range(
lon_range,
full_coarse_coord=coarse_lon,
full_fine_coord=fine_lon,
downscale_factor=downscale_factor,
)
# result should cover from 1 fine step below coarse_min=-2 to 1 fine step above
# coarse_max=3
# coarse_min=-2 → 358 in 0-360; fine just below 358 is 357.75 → shifted = -2.25
# coarse_max=3; fine just above 3 is 3.25
assert result.start == pytest.approx(-2.25)
assert result.stop == pytest.approx(3.25)

# Verify the returned interval properly covers the coarse points with
# n_half_fine padding
subsel_coarse = coarse_lon[
(coarse_lon >= lon_range.start) & (coarse_lon <= lon_range.stop)
]
assert len(subsel_coarse) == len(coarse_lon) # all 6 coarse points selected

# The fine interval spans n_coarse * downscale_factor fine points
# After rolling to match the shifted convention, subset should be contiguous
from fme.downscaling.data.utils import compute_lon_roll, roll_lon_coords

fine_roll = compute_lon_roll(fine_lon, result.start)
rolled_fine = roll_lon_coords(fine_lon, fine_roll, result.start)
subsel_fine = result.subset_of(rolled_fine)
assert len(subsel_fine) == len(subsel_coarse) * downscale_factor


def test_adjust_fine_coord_range_negative_lat_not_confused_with_wrapping():
"""Southern hemisphere lat (negative coarse_min) must not trigger lon-wrap logic."""
downscale_factor = 2
coarse_lat = torch.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0])
# Fine lat grid also spans negative values — NOT a 0-360 lon grid
fine_lat = torch.arange(-3.25, 3.5, 0.5)
lat_range = ClosedInterval(-3.5, 3.5)
result = adjust_fine_coord_range(
lat_range,
full_coarse_coord=coarse_lat,
full_fine_coord=fine_lat,
downscale_factor=downscale_factor,
)
# Standard non-wrapping behaviour: fine_min should be below coarse_min=-3
assert result.start < -3.0
assert result.stop > 3.0


def test_adjust_fine_coord_range_raises_near_domain_boundary():
downscale_factor = 4 # n_half_fine = 2
coarse_edges = torch.linspace(0, 6, 7)
Expand Down
Loading
Loading