Skip to content

Commit 9edd073

Browse files
authored
Fix normalize dask paths: replace boolean indexing with lazy reductions (#1125)
* Add sweep-performance design spec

  Parallel subagent triage + ralph-loop workflow for auditing all xrspatial modules for performance bottlenecks, OOM risk under 30TB dask workloads, and backend-specific anti-patterns.

* Add sweep-performance implementation plan

  7 tasks covering command scaffold, module scoring, parallel subagent dispatch, report merging, ralph-loop generation, and smoke tests.

* Add sweep-performance slash command

* Fix normalize dask paths: replace boolean indexing with nanmin/nanmax (#1124)

  Replace `data[finite_mask]` (boolean fancy indexing that materializes dask arrays) with `da.where(finite_mask, data, nan)` + `da.nanmin()`/`da.nanmax()`/`da.nanmean()`/`da.nanstd()` for lazy per-chunk reductions. Guard division by zero in rescale with safe_range to prevent inf/nan in lazy evaluation (da.where evaluates both branches).
1 parent d45f27a commit 9edd073

File tree

1 file changed

+22
-33
lines changed

1 file changed

+22
-33
lines changed

xrspatial/normalize.py

Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -64,24 +64,20 @@ def _run_numpy_rescale(data, new_min, new_max):
6464

6565

6666
def _run_dask_numpy_rescale(data, new_min, new_max):
67-
# Compute global stats first (returns scalars), then map element-wise.
67+
# Replace non-finite values with NaN so nanmin/nanmax skip them,
68+
# avoiding boolean fancy indexing (which materializes dask arrays).
6869
finite_mask = da.isfinite(data)
69-
finite_vals = data[finite_mask]
70-
data_min = finite_vals.min()
71-
data_max = finite_vals.max()
70+
finite_data = da.where(finite_mask, data, np.nan)
71+
data_min = da.nanmin(finite_data)
72+
data_max = da.nanmax(finite_data)
7273

7374
new_range = new_max - new_min
7475
data_range = data_max - data_min
75-
76-
out = da.where(
77-
finite_mask,
78-
da.where(
79-
data_range == 0,
80-
new_min,
81-
(data - data_min) / data_range * new_range + new_min,
82-
),
83-
np.nan,
84-
)
76+
# Guard against division by zero: substitute 1.0 for a zero data_range
77+
# in the division, then overwrite with new_min where data_range == 0.
78+
safe_range = da.where(data_range == 0, 1.0, data_range)
79+
scaled = (data - data_min) / safe_range * new_range + new_min
80+
out = da.where(finite_mask, da.where(data_range == 0, new_min, scaled), np.nan)
8581
return out
8682

8783

@@ -108,22 +104,15 @@ def _run_cupy_rescale(data, new_min, new_max):
108104
def _run_dask_cupy_rescale(data, new_min, new_max):
109105
# Same lazy approach as dask+numpy; dask dispatches to cupy chunks.
110106
finite_mask = da.isfinite(data)
111-
finite_vals = data[finite_mask]
112-
data_min = finite_vals.min()
113-
data_max = finite_vals.max()
107+
finite_data = da.where(finite_mask, data, np.nan)
108+
data_min = da.nanmin(finite_data)
109+
data_max = da.nanmax(finite_data)
114110

115111
new_range = new_max - new_min
116112
data_range = data_max - data_min
117-
118-
out = da.where(
119-
finite_mask,
120-
da.where(
121-
data_range == 0,
122-
new_min,
123-
(data - data_min) / data_range * new_range + new_min,
124-
),
125-
np.nan,
126-
)
113+
safe_range = da.where(data_range == 0, 1.0, data_range)
114+
scaled = (data - data_min) / safe_range * new_range + new_min
115+
out = da.where(finite_mask, da.where(data_range == 0, new_min, scaled), np.nan)
127116
return out
128117

129118

@@ -224,9 +213,9 @@ def _run_numpy_standardize(data, ddof):
224213

225214
def _run_dask_numpy_standardize(data, ddof):
226215
finite_mask = da.isfinite(data)
227-
finite_vals = data[finite_mask]
228-
mean = finite_vals.mean()
229-
std = finite_vals.std(ddof=ddof)
216+
finite_data = da.where(finite_mask, data, np.nan)
217+
mean = da.nanmean(finite_data)
218+
std = da.nanstd(finite_data, ddof=ddof)
230219

231220
out = da.where(
232221
finite_mask,
@@ -254,9 +243,9 @@ def _run_cupy_standardize(data, ddof):
254243

255244
def _run_dask_cupy_standardize(data, ddof):
256245
finite_mask = da.isfinite(data)
257-
finite_vals = data[finite_mask]
258-
mean = finite_vals.mean()
259-
std = finite_vals.std(ddof=ddof)
246+
finite_data = da.where(finite_mask, data, np.nan)
247+
mean = da.nanmean(finite_data)
248+
std = da.nanstd(finite_data, ddof=ddof)
260249

261250
out = da.where(
262251
finite_mask,

0 commit comments

Comments
 (0)