Skip to content

Commit 24ac602

Browse files
committed
Fix dask benchmark setup for proximity and zonal
- Proximity: compute unique values from numpy data before slicing, since dask unique() returns unknown chunk sizes - Zonal: skip custom stats benchmark for dask (API only supports default stats); fix else branch to not override None for dask
1 parent ad1f1b9 commit 24ac602

2 files changed

Lines changed: 10 additions & 18 deletions

File tree

benchmarks/benchmarks/proximity.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@ class Base:
1717
def setup(self, nx, n_target_values, distance_metric, type):
1818
ny = nx // 2
1919
self.agg = get_xr_dataarray((ny, nx), type, is_int=True)
20-
unique_values = np.unique(self.agg.data)
20+
data = self.agg.data
21+
if type == "dask":
22+
data = data.compute()
23+
unique_values = np.unique(data)
2124
self.target_values = unique_values[:n_target_values]
2225

2326

benchmarks/benchmarks/zonal.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -64,22 +64,9 @@ def setup(self, raster_dim, zone_dim, backend):
6464
self.zones = create_arr(zones, backend=backend)
6565

6666
# Now setup the custom stat funcs
67-
if backend == 'dask':
68-
from xrspatial.utils import ngjit
69-
70-
@ngjit
71-
def l2normKernel(arr):
72-
acc = 0
73-
for x in arr:
74-
acc += x * x
75-
return np.sqrt(acc)
76-
77-
self.custom_stats = {
78-
'double_sum': lambda val: val.sum()*2,
79-
'l2norm': lambda val: np.sqrt(np.sum(val * val)),
80-
'l2normKernel': lambda val: l2normKernel(val)
81-
}
82-
elif backend == 'cupy':
67+
# Dask backend only supports default stats, so custom_stats is None.
68+
self.custom_stats = None
69+
if backend == 'cupy':
8370
import cupy
8471
l2normKernel = cupy.ReductionKernel(
8572
in_params='T x', out_params='float64 y',
@@ -92,7 +79,7 @@ def l2normKernel(arr):
9279
'l2norm': lambda val: np.sqrt(cupy.sum(val * val)),
9380
'l2normKernel': lambda val: l2normKernel(val)
9481
}
95-
else:
82+
elif backend == 'numpy':
9683
from xrspatial.utils import ngjit
9784

9885
@ngjit
@@ -112,5 +99,7 @@ def time_zonal_stats_default(self, raster_dim, zone_dim, backend):
11299
zonal.stats(zones=self.zones, values=self.values)
113100

114101
def time_zonal_stats_custom(self, raster_dim, zone_dim, backend):
102+
if self.custom_stats is None:
103+
raise NotImplementedError()
115104
zonal.stats(zones=self.zones, values=self.values,
116105
stats_funcs=self.custom_stats)

0 commit comments

Comments
 (0)