Skip to content

Commit 3c71242

Browse files
fix: derive coarse spatial transforms from coordinates (#168)
* fix: derive coarse spatial transforms from coordinates * refactor: improve function definitions for clarity and consistency
1 parent 4196a67 commit 3c71242

3 files changed

Lines changed: 129 additions & 71 deletions

File tree

src/eopf_geozarr/s2_optimization/s2_multiscale.py

Lines changed: 68 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,71 @@ def get_grid_spacing(ds: xr.DataArray, coords: tuple[Hashable, ...]) -> tuple[fl
6969
return tuple(np.abs(ds.coords[coord][0].data - ds.coords[coord][1].data) for coord in coords)
7070

7171

72+
def _transform_from_coordinates(
73+
dataset: xr.Dataset,
74+
) -> tuple[float, float, float, float, float, float] | None:
75+
"""Construct an affine transform from dataset coordinates when possible."""
76+
if "x" not in dataset.coords or "y" not in dataset.coords:
77+
return None
78+
79+
x_coords = dataset.coords["x"].values
80+
y_coords = dataset.coords["y"].values
81+
if len(x_coords) < 2 or len(y_coords) < 2:
82+
return None
83+
84+
pixel_size_x = float(np.abs(x_coords[1] - x_coords[0]))
85+
pixel_size_y = float(np.abs(y_coords[1] - y_coords[0]))
86+
x_min = float(x_coords.min())
87+
y_max = float(y_coords.max())
88+
return (pixel_size_x, 0.0, x_min, 0.0, -pixel_size_y, y_max)
89+
90+
91+
def _rio_transform_matches_coordinates(
92+
transform: tuple[float, float, float, float, float, float] | None,
93+
coordinate_transform: tuple[float, float, float, float, float, float] | None,
94+
) -> bool:
95+
"""Check whether rio-derived metadata matches the current x/y grid."""
96+
if transform is None or coordinate_transform is None:
97+
return False
98+
99+
return all(np.isclose(a, b) for a, b in zip(transform, coordinate_transform, strict=False))
100+
101+
102+
def _preferred_spatial_transform(
103+
dataset: xr.Dataset,
104+
) -> tuple[float, float, float, float, float, float] | None:
105+
"""Prefer rio metadata only when it matches the current coordinate grid."""
106+
coordinate_transform = _transform_from_coordinates(dataset)
107+
rio_transform: tuple[float, float, float, float, float, float] | None = None
108+
109+
if hasattr(dataset, "rio") and hasattr(dataset.rio, "transform"):
110+
try:
111+
rio_value = dataset.rio.transform
112+
if callable(rio_value):
113+
rio_value = rio_value()
114+
rio_values = tuple(float(value) for value in tuple(rio_value)[:6])
115+
if len(rio_values) == 6:
116+
rio_transform = (
117+
rio_values[0],
118+
rio_values[1],
119+
rio_values[2],
120+
rio_values[3],
121+
rio_values[4],
122+
rio_values[5],
123+
)
124+
except (AttributeError, TypeError, ValueError):
125+
rio_transform = None
126+
127+
if (
128+
rio_transform is not None
129+
and not all(value == 0 for value in rio_transform)
130+
and _rio_transform_matches_coordinates(rio_transform, coordinate_transform)
131+
):
132+
return rio_transform
133+
134+
return coordinate_transform or rio_transform
135+
136+
72137
def _coarsen_variable(var_name: str, var_data: xr.DataArray, factor: int) -> xr.DataArray:
73138
"""Coarsen a single variable using type-aware resampling.
74139
@@ -607,56 +672,7 @@ def add_multiscales_metadata_to_parent(
607672
first_var = next(iter(dataset.data_vars.values()))
608673
height, width = first_var.shape[-2:]
609674

610-
# Calculate spatial transform (affine transformation)
611-
transform = None
612-
if hasattr(dataset, "rio") and hasattr(dataset.rio, "transform"):
613-
try:
614-
# Try to get transform as property first
615-
rio_transform = dataset.rio.transform
616-
if callable(rio_transform):
617-
rio_transform = rio_transform()
618-
transform = tuple(rio_transform)[:6] # Get 6 coefficients
619-
log.info("Got transform from rio accessor", transform=transform, level=res_name)
620-
except (AttributeError, TypeError) as e:
621-
log.warning(
622-
"Could not get transform from rio accessor", error=str(e), level=res_name
623-
)
624-
625-
if transform is None or all(t == 0 for t in transform):
626-
# Fallback: construct from grid spacing and bounds
627-
if "x" in dataset.coords and "y" in dataset.coords:
628-
# Use coordinate arrays to calculate spacing
629-
x_coords = dataset.coords["x"].values
630-
y_coords = dataset.coords["y"].values
631-
632-
if len(x_coords) > 1 and len(y_coords) > 1:
633-
# Calculate pixel size from actual coordinate spacing
634-
pixel_size_x = float(np.abs(x_coords[1] - x_coords[0]))
635-
pixel_size_y = float(np.abs(y_coords[1] - y_coords[0]))
636-
637-
x_min = float(x_coords.min())
638-
y_max = float(y_coords.max())
639-
transform = (pixel_size_x, 0.0, x_min, 0.0, -pixel_size_y, y_max)
640-
log.info(
641-
"Calculated transform from coordinates",
642-
transform=transform,
643-
pixel_size_x=pixel_size_x,
644-
pixel_size_y=pixel_size_y,
645-
level=res_name,
646-
)
647-
else:
648-
log.warning(
649-
"Insufficient coordinate points for transform calculation",
650-
x_len=len(x_coords),
651-
y_len=len(y_coords),
652-
level=res_name,
653-
)
654-
else:
655-
log.warning(
656-
"Missing x/y coordinates for transform calculation",
657-
coords=list(dataset.coords.keys()),
658-
level=res_name,
659-
)
675+
transform = _preferred_spatial_transform(dataset)
660676

661677
# Calculate zoom level (higher resolution = higher zoom)
662678
tile_width = 256
@@ -1162,30 +1178,11 @@ def write_geo_metadata(
11621178
y_min, y_max = float(y_coords.min()), float(y_coords.max())
11631179
dataset.attrs["spatial:bbox"] = [x_min, y_min, x_max, y_max]
11641180

1165-
# Calculate spatial transform (affine transformation)
1166-
spatial_transform = None
1167-
if hasattr(dataset, "rio") and hasattr(dataset.rio, "transform"):
1168-
try:
1169-
rio_transform = dataset.rio.transform
1170-
if callable(rio_transform):
1171-
rio_transform = rio_transform()
1172-
spatial_transform = list(rio_transform)[:6]
1173-
except (AttributeError, TypeError):
1174-
# Fallback: construct from coordinate spacing
1175-
pixel_size_x = float(get_grid_spacing(dataset, ("x",))[0])
1176-
pixel_size_y = float(get_grid_spacing(dataset, ("y",))[0])
1177-
spatial_transform = [
1178-
pixel_size_x,
1179-
0.0,
1180-
x_min,
1181-
0.0,
1182-
-pixel_size_y,
1183-
y_max,
1184-
]
1181+
spatial_transform = _preferred_spatial_transform(dataset)
11851182

11861183
# Only add spatial:transform if we have valid transform data (not all zeros)
11871184
if spatial_transform is not None and not all(t == 0 for t in spatial_transform):
1188-
dataset.attrs["spatial:transform"] = spatial_transform
1185+
dataset.attrs["spatial:transform"] = list(spatial_transform)
11891186

11901187
# Add spatial shape if data variables exist
11911188
if dataset.data_vars:

tests/test_s2_multiscale.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pathlib
77
from itertools import pairwise
88
from pathlib import Path
9+
from unittest.mock import patch
910

1011
import numpy as np
1112
import pytest
@@ -19,6 +20,7 @@
1920

2021
from eopf_geozarr.s2_optimization.s2_multiscale import (
2122
_coarsen_variable,
23+
add_multiscales_metadata_to_parent,
2224
calculate_aligned_chunk_size,
2325
calculate_simple_shard_dimensions,
2426
create_downsampled_resolution_group,
@@ -56,6 +58,47 @@ def test_create_downsampled_resolution_group_quality_mask() -> None:
5658
assert out["quality_clouds"].shape == (3, 4)
5759

5860

61+
def test_add_multiscales_metadata_prefers_coordinate_transform_for_inconsistent_rio(
62+
tmp_path: pathlib.Path,
63+
) -> None:
64+
"""Derived levels should not reuse a stale rio transform."""
65+
66+
def _dataset(resolution: int, size: int, x0: float, y0: float) -> xr.Dataset:
67+
x = x0 + np.arange(size, dtype="float64") * resolution
68+
y = y0 - np.arange(size, dtype="float64") * resolution
69+
ds = xr.Dataset(
70+
{"band": (["y", "x"], np.ones((size, size), dtype=np.uint16))},
71+
coords={"x": x, "y": y},
72+
)
73+
return ds.rio.write_crs("EPSG:32631")
74+
75+
r10m = _dataset(10, 12, 600000.0, 4900020.0)
76+
r120m = _dataset(120, 3, 600030.0, 4899990.0)
77+
78+
parent_group = zarr.create_group(tmp_path / "multiscales.zarr")
79+
80+
def stale_transform() -> tuple[float, float, float, float, float, float]:
81+
return (60.0, 0.0, 600030.0, 0.0, -60.0, 4899990.0)
82+
83+
with patch.object(r120m.rio, "transform", stale_transform):
84+
add_multiscales_metadata_to_parent(
85+
parent_group,
86+
{"r10m": r10m, "r120m": r120m},
87+
multiscales_flavor={"experimental_multiscales_convention"},
88+
)
89+
90+
layout = parent_group.attrs["multiscales"]["layout"]
91+
derived_level = next(level for level in layout if level["asset"] == "r120m")
92+
assert tuple(derived_level["spatial:transform"]) == (
93+
120.0,
94+
0.0,
95+
600030.0,
96+
0.0,
97+
-120.0,
98+
4899990.0,
99+
)
100+
101+
59102
def test_calculate_simple_shard_dimensions() -> None:
60103
"""Test simplified shard dimensions calculation."""
61104
# Test 3D data (time, y, x) - shards are multiples of chunks

tests/test_s2_multiscale_geo_metadata.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,24 @@ def test_write_geo_metadata_integration_with_stream_write(self, tmp_path: Path)
249249
assert written_ds.rio.crs is not None
250250
assert written_ds.rio.crs.to_epsg() == 32632
251251

252+
def test_write_geo_metadata_prefers_coordinate_transform_for_inconsistent_rio(self) -> None:
253+
"""Derived datasets should derive spatial:transform from current coordinates."""
254+
255+
x = 600030.0 + np.arange(3, dtype="float64") * 120.0
256+
y = 4899990.0 - np.arange(3, dtype="float64") * 120.0
257+
ds = xr.Dataset(
258+
{"b01": (["y", "x"], np.ones((3, 3), dtype=np.uint16))},
259+
coords={"x": x, "y": y},
260+
).rio.write_crs("EPSG:32631")
261+
262+
def stale_transform() -> tuple[float, float, float, float, float, float]:
263+
return (60.0, 0.0, 600030.0, 0.0, -60.0, 4899990.0)
264+
265+
with patch.object(ds.rio, "transform", stale_transform):
266+
write_geo_metadata(ds)
267+
268+
assert ds.attrs["spatial:transform"] == [120.0, 0.0, 600030.0, 0.0, -120.0, 4899990.0]
269+
252270

253271
class TestWriteGeoMetadataEdgeCases:
254272
"""Test edge cases for _write_geo_metadata method."""

0 commit comments

Comments
 (0)