1515import xarray as xr
1616from cubed import raise_if_computes as raise_if_cubed_computes
1717from xarray import DataArray , Dataset , Variable
18+ from xarray .core import duck_array_ops
1819from xarray .tests import (
1920 assert_allclose ,
2021 assert_array_equal ,
@@ -132,7 +133,7 @@ def test_roll(self):
132133 u = self .eager_var
133134 v = self .lazy_var
134135 self .assertLazyAndIdentical (u .roll (x = 2 ), v .roll (x = 2 ))
135- # assert v.data.chunks == v.roll(x=1).data.chunks # TODO: fails
136+ assert v .data .chunks == v .roll (x = 1 ).data .chunks
136137
137138 def test_unary_op (self ):
138139 u = self .eager_var
@@ -146,7 +147,7 @@ def test_binary_op(self):
146147 v = self .lazy_var
147148 self .assertLazyAndIdentical (2 * u , 2 * v )
148149 self .assertLazyAndIdentical (u + u , v + v )
149- # self.assertLazyAndIdentical(u[0] + u, v[0] + v) # TODO: fails
150+ self .assertLazyAndIdentical (u [0 ] + u , v [0 ] + v )
150151
151152 def test_binary_op_bitshift (self ) -> None :
152153 # bit shifts only work on ints so we need to generate
@@ -185,22 +186,23 @@ def test_reduce(self):
185186 u = self .eager_var
186187 v = self .lazy_var
187188 self .assertLazyAndAllClose (u .mean (), v .mean ())
188- # TODO: other reduce functions need work
189+ self .assertLazyAndAllClose ((u > 1 ).any (), (v > 1 ).any ())
190+ self .assertLazyAndAllClose ((u < 1 ).all ("x" ), (v < 1 ).all ("x" ))
191+ with raise_if_cubed_computes ():
192+ v .reduce (duck_array_ops .mean )
193+ # TODO: std, argmax, argmin compute eagerly (not lazy) in cubed
189194 # self.assertLazyAndAllClose(u.std(), v.std())
190195 # with raise_if_cubed_computes():
191196 # actual = v.argmax(dim="x")
192197 # self.assertLazyAndAllClose(u.argmax(dim="x"), actual)
193198 # with raise_if_cubed_computes():
194199 # actual = v.argmin(dim="x")
195200 # self.assertLazyAndAllClose(u.argmin(dim="x"), actual)
196- # self.assertLazyAndAllClose((u > 1).any(), (v > 1).any())
197- # self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x"))
201+ # TODO: median no longer raises NotImplementedError in xarray
198202 # with pytest.raises(NotImplementedError, match=r"only works along an axis"):
199203 # v.median()
200204 # with pytest.raises(NotImplementedError, match=r"only works along an axis"):
201205 # v.median(v.dims)
202- # with raise_if_cubed_computes():
203- # v.reduce(duck_array_ops.mean)
204206
205207 def test_missing_values (self ):
206208 values = np .array ([0 , 1 , np .nan , 3 ])
@@ -210,7 +212,7 @@ def test_missing_values(self):
210212 lazy_var = Variable ("x" , data )
211213 self .assertLazyAndIdentical (eager_var , lazy_var .fillna (lazy_var ))
212214 self .assertLazyAndIdentical (Variable ("x" , range (4 )), lazy_var .fillna (2 ))
213- # self.assertLazyAndIdentical(eager_var.count(), lazy_var.count()) # TODO: doesn't use array API
215+ self .assertLazyAndIdentical (eager_var .count (), lazy_var .count ())
214216
215217 def test_concat (self ):
216218 u = self .eager_var
@@ -423,7 +425,6 @@ def test_ufuncs(self):
423425 v = self .lazy_array
424426 self .assertLazyAndAllClose (np .sin (u ), np .sin (v ))
425427
426- @pytest .mark .xfail (reason = "failure in cubed" )
427428 def test_where_dispatching (self ):
428429 a = np .arange (10 )
429430 b = a > 3
@@ -665,16 +666,15 @@ def test_unify_chunks(map_ds):
665666 assert out_a .chunks == ((4 , 4 , 2 ), (5 , 5 , 5 , 5 ))
666667 assert out_b .chunks == expected_chunks
667668
668- # TODO: following fails
669- # # Test unordered dims
670- # da = ds_copy["cxy"]
671- # out_a, out_b = xr.unify_chunks(da.chunk({"x": -1}), da.T.chunk({"y": -1}))
672- # assert out_a.chunks == ((4, 4, 2), (5, 5, 5, 5))
673- # assert out_b.chunks == ((5, 5, 5, 5), (4, 4, 2))
669+ # Test unordered dims
670+ da = ds_copy ["cxy" ]
671+ out_a , out_b = xr .unify_chunks (da .chunk ({"x" : - 1 }), da .T .chunk ({"y" : - 1 }))
672+ assert out_a .chunks == ((4 , 4 , 2 ), (10 , 10 ))
673+ assert out_b .chunks == ((10 , 10 ), (4 , 4 , 2 ))
674674
675- # # Test mismatch
676- # with pytest.raises(ValueError, match=r"Dimension 'x' size mismatch: 10 != 2"):
677- # xr.unify_chunks(da, da.isel(x=slice(2)))
675+ # Test mismatch
676+ with pytest .raises (ValueError , match = r"Dimension 'x' size mismatch: 10 != 2" ):
677+ xr .unify_chunks (da , da .isel (x = slice (2 )))
678678
679679
680680@pytest .mark .parametrize ("obj" , [make_ds (), make_da ()])
0 commit comments