|
| 1 | +"""Regression tests for issue #1624. |
| 2 | +
|
| 3 | +After #1597/#1601 widened ``_delayed_read_window`` to always pass |
| 4 | +``target_dtype`` through to per-chunk reads, every chunk ran |
| 5 | +``arr.astype(target_dtype)`` even when ``arr.dtype == target_dtype`` |
| 6 | +already. ``numpy.ndarray.astype`` defaults to ``copy=True``, so it |
| 7 | +allocated a same-dtype, chunk-sized buffer and copied into it on every |
| 8 | +chunk of every read, doubling peak per-chunk memory on plain float reads. |
| 9 | +
|
| 10 | +The fix gates the astype on a real dtype mismatch. The #1597 mask path |
| 11 | +still promotes uint -> float64 inline so every chunk lands in the |
| 12 | +dask-declared dtype. |
| 13 | +""" |
| 14 | +from __future__ import annotations |
| 15 | + |
| 16 | +import numpy as np |
| 17 | +import pytest |
| 18 | + |
| 19 | +from xrspatial.geotiff import open_geotiff, read_geotiff_dask |
| 20 | +from xrspatial.geotiff._writer import write |
| 21 | + |
| 22 | + |
@pytest.fixture
def float32_no_nodata_tif(tmp_path):
    """Create a 16x16 float32 TIFF that carries no nodata value.

    The pixels come from ``RandomState(1624)`` so the content is
    reproducible. The file is written with ``compression='none'`` and
    ``tiled=False``.

    Returns
    -------
    tuple
        ``(path, pixels)``: the file path as a string and the array that
        was written.
    """
    pixels = np.random.RandomState(1624).rand(16, 16).astype(np.float32)
    tif_path = str(tmp_path / 'float32_no_nodata_1624.tif')
    write(pixels, tif_path, compression='none', tiled=False)
    return tif_path, pixels
| 31 | + |
| 32 | + |
@pytest.fixture
def uint16_with_sentinel_in_first_chunk(tmp_path):
    """Create an 8x8 uint16 TIFF whose nodata value (65535) appears twice.

    The base values are 1..64; the sentinel overwrites pixels ``(0, 0)``
    and ``(6, 6)``. The file is written with ``nodata=65535``,
    ``compression='none'`` and ``tiled=False``.

    Returns
    -------
    tuple
        ``(path, data)``: the file path as a string and the array that
        was written (sentinels included).
    """
    sentinel = 65535
    data = np.arange(64, dtype=np.uint16).reshape(8, 8) + 1
    # One sentinel at the very first pixel, one further into the raster.
    for row, col in ((0, 0), (6, 6)):
        data[row, col] = sentinel
    tif_path = str(tmp_path / 'uint16_sentinel_1624.tif')
    write(data, tif_path, nodata=sentinel, compression='none', tiled=False)
    return tif_path, data
| 42 | + |
| 43 | + |
def test_float32_chunks_avoid_redundant_copy(float32_no_nodata_tif,
                                             monkeypatch):
    """Plain float32 dask read keeps its dtype and round-trips the data.

    Output-level guard for #1624: the lazy array and the computed result
    must both be ``float32``, and the computed pixels must equal what the
    fixture wrote. This test does not count ``astype`` calls; the call
    tracing is done by ``test_astype_skipped_when_dtypes_match``.

    ``monkeypatch`` is requested but not used; it is kept only so the
    test signature stays unchanged.
    """
    path, arr = float32_no_nodata_tif

    dk = read_geotiff_dask(path, chunks=4)
    assert dk.dtype == np.float32

    # Force compute so the per-chunk read function actually runs.
    out = dk.compute()
    assert out.dtype == np.float32

    # Shape/value check: the written float32 pixels should come back
    # unchanged (assert_array_equal also rejects a shape mismatch).
    np.testing.assert_array_equal(np.asarray(out), arr)
| 73 | + |
| 74 | + |
def test_uint16_mask_path_still_promotes(uint16_with_sentinel_in_first_chunk):
    """The #1597 promotion still runs when sentinels are present.

    Reads the same uint16 file eagerly and with ``chunks=4`` and checks
    that the chunked read declares and produces ``float64``, that its NaN
    positions match the eager read, and that every other pixel value
    matches the eager read as well.
    """
    path, _ = uint16_with_sentinel_in_first_chunk
    eager = open_geotiff(path)
    dk = open_geotiff(path, chunks=4)
    assert dk.dtype == np.float64
    computed = dk.compute()
    assert computed.dtype == np.float64
    np.testing.assert_array_equal(np.isnan(computed.values),
                                  np.isnan(eager.values))
    # assert_array_equal treats NaNs in matching positions as equal, so
    # this additionally pins the valid pixels: a promotion that produced
    # the right mask but wrong values would be caught here.
    np.testing.assert_array_equal(computed.values, eager.values)
| 85 | + |
| 86 | + |
def test_astype_skipped_when_dtypes_match(float32_no_nodata_tif, monkeypatch):
    """Trace ``astype`` calls on the arrays returned by ``read_to_array``.

    ``read_to_array`` is replaced on ``xrspatial.geotiff`` with a wrapper
    that re-wraps each returned array in an ndarray subclass which logs
    the dtype passed to every ``astype`` call. After a chunked read has
    been computed, none of the logged dtypes may equal the dtype of the
    array they were called on; a match is the #1624 regression.
    """
    import xrspatial.geotiff as gt
    from xrspatial.geotiff import _reader as reader_mod

    tif_path = float32_no_nodata_tif[0]

    class _AstypeTrackingArray(np.ndarray):
        """ndarray view that remembers the target dtype of each astype."""

        def __new__(cls, input_array):
            view = np.asarray(input_array).view(cls)
            # Start every explicitly wrapped array with its own log.
            view._astype_calls = []
            return view

        def __array_finalize__(self, obj):
            # Derived arrays share the log of the array they came from.
            if obj is not None:
                self._astype_calls = getattr(obj, '_astype_calls', [])

        def astype(self, dtype, *args, **kwargs):
            requested = np.dtype(dtype)
            self._astype_calls.append(requested)
            return super().astype(dtype, *args, **kwargs)

    seen = []
    real_read_to_array = reader_mod.read_to_array

    def tracking_read_to_array(*args, **kwargs):
        result, meta = real_read_to_array(*args, **kwargs)
        wrapped_result = _AstypeTrackingArray(result)
        seen.append(wrapped_result)
        return wrapped_result, meta

    monkeypatch.setattr(gt, 'read_to_array', tracking_read_to_array)

    # Computing the lazy array is what runs the per-chunk reads.
    read_geotiff_dask(tif_path, chunks=4).compute()

    assert seen, "read_to_array was not invoked"
    for tracked in seen:
        redundant = any(call == tracked.dtype
                        for call in tracked._astype_calls)
        assert not redundant, (
            f"Same-dtype astype still runs per chunk "
            f"(dtype={tracked.dtype}, calls={tracked._astype_calls}); "
            f"this is the #1624 regression."
        )
| 141 | + |
| 142 | + |
def test_caller_supplied_dtype_still_casts(float32_no_nodata_tif):
    """A float32 file read with ``dtype=np.float64`` comes back float64.

    Checks both the dtype declared on the lazy array and the dtype of
    the computed result.
    """
    tif_path = float32_no_nodata_tif[0]
    lazy = read_geotiff_dask(tif_path, dtype=np.float64, chunks=4)
    assert lazy.dtype == np.float64
    result = lazy.compute()
    assert result.dtype == np.float64
0 commit comments