Skip to content

Commit 16b6dcf

Browse files
authored
When using rioxarray.open_rasterio() with band_as_variable=True and then calling .to_array().sel(variable='band_1', drop=True), the resulting dask array has a different chunk structure than arrays created directly. (#846)
The crosstab function failed because it pairs dask blocks positionally via to_delayed().ravel(), which requires aligned chunks. To fix this, chunk validation was added for 2D inputs in crosstab, using the existing validate_arrays function (which is already used by zonal_stats). This function: 1. Validates that shapes match
1 parent 7c8d575 commit 16b6dcf

2 files changed

Lines changed: 56 additions & 0 deletions

File tree

xrspatial/tests/test_zonal.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,56 @@ def test_nodata_values_crosstab_3d(
716716
assert_input_data_unmodified(data_values_3d, copied_data_values_3d)
717717

718718

719+
@pytest.mark.skipif(not dask_array_available(), reason="Requires Dask")
def test_crosstab_dask_from_dataset():
    """
    Test crosstab with dask arrays originating from xarray Datasets.

    This is a regression test for issue #777 where dask arrays created via
    Dataset.to_array().sel() had misaligned chunks that caused IndexError.

    The values raster is selected out of a Dataset whose bands are chunked
    as (2, 3), while the zones raster is chunked as (3, 4), so the two
    inputs handed to crosstab deliberately do not share a chunk layout.
    """
    # Simulate what happens with rioxarray band_as_variable=True
    data_band1 = np.array([[0, 0, 1, 1, 2, 2, 3, 3],
                           [0, 0, 1, 1, 2, 2, 3, 3],
                           [0, 0, 1, 1, 2, 2, 3, 3]], dtype=float)
    data_band2 = np.array([[1, 1, 2, 2, 3, 3, 0, 0],
                           [1, 1, 2, 2, 3, 3, 0, 0],
                           [1, 1, 2, 2, 3, 3, 0, 0]], dtype=float)

    # Both bands share the same (2, 3) chunking; the mismatch exercised by
    # this test is between these bands and the (3, 4)-chunked zones below.
    dask_band1 = da.from_array(data_band1, chunks=(2, 3))
    dask_band2 = da.from_array(data_band2, chunks=(2, 3))

    ds = xr.Dataset({
        'band_1': (['y', 'x'], dask_band1),
        'band_2': (['y', 'x'], dask_band2),
    })

    # This is the pattern from issue #777:
    # to_array().sel(variable='band_1', drop=True)
    values = ds.to_array().sel(variable='band_1', drop=True)

    # Zones hold the same data as band_1 but use a different chunk layout
    zones_data = np.array([[0, 0, 1, 1, 2, 2, 3, 3],
                           [0, 0, 1, 1, 2, 2, 3, 3],
                           [0, 0, 1, 1, 2, 2, 3, 3]], dtype=float)
    zones_dask = da.from_array(zones_data, chunks=(3, 4))
    zones = xr.DataArray(zones_dask, dims=['y', 'x'])

    # This should not raise an error
    result = crosstab(zones, values)
    assert isinstance(result, dd.DataFrame)

    # Evaluate the lazy result explicitly so that any failure while the
    # graph is executed is raised here rather than inside the helper.
    result.compute()

    # zones and values are identical rasters, so every zone's 6 cells
    # (3 rows x 2 columns) fall into the matching value category.
    expected = {
        'zone': [0.0, 1.0, 2.0, 3.0],
        0.0: [6, 0, 0, 0],
        1.0: [0, 6, 0, 0],
        2.0: [0, 0, 6, 0],
        3.0: [0, 0, 0, 6],
    }
    check_results('dask+numpy', result, expected)
767+
768+
719769
def test_apply():
720770

721771
def func(x):

xrspatial/zonal.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,6 +1034,12 @@ def crosstab(
10341034
if values.ndim not in [2, 3]:
10351035
raise ValueError("`values` must use either 2D or 3D coordinates.")
10361036

1037+
# For 2D values, validate and align chunks between zones and values
1038+
# This is critical for dask arrays that may come from different sources
1039+
# (e.g., xarray Datasets via to_array().sel())
1040+
if values.ndim == 2:
1041+
validate_arrays(zones, values)
1042+
10371043
agg_2d = ["percentage", "count"]
10381044
agg_3d_numpy = _DEFAULT_STATS.keys()
10391045
agg_3d_dask = ["count"]

0 commit comments

Comments
 (0)