@@ -97,8 +97,8 @@ def test_numpy_passthrough():
9797# Input validation
9898# ---------------------------------------------------------------------------
9999
100- def test_rejects_non_dataarray ():
101- with pytest .raises (TypeError , match = "expected xr.DataArray" ):
100+ def test_rejects_non_dataarray_or_dataset ():
101+ with pytest .raises (TypeError , match = "expected xr.DataArray or xr.Dataset " ):
102102 rechunk_no_shuffle (np .zeros ((10 , 10 )))
103103
104104
@@ -121,3 +121,75 @@ def test_accessor():
121121 direct = rechunk_no_shuffle (raster , target_mb = 16 )
122122 via_accessor = raster .xrs .rechunk_no_shuffle (target_mb = 16 )
123123 assert direct .chunks == via_accessor .chunks
124+
125+
126+ # ---------------------------------------------------------------------------
127+ # Dataset support
128+ # ---------------------------------------------------------------------------
129+
130+ def _make_dask_dataset (chunks = 128 ):
131+ """Dataset with two dask variables and one numpy variable."""
132+ dask_a = xr .DataArray (
133+ da .zeros ((512 , 512 ), chunks = chunks , dtype = np .float32 ),
134+ dims = ['y' , 'x' ], name = 'a' ,
135+ )
136+ dask_b = xr .DataArray (
137+ da .ones ((512 , 512 ), chunks = chunks , dtype = np .float64 ),
138+ dims = ['y' , 'x' ], name = 'b' ,
139+ )
140+ numpy_c = xr .DataArray (
141+ np .zeros ((512 , 512 ), dtype = np .float32 ),
142+ dims = ['y' , 'x' ], name = 'c' ,
143+ )
144+ return xr .Dataset ({'a' : dask_a , 'b' : dask_b , 'c' : numpy_c },
145+ attrs = {'crs' : 'EPSG:32610' })
146+
147+
148+ def test_dataset_rechunks_all_dask_vars ():
149+ """Both dask variables should get bigger chunks."""
150+ ds = _make_dask_dataset (chunks = 64 )
151+ result = rechunk_no_shuffle (ds , target_mb = 16 )
152+ assert isinstance (result , xr .Dataset )
153+ for name in ['a' , 'b' ]:
154+ orig_chunk = ds [name ].chunks [0 ][0 ]
155+ new_chunk = result [name ].chunks [0 ][0 ]
156+ assert new_chunk > orig_chunk
157+ assert new_chunk % orig_chunk == 0
158+
159+
160+ def test_dataset_numpy_var_unchanged ():
161+ """Numpy-backed variable passes through without modification."""
162+ ds = _make_dask_dataset ()
163+ result = rechunk_no_shuffle (ds , target_mb = 16 )
164+ # 'c' is numpy-backed, should still be numpy
165+ assert not hasattr (result ['c' ].data , 'dask' )
166+ np .testing .assert_array_equal (result ['c' ].values , ds ['c' ].values )
167+
168+
169+ def test_dataset_preserves_attrs_and_coords ():
170+ """Dataset attributes and coordinates survive rechunking."""
171+ ds = _make_dask_dataset ()
172+ ds = ds .assign_coords (y = np .arange (512 ), x = np .arange (512 ))
173+ result = rechunk_no_shuffle (ds , target_mb = 16 )
174+ assert result .attrs == ds .attrs
175+ xr .testing .assert_equal (result .coords .to_dataset (), ds .coords .to_dataset ())
176+
177+
178+ def test_dataset_preserves_values ():
179+ """Data values are identical after rechunking."""
180+ np .random .seed (1069 )
181+ arr = da .from_array (np .random .rand (256 , 256 ).astype (np .float32 ), chunks = 64 )
182+ ds = xr .Dataset ({'v' : xr .DataArray (arr , dims = ['y' , 'x' ])})
183+ result = rechunk_no_shuffle (ds , target_mb = 1 )
184+ np .testing .assert_array_equal (ds ['v' ].values , result ['v' ].values )
185+
186+
187+ def test_dataset_accessor ():
188+ """The Dataset .xrs.rechunk_no_shuffle() accessor works."""
189+ import xrspatial # noqa: F401
190+ ds = _make_dask_dataset (chunks = 64 )
191+ direct = rechunk_no_shuffle (ds , target_mb = 16 )
192+ via_accessor = ds .xrs .rechunk_no_shuffle (target_mb = 16 )
193+ for name in ds .data_vars :
194+ if hasattr (ds [name ].data , 'dask' ):
195+ assert direct [name ].chunks == via_accessor [name ].chunks
0 commit comments