Skip to content

Commit f53084b

Browse files
authored
Merge pull request #94 from wietzesuijker/fix/s2-quality-mask-downsampling-typo
fix: typo to prevent crash in quality-mask downsampling
2 parents 1151dba + fe07434 commit f53084b

2 files changed

Lines changed: 20 additions & 3 deletions

File tree

src/eopf_geozarr/s2_optimization/s2_multiscale.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -598,9 +598,7 @@ def create_downsampled_resolution_group(source_dataset: xr.Dataset, factor: int)
598598
continue
599599
var_typ = determine_variable_type(var_name, var_data)
600600
if var_typ == "quality_mask":
601-
lazy_downsampled = (
602-
var_data.coarsen({"x": factor, "y": factor}, boundary="trim").max().sdyupr
603-
)
601+
lazy_downsampled = var_data.coarsen({"x": factor, "y": factor}, boundary="trim").max()
604602
elif var_typ == "reflectance":
605603
lazy_downsampled = var_data.coarsen({"x": factor, "y": factor}, boundary="trim").mean()
606604
elif var_typ == "classification":

tests/test_s2_multiscale.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from eopf_geozarr.s2_optimization.s2_multiscale import (
1515
calculate_aligned_chunk_size,
1616
calculate_simple_shard_dimensions,
17+
create_downsampled_resolution_group,
1718
create_measurements_encoding,
1819
create_multiscale_from_datatree,
1920
)
@@ -54,6 +55,24 @@ def sample_dataset() -> xr.Dataset:
5455
class TestS2MultiscaleFunctions:
5556
"""Test suite for S2 multiscale functions."""
5657

58+
def test_create_downsampled_resolution_group_quality_mask(self) -> None:
59+
"""Quality-mask downsampling should not crash and should preserve dtype."""
60+
x = np.arange(8)
61+
y = np.arange(6)
62+
quality = xr.DataArray(
63+
np.random.randint(0, 2, (6, 8), dtype=np.uint8),
64+
dims=["y", "x"],
65+
coords={"y": y, "x": x},
66+
name="quality_clouds",
67+
)
68+
ds = xr.Dataset({"quality_clouds": quality})
69+
70+
out = create_downsampled_resolution_group(ds, factor=2)
71+
72+
assert "quality_clouds" in out.data_vars
73+
assert out["quality_clouds"].dtype == np.uint8
74+
assert out["quality_clouds"].shape == (3, 4)
75+
5776
def test_calculate_simple_shard_dimensions(self) -> None:
5877
"""Test simplified shard dimensions calculation."""
5978
# Test 3D data (time, y, x) - shards are multiples of chunks

0 commit comments

Comments
 (0)