Skip to content

Commit 61178c3

Browse files
committed
Add overview resampling options: nearest, min, max, median, mode, cubic
_make_overview() now accepts a method parameter instead of hardcoding 2x2 block averaging. Available methods: - mean (default): nanmean of each 2x2 block, same as before - nearest: top-left pixel of each block (no interpolation) - min/max: nanmin/nanmax of each block - median: nanmedian of each block - mode: most frequent value per block (for classified rasters) - cubic: scipy.ndimage.zoom with order=3 (requires scipy) All methods work on both 2D and 3D (multi-band) arrays. Exposed via overview_resampling= parameter on write() and write_geotiff(). 12 new tests covering each method, NaN handling, multi-band, COG round-trips with nearest and mode, the public API, and error on invalid method names.
1 parent f90791f commit 61178c3

File tree

3 files changed

+226
-24
lines changed

3 files changed

+226
-24
lines changed

xrspatial/geotiff/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,8 @@ def write_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
186186
tile_size: int = 256,
187187
predictor: bool = False,
188188
cog: bool = False,
189-
overview_levels: list[int] | None = None) -> None:
189+
overview_levels: list[int] | None = None,
190+
overview_resampling: str = 'mean') -> None:
190191
"""Write data as a GeoTIFF or Cloud Optimized GeoTIFF.
191192
192193
Parameters
@@ -211,6 +212,9 @@ def write_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
211212
Write as Cloud Optimized GeoTIFF.
212213
overview_levels : list[int] or None
213214
Overview decimation factors. Only used when cog=True.
215+
overview_resampling : str
216+
Resampling method for overviews: 'mean' (default), 'nearest',
217+
'min', 'max', 'median', 'mode', or 'cubic'.
214218
"""
215219
geo_transform = None
216220
epsg = crs
@@ -243,6 +247,7 @@ def write_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
243247
predictor=predictor,
244248
cog=cog,
245249
overview_levels=overview_levels,
250+
overview_resampling=overview_resampling,
246251
raster_type=raster_type,
247252
)
248253

xrspatial/geotiff/_writer.py

Lines changed: 69 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -69,40 +69,85 @@ def _compression_tag(compression_name: str) -> int:
6969
return _map[name]
7070

7171

72-
def _make_overview(arr: np.ndarray) -> np.ndarray:
73-
"""Generate a 2x decimated overview using 2x2 block averaging.
72+
OVERVIEW_METHODS = ('mean', 'nearest', 'min', 'max', 'median', 'mode', 'cubic')
73+
74+
75+
def _block_reduce_2d(arr2d, method):
76+
"""2x block-reduce a single 2D plane using *method*."""
77+
h, w = arr2d.shape
78+
h2 = (h // 2) * 2
79+
w2 = (w // 2) * 2
80+
cropped = arr2d[:h2, :w2]
81+
oh, ow = h2 // 2, w2 // 2
82+
83+
if method == 'nearest':
84+
# Top-left pixel of each 2x2 block
85+
return cropped[::2, ::2].copy()
86+
87+
if method == 'cubic':
88+
try:
89+
from scipy.ndimage import zoom
90+
except ImportError:
91+
raise ImportError(
92+
"scipy is required for cubic overview resampling. "
93+
"Install it with: pip install scipy")
94+
return zoom(arr2d, 0.5, order=3).astype(arr2d.dtype)
95+
96+
if method == 'mode':
97+
# Most-common value per 2x2 block (useful for classified rasters)
98+
blocks = cropped.reshape(oh, 2, ow, 2).transpose(0, 2, 1, 3).reshape(oh, ow, 4)
99+
out = np.empty((oh, ow), dtype=arr2d.dtype)
100+
for r in range(oh):
101+
for c in range(ow):
102+
vals, counts = np.unique(blocks[r, c], return_counts=True)
103+
out[r, c] = vals[counts.argmax()]
104+
return out
105+
106+
# Block reshape for mean/min/max/median
107+
if arr2d.dtype.kind == 'f':
108+
blocks = cropped.reshape(oh, 2, ow, 2)
109+
else:
110+
blocks = cropped.astype(np.float64).reshape(oh, 2, ow, 2)
111+
112+
if method == 'mean':
113+
result = np.nanmean(blocks, axis=(1, 3))
114+
elif method == 'min':
115+
result = np.nanmin(blocks, axis=(1, 3))
116+
elif method == 'max':
117+
result = np.nanmax(blocks, axis=(1, 3))
118+
elif method == 'median':
119+
flat = blocks.transpose(0, 2, 1, 3).reshape(oh, ow, 4)
120+
result = np.nanmedian(flat, axis=2)
121+
else:
122+
raise ValueError(
123+
f"Unknown overview resampling method: {method!r}. "
124+
f"Use one of: {OVERVIEW_METHODS}")
125+
126+
if arr2d.dtype.kind != 'f':
127+
return np.round(result).astype(arr2d.dtype)
128+
return result.astype(arr2d.dtype)
129+
130+
131+
def _make_overview(arr: np.ndarray, method: str = 'mean') -> np.ndarray:
132+
"""Generate a 2x decimated overview.
74133
75134
Parameters
76135
----------
77136
arr : np.ndarray
78137
2D or 3D (height, width, bands) array.
138+
method : str
139+
Resampling method: 'mean' (default), 'nearest', 'min', 'max',
140+
'median', 'mode', or 'cubic'.
79141
80142
Returns
81143
-------
82144
np.ndarray
83145
Half-resolution array.
84146
"""
85-
h, w = arr.shape[:2]
86-
h2 = (h // 2) * 2
87-
w2 = (w // 2) * 2
88-
cropped = arr[:h2, :w2]
89-
90147
if arr.ndim == 3:
91-
# Multi-band: average each band independently
92-
bands = arr.shape[2]
93-
if arr.dtype.kind == 'f':
94-
blocks = cropped.reshape(h2 // 2, 2, w2 // 2, 2, bands)
95-
return np.nanmean(blocks, axis=(1, 3)).astype(arr.dtype)
96-
else:
97-
blocks = cropped.astype(np.float64).reshape(h2 // 2, 2, w2 // 2, 2, bands)
98-
return np.round(blocks.mean(axis=(1, 3))).astype(arr.dtype)
99-
else:
100-
if arr.dtype.kind == 'f':
101-
blocks = cropped.reshape(h2 // 2, 2, w2 // 2, 2)
102-
return np.nanmean(blocks, axis=(1, 3)).astype(arr.dtype)
103-
else:
104-
blocks = cropped.astype(np.float64).reshape(h2 // 2, 2, w2 // 2, 2)
105-
return np.round(blocks.mean(axis=(1, 3))).astype(arr.dtype)
148+
bands = [_block_reduce_2d(arr[:, :, b], method) for b in range(arr.shape[2])]
149+
return np.stack(bands, axis=2)
150+
return _block_reduce_2d(arr, method)
106151

107152

108153
# ---------------------------------------------------------------------------
@@ -619,6 +664,7 @@ def write(data: np.ndarray, path: str, *,
619664
predictor: bool = False,
620665
cog: bool = False,
621666
overview_levels: list[int] | None = None,
667+
overview_resampling: str = 'mean',
622668
raster_type: int = 1) -> None:
623669
"""Write a numpy array as a GeoTIFF or COG.
624670
@@ -676,7 +722,7 @@ def write(data: np.ndarray, path: str, *,
676722

677723
current = data
678724
for _ in overview_levels:
679-
current = _make_overview(current)
725+
current = _make_overview(current, method=overview_resampling)
680726
oh, ow = current.shape[:2]
681727
if tiled:
682728
o_off, o_bc, o_data = _write_tiled(current, comp_tag, predictor, tile_size)

xrspatial/geotiff/tests/test_features.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,157 @@ def test_zstd_public_api(self, tmp_path):
256256
np.testing.assert_array_equal(result.values, arr)
257257

258258

259+
# -----------------------------------------------------------------------
260+
# Overview resampling methods
261+
# -----------------------------------------------------------------------
262+
263+
class TestOverviewResampling:
264+
265+
def test_mean_default(self, tmp_path):
266+
"""Default mean resampling produces correct 2x2 block averages."""
267+
from xrspatial.geotiff._writer import _make_overview
268+
arr = np.array([[1, 3, 5, 7],
269+
[2, 4, 6, 8],
270+
[10, 20, 30, 40],
271+
[10, 20, 30, 40]], dtype=np.float32)
272+
ov = _make_overview(arr, 'mean')
273+
assert ov.shape == (2, 2)
274+
# (1+3+2+4)/4 = 2.5
275+
assert ov[0, 0] == pytest.approx(2.5)
276+
277+
def test_nearest(self, tmp_path):
278+
"""Nearest resampling picks top-left pixel of each 2x2 block."""
279+
from xrspatial.geotiff._writer import _make_overview
280+
arr = np.array([[10, 20, 30, 40],
281+
[50, 60, 70, 80],
282+
[90, 100, 110, 120],
283+
[130, 140, 150, 160]], dtype=np.uint8)
284+
ov = _make_overview(arr, 'nearest')
285+
assert ov.shape == (2, 2)
286+
assert ov[0, 0] == 10
287+
assert ov[0, 1] == 30
288+
assert ov[1, 0] == 90
289+
assert ov[1, 1] == 110
290+
291+
def test_min(self, tmp_path):
292+
from xrspatial.geotiff._writer import _make_overview
293+
arr = np.array([[10, 1, 5, 3],
294+
[20, 2, 6, 4],
295+
[30, 3, 7, 5],
296+
[40, 4, 8, 6]], dtype=np.float32)
297+
ov = _make_overview(arr, 'min')
298+
assert ov[0, 0] == pytest.approx(1.0)
299+
assert ov[0, 1] == pytest.approx(3.0)
300+
301+
def test_max(self, tmp_path):
302+
from xrspatial.geotiff._writer import _make_overview
303+
arr = np.array([[10, 1, 5, 3],
304+
[20, 2, 6, 4],
305+
[30, 3, 7, 5],
306+
[40, 4, 8, 6]], dtype=np.float32)
307+
ov = _make_overview(arr, 'max')
308+
assert ov[0, 0] == pytest.approx(20.0)
309+
assert ov[1, 1] == pytest.approx(8.0)
310+
311+
def test_median(self, tmp_path):
312+
from xrspatial.geotiff._writer import _make_overview
313+
arr = np.array([[1, 2, 10, 20],
314+
[3, 100, 30, 40],
315+
[0, 0, 0, 0],
316+
[0, 0, 0, 0]], dtype=np.float32)
317+
ov = _make_overview(arr, 'median')
318+
assert ov.shape == (2, 2)
319+
# median of [1, 2, 3, 100] = 2.5
320+
assert ov[0, 0] == pytest.approx(2.5)
321+
322+
def test_mode(self, tmp_path):
323+
"""Mode picks the most common value in each 2x2 block."""
324+
from xrspatial.geotiff._writer import _make_overview
325+
arr = np.array([[1, 1, 2, 3],
326+
[1, 2, 2, 2],
327+
[5, 5, 5, 6],
328+
[5, 7, 6, 6]], dtype=np.uint8)
329+
ov = _make_overview(arr, 'mode')
330+
assert ov[0, 0] == 1 # 1 appears 3 times
331+
assert ov[0, 1] == 2 # 2 appears 3 times
332+
assert ov[1, 0] == 5 # 5 appears 3 times
333+
assert ov[1, 1] == 6 # 6 appears 3 times
334+
335+
def test_mean_with_nan(self, tmp_path):
336+
"""Mean resampling ignores NaN values."""
337+
from xrspatial.geotiff._writer import _make_overview
338+
arr = np.array([[np.nan, 2, 4, 6],
339+
[1, 3, np.nan, 8],
340+
[10, 20, 30, 40],
341+
[10, 20, 30, 40]], dtype=np.float32)
342+
ov = _make_overview(arr, 'mean')
343+
# nanmean([nan, 2, 1, 3]) = 2.0
344+
assert ov[0, 0] == pytest.approx(2.0)
345+
346+
def test_multiband(self, tmp_path):
347+
"""Resampling works on 3D (multi-band) arrays."""
348+
from xrspatial.geotiff._writer import _make_overview
349+
arr = np.zeros((4, 4, 3), dtype=np.uint8)
350+
arr[:, :, 0] = 100
351+
arr[:, :, 1] = 200
352+
arr[:, :, 2] = 50
353+
ov = _make_overview(arr, 'mean')
354+
assert ov.shape == (2, 2, 3)
355+
assert ov[0, 0, 0] == 100
356+
assert ov[0, 0, 1] == 200
357+
assert ov[0, 0, 2] == 50
358+
359+
def test_cog_round_trip_nearest(self, tmp_path):
360+
"""COG with nearest resampling writes and reads back."""
361+
arr = np.arange(256, dtype=np.float32).reshape(16, 16)
362+
path = str(tmp_path / 'cog_nearest.tif')
363+
write(arr, path, compression='deflate', tiled=True, tile_size=8,
364+
cog=True, overview_levels=[1], overview_resampling='nearest')
365+
366+
result, _ = read_to_array(path)
367+
np.testing.assert_array_equal(result, arr)
368+
369+
def test_cog_round_trip_mode(self, tmp_path):
370+
"""COG with mode resampling for classified data."""
371+
arr = np.array([[0, 0, 1, 1, 2, 2, 3, 3],
372+
[0, 0, 1, 1, 2, 2, 3, 3],
373+
[4, 4, 5, 5, 6, 6, 7, 7],
374+
[4, 4, 5, 5, 6, 6, 7, 7],
375+
[0, 0, 1, 1, 2, 2, 3, 3],
376+
[0, 0, 1, 1, 2, 2, 3, 3],
377+
[4, 4, 5, 5, 6, 6, 7, 7],
378+
[4, 4, 5, 5, 6, 6, 7, 7]], dtype=np.uint8)
379+
path = str(tmp_path / 'cog_mode.tif')
380+
write(arr, path, compression='deflate', tiled=True, tile_size=4,
381+
cog=True, overview_levels=[1], overview_resampling='mode')
382+
383+
# Full res should be exact
384+
result, _ = read_to_array(path)
385+
np.testing.assert_array_equal(result, arr)
386+
387+
# Overview should have mode-reduced values
388+
ov, _ = read_to_array(path, overview_level=1)
389+
assert ov.shape == (4, 4)
390+
assert ov[0, 0] == 0
391+
assert ov[0, 1] == 1
392+
393+
def test_write_geotiff_api(self, tmp_path):
394+
"""overview_resampling kwarg works through the public API."""
395+
arr = np.arange(64, dtype=np.float32).reshape(8, 8)
396+
path = str(tmp_path / 'api_nearest.tif')
397+
write_geotiff(arr, path, compression='deflate',
398+
cog=True, overview_resampling='nearest')
399+
400+
result = read_geotiff(path)
401+
np.testing.assert_array_equal(result.values, arr)
402+
403+
def test_invalid_method(self):
404+
from xrspatial.geotiff._writer import _make_overview
405+
arr = np.ones((4, 4), dtype=np.float32)
406+
with pytest.raises(ValueError, match="Unknown overview resampling"):
407+
_make_overview(arr, 'bicubic_spline')
408+
409+
259410
# -----------------------------------------------------------------------
260411
# BigTIFF write
261412
# -----------------------------------------------------------------------

0 commit comments

Comments
 (0)