Skip to content

Commit 6c6fa8a

Browse files
committed
Add dask/cupy/dask+cupy support for trim() and crop()
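Both functions need a NumPy array only for the boundary scan; the output is a slice of the original DataArray, so the input's backend is preserved. The data is materialised to NumPy for the scan (CuPy arrays are copied to host with `.get()`, Dask arrays are evaluated with `.compute()`, and Dask-backed CuPy results are then copied to host), and the original raster is then sliced with the resulting bounds. Update the README feature matrix for trim and crop.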
1 parent 314a65e commit 6c6fa8a

File tree

3 files changed

+129
-5
lines changed

3 files changed

+129
-5
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,9 +237,9 @@ In the GIS world, rasters are used for representing continuous phenomena (e.g. e
237237
| Name | Description | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray |
238238
|:----------:|:------------|:----------------------:|:--------------------:|:-------------------:|:------:|
239239
| [Apply](xrspatial/zonal.py) | Applies a custom function to each zone in a classified raster | ✅️ | ✅️ | 🔄 | 🔄 |
240-
| [Crop](xrspatial/zonal.py) | Extracts the bounding rectangle of a specific zone | ✅️ | | | |
240+
| [Crop](xrspatial/zonal.py) | Extracts the bounding rectangle of a specific zone | ✅️ | 🔄 | 🔄 | 🔄 |
241241
| [Regions](xrspatial/zonal.py) | Identifies connected regions of non-zero cells | ✅️ | ✅️ | ✅️ | ✅️ |
242-
| [Trim](xrspatial/zonal.py) | Removes nodata border rows and columns from a raster | ✅️ | | | |
242+
| [Trim](xrspatial/zonal.py) | Removes nodata border rows and columns from a raster | ✅️ | 🔄 | 🔄 | 🔄 |
243243
| [Zonal Statistics](xrspatial/zonal.py) | Computes summary statistics for a value raster within each zone | ✅️ | ✅️| ✅️ | 🔄 |
244244
| [Zonal Cross Tabulate](xrspatial/zonal.py) | Cross-tabulates agreement between two categorical rasters | ✅️ | ✅️| 🔄 | 🔄 |
245245

xrspatial/tests/test_dask_cupy_gaps.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
"""Tests for dask+cupy backends: perlin, terrain, crosstab."""
1+
"""Tests for dask+cupy backends: perlin, terrain, crosstab, trim, crop."""
22

33
import numpy as np
44
import xarray as xr
55

66
from xrspatial import generate_terrain, perlin
77
from xrspatial.tests.general_checks import cuda_and_cupy_available, dask_array_available
88
from xrspatial.utils import has_cuda_and_cupy
9+
from xrspatial.zonal import crop, trim
910

1011

1112
def _make_raster(shape=(50, 50), backend='numpy', chunks=(10, 10)):
@@ -156,3 +157,108 @@ def test_crosstab_dask_cupy():
156157
np.testing.assert_array_equal(
157158
df_numpy_sorted.values, df_computed_sorted.values,
158159
)
160+
161+
162+
# ---- trim: dask, cupy, dask+cupy ----
163+
164+
_TRIM_ARR = np.array([
165+
[0, 0, 0, 0],
166+
[0, 4, 0, 0],
167+
[0, 4, 4, 0],
168+
[0, 1, 1, 0],
169+
[0, 0, 0, 0],
170+
], dtype=np.int64)
171+
_TRIM_EXPECTED_SHAPE = (3, 2)
172+
_TRIM_EXPECTED = np.array([[4, 0], [4, 4], [1, 1]], dtype=np.int64)
173+
174+
175+
@dask_array_available
176+
def test_trim_dask():
177+
import dask.array as da
178+
179+
raster = xr.DataArray(
180+
da.from_array(_TRIM_ARR, chunks=(3, 2)), dims=['y', 'x'],
181+
)
182+
result = trim(raster, values=(0,))
183+
assert result.shape == _TRIM_EXPECTED_SHAPE
184+
np.testing.assert_array_equal(result.data.compute(), _TRIM_EXPECTED)
185+
186+
187+
@cuda_and_cupy_available
188+
def test_trim_cupy():
189+
import cupy
190+
191+
raster = xr.DataArray(cupy.asarray(_TRIM_ARR), dims=['y', 'x'])
192+
result = trim(raster, values=(0,))
193+
assert result.shape == _TRIM_EXPECTED_SHAPE
194+
np.testing.assert_array_equal(result.data.get(), _TRIM_EXPECTED)
195+
196+
197+
@cuda_and_cupy_available
198+
@dask_array_available
199+
def test_trim_dask_cupy():
200+
import cupy
201+
import dask.array as da
202+
203+
gpu = cupy.asarray(_TRIM_ARR)
204+
raster = xr.DataArray(da.from_array(gpu, chunks=(3, 2)), dims=['y', 'x'])
205+
result = trim(raster, values=(0,))
206+
assert result.shape == _TRIM_EXPECTED_SHAPE
207+
computed = result.data.compute()
208+
assert isinstance(computed, cupy.ndarray)
209+
np.testing.assert_array_equal(computed.get(), _TRIM_EXPECTED)
210+
211+
212+
# ---- crop: dask, cupy, dask+cupy ----
213+
214+
_CROP_ARR = np.array([
215+
[0, 4, 0, 3],
216+
[0, 4, 4, 3],
217+
[0, 1, 1, 3],
218+
[0, 1, 1, 3],
219+
[0, 0, 0, 0],
220+
], dtype=np.int64)
221+
_CROP_EXPECTED_SHAPE = (4, 3)
222+
_CROP_EXPECTED = np.array([
223+
[4, 0, 3],
224+
[4, 4, 3],
225+
[1, 1, 3],
226+
[1, 1, 3],
227+
], dtype=np.int64)
228+
229+
230+
@dask_array_available
231+
def test_crop_dask():
232+
import dask.array as da
233+
234+
raster = xr.DataArray(
235+
da.from_array(_CROP_ARR, chunks=(3, 2)), dims=['y', 'x'],
236+
)
237+
result = crop(raster, raster, zones_ids=(1, 3))
238+
assert result.shape == _CROP_EXPECTED_SHAPE
239+
np.testing.assert_array_equal(result.data.compute(), _CROP_EXPECTED)
240+
241+
242+
@cuda_and_cupy_available
243+
def test_crop_cupy():
244+
import cupy
245+
246+
raster = xr.DataArray(cupy.asarray(_CROP_ARR), dims=['y', 'x'])
247+
result = crop(raster, raster, zones_ids=(1, 3))
248+
assert result.shape == _CROP_EXPECTED_SHAPE
249+
np.testing.assert_array_equal(result.data.get(), _CROP_EXPECTED)
250+
251+
252+
@cuda_and_cupy_available
253+
@dask_array_available
254+
def test_crop_dask_cupy():
255+
import cupy
256+
import dask.array as da
257+
258+
gpu = cupy.asarray(_CROP_ARR)
259+
raster = xr.DataArray(da.from_array(gpu, chunks=(3, 2)), dims=['y', 'x'])
260+
result = crop(raster, raster, zones_ids=(1, 3))
261+
assert result.shape == _CROP_EXPECTED_SHAPE
262+
computed = result.data.compute()
263+
assert isinstance(computed, cupy.ndarray)
264+
np.testing.assert_array_equal(computed.get(), _CROP_EXPECTED)

xrspatial/zonal.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1890,7 +1890,16 @@ def trim(
18901890
"""
18911891
_validate_raster(raster, func_name='trim', name='raster', ndim=2)
18921892

1893-
top, bottom, left, right = _trim(raster.data, values)
1893+
data = raster.data
1894+
# _trim needs element access; materialise to numpy for non-numpy backends
1895+
if is_cupy_array(data):
1896+
data = data.get()
1897+
elif has_dask_array() and isinstance(data, da.Array):
1898+
data = data.compute()
1899+
if is_cupy_array(data):
1900+
data = data.get()
1901+
1902+
top, bottom, left, right = _trim(data, values)
18941903
arr = raster[top: bottom + 1, left: right + 1]
18951904
arr.name = name
18961905
return arr
@@ -2113,7 +2122,16 @@ def crop(
21132122
_validate_raster(zones, func_name='crop', name='zones', ndim=2)
21142123
_validate_raster(values, func_name='crop', name='values', ndim=2)
21152124

2116-
top, bottom, left, right = _crop(zones.data, zones_ids)
2125+
data = zones.data
2126+
# _crop is @ngjit; materialise to numpy for non-numpy backends
2127+
if is_cupy_array(data):
2128+
data = data.get()
2129+
elif has_dask_array() and isinstance(data, da.Array):
2130+
data = data.compute()
2131+
if is_cupy_array(data):
2132+
data = data.get()
2133+
2134+
top, bottom, left, right = _crop(data, zones_ids)
21172135
arr = values[top: bottom + 1, left: right + 1]
21182136
arr.name = name
21192137
return arr

0 commit comments

Comments
 (0)