Skip to content

Commit c1a1f28

Browse files
authored
Merge pull request #152 from d-v-b/feat/include-b08-in-r10m-r20m-r60m
feat: include b08 in resolution groups
2 parents 117a53e + 2bdad1b commit c1a1f28

5 files changed

Lines changed: 2102 additions & 43 deletions

src/eopf_geozarr/s2_optimization/s2_multiscale.py

Lines changed: 115 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
XarrayDataArrayEncoding,
3636
)
3737
from eopf_geozarr.s2_optimization.common import DISTRIBUTED_AVAILABLE
38+
from eopf_geozarr.s2_optimization.s2_band_mapping import BAND_INFO
3839

3940
from .s2_resampling import determine_variable_type, downsample_variable
4041

@@ -66,6 +67,101 @@ def get_grid_spacing(ds: xr.DataArray, coords: tuple[Hashable, ...]) -> tuple[fl
6667
return tuple(np.abs(ds.coords[coord][0].data - ds.coords[coord][1].data) for coord in coords)
6768

6869

70+
def _coarsen_variable(var_name: str, var_data: xr.DataArray, factor: int) -> xr.DataArray:
71+
"""Coarsen a single variable using type-aware resampling.
72+
73+
Dispatches to the appropriate coarsen reduction (mean, max, subsample)
74+
based on `determine_variable_type`. Preserves encoding and dtype.
75+
"""
76+
var_type = determine_variable_type(var_name, var_data)
77+
coarsened = var_data.coarsen({"x": factor, "y": factor}, boundary="trim")
78+
if var_type in ("reflectance", "probability"):
79+
result = coarsened.mean()
80+
elif var_type == "classification":
81+
result = coarsened.reduce(subsample_2)
82+
elif var_type == "quality_mask":
83+
result = coarsened.max()
84+
else:
85+
raise ValueError(f"Unknown variable type {var_type}")
86+
87+
result.encoding = var_data.encoding
88+
return result.astype(var_data.dtype)
89+
90+
91+
def inject_missing_bands(
92+
dataset: xr.Dataset,
93+
dt_input: xr.DataTree,
94+
target_resolution: int,
95+
*,
96+
bands: set[str] | None = None,
97+
) -> xr.Dataset:
98+
"""Inject bands whose native resolution is finer than `target_resolution`.
99+
100+
For each spectral band defined in `BAND_INFO` whose native resolution is
101+
finer than `target_resolution`, this function checks whether the band is
102+
already present in `dataset`. If not, it looks for the band in the
103+
appropriate source group (e.g. `/measurements/reflectance/r10m`),
104+
downsamples it to the target grid using the type-aware resampling from
105+
`determine_variable_type`, and merges it into `dataset`.
106+
107+
Args:
108+
dataset: The target-resolution dataset (e.g. the r20m or r60m
109+
reflectance group).
110+
dt_input: The full input DataTree (used to locate finer-resolution
111+
source bands).
112+
target_resolution: Target resolution in metres (e.g. 20 or 60).
113+
bands: If provided, only inject these band names. If `None`
114+
(default), inject every eligible band from `BAND_INFO`.
115+
116+
Returns:
117+
`dataset` with any missing finer-resolution bands added.
118+
"""
119+
for band_name, info in BAND_INFO.items():
120+
if bands is not None and band_name not in bands:
121+
continue
122+
native_res = info.native_resolution # type: ignore[attr-defined]
123+
if native_res >= target_resolution:
124+
continue
125+
if band_name in dataset.data_vars:
126+
continue
127+
128+
source_path = f"/measurements/reflectance/r{native_res}m"
129+
if source_path not in dt_input.groups:
130+
continue
131+
132+
source_ds = dt_input[source_path].to_dataset()
133+
if band_name not in source_ds.data_vars:
134+
continue
135+
136+
band_src = source_ds[band_name]
137+
factor = target_resolution // native_res
138+
band_ds = _coarsen_variable(band_name, band_src, factor)
139+
140+
# Replace coordinates with the target dataset's coordinates so that
141+
# xarray.Dataset.assign does not try to align on mismatched values.
142+
band_ds = xr.DataArray(
143+
band_ds.values,
144+
dims=band_ds.dims,
145+
coords={d: dataset.coords[d] for d in band_ds.dims if d in dataset.coords},
146+
attrs=band_ds.attrs,
147+
name=band_name,
148+
)
149+
150+
# Preserve source encoding so downstream encoding logic can inspect it
151+
band_ds.encoding = band_src.encoding.copy()
152+
153+
dataset = dataset.assign({band_name: band_ds})
154+
log.info(
155+
"Injected downsampled band from finer resolution",
156+
band=band_name,
157+
source=f"r{native_res}m",
158+
target=f"r{target_resolution}m",
159+
shape=band_ds.shape,
160+
)
161+
162+
return dataset
163+
164+
69165
def create_multiscale_from_datatree(
70166
dt_input: xr.DataTree,
71167
*,
@@ -121,6 +217,22 @@ def create_multiscale_from_datatree(
121217
)
122218

123219
if is_measurement_group:
220+
# Inject bands whose native resolution is finer than this group's
221+
# (e.g. b08 native at 10m into r20m/r60m) so they propagate through
222+
# the full overview chain (r120m … r720m).
223+
if group_path.startswith("/measurements/reflectance/"):
224+
try:
225+
group_resolution = int(group_name[1:-1])
226+
except ValueError:
227+
group_resolution = 0
228+
if group_resolution > 10:
229+
dataset = inject_missing_bands(
230+
dataset,
231+
dt_input,
232+
group_resolution,
233+
bands={"b08"},
234+
)
235+
124236
# Measurement groups: apply custom encoding
125237
encoding = create_measurements_encoding(
126238
dataset,
@@ -732,24 +844,7 @@ def create_downsampled_resolution_group(source_dataset: xr.Dataset, factor: int)
732844
for var_name, var_data in source_dataset.data_vars.items():
733845
if var_data.ndim < 2:
734846
continue
735-
var_typ = determine_variable_type(var_name, var_data)
736-
if var_typ == "quality_mask":
737-
lazy_downsampled = var_data.coarsen({"x": factor, "y": factor}, boundary="trim").max()
738-
elif var_typ == "reflectance":
739-
lazy_downsampled = var_data.coarsen({"x": factor, "y": factor}, boundary="trim").mean()
740-
elif var_typ == "classification":
741-
lazy_downsampled = var_data.coarsen({"x": factor, "y": factor}, boundary="trim").reduce(
742-
subsample_2
743-
)
744-
elif var_typ == "probability":
745-
lazy_downsampled = var_data.coarsen({"x": factor, "y": factor}, boundary="trim").mean()
746-
else:
747-
raise ValueError(f"Unknown variable type {var_typ}")
748-
749-
# preserve encoding
750-
lazy_downsampled.encoding = var_data.encoding
751-
# Ensure that dtype is preserved
752-
lazy_vars[var_name] = lazy_downsampled.astype(var_data.dtype)
847+
lazy_vars[var_name] = _coarsen_variable(var_name, var_data, factor)
753848

754849
if not lazy_vars:
755850
return xr.Dataset()
@@ -760,8 +855,8 @@ def create_downsampled_resolution_group(source_dataset: xr.Dataset, factor: int)
760855

761856
def subsample_2(a: xr.DataArray, axis: tuple[int, ...] | None = None) -> xr.DataArray:
762857
if axis is None:
763-
return a[((slice(None, None, 2),) * a.ndim)]
764-
indexer = [slice(None, None, 2) if i in axis else slice(None) for i in range(a.ndim)]
858+
return a[((0,) * a.ndim)]
859+
indexer = [0 if i in axis else slice(None) for i in range(a.ndim)]
765860
return a[tuple(indexer)]
766861

767862

0 commit comments

Comments
 (0)