Skip to content

Commit 6cc947b

Browse files
committed
Fix open_geotiff(gpu=True) silently dropping on_gpu_failure kwarg (#1615)
open_geotiff's GPU branch called read_geotiff_gpu without forwarding on_gpu_failure, so callers wanting strict GPU-failure mode through the dispatcher entry point had no way to enable it. They had to drop down to read_geotiff_gpu directly, defeating the whole point of the auto-dispatch. This is the same pattern as #1561 and #1605: a kwarg added to a backend without a matching update to the dispatcher signature. Add on_gpu_failure to open_geotiff with a sentinel default. When gpu=True and on_gpu_failure is supplied, forward it to read_geotiff_gpu; when it is omitted, read_geotiff_gpu's own default applies. When on_gpu_failure is supplied with gpu=False (or gpu unset), raise ValueError so the kwarg cannot be silently dropped on the CPU or dask path. Tests cover: signature presence, gpu=False rejection (default and explicit), chunks-only rejection, dispatch byte-stability for the default case, invalid-value rejection via read_geotiff_gpu's validator, and GPU-path end-to-end forwarding (gated on CUDA). Closes #1615.
1 parent 0088dfe commit 6cc947b

2 files changed

Lines changed: 235 additions & 12 deletions

File tree

xrspatial/geotiff/__init__.py

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,20 @@
5555
]
5656

5757

58+
# Sentinels distinguishing "user passed this kwarg explicitly" from "user
59+
# passed nothing". A plain default of None does not work because None is
60+
# itself a value a caller could supply. ``read_geotiff_gpu`` needs both
61+
# sentinels so it can tell whether the deprecated ``gpu=`` and the new
62+
# ``on_gpu_failure=`` were *each* supplied, and refuse the ambiguous
63+
# both-supplied case regardless of which values were chosen.
64+
# ``open_geotiff`` also uses ``_ON_GPU_FAILURE_SENTINEL`` to distinguish
65+
# "caller never set on_gpu_failure" (default sentinel: skip forwarding so
66+
# the read_geotiff_gpu signature default applies) from "caller set
67+
# on_gpu_failure=<value>" (forward verbatim).
68+
_GPU_DEPRECATED_SENTINEL = object()
69+
_ON_GPU_FAILURE_SENTINEL = object()
70+
71+
5872
def _wkt_to_epsg(wkt_or_proj: str) -> int | None:
5973
"""Try to extract an EPSG code from a WKT or PROJ string.
6074
@@ -404,7 +418,8 @@ def open_geotiff(source, *, dtype=None, window=None,
404418
name: str | None = None,
405419
chunks: int | tuple | None = None,
406420
gpu: bool = False,
407-
max_pixels: int | None = None) -> xr.DataArray:
421+
max_pixels: int | None = None,
422+
on_gpu_failure=_ON_GPU_FAILURE_SENTINEL) -> xr.DataArray:
408423
"""Read a GeoTIFF, COG, or VRT file into an xarray.DataArray.
409424
410425
Automatically dispatches to the best backend:
@@ -442,6 +457,13 @@ def open_geotiff(source, *, dtype=None, window=None,
442457
Maximum allowed pixel count (width * height * samples). None
443458
uses the default (~1 billion). Raise to read legitimately
444459
large files.
460+
on_gpu_failure : {'auto', 'strict'}, optional
461+
Forwarded to ``read_geotiff_gpu`` when ``gpu=True``. Controls
462+
whether GPU decode failures fall back to CPU (``'auto'``,
463+
default) or re-raise the original exception (``'strict'``).
464+
Passing this kwarg with ``gpu=False`` raises ``ValueError``
465+
because the policy only applies to the GPU pipeline. See
466+
``read_geotiff_gpu`` for the full description.
445467
446468
Returns
447469
-------
@@ -475,6 +497,18 @@ def open_geotiff(source, *, dtype=None, window=None,
475497

476498
source = _coerce_path(source)
477499

500+
# ``on_gpu_failure`` is GPU-only. Reject it up front for CPU/dask paths
501+
# rather than silently dropping it once dispatch is decided -- callers
502+
# otherwise have no way to learn that the policy is being ignored.
503+
# ``gpu=False`` (the default) on a ``.vrt`` source still routes through
504+
# ``read_vrt`` below which has no GPU-failure concept, so the same
505+
# rejection rule applies there.
506+
if on_gpu_failure is not _ON_GPU_FAILURE_SENTINEL and not gpu:
507+
raise ValueError(
508+
"on_gpu_failure only applies when gpu=True. "
509+
"Pass gpu=True to enable the GPU pipeline, or drop "
510+
"on_gpu_failure to keep the default CPU path.")
511+
478512
# VRT files (string paths only -- VRT XML references other files on disk)
479513
if isinstance(source, str) and source.lower().endswith('.vrt'):
480514
return read_vrt(source, dtype=dtype, window=window, band=band,
@@ -496,11 +530,15 @@ def open_geotiff(source, *, dtype=None, window=None,
496530

497531
# GPU path
498532
if gpu:
533+
gpu_kwargs = {}
534+
if on_gpu_failure is not _ON_GPU_FAILURE_SENTINEL:
535+
gpu_kwargs['on_gpu_failure'] = on_gpu_failure
499536
return read_geotiff_gpu(source, dtype=dtype,
500537
overview_level=overview_level,
501538
window=window, band=band,
502539
name=name, chunks=chunks,
503-
max_pixels=max_pixels)
540+
max_pixels=max_pixels,
541+
**gpu_kwargs)
504542

505543
# Dask path (CPU)
506544
if chunks is not None:
@@ -665,16 +703,6 @@ def _apply_nodata_mask_gpu(arr_gpu, nodata):
665703
)
666704

667705

668-
# Sentinels distinguishing "user passed this kwarg explicitly" from "user
669-
# passed nothing". A plain default of None would not work because None is
670-
# itself a value a caller could supply. ``read_geotiff_gpu`` needs both
671-
# sentinels so it can tell whether the deprecated ``gpu=`` and the new
672-
# ``on_gpu_failure=`` were *each* supplied, and refuse the ambiguous
673-
# both-supplied case regardless of which values were chosen.
674-
_GPU_DEPRECATED_SENTINEL = object()
675-
_ON_GPU_FAILURE_SENTINEL = object()
676-
677-
678706
# TIFF type ids needed when synthesizing extra_tags entries from attrs.
679707
_TIFF_BYTE = 1
680708
_TIFF_ASCII = 2
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
"""Regression tests for issue #1615.
2+
3+
``open_geotiff(gpu=True)`` used to silently drop the ``on_gpu_failure``
4+
kwarg: the dispatcher did not declare it and the GPU branch did not
5+
forward it to ``read_geotiff_gpu``. Callers that wanted strict GPU
6+
failure semantics had to bypass ``open_geotiff`` entirely and call
7+
``read_geotiff_gpu`` directly, defeating the dispatcher.
8+
9+
The fix:
10+
11+
* ``open_geotiff`` accepts ``on_gpu_failure`` and forwards it to
12+
``read_geotiff_gpu`` when ``gpu=True``.
13+
* ``on_gpu_failure`` paired with ``gpu=False`` (or unset) raises
14+
``ValueError`` so the kwarg cannot be silently ignored.
15+
* The default sentinel keeps the dispatcher's behavior bit-stable for
16+
callers that never set ``on_gpu_failure``: ``read_geotiff_gpu`` is
17+
called without the kwarg, taking its own ``'auto'`` default.
18+
"""
19+
from __future__ import annotations
20+
21+
import importlib.util
22+
23+
import numpy as np
24+
import pytest
25+
import xarray as xr
26+
27+
28+
def _gpu_available() -> bool:
29+
"""True if cupy is importable and CUDA is initialized."""
30+
if importlib.util.find_spec("cupy") is None:
31+
return False
32+
try:
33+
import cupy
34+
return bool(cupy.cuda.is_available())
35+
except Exception:
36+
return False
37+
38+
39+
_HAS_GPU = _gpu_available()
40+
_gpu_only = pytest.mark.skipif(
41+
not _HAS_GPU,
42+
reason="cupy + CUDA required",
43+
)
44+
45+
46+
@pytest.fixture
47+
def small_tiff_path(tmp_path):
48+
"""4x6 single-band tiled tiff usable by both CPU and GPU readers."""
49+
from xrspatial.geotiff import to_geotiff
50+
51+
arr = np.arange(24, dtype=np.float32).reshape(4, 6)
52+
da = xr.DataArray(
53+
arr,
54+
dims=['y', 'x'],
55+
coords={
56+
'y': np.array([0.5, 1.5, 2.5, 3.5]),
57+
'x': np.array([0.5, 1.5, 2.5, 3.5, 4.5, 5.5]),
58+
},
59+
attrs={'crs': 4326},
60+
)
61+
p = tmp_path / 'on_gpu_failure_1615.tif'
62+
to_geotiff(da, str(p), tile_size=4)
63+
return str(p), arr
64+
65+
66+
# --------------------------------------------------------------------
67+
# Signature: ``open_geotiff`` accepts ``on_gpu_failure``.
68+
# --------------------------------------------------------------------
69+
70+
71+
def test_open_geotiff_signature_includes_on_gpu_failure():
72+
"""Direct ``inspect.signature`` check on the public dispatcher."""
73+
import inspect
74+
75+
from xrspatial.geotiff import open_geotiff
76+
77+
sig = inspect.signature(open_geotiff)
78+
assert 'on_gpu_failure' in sig.parameters, (
79+
"open_geotiff must declare on_gpu_failure so the GPU policy is "
80+
"addressable through the dispatcher entry point"
81+
)
82+
83+
84+
# --------------------------------------------------------------------
85+
# ``on_gpu_failure`` with ``gpu=False`` rejects up front.
86+
# --------------------------------------------------------------------
87+
88+
89+
def test_on_gpu_failure_with_gpu_false_raises_value_error(small_tiff_path):
90+
"""Refuse rather than silently dropping the kwarg on the CPU path."""
91+
from xrspatial.geotiff import open_geotiff
92+
93+
path, _ = small_tiff_path
94+
with pytest.raises(ValueError, match="on_gpu_failure only applies"):
95+
open_geotiff(path, on_gpu_failure='strict')
96+
97+
98+
def test_on_gpu_failure_with_explicit_gpu_false_raises(small_tiff_path):
99+
"""``gpu=False`` explicitly is rejected just like the default."""
100+
from xrspatial.geotiff import open_geotiff
101+
102+
path, _ = small_tiff_path
103+
with pytest.raises(ValueError, match="on_gpu_failure only applies"):
104+
open_geotiff(path, gpu=False, on_gpu_failure='auto')
105+
106+
107+
def test_on_gpu_failure_with_chunks_only_raises(small_tiff_path):
108+
"""Dask CPU path is not GPU and should refuse the kwarg."""
109+
from xrspatial.geotiff import open_geotiff
110+
111+
path, _ = small_tiff_path
112+
with pytest.raises(ValueError, match="on_gpu_failure only applies"):
113+
open_geotiff(path, chunks=2, on_gpu_failure='auto')
114+
115+
116+
# --------------------------------------------------------------------
117+
# Default sentinel does not change CPU/dask behavior.
118+
# --------------------------------------------------------------------
119+
120+
121+
def test_default_dispatch_unchanged_cpu(small_tiff_path):
122+
"""Not passing ``on_gpu_failure`` keeps CPU dispatch byte-stable."""
123+
from xrspatial.geotiff import open_geotiff
124+
125+
path, arr = small_tiff_path
126+
da = open_geotiff(path)
127+
np.testing.assert_array_equal(da.values, arr)
128+
129+
130+
def test_default_dispatch_unchanged_dask(small_tiff_path):
131+
"""Not passing ``on_gpu_failure`` keeps Dask CPU dispatch byte-stable."""
132+
from xrspatial.geotiff import open_geotiff
133+
134+
path, arr = small_tiff_path
135+
da = open_geotiff(path, chunks=2)
136+
np.testing.assert_array_equal(da.values, arr)
137+
138+
139+
# --------------------------------------------------------------------
140+
# GPU forwarding: real behavior parity with ``read_geotiff_gpu``.
141+
# --------------------------------------------------------------------
142+
143+
144+
@_gpu_only
145+
def test_open_geotiff_gpu_forwards_on_gpu_failure_auto(small_tiff_path):
146+
"""``open_geotiff(gpu=True, on_gpu_failure='auto')`` works end-to-end."""
147+
from xrspatial.geotiff import open_geotiff
148+
149+
path, arr = small_tiff_path
150+
da = open_geotiff(path, gpu=True, on_gpu_failure='auto')
151+
# CuPy-backed DataArray -- .data.get() pulls back to host for comparison.
152+
np.testing.assert_array_equal(da.data.get(), arr)
153+
154+
155+
@_gpu_only
156+
def test_open_geotiff_gpu_forwards_on_gpu_failure_strict(small_tiff_path):
157+
"""``on_gpu_failure='strict'`` also reaches the GPU pipeline."""
158+
from xrspatial.geotiff import open_geotiff
159+
160+
path, arr = small_tiff_path
161+
da = open_geotiff(path, gpu=True, on_gpu_failure='strict')
162+
np.testing.assert_array_equal(da.data.get(), arr)
163+
164+
165+
@_gpu_only
166+
def test_open_geotiff_gpu_rejects_invalid_on_gpu_failure(small_tiff_path):
167+
"""An invalid value still surfaces from the underlying validator."""
168+
from xrspatial.geotiff import open_geotiff
169+
170+
path, _ = small_tiff_path
171+
with pytest.raises(ValueError, match="on_gpu_failure must be"):
172+
open_geotiff(path, gpu=True, on_gpu_failure='loose')
173+
174+
175+
# --------------------------------------------------------------------
176+
# Hardware-independent check that also runs on CPU-only CI: passing an invalid value
177+
# with gpu=True still routes through to read_geotiff_gpu's validator.
178+
# --------------------------------------------------------------------
179+
180+
181+
def test_invalid_on_gpu_failure_reaches_gpu_validator_on_cpu(small_tiff_path):
182+
"""Even on a CPU-only host, the kwarg should reach ``read_geotiff_gpu``.
183+
184+
On a host without cupy, ``read_geotiff_gpu`` would raise ``ImportError``;
185+
with cupy it runs through to the actual decode. The kwarg validator runs
186+
*before* the cupy import, so an invalid value surfaces deterministically
187+
in both environments. This pins the forwarding wire even on CPU-only CI.
188+
"""
189+
from xrspatial.geotiff import open_geotiff
190+
191+
path, _ = small_tiff_path
192+
# gpu=True + invalid value: validation fires before any cupy import,
193+
# so the ValueError reaches us on every host.
194+
with pytest.raises(ValueError, match="on_gpu_failure must be"):
195+
open_geotiff(path, gpu=True, on_gpu_failure='loose')

0 commit comments

Comments
 (0)