Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 67 additions & 20 deletions xrspatial/tests/test_zonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,18 +680,8 @@ def test_stats_all_nan_zone(backend):
'sum': [12.0],
'count': [2],
}
elif 'dask' in backend:
# dask uses nansum reduction, so count/sum of all-NaN become 0
expected = {
'zone': [1, 2],
'mean': [np.nan, 6.0],
'max': [np.nan, 7.0],
'min': [np.nan, 5.0],
'sum': [0.0, 12.0],
'count': [0, 2],
}
else:
# numpy keeps empty zone with NaN for every stat
# numpy and dask both return NaN for all-NaN zones
expected = {
'zone': [1, 2],
'mean': [np.nan, 6.0],
Expand Down Expand Up @@ -798,16 +788,8 @@ def test_stats_nodata_wipes_zone(backend):
'sum': [10.0],
'count': [2],
}
elif 'dask' in backend:
expected = {
'zone': [1, 2],
'mean': [np.nan, 5.0],
'max': [np.nan, 7.0],
'min': [np.nan, 3.0],
'sum': [0.0, 10.0],
'count': [0, 2],
}
else:
# numpy and dask both return NaN for zones with no valid values
expected = {
'zone': [1, 2],
'mean': [np.nan, 5.0],
Expand Down Expand Up @@ -868,6 +850,71 @@ def test_zonal_stats_inputs_unmodified(backend, data_zones, data_values_2d, resu
assert_input_data_unmodified(data_values_2d, copied_data_values_2d)


@pytest.mark.filterwarnings("ignore:All-NaN slice encountered:RuntimeWarning")
@pytest.mark.filterwarnings("ignore:invalid value encountered in divide:RuntimeWarning")
@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])
def test_stats_variance_numerical_stability_1090(backend):
    """Dask std/var should match numpy for data with large mean, small spread.

    Regression test for #1090: the naive one-pass formula
    ``(Σx² − (Σx)²/n) / n`` loses precision through catastrophic
    cancellation. The fix uses Chan-Golub-LeVeque parallel merge.
    """
    if 'dask' in backend and not dask_array_available():
        pytest.skip("Requires Dask")

    # A huge offset (1e8) with unit spread: the naive formula would cancel
    # away nearly all of float64's significant digits.
    zones_data = np.ones((1, 6), dtype=int)
    values_data = (1e8 + np.arange(6, dtype=np.float64)).reshape(1, 6)

    zones = create_test_raster(zones_data, backend, chunks=(1, 3))
    values = create_test_raster(values_data, backend, chunks=(1, 3))

    df_result = stats(
        zones=zones, values=values, stats_funcs=['mean', 'std', 'var'])
    if hasattr(df_result, 'compute'):
        df_result = df_result.compute()

    # Shifting data by a constant leaves variance unchanged, so the
    # reference is simply var/std of [0, 1, 2, 3, 4, 5] (population form).
    baseline = np.arange(6, dtype=np.float64)
    expected_var = np.var(baseline)
    expected_std = np.std(baseline)

    actual_var = float(df_result['var'].iloc[0])
    actual_std = float(df_result['std'].iloc[0])

    assert abs(actual_var - expected_var) < 1e-6, (
        f"var={actual_var}, expected={expected_var}"
    )
    assert abs(actual_std - expected_std) < 1e-6, (
        f"std={actual_std}, expected={expected_std}"
    )


def test_stats_nodata_none_no_warning_1090():
    """Passing nodata_values=None (the default) should not trigger warnings.

    Regression test for #1090: ``zone_values != None`` triggered a numpy
    FutureWarning.
    """
    import warnings

    zones = xr.DataArray(np.array([[1.0, 1.0], [2.0, 2.0]]))
    values = xr.DataArray(np.array([[1.0, 2.0], [3.0, 4.0]]))

    # Escalate every warning to an error so any FutureWarning fails the test.
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        df = stats(zones=zones, values=values, nodata_values=None)

    assert len(df) == 2
    assert float(df['mean'].iloc[0]) == 1.5
    assert float(df['mean'].iloc[1]) == 3.5


@pytest.mark.filterwarnings("ignore:All-NaN slice encountered:RuntimeWarning")
@pytest.mark.filterwarnings("ignore:invalid value encountered in divide:RuntimeWarning")
@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])
Expand Down
127 changes: 101 additions & 26 deletions xrspatial/zonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,21 +198,77 @@ def _stats_majority(data):
min=lambda z: z.min(),
sum=lambda z: z.sum(),
count=lambda z: _stats_count(z),
sum_squares=lambda z: (z**2).sum()
sum_squares=lambda z: ((z - z.mean()) ** 2).sum() # block-level M2
)


def _nanreduce_preserve_allnan(blocks, func):
"""Reduce across blocks, returning NaN when ALL blocks are NaN for a zone.

``np.nansum`` returns 0 for all-NaN input; we want NaN so that zones
with no valid values propagate NaN, consistent with the numpy backend.
"""
result = func(blocks, axis=0)
all_nan = np.all(np.isnan(blocks), axis=0)
result[all_nan] = np.nan
return result


_DASK_STATS = dict(
max=lambda block_maxes: np.nanmax(block_maxes, axis=0),
min=lambda block_mins: np.nanmin(block_mins, axis=0),
sum=lambda block_sums: np.nansum(block_sums, axis=0),
count=lambda block_counts: np.nansum(block_counts, axis=0),
sum_squares=lambda block_sum_squares: np.nansum(block_sum_squares, axis=0),
squared_sum=lambda block_sums: np.nansum(block_sums, axis=0)**2,
max=lambda blocks: _nanreduce_preserve_allnan(blocks, np.nanmax),
min=lambda blocks: _nanreduce_preserve_allnan(blocks, np.nanmin),
sum=lambda blocks: _nanreduce_preserve_allnan(blocks, np.nansum),
count=lambda blocks: _nanreduce_preserve_allnan(blocks, np.nansum),
sum_squares=lambda blocks: _nanreduce_preserve_allnan(blocks, np.nansum),
)
def _dask_mean(sums, counts): return sums / counts # noqa
def _dask_std(sum_squares, squared_sum, n): return np.sqrt((sum_squares - squared_sum/n) / n) # noqa
def _dask_var(sum_squares, squared_sum, n): return (sum_squares - squared_sum/n) / n # noqa


def _dask_mean(sums, counts): # noqa
return sums / counts


def _parallel_variance(block_counts, block_sums, block_m2s):
"""Population variance via Chan-Golub-LeVeque parallel merge.

Each input is (n_blocks, n_zones). ``block_m2s`` contains
per-block M2 values (sum of squared deviations from the block mean),
NOT raw sum-of-squares. Returns (n_zones,) population variance,
with NaN for zones that have no valid values in any block.
"""
n_blocks = block_counts.shape[0]
n_zones = block_counts.shape[1]

n_acc = np.zeros(n_zones, dtype=np.float64)
mean_acc = np.zeros(n_zones, dtype=np.float64)
m2_acc = np.zeros(n_zones, dtype=np.float64)

for i in range(n_blocks):
nc = np.asarray(block_counts[i], dtype=np.float64)
sc = np.asarray(block_sums[i], dtype=np.float64)
m2_b = np.asarray(block_m2s[i], dtype=np.float64)

has_data = np.isfinite(nc) & (nc > 0)
nc_safe = np.where(has_data, nc, 1.0) # avoid /0

with np.errstate(invalid='ignore', divide='ignore'):
mean_b = sc / nc_safe

nc = np.where(has_data, nc, 0.0)
n_ab = n_acc + nc

delta = mean_b - mean_acc
with np.errstate(invalid='ignore', divide='ignore'):
n_ab_safe = np.where(n_ab > 0, n_ab, 1.0)
correction = delta ** 2 * n_acc * nc / n_ab_safe
new_mean = mean_acc + delta * nc / n_ab_safe

m2_acc = np.where(has_data, m2_acc + m2_b + correction, m2_acc)
mean_acc = np.where(has_data, new_mean, mean_acc)
n_acc = np.where(has_data, n_ab, n_acc)

with np.errstate(invalid='ignore', divide='ignore'):
var = np.where(n_acc > 0, m2_acc / n_acc, np.nan)
return var


@ngjit
Expand Down Expand Up @@ -269,7 +325,10 @@ def _calc_stats(
if unique_zones[i] in zone_ids:
zone_values = values_by_zones[start:end]
# filter out non-finite and nodata_values
zone_values = zone_values[np.isfinite(zone_values) & (zone_values != nodata_values)]
mask = np.isfinite(zone_values)
if nodata_values is not None:
mask = mask & (zone_values != nodata_values)
zone_values = zone_values[mask]
if len(zone_values) > 0:
results[i] = func(zone_values)
start = end
Expand Down Expand Up @@ -342,9 +401,11 @@ def _stats_dask_numpy(
sum=values.dtype,
count=np.int64,
sum_squares=values.dtype,
squared_sum=values.dtype,
)

# Keep per-block stacked arrays for the parallel variance merge
stacked_blocks = {}

for s in basis_stats:
if s == 'sum_squares' and not compute_sum_squares:
continue
Expand All @@ -358,21 +419,34 @@ def _stats_dask_numpy(
for z, v in zip(zones_blocks, values_blocks)
]
zonal_stats = da.stack(stats_by_block, allow_unknown_chunksizes=True)

if compute_sum_squares and s in ('count', 'sum', 'sum_squares'):
stacked_blocks[s] = zonal_stats

stats_func_by_block = delayed(_DASK_STATS[s])
stats_dict[s] = da.from_delayed(
stats_func_by_block(zonal_stats), shape=(np.nan,), dtype=np.float64
)

if 'mean' in stats_funcs:
stats_dict['mean'] = _dask_mean(stats_dict['sum'], stats_dict['count'])
if 'std' in stats_funcs:
stats_dict['std'] = _dask_std(
stats_dict['sum_squares'], stats_dict['sum'] ** 2, stats_dict['count']
)
if 'var' in stats_funcs:
stats_dict['var'] = _dask_var(
stats_dict['sum_squares'], stats_dict['sum'] ** 2, stats_dict['count']

if 'std' in stats_funcs or 'var' in stats_funcs:
var_result = da.from_delayed(
delayed(_parallel_variance)(
stacked_blocks['count'],
stacked_blocks['sum'],
stacked_blocks['sum_squares'],
),
shape=(np.nan,), dtype=np.float64,
)
if 'var' in stats_funcs:
stats_dict['var'] = var_result
if 'std' in stats_funcs:
stats_dict['std'] = da.from_delayed(
delayed(np.sqrt)(var_result),
shape=(np.nan,), dtype=np.float64,
)

# generate dask dataframe
stats_df = dd.concat([dd.from_dask_array(s) for s in stats_dict.values()], axis=1, ignore_unknown_divisions=True)
Expand Down Expand Up @@ -846,9 +920,10 @@ def _single_zone_crosstab_2d(
):
# 1D flatten zone_values, i.e, original data is 2D
# filter out non-finite and nodata_values
zone_values = zone_values[
np.isfinite(zone_values) & (zone_values != nodata_values)
]
mask = np.isfinite(zone_values)
if nodata_values is not None:
mask = mask & (zone_values != nodata_values)
zone_values = zone_values[mask]
total_count = zone_values.shape[0]
crosstab_dict[TOTAL_COUNT].append(total_count)

Expand Down Expand Up @@ -877,10 +952,10 @@ def _single_zone_crosstab_3d(
if cat in cat_ids:
zone_cat_data = zone_values[j]
# filter out non-finite and nodata_values
zone_cat_data = zone_cat_data[
np.isfinite(zone_cat_data)
& (zone_cat_data != nodata_values)
]
cat_mask = np.isfinite(zone_cat_data)
if nodata_values is not None:
cat_mask = cat_mask & (zone_cat_data != nodata_values)
zone_cat_data = zone_cat_data[cat_mask]
crosstab_dict[cat].append(stats_func(zone_cat_data))


Expand Down
Loading