Skip to content

Commit 5af16e1

Browse files
authored
geotiff: forward max_pixels, window, band on GPU stripped fallback (#1732) (#1738)
1 parent 99b1817 commit 5af16e1

2 files changed

Lines changed: 203 additions & 7 deletions

File tree

xrspatial/geotiff/__init__.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2683,9 +2683,19 @@ def read_geotiff_gpu(source: str, *,
26832683
# 5-8 today (the 2/3/4 fix in #1539 is in a sibling PR). Discard
26842684
# its geo_info and apply our own transform update below so the
26852685
# result is correct regardless of merge order.
2686+
#
2687+
# Forward ``max_pixels``, ``window``, and ``band`` so the
2688+
# caller's safety cap is honoured, windowed reads avoid
2689+
# decoding the full image, and single-band selection on a
2690+
# multi-band source skips the unused channels. Without this,
2691+
# the stripped GPU path bypassed all three (issue #1732).
2692+
# Orientation != 1 + window is already rejected earlier (line 2495 as of this commit),
2693+
# so ``window`` is None whenever ``geo_info`` will be remapped
2694+
# below.
26862695
src.close()
26872696
arr_cpu, _ = _read_to_array(
2688-
source, overview_level=overview_level)
2697+
source, overview_level=overview_level,
2698+
window=window, band=band, max_pixels=max_pixels)
26892699
arr_gpu = cupy.asarray(arr_cpu)
26902700
if orientation != 1:
26912701
geo_info = _apply_orientation_geo_info(
@@ -2708,12 +2718,36 @@ def read_geotiff_gpu(source: str, *,
27082718
target = np.dtype(dtype)
27092719
_validate_dtype_cast(np.dtype(str(arr_gpu.dtype)), target)
27102720
arr_gpu = arr_gpu.astype(target)
2711-
# Apply window/band slicing post-decode. The stripped CPU
2712-
# fallback already produces the full-image array; slice on the
2713-
# GPU so the result matches ``open_geotiff`` /
2714-
# ``read_geotiff_dask`` semantics.
2715-
arr_gpu, coords = _gpu_apply_window_band(
2716-
arr_gpu, geo_info, window=window, band=band)
2721+
# ``read_to_array`` already applied window + band slicing, so
2722+
# ``arr_gpu`` is at output shape. Compute coords for that
2723+
# shape without re-slicing.
2724+
if window is not None:
2725+
r0, c0, r1, c1 = window
2726+
t = geo_info.transform
2727+
if t is None:
2728+
coords = {
2729+
'y': np.arange(r1 - r0, dtype=np.int64),
2730+
'x': np.arange(c1 - c0, dtype=np.int64),
2731+
}
2732+
elif geo_info.raster_type == RASTER_PIXEL_IS_POINT:
2733+
coords = {
2734+
'x': (np.arange(c0, c1, dtype=np.float64)
2735+
* t.pixel_width + t.origin_x),
2736+
'y': (np.arange(r0, r1, dtype=np.float64)
2737+
* t.pixel_height + t.origin_y),
2738+
}
2739+
else:
2740+
coords = {
2741+
'x': (np.arange(c0, c1, dtype=np.float64)
2742+
* t.pixel_width + t.origin_x
2743+
+ t.pixel_width * 0.5),
2744+
'y': (np.arange(r0, r1, dtype=np.float64)
2745+
* t.pixel_height + t.origin_y
2746+
+ t.pixel_height * 0.5),
2747+
}
2748+
else:
2749+
coords = _geo_to_coords(
2750+
geo_info, arr_gpu.shape[0], arr_gpu.shape[1])
27172751
# Multi-band stripped reads come back as (y, x, band); mirror
27182752
# the tiled branch so dims line up with ndim. Single-band stays
27192753
# 2-D ('y', 'x').
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
"""Regression tests for issue #1732.
2+
3+
The stripped-TIFF fallback inside ``read_geotiff_gpu`` previously called
4+
``read_to_array(source, overview_level=overview_level)`` and threw away
5+
the caller's ``max_pixels``, ``window``, and ``band`` arguments. That
6+
meant:
7+
8+
- a user-supplied ``max_pixels`` safety cap was silently ignored on
9+
stripped files (the default ~1B pixel cap applied instead),
10+
- windowed reads decoded the entire image before slicing on the GPU, and
11+
- single-band selection on a multi-band stripped file still decoded
12+
every band on the CPU.
13+
14+
These tests assert that all three kwargs are now forwarded to
15+
``read_to_array`` so the stripped GPU path matches the contract of the
16+
tiled GPU path.
17+
"""
18+
from __future__ import annotations
19+
20+
import importlib.util
21+
import os
22+
import tempfile
23+
24+
import numpy as np
25+
import pytest
26+
import xarray as xr
27+
28+
29+
def _gpu_available() -> bool:
30+
if importlib.util.find_spec("cupy") is None:
31+
return False
32+
try:
33+
import cupy
34+
return bool(cupy.cuda.is_available())
35+
except Exception:
36+
return False
37+
38+
39+
_HAS_GPU = _gpu_available()
40+
_gpu_only = pytest.mark.skipif(
41+
not _HAS_GPU,
42+
reason="cupy + CUDA required",
43+
)
44+
45+
46+
@_gpu_only
47+
def test_stripped_max_pixels_cap_is_enforced():
48+
"""max_pixels smaller than the file must raise before full decode."""
49+
from xrspatial.geotiff import to_geotiff, read_geotiff_gpu
50+
51+
rng = np.random.RandomState(20260512)
52+
data = rng.randint(0, 200, size=(64, 96)).astype(np.uint8)
53+
da = xr.DataArray(data, dims=['y', 'x'])
54+
55+
with tempfile.TemporaryDirectory() as d:
56+
p = os.path.join(d, 'tmp_1732_cap.tif')
57+
to_geotiff(da, p, tiled=False)
58+
# 64 * 96 = 6144 pixels; cap at 1000 must reject.
59+
with pytest.raises(ValueError, match="max_pixels|pixel"):
60+
read_geotiff_gpu(p, max_pixels=1000)
61+
62+
63+
@_gpu_only
64+
def test_stripped_window_returns_only_window():
65+
"""Windowed read on a stripped file returns the window-sized array
66+
with coords and transform that match the window origin.
67+
68+
The post-decode ``_gpu_apply_window_band`` call was replaced with a
69+
coord-only computation in #1732. Compare against the CPU eager path
70+
(which is the parity reference for this exact fixture) so a
71+
regression in the coord-only branch -- or a drift in the windowed
72+
``attrs['transform']`` -- shows up here.
73+
"""
74+
from xrspatial.geotiff import to_geotiff, open_geotiff, read_geotiff_gpu
75+
76+
rng = np.random.RandomState(20260512)
77+
data = rng.randint(0, 200, size=(64, 96)).astype(np.uint8)
78+
# Explicit y/x coords give the file a real georef so the coord-only
79+
# path computes a non-trivial windowed transform / origin -- a plain
80+
# ``dims=['y','x']`` array writes a no-georef TIFF where the coord
81+
# branch degenerates to integer arange and would not catch a
82+
# transform-math regression.
83+
da = xr.DataArray(
84+
data,
85+
dims=['y', 'x'],
86+
coords={
87+
'y': np.arange(64, dtype=np.float64) * 0.5 + 100.0,
88+
'x': np.arange(96, dtype=np.float64) * 0.5 + 200.0,
89+
},
90+
attrs={'crs': 4326},
91+
)
92+
93+
with tempfile.TemporaryDirectory() as d:
94+
p = os.path.join(d, 'tmp_1732_win.tif')
95+
to_geotiff(da, p, tiled=False)
96+
win = (8, 16, 40, 80) # 32x64 window
97+
out = read_geotiff_gpu(p, window=win)
98+
assert out.shape == (32, 64)
99+
np.testing.assert_array_equal(out.data.get(), data[8:40, 16:80])
100+
101+
# Coords + transform parity vs the CPU eager path. CPU runs
102+
# through ``open_geotiff``'s own windowed-coord branch (line
103+
# ~833), so any drift between the two coord computations is
104+
# caught here.
105+
cpu = open_geotiff(p, window=win)
106+
np.testing.assert_array_equal(out.coords['y'].values,
107+
cpu.coords['y'].values)
108+
np.testing.assert_array_equal(out.coords['x'].values,
109+
cpu.coords['x'].values)
110+
# ``attrs['transform']`` carries the windowed origin (origin_x
111+
# shifted by c0 * pixel_width, origin_y by r0 * pixel_height).
112+
# Pin the exact tuple as well as parity with CPU so a regression
113+
# in either ``_populate_attrs_from_geo_info`` or the coord-only
114+
# branch is visible.
115+
assert out.attrs['transform'] == cpu.attrs['transform']
116+
# pixel_width=0.5, first pixel-centre x=200, c0=16 -> window's first centre x = 200 + 16*0.5 = 208 (centre, not the edge origin asserted below)
117+
# pixel_height=0.5, origin_y=100-0.5*0.5=99.75 (PixelIsArea),
118+
# r0=8 -> 99.75 + 8*0.5 = 103.75
119+
# to_geotiff writes the raw geo-transform (edge origin), so:
120+
# origin_x_raw = 200 - 0.25 = 199.75; +16*0.5 = 207.75
121+
# origin_y_raw = 100 - 0.25 = 99.75; +8*0.5 = 103.75
122+
assert out.attrs['transform'] == (0.5, 0.0, 207.75,
123+
0.0, 0.5, 103.75)
124+
125+
126+
@_gpu_only
127+
def test_stripped_band_selection_returns_2d():
128+
"""Selecting band=1 on a 3-band stripped file returns a 2D array
129+
matching the requested band."""
130+
from xrspatial.geotiff import to_geotiff, read_geotiff_gpu
131+
132+
rng = np.random.RandomState(20260512)
133+
data = rng.randint(0, 200, size=(48, 80, 3)).astype(np.uint8)
134+
da = xr.DataArray(data, dims=['y', 'x', 'band'])
135+
136+
with tempfile.TemporaryDirectory() as d:
137+
p = os.path.join(d, 'tmp_1732_band.tif')
138+
to_geotiff(da, p, tiled=False)
139+
out = read_geotiff_gpu(p, band=1)
140+
assert out.dims == ('y', 'x')
141+
assert out.shape == (48, 80)
142+
np.testing.assert_array_equal(out.data.get(), data[:, :, 1])
143+
144+
145+
@_gpu_only
146+
def test_stripped_window_plus_band():
147+
"""Windowed read with band selection composes correctly."""
148+
from xrspatial.geotiff import to_geotiff, read_geotiff_gpu
149+
150+
rng = np.random.RandomState(20260512)
151+
data = rng.randint(0, 200, size=(48, 80, 3)).astype(np.uint8)
152+
da = xr.DataArray(data, dims=['y', 'x', 'band'])
153+
154+
with tempfile.TemporaryDirectory() as d:
155+
p = os.path.join(d, 'tmp_1732_wb.tif')
156+
to_geotiff(da, p, tiled=False)
157+
win = (4, 8, 36, 72) # 32x64
158+
out = read_geotiff_gpu(p, window=win, band=2)
159+
assert out.dims == ('y', 'x')
160+
assert out.shape == (32, 64)
161+
np.testing.assert_array_equal(
162+
out.data.get(), data[4:36, 8:72, 2])

0 commit comments

Comments
 (0)