Skip to content

Commit 7fa9e04

Browse files
authored
Cut head_tail_breaks and box_plot dask re-scans (#1213)
head_tail_breaks (dask) called .compute() three times per iteration of its while-loop (mean, new-mask count, old-mask count) and rebuilt the same data_clean graph every time. For N iterations that was 3N+1 full graph traversals. Persist data_clean once, track the running mask count across iterations, and fuse the mean+head-count reductions into a single dask.compute() per iteration. Wall time drops from ~910 ms to ~340 ms on 256x256 chunks=64. box_plot (dask and dask+cupy) did data_clean[da.isfinite(data_clean)] which is boolean fancy indexing on a dask array. That forces compute_chunk_sizes, materializing a full scan just to know the output chunk layout before percentile can run. Swap in the same seeded _generate_sample_indices sampler that natural_breaks/quantile already use: gather 200k indices on the dask array, compute the sample and the global nanmax in one dask.compute() call, and take percentiles on the finite portion of the sample in numpy.
1 parent 2cf545d commit 7fa9e04

1 file changed

Lines changed: 83 additions & 43 deletions

File tree

xrspatial/classify.py

Lines changed: 83 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,19 +1079,31 @@ def _run_head_tail_breaks(agg, module):
10791079

10801080
def _run_dask_head_tail_breaks(agg):
10811081
data = agg.data
1082-
data_clean = da.where(da.isinf(data), np.nan, data)
1082+
# Persist once so the iterative loop does not re-read from storage on
1083+
# every scan. Fuse the three reductions per iteration into a single
1084+
# dask.compute() to cut graph traversals from 3N+1 to N+1.
1085+
data_clean = da.where(da.isinf(data), np.nan, data).persist()
10831086
bins = []
10841087
mask = da.isfinite(data_clean)
1088+
total_count_lazy = mask.sum()
1089+
total_count = int(total_count_lazy.compute())
1090+
if total_count == 0:
1091+
max_v = float(da.nanmax(data_clean).compute())
1092+
return _bin(agg, np.array([max_v]), np.arange(1))
1093+
1094+
current_total = total_count
10851095
while True:
10861096
current = da.where(mask, data_clean, np.nan)
1087-
mean_v = float(da.nanmean(current).compute())
1097+
new_mask = mask & (data_clean > da.nanmean(current))
1098+
# Fuse mean and head-count into one graph evaluation.
1099+
mean_v, head_count = dask.compute(da.nanmean(current), new_mask.sum())
1100+
mean_v = float(mean_v)
1101+
head_count = int(head_count)
10881102
bins.append(mean_v)
1089-
new_mask = mask & (data_clean > mean_v)
1090-
head_count = int(new_mask.sum().compute())
1091-
total_count = int(mask.sum().compute())
1092-
if head_count == 0 or head_count / total_count > 0.40:
1103+
if head_count == 0 or head_count / current_total > 0.40:
10931104
break
10941105
mask = new_mask
1106+
current_total = head_count
10951107
max_v = float(da.nanmax(data_clean).compute())
10961108
bins.append(max_v)
10971109
bins = np.array(bins)
@@ -1366,6 +1378,22 @@ def maximum_breaks(agg: xr.DataArray,
13661378
attrs=agg.attrs)
13671379

13681380

1381+
_BOX_PLOT_DEFAULT_SAMPLE = 200_000
1382+
1383+
1384+
def _box_plot_bins_from_sample(finite_np, hinge, max_v):
1385+
q1 = float(np.percentile(finite_np, 25))
1386+
q2 = float(np.percentile(finite_np, 50))
1387+
q3 = float(np.percentile(finite_np, 75))
1388+
iqr = q3 - q1
1389+
raw_bins = [q1 - hinge * iqr, q1, q2, q3, q3 + hinge * iqr, max_v]
1390+
bins = np.sort(np.unique(raw_bins))
1391+
bins = bins[bins <= max_v]
1392+
if bins[-1] < max_v:
1393+
bins = np.append(bins, max_v)
1394+
return bins
1395+
1396+
13691397
def _run_box_plot(agg, hinge, module):
13701398
data = agg.data
13711399
data_clean = module.where(module.isinf(data), np.nan, data)
@@ -1376,51 +1404,63 @@ def _run_box_plot(agg, hinge, module):
13761404
q2 = float(cupy.percentile(finite_data, 50).get())
13771405
q3 = float(cupy.percentile(finite_data, 75).get())
13781406
max_v = float(cupy.nanmax(finite_data).get())
1379-
elif module == da:
1380-
q1_l = da.percentile(finite_data, 25)
1381-
q2_l = da.percentile(finite_data, 50)
1382-
q3_l = da.percentile(finite_data, 75)
1383-
max_l = da.nanmax(data_clean)
1384-
q1, q2, q3, max_v = dask.compute(q1_l, q2_l, q3_l, max_l)
1385-
q1, q2, q3, max_v = q1.item(), q2.item(), q3.item(), max_v.item()
1386-
else:
1387-
q1 = float(np.percentile(finite_data, 25))
1388-
q2 = float(np.percentile(finite_data, 50))
1389-
q3 = float(np.percentile(finite_data, 75))
1390-
max_v = float(np.nanmax(finite_data))
1407+
iqr = q3 - q1
1408+
raw_bins = [q1 - hinge * iqr, q1, q2, q3, q3 + hinge * iqr, max_v]
1409+
bins = np.sort(np.unique(raw_bins))
1410+
bins = bins[bins <= max_v]
1411+
if bins[-1] < max_v:
1412+
bins = np.append(bins, max_v)
1413+
else: # numpy
1414+
finite_np = np.asarray(finite_data)
1415+
max_v = float(np.nanmax(finite_np))
1416+
bins = _box_plot_bins_from_sample(finite_np, hinge, max_v)
13911417

1392-
iqr = q3 - q1
1393-
raw_bins = [q1 - hinge * iqr, q1, q2, q3, q3 + hinge * iqr, max_v]
1394-
bins = np.sort(np.unique(raw_bins))
1395-
# Remove bins above max (they'd create empty classes)
1396-
bins = bins[bins <= max_v]
1397-
if bins[-1] < max_v:
1398-
bins = np.append(bins, max_v)
13991418
out = _bin(agg, bins, np.arange(len(bins)))
14001419
return out
14011420

14021421

1422+
def _run_dask_box_plot(agg, hinge):
1423+
"""Dask+numpy box_plot.
1424+
1425+
Avoids boolean fancy indexing on a dask array (which produces unknown
1426+
chunk sizes and forces a chunk-size compute pass). Samples the data
1427+
via the same seeded index sampler used by natural_breaks, then
1428+
computes percentiles on the finite portion of the sample in numpy.
1429+
"""
1430+
data = agg.data
1431+
data_clean = da.where(da.isinf(data), np.nan, data)
1432+
num_data = data_clean.size
1433+
num_sample = min(_BOX_PLOT_DEFAULT_SAMPLE, num_data)
1434+
sample_idx = _generate_sample_indices(num_data, num_sample)
1435+
1436+
sample_lazy = data_clean.ravel()[sample_idx]
1437+
max_lazy = da.nanmax(data_clean)
1438+
sample_np, max_v = dask.compute(sample_lazy, max_lazy)
1439+
sample_np = np.asarray(sample_np)
1440+
finite_np = sample_np[np.isfinite(sample_np)]
1441+
max_v = float(max_v)
1442+
1443+
bins = _box_plot_bins_from_sample(finite_np, hinge, max_v)
1444+
return _bin(agg, bins, np.arange(len(bins)))
1445+
1446+
14031447
def _run_dask_cupy_box_plot(agg, hinge):
1448+
"""Dask+cupy box_plot: sample on-device, transfer only the sample."""
14041449
data = agg.data
1405-
data_cpu = data.map_blocks(cupy.asnumpy, dtype=data.dtype, meta=np.array(()))
1406-
data_clean = da.where(da.isinf(data_cpu), np.nan, data_cpu)
1407-
finite_data = data_clean[da.isfinite(data_clean)]
1450+
data_clean = da.where(da.isinf(data), np.nan, data)
1451+
num_data = data_clean.size
1452+
num_sample = min(_BOX_PLOT_DEFAULT_SAMPLE, num_data)
1453+
sample_idx = _generate_sample_indices(num_data, num_sample)
14081454

1409-
q1_l = da.percentile(finite_data, 25)
1410-
q2_l = da.percentile(finite_data, 50)
1411-
q3_l = da.percentile(finite_data, 75)
1412-
max_l = da.nanmax(data_clean)
1413-
q1, q2, q3, max_v = dask.compute(q1_l, q2_l, q3_l, max_l)
1414-
q1, q2, q3, max_v = q1.item(), q2.item(), q3.item(), max_v.item()
1455+
sample_lazy = data_clean.ravel()[sample_idx]
1456+
max_lazy = da.nanmax(data_clean)
1457+
sample_cp, max_v = dask.compute(sample_lazy, max_lazy)
1458+
sample_np = cupy.asnumpy(sample_cp)
1459+
finite_np = sample_np[np.isfinite(sample_np)]
1460+
max_v = float(cupy.asnumpy(max_v).item()) if hasattr(max_v, 'get') else float(max_v)
14151461

1416-
iqr = q3 - q1
1417-
raw_bins = [q1 - hinge * iqr, q1, q2, q3, q3 + hinge * iqr, max_v]
1418-
bins = np.sort(np.unique(raw_bins))
1419-
bins = bins[bins <= max_v]
1420-
if bins[-1] < max_v:
1421-
bins = np.append(bins, max_v)
1422-
out = _bin(agg, bins, np.arange(len(bins)))
1423-
return out
1462+
bins = _box_plot_bins_from_sample(finite_np, hinge, max_v)
1463+
return _bin(agg, bins, np.arange(len(bins)))
14241464

14251465

14261466
@supports_dataset
@@ -1459,7 +1499,7 @@ def box_plot(agg: xr.DataArray,
14591499

14601500
mapper = ArrayTypeFunctionMapping(
14611501
numpy_func=lambda *args: _run_box_plot(*args, module=np),
1462-
dask_func=lambda *args: _run_box_plot(*args, module=da),
1502+
dask_func=_run_dask_box_plot,
14631503
cupy_func=lambda *args: _run_box_plot(*args, module=cupy),
14641504
dask_cupy_func=_run_dask_cupy_box_plot,
14651505
)

0 commit comments

Comments
 (0)