Skip to content

Commit 7ce3707

Browse files
committed
optimize classification ops: reduce memory allocations and dask passes
- Remove unnecessary .ravel() in _run_equal_interval; nanmin/nanmax work on 2D
- Combine double where(±inf) into single isinf pass in _run_equal_interval and _run_cupy_bin, halving temporary allocations
- Use dask.compute(min, max) instead of two separate .compute() calls so dask reads data once instead of twice
- Build cuts as numpy array for all backends (was needlessly dask for k elements)
- Replace boolean fancy indexing in dask natural_break functions with da.where + da.nanmax to preserve chunk structure
- Delete _run_dask_cupy_equal_interval; unified _run_equal_interval with module=da handles both dask+numpy and dask+cupy
1 parent 9b726c5 commit 7ce3707

File tree

1 file changed

+27
-51
lines changed

1 file changed

+27
-51
lines changed

xrspatial/classify.py

Lines changed: 27 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -15,8 +15,10 @@ class cupy(object):
1515
ndarray = False
1616

1717
try:
18+
import dask
1819
import dask.array as da
1920
except ImportError:
21+
dask = None
2022
da = None
2123

2224
import numba as nb
@@ -232,9 +234,8 @@ def _run_gpu_bin(data, bins, new_values, out):
232234

233235

234236
def _run_cupy_bin(data, bins, new_values):
235-
# replace inf by nan to avoid classify these values as we want to treat them as outliers
236-
data = cupy.where(data == cupy.inf, cupy.nan, data)
237-
data = cupy.where(data == -cupy.inf, cupy.nan, data)
237+
# replace ±inf with nan in a single pass to avoid classifying outliers
238+
data = cupy.where(cupy.isinf(data), cupy.nan, data)
238239

239240
bins_cupy = cupy.asarray(bins)
240241
new_values_cupy = cupy.asarray(new_values)
@@ -679,7 +680,9 @@ def _generate_sample_indices(num_data, num_sample, seed=1234567890):
679680

680681
def _run_dask_natural_break(agg, num_sample, k):
681682
data = agg.data
682-
max_data = float(da.max(data[da.isfinite(data)]).compute())
683+
# Avoid boolean fancy indexing which flattens dask arrays and
684+
# produces chunks of unknown size; use element-wise where instead
685+
max_data = float(da.nanmax(da.where(da.isinf(data), np.nan, data)).compute())
683686

684687
num_data = data.size
685688
if num_sample is not None and num_sample < num_data:
@@ -699,7 +702,9 @@ def _run_dask_natural_break(agg, num_sample, k):
699702

700703
def _run_dask_cupy_natural_break(agg, num_sample, k):
701704
data = agg.data
702-
max_data = float(da.max(data[da.isfinite(data)]).compute().item())
705+
# Avoid boolean fancy indexing which flattens dask arrays and
706+
# produces chunks of unknown size; use element-wise where instead
707+
max_data = float(da.nanmax(da.where(da.isinf(data), np.nan, data)).compute().item())
703708

704709
num_data = data.size
705710
if num_sample is not None and num_sample < num_data:
@@ -817,57 +822,28 @@ def natural_breaks(agg: xr.DataArray,
817822

818823

819824
def _run_equal_interval(agg, k, module):
820-
data = agg.data.ravel()
821-
if module == cupy:
822-
nan = cupy.nan
823-
inf = cupy.inf
824-
else:
825-
nan = np.nan
826-
inf = np.inf
825+
data = agg.data
827826

828-
data = module.where(data == inf, nan, data)
829-
data = module.where(data == -inf, nan, data)
827+
# Replace ±inf with nan in a single pass (no ravel needed)
828+
data_clean = module.where(module.isinf(data), np.nan, data)
830829

831-
max_data = module.nanmax(data)
832-
min_data = module.nanmin(data)
830+
min_lazy = module.nanmin(data_clean)
831+
max_lazy = module.nanmax(data_clean)
833832

834833
if module == cupy:
835-
min_data = min_data.get()
836-
max_data = max_data.get()
837-
838-
if module == da:
839-
min_data = min_data.compute()
840-
max_data = max_data.compute()
841-
842-
width = (max_data - min_data) * 1.0 / k
843-
cuts = module.arange(min_data + width, max_data + width, width)
844-
l_cuts = cuts.shape[0]
845-
if l_cuts > k:
846-
# handle overshooting
847-
cuts = cuts[0:k]
848-
849-
if module == da:
850-
# work around to assign cuts[-1] = max_data
851-
bins = da.concatenate([cuts[:k-1], [max_data]])
852-
out = _bin(agg, bins, np.arange(l_cuts))
834+
min_data = float(min_lazy.get())
835+
max_data = float(max_lazy.get())
836+
elif module == da:
837+
# Compute both in a single pass over the data
838+
min_data, max_data = dask.compute(min_lazy, max_lazy)
839+
min_data = float(min_data)
840+
max_data = float(max_data)
853841
else:
854-
cuts[-1] = max_data
855-
out = _bin(agg, cuts, np.arange(l_cuts))
856-
857-
return out
858-
859-
860-
def _run_dask_cupy_equal_interval(agg, k):
861-
data = agg.data.ravel()
862-
863-
# replace inf with nan
864-
data = da.where(data == np.inf, np.nan, data)
865-
data = da.where(data == -np.inf, np.nan, data)
866-
867-
min_data = float(da.nanmin(data).compute().item())
868-
max_data = float(da.nanmax(data).compute().item())
842+
min_data = float(min_lazy)
843+
max_data = float(max_lazy)
869844

870-
width = (max_data - min_data) * 1.0 / k
845+
width = (max_data - min_data) / k
846+
# Build cuts as numpy — only k elements, no need for dask/cupy overhead
871847
cuts = np.arange(min_data + width, max_data + width, width)
872848
l_cuts = cuts.shape[0]
873849
if l_cuts > k:
@@ -938,7 +914,7 @@ def equal_interval(agg: xr.DataArray,
938914
numpy_func=lambda *args: _run_equal_interval(*args, module=np),
939915
dask_func=lambda *args: _run_equal_interval(*args, module=da),
940916
cupy_func=lambda *args: _run_equal_interval(*args, module=cupy),
941-
dask_cupy_func=_run_dask_cupy_equal_interval
917+
dask_cupy_func=lambda *args: _run_equal_interval(*args, module=da)
942918
)
943919
out = mapper(agg)(agg, k)
944920
return xr.DataArray(out,

0 commit comments

Comments
 (0)