Skip to content

Commit 835b5b1

Browse files
committed
GPU-enable all classification ops: equal_interval, quantile, natural_breaks now support all 4 backends (#190)
- Add Dask+CuPy backend for equal_interval via _run_dask_cupy_equal_interval
- Replace quantile Dask+CuPy NotImplementedError with a working implementation that materializes data to CPU for percentile computation
- Add CuPy, Dask+NumPy, and Dask+CuPy backends for natural_breaks by extracting a shared _compute_natural_break_bins helper
- Add 7 new tests covering all new backend combinations
- Update README feature matrix to reflect full backend support
1 parent a28019b commit 835b5b1

File tree

3 files changed

+159
-25
lines changed

3 files changed

+159
-25
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,10 @@ In the GIS world, rasters are used for representing continuous phenomena (e.g. e
137137

138138
| Name | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray |
139139
|:----------:|:----------------------:|:--------------------:|:-------------------:|:------:|
140-
| [Equal Interval](xrspatial/classify.py) |✅️ |✅️ |✅️ | |
141-
| [Natural Breaks](xrspatial/classify.py) |✅️ | | | |
140+
| [Equal Interval](xrspatial/classify.py) |✅️ |✅️ |✅️ |✅️ |
141+
| [Natural Breaks](xrspatial/classify.py) |✅️ |✅️ |✅️ |✅️ |
142142
| [Reclassify](xrspatial/classify.py) |✅️ ||||
143-
| [Quantile](xrspatial/classify.py) |✅️ |✅️ |✅️ | |
143+
| [Quantile](xrspatial/classify.py) |✅️ |✅️ |✅️ |✅️ |
144144

145145
-------
146146

xrspatial/classify.py

Lines changed: 90 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -398,10 +398,18 @@ def _run_quantile(data, k, module):
398398

399399

400400
def _run_dask_cupy_quantile(data, k):
401-
msg = 'Currently percentile calculation has not' \
402-
'been supported for Dask array backed by CuPy.' \
403-
'See issue at https://github.com/dask/dask/issues/6942'
404-
raise NotImplementedError(msg)
401+
w = 100.0 / k
402+
p = np.arange(w, 100 + w, w)
403+
if p[-1] > 100.0:
404+
p[-1] = 100.0
405+
406+
finite_mask = da.isfinite(data)
407+
finite_data = data[finite_mask].compute()
408+
# transfer from GPU to CPU
409+
finite_data_np = cupy.asnumpy(finite_data)
410+
q = np.percentile(finite_data_np, p)
411+
q = np.unique(q)
412+
return q
405413

406414

407415
def _quantile(agg, k):
@@ -575,25 +583,28 @@ def _run_jenks(data, n_classes):
575583
return kclass
576584

577585

578-
def _run_natural_break(agg, num_sample, k):
579-
data = agg.data
580-
num_data = data.size
581-
max_data = np.max(data[np.isfinite(data)])
586+
def _compute_natural_break_bins(data_flat_np, num_sample, k, max_data):
587+
"""Shared helper: compute natural break bins from a flat numpy array.
588+
589+
Returns (bins, uvk) where bins is a numpy array of bin edges
590+
and uvk is the number of unique values.
591+
"""
592+
num_data = data_flat_np.size
582593

583594
if num_sample is not None and num_sample < num_data:
584595
# randomly select sample from the whole dataset
585596
# create a pseudo random number generator
586-
# Note: cupy and nupy generate different random numbers
597+
# Note: cupy and numpy generate different random numbers
587598
# use numpy.random to ensure the same result
588599
generator = np.random.RandomState(1234567890)
589600
idx = np.linspace(
590-
0, data.size, data.size, endpoint=False, dtype=np.uint32
601+
0, num_data, num_data, endpoint=False, dtype=np.uint32
591602
)
592603
generator.shuffle(idx)
593604
sample_idx = idx[:num_sample]
594-
sample_data = data.flatten()[sample_idx]
605+
sample_data = data_flat_np[sample_idx]
595606
else:
596-
sample_data = data.flatten()
607+
sample_data = data_flat_np
597608

598609
# warning if number of total data points to fit the model bigger than 40k
599610
if sample_data.size >= 40000:
@@ -627,6 +638,45 @@ def _run_natural_break(agg, num_sample, k):
627638
bins = np.array(centroids[1:])
628639
bins[-1] = max_data
629640

641+
return bins, uvk
642+
643+
644+
def _run_natural_break(agg, num_sample, k):
645+
data = agg.data
646+
max_data = float(np.max(data[np.isfinite(data)]))
647+
data_flat_np = data.flatten()
648+
649+
bins, uvk = _compute_natural_break_bins(data_flat_np, num_sample, k, max_data)
650+
out = _bin(agg, bins, np.arange(uvk))
651+
return out
652+
653+
654+
def _run_cupy_natural_break(agg, num_sample, k):
655+
data = agg.data
656+
max_data = float(cupy.max(data[cupy.isfinite(data)]).get())
657+
data_flat_np = cupy.asnumpy(data.ravel())
658+
659+
bins, uvk = _compute_natural_break_bins(data_flat_np, num_sample, k, max_data)
660+
out = _bin(agg, bins, np.arange(uvk))
661+
return out
662+
663+
664+
def _run_dask_natural_break(agg, num_sample, k):
665+
data = agg.data
666+
max_data = float(da.max(data[da.isfinite(data)]).compute())
667+
data_flat_np = np.asarray(data.ravel().compute())
668+
669+
bins, uvk = _compute_natural_break_bins(data_flat_np, num_sample, k, max_data)
670+
out = _bin(agg, bins, np.arange(uvk))
671+
return out
672+
673+
674+
def _run_dask_cupy_natural_break(agg, num_sample, k):
675+
data = agg.data
676+
max_data = float(da.max(data[da.isfinite(data)]).compute().item())
677+
data_flat_np = cupy.asnumpy(data.ravel().compute())
678+
679+
bins, uvk = _compute_natural_break_bins(data_flat_np, num_sample, k, max_data)
630680
out = _bin(agg, bins, np.arange(uvk))
631681
return out
632682

@@ -644,7 +694,8 @@ def natural_breaks(agg: xr.DataArray,
644694
Parameters
645695
----------
646696
agg : xarray.DataArray
647-
2D NumPy DataArray of values to be reclassified.
697+
2D NumPy, CuPy, NumPy-backed Dask, or CuPy-backed Dask array
698+
of values to be reclassified.
648699
num_sample : int, default=20000
649700
Number of sample data points used to fit the model.
650701
Natural Breaks (Jenks) classification is indeed O(n²) complexity,
@@ -715,13 +766,10 @@ def natural_breaks(agg: xr.DataArray,
715766
"""
716767

717768
mapper = ArrayTypeFunctionMapping(
718-
numpy_func=lambda *args: _run_natural_break(*args),
719-
dask_func=lambda *args: not_implemented_func(
720-
*args, messages='natural_breaks() does not support dask with numpy backed DataArray.'), # noqa
721-
cupy_func=lambda *args: not_implemented_func(
722-
*args, messages='natural_breaks() does not support cupy backed DataArray.'), # noqa
723-
dask_cupy_func=lambda *args: not_implemented_func(
724-
*args, messages='natural_breaks() does not support dask with cupy backed DataArray.'), # noqa
769+
numpy_func=_run_natural_break,
770+
dask_func=_run_dask_natural_break,
771+
cupy_func=_run_cupy_natural_break,
772+
dask_cupy_func=_run_dask_cupy_natural_break,
725773
)
726774
out = mapper(agg)(agg, num_sample, k)
727775
return xr.DataArray(out,
@@ -772,6 +820,27 @@ def _run_equal_interval(agg, k, module):
772820
return out
773821

774822

823+
def _run_dask_cupy_equal_interval(agg, k):
824+
data = agg.data.ravel()
825+
826+
# replace inf with nan
827+
data = da.where(data == np.inf, np.nan, data)
828+
data = da.where(data == -np.inf, np.nan, data)
829+
830+
min_data = float(da.nanmin(data).compute().item())
831+
max_data = float(da.nanmax(data).compute().item())
832+
833+
width = (max_data - min_data) * 1.0 / k
834+
cuts = np.arange(min_data + width, max_data + width, width)
835+
l_cuts = cuts.shape[0]
836+
if l_cuts > k:
837+
cuts = cuts[0:k]
838+
839+
cuts[-1] = max_data
840+
out = _bin(agg, cuts, np.arange(l_cuts))
841+
return out
842+
843+
775844
def equal_interval(agg: xr.DataArray,
776845
k: int = 5,
777846
name: Optional[str] = 'equal_interval') -> xr.DataArray:
@@ -832,8 +901,7 @@ def equal_interval(agg: xr.DataArray,
832901
numpy_func=lambda *args: _run_equal_interval(*args, module=np),
833902
dask_func=lambda *args: _run_equal_interval(*args, module=da),
834903
cupy_func=lambda *args: _run_equal_interval(*args, module=cupy),
835-
dask_cupy_func=lambda *args: not_implemented_func(
836-
*args, messages='equal_interval() does support dask with cupy backed DataArray.'), # noqa
904+
dask_cupy_func=_run_dask_cupy_equal_interval
837905
)
838906
out = mapper(agg)(agg, k)
839907
return xr.DataArray(out,

xrspatial/tests/test_classify.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,3 +284,69 @@ def test_equal_interval_cupy(result_equal_interval):
284284
cupy_agg = input_data(backend='cupy')
285285
cupy_result = equal_interval(cupy_agg, k=k)
286286
general_output_checks(cupy_agg, cupy_result, expected_result, verify_dtype=True)
287+
288+
289+
@dask_array_available
290+
@cuda_and_cupy_available
291+
def test_equal_interval_dask_cupy(result_equal_interval):
292+
k, expected_result = result_equal_interval
293+
dask_cupy_agg = input_data(backend='dask+cupy')
294+
dask_cupy_result = equal_interval(dask_cupy_agg, k=k)
295+
general_output_checks(dask_cupy_agg, dask_cupy_result, expected_result, verify_dtype=True)
296+
297+
298+
@dask_array_available
299+
@cuda_and_cupy_available
300+
def test_quantile_dask_cupy(result_quantile):
301+
# Relaxed verification (same pattern as test_quantile_dask_numpy)
302+
# because percentile is computed on CPU from materialized data
303+
dask_cupy_agg = input_data('dask+cupy')
304+
k, expected_result = result_quantile
305+
dask_cupy_quantile = quantile(dask_cupy_agg, k=k)
306+
general_output_checks(dask_cupy_agg, dask_cupy_quantile)
307+
dask_cupy_quantile = dask_cupy_quantile.compute()
308+
import cupy as cp
309+
result_data = cp.asnumpy(dask_cupy_quantile.data)
310+
unique_elements = np.unique(result_data[np.isfinite(result_data)])
311+
assert len(unique_elements) == k
312+
313+
314+
@cuda_and_cupy_available
315+
def test_natural_breaks_cupy(result_natural_breaks):
316+
cupy_agg = input_data('cupy')
317+
k, expected_result = result_natural_breaks
318+
cupy_natural_breaks = natural_breaks(cupy_agg, k=k)
319+
general_output_checks(cupy_agg, cupy_natural_breaks, expected_result, verify_dtype=True)
320+
321+
322+
@dask_array_available
323+
def test_natural_breaks_dask_numpy(result_natural_breaks):
324+
dask_agg = input_data('dask+numpy')
325+
k, expected_result = result_natural_breaks
326+
dask_natural_breaks = natural_breaks(dask_agg, k=k)
327+
general_output_checks(dask_agg, dask_natural_breaks, expected_result, verify_dtype=True)
328+
329+
330+
@dask_array_available
331+
@cuda_and_cupy_available
332+
def test_natural_breaks_dask_cupy(result_natural_breaks):
333+
dask_cupy_agg = input_data('dask+cupy')
334+
k, expected_result = result_natural_breaks
335+
dask_cupy_natural_breaks = natural_breaks(dask_cupy_agg, k=k)
336+
general_output_checks(dask_cupy_agg, dask_cupy_natural_breaks, expected_result, verify_dtype=True)
337+
338+
339+
@cuda_and_cupy_available
340+
def test_natural_breaks_cupy_num_sample(result_natural_breaks_num_sample):
341+
cupy_agg = input_data('cupy')
342+
k, num_sample, expected_result = result_natural_breaks_num_sample
343+
cupy_natural_breaks = natural_breaks(cupy_agg, k=k, num_sample=num_sample)
344+
general_output_checks(cupy_agg, cupy_natural_breaks, expected_result, verify_dtype=True)
345+
346+
347+
@dask_array_available
348+
def test_natural_breaks_dask_numpy_num_sample(result_natural_breaks_num_sample):
349+
dask_agg = input_data('dask+numpy')
350+
k, num_sample, expected_result = result_natural_breaks_num_sample
351+
dask_natural_breaks = natural_breaks(dask_agg, k=k, num_sample=num_sample)
352+
general_output_checks(dask_agg, dask_natural_breaks, expected_result, verify_dtype=True)

0 commit comments

Comments
 (0)