Skip to content

Commit 20a090e

Browse files
committed
Add Dataset tests for rechunk_no_shuffle (#1069)
1 parent 6258962 commit 20a090e

File tree

2 files changed

+75
-2
lines changed

2 files changed

+75
-2
lines changed

xrspatial/tests/test_accessor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def test_dataset_accessor_has_expected_methods():
106106
'proximity', 'allocation', 'direction', 'cost_distance',
107107
'ndvi', 'evi', 'arvi', 'savi', 'nbr', 'sipi',
108108
'rasterize',
109+
'rechunk_no_shuffle',
109110
]
110111
for name in expected:
111112
assert name in names, f"Missing method: {name}"

xrspatial/tests/test_rechunk_no_shuffle.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@ def test_numpy_passthrough():
9797
# Input validation
9898
# ---------------------------------------------------------------------------
9999

100-
def test_rejects_non_dataarray():
101-
with pytest.raises(TypeError, match="expected xr.DataArray"):
100+
def test_rejects_non_dataarray_or_dataset():
101+
with pytest.raises(TypeError, match="expected xr.DataArray or xr.Dataset"):
102102
rechunk_no_shuffle(np.zeros((10, 10)))
103103

104104

@@ -121,3 +121,75 @@ def test_accessor():
121121
direct = rechunk_no_shuffle(raster, target_mb=16)
122122
via_accessor = raster.xrs.rechunk_no_shuffle(target_mb=16)
123123
assert direct.chunks == via_accessor.chunks
124+
125+
126+
# ---------------------------------------------------------------------------
127+
# Dataset support
128+
# ---------------------------------------------------------------------------
129+
130+
def _make_dask_dataset(chunks=128):
131+
"""Dataset with two dask variables and one numpy variable."""
132+
dask_a = xr.DataArray(
133+
da.zeros((512, 512), chunks=chunks, dtype=np.float32),
134+
dims=['y', 'x'], name='a',
135+
)
136+
dask_b = xr.DataArray(
137+
da.ones((512, 512), chunks=chunks, dtype=np.float64),
138+
dims=['y', 'x'], name='b',
139+
)
140+
numpy_c = xr.DataArray(
141+
np.zeros((512, 512), dtype=np.float32),
142+
dims=['y', 'x'], name='c',
143+
)
144+
return xr.Dataset({'a': dask_a, 'b': dask_b, 'c': numpy_c},
145+
attrs={'crs': 'EPSG:32610'})
146+
147+
148+
def test_dataset_rechunks_all_dask_vars():
149+
"""Both dask variables should get bigger chunks."""
150+
ds = _make_dask_dataset(chunks=64)
151+
result = rechunk_no_shuffle(ds, target_mb=16)
152+
assert isinstance(result, xr.Dataset)
153+
for name in ['a', 'b']:
154+
orig_chunk = ds[name].chunks[0][0]
155+
new_chunk = result[name].chunks[0][0]
156+
assert new_chunk > orig_chunk
157+
assert new_chunk % orig_chunk == 0
158+
159+
160+
def test_dataset_numpy_var_unchanged():
161+
"""Numpy-backed variable passes through without modification."""
162+
ds = _make_dask_dataset()
163+
result = rechunk_no_shuffle(ds, target_mb=16)
164+
# 'c' is numpy-backed, should still be numpy
165+
assert not hasattr(result['c'].data, 'dask')
166+
np.testing.assert_array_equal(result['c'].values, ds['c'].values)
167+
168+
169+
def test_dataset_preserves_attrs_and_coords():
170+
"""Dataset attributes and coordinates survive rechunking."""
171+
ds = _make_dask_dataset()
172+
ds = ds.assign_coords(y=np.arange(512), x=np.arange(512))
173+
result = rechunk_no_shuffle(ds, target_mb=16)
174+
assert result.attrs == ds.attrs
175+
xr.testing.assert_equal(result.coords.to_dataset(), ds.coords.to_dataset())
176+
177+
178+
def test_dataset_preserves_values():
179+
"""Data values are identical after rechunking."""
180+
np.random.seed(1069)
181+
arr = da.from_array(np.random.rand(256, 256).astype(np.float32), chunks=64)
182+
ds = xr.Dataset({'v': xr.DataArray(arr, dims=['y', 'x'])})
183+
result = rechunk_no_shuffle(ds, target_mb=1)
184+
np.testing.assert_array_equal(ds['v'].values, result['v'].values)
185+
186+
187+
def test_dataset_accessor():
188+
"""The Dataset .xrs.rechunk_no_shuffle() accessor works."""
189+
import xrspatial # noqa: F401
190+
ds = _make_dask_dataset(chunks=64)
191+
direct = rechunk_no_shuffle(ds, target_mb=16)
192+
via_accessor = ds.xrs.rechunk_no_shuffle(target_mb=16)
193+
for name in ds.data_vars:
194+
if hasattr(ds[name].data, 'dask'):
195+
assert direct[name].chunks == via_accessor[name].chunks

0 commit comments

Comments
 (0)