Skip to content

Commit b64b75a

Browse files
committed
Skip no-op astype in _delayed_read_window (#1624)
After #1601 widened the call site to always pass target_dtype, every dask chunk ran arr.astype(target_dtype) even when arr.dtype already matched. numpy.astype defaults to copy=True, so it allocated a same-dtype, chunk-sized buffer and performed a memcpy on every read. Gate the cast on a real dtype mismatch; the #1597 mask path still promotes uint to float64 inline, so every chunk lands in the dask-declared dtype. The regression test in test_dask_no_op_astype_1624.py wraps read_to_array so that it returns an ndarray subclass that records astype calls, then asserts that no same-dtype call survives the per-chunk return path.
1 parent 9a5f55e commit b64b75a

2 files changed

Lines changed: 157 additions & 1 deletion

File tree

xrspatial/geotiff/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1868,7 +1868,14 @@ def _read(http_meta):
18681868
if mask.any():
18691869
arr = arr.astype(np.float64)
18701870
arr[mask] = np.nan
1871-
if target_dtype is not None:
1871+
if target_dtype is not None and arr.dtype != target_dtype:
1872+
# Skip the cast when dtype already matches. ``numpy.astype``
1873+
# defaults to ``copy=True`` and would otherwise allocate a
1874+
# full chunk-sized buffer and memcpy on every read just to
1875+
# land in the same dtype the array already has. The int->
1876+
# float64 promotion above (sentinel-hit branch) keeps the
1877+
# contract that every chunk lands in the dask-declared
1878+
# dtype; this guard only elides no-op casts. See #1624.
18721879
arr = arr.astype(target_dtype)
18731880
return arr
18741881
return _read(http_meta_key)
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
"""Regression tests for issue #1624.
2+
3+
After #1597/#1601 widened ``_delayed_read_window`` to always pass
4+
``target_dtype`` through to per-chunk reads, every chunk ran
5+
``arr.astype(target_dtype)`` even when ``arr.dtype == target_dtype``
6+
already. ``numpy.ndarray.astype`` defaults to ``copy=True`` and so
7+
allocated a same-dtype chunk-sized buffer and memcpy on every chunk of
8+
every read, doubling peak per-chunk memory on plain float reads.
9+
10+
The fix gates the astype on a real dtype mismatch. The #1597 mask path
11+
still promotes uint -> float64 inline so every chunk lands in the
12+
dask-declared dtype.
13+
"""
14+
from __future__ import annotations
15+
16+
import numpy as np
17+
import pytest
18+
19+
from xrspatial.geotiff import open_geotiff, read_geotiff_dask
20+
from xrspatial.geotiff._writer import write
21+
22+
23+
@pytest.fixture
def float32_no_nodata_tif(tmp_path):
    """Create a small float32 GeoTIFF that carries no nodata value.

    Returns a ``(path, data)`` tuple: the file location as a string and
    the 16x16 float32 array that was written to it.
    """
    # A fixed seed keeps the pixel values reproducible between runs.
    data = np.random.RandomState(1624).rand(16, 16).astype(np.float32)
    tif_path = str(tmp_path / 'float32_no_nodata_1624.tif')
    write(data, tif_path, compression='none', tiled=False)
    return tif_path, data
31+
32+
33+
@pytest.fixture
def uint16_with_sentinel_in_first_chunk(tmp_path):
    """Create an 8x8 uint16 GeoTIFF whose nodata sentinel is 65535.

    The pixels hold 1..64 in row-major order, except ``(0, 0)`` and
    ``(6, 6)``, which are overwritten with the sentinel. With 4x4 chunks
    the first sentinel therefore falls in the very first chunk.

    Returns a ``(path, data)`` tuple: the file location as a string and
    the array that was written to it.
    """
    sentinel = 65535
    data = np.arange(1, 65, dtype=np.uint16).reshape(8, 8)
    for row, col in ((0, 0), (6, 6)):
        data[row, col] = sentinel
    tif_path = str(tmp_path / 'uint16_sentinel_1624.tif')
    write(data, tif_path, nodata=sentinel, compression='none', tiled=False)
    return tif_path, data
42+
43+
44+
def test_float32_chunks_avoid_redundant_copy(float32_no_nodata_tif):
    """Plain float32 dask read keeps its dtype and round-trips the pixels.

    This is an output-level check only: it confirms that skipping the
    same-dtype cast does not change the declared dtype, the computed
    dtype, or the pixel values. It does not observe ``astype`` calls
    itself; the call-level trace lives in
    ``test_astype_skipped_when_dtypes_match``.
    """
    path, expected = float32_no_nodata_tif

    dk = read_geotiff_dask(path, chunks=4)
    assert dk.dtype == np.float32

    # Force compute so the per-chunk read function actually runs.
    out = dk.compute()
    assert out.dtype == np.float32

    # ``np.asarray`` normalises the computed result to a plain ndarray
    # before comparing against the array the fixture wrote to disk.
    np.testing.assert_array_equal(np.asarray(out), expected)
73+
74+
75+
def test_uint16_mask_path_still_promotes(uint16_with_sentinel_in_first_chunk):
    """Sentinel-bearing uint16 data is still promoted to float64 (#1597).

    The chunked read must declare and produce float64, and its NaN
    positions must match those of the eager (non-chunked) read.
    """
    path = uint16_with_sentinel_in_first_chunk[0]

    eager = open_geotiff(path)
    lazy = open_geotiff(path, chunks=4)
    assert lazy.dtype == np.float64

    result = lazy.compute()
    assert result.dtype == np.float64

    nan_in_lazy = np.isnan(result.values)
    nan_in_eager = np.isnan(eager.values)
    np.testing.assert_array_equal(nan_in_lazy, nan_in_eager)
85+
86+
87+
def test_astype_skipped_when_dtypes_match(float32_no_nodata_tif, monkeypatch):
    """Trace ``astype`` calls on the arrays handed back by the reader.

    ``read_to_array`` is replaced on the ``xrspatial.geotiff`` module by
    a wrapper that re-wraps the returned array in an ndarray subclass
    which logs the target dtype of every ``astype`` call. After a full
    compute, no logged dtype may equal the dtype of the array it was
    logged on; a match means the same-dtype cast from #1624 is back.
    """
    import xrspatial.geotiff as gt
    from xrspatial.geotiff import _reader as reader_mod

    class _RecordingArray(np.ndarray):
        """ndarray subclass that logs the dtype of each astype call."""

        def __new__(cls, input_array):
            view = np.asarray(input_array).view(cls)
            view._astype_calls = []
            return view

        def __array_finalize__(self, obj):
            # Arrays derived from a recording array share its log list;
            # anything else starts with a fresh, empty one.
            if obj is not None:
                self._astype_calls = getattr(obj, '_astype_calls', [])

        def astype(self, dtype, *args, **kwargs):
            self._astype_calls.append(np.dtype(dtype))
            return super().astype(dtype, *args, **kwargs)

    seen = []
    real_read = reader_mod.read_to_array

    def recording_read(*args, **kwargs):
        data, meta = real_read(*args, **kwargs)
        wrapped = _RecordingArray(data)
        seen.append(wrapped)
        return wrapped, meta

    monkeypatch.setattr(gt, 'read_to_array', recording_read)

    tif_path = float32_no_nodata_tif[0]
    # compute() is what makes the per-chunk reads actually execute.
    read_geotiff_dask(tif_path, chunks=4).compute()

    assert seen, "read_to_array was not invoked"
    for chunk in seen:
        redundant = [d for d in chunk._astype_calls if d == chunk.dtype]
        assert not redundant, (
            f"Same-dtype astype still runs per chunk "
            f"(dtype={chunk.dtype}, calls={chunk._astype_calls}); "
            f"this is the #1624 regression."
        )
141+
142+
143+
def test_caller_supplied_dtype_still_casts(float32_no_nodata_tif):
    """A caller asking for float64 on float32 data still gets a cast.

    The dtype differs from the stored one, so both the declared dask
    dtype and the computed result must be float64.
    """
    tif_path = float32_no_nodata_tif[0]
    wanted = np.float64

    lazy = read_geotiff_dask(tif_path, dtype=wanted, chunks=4)
    assert lazy.dtype == wanted

    computed = lazy.compute()
    assert computed.dtype == wanted

0 commit comments

Comments
 (0)