Skip to content

Commit 21b4aad

Browse files
authored
Reinstate tests that now pass due to changes in Cubed. (#45)
1 parent 99855c0 commit 21b4aad

1 file changed

Lines changed: 18 additions & 18 deletions

File tree

cubed_xarray/tests/test_xarray.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import xarray as xr
1616
from cubed import raise_if_computes as raise_if_cubed_computes
1717
from xarray import DataArray, Dataset, Variable
18+
from xarray.core import duck_array_ops
1819
from xarray.tests import (
1920
assert_allclose,
2021
assert_array_equal,
@@ -132,7 +133,7 @@ def test_roll(self):
132133
u = self.eager_var
133134
v = self.lazy_var
134135
self.assertLazyAndIdentical(u.roll(x=2), v.roll(x=2))
135-
# assert v.data.chunks == v.roll(x=1).data.chunks # TODO: fails
136+
assert v.data.chunks == v.roll(x=1).data.chunks
136137

137138
def test_unary_op(self):
138139
u = self.eager_var
@@ -146,7 +147,7 @@ def test_binary_op(self):
146147
v = self.lazy_var
147148
self.assertLazyAndIdentical(2 * u, 2 * v)
148149
self.assertLazyAndIdentical(u + u, v + v)
149-
# self.assertLazyAndIdentical(u[0] + u, v[0] + v) # TODO: fails
150+
self.assertLazyAndIdentical(u[0] + u, v[0] + v)
150151

151152
def test_binary_op_bitshift(self) -> None:
152153
# bit shifts only work on ints so we need to generate
@@ -185,22 +186,23 @@ def test_reduce(self):
185186
u = self.eager_var
186187
v = self.lazy_var
187188
self.assertLazyAndAllClose(u.mean(), v.mean())
188-
# TODO: other reduce functions need work
189+
self.assertLazyAndAllClose((u > 1).any(), (v > 1).any())
190+
self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x"))
191+
with raise_if_cubed_computes():
192+
v.reduce(duck_array_ops.mean)
193+
# TODO: std, argmax, argmin compute eagerly (not lazy) in cubed
189194
# self.assertLazyAndAllClose(u.std(), v.std())
190195
# with raise_if_cubed_computes():
191196
# actual = v.argmax(dim="x")
192197
# self.assertLazyAndAllClose(u.argmax(dim="x"), actual)
193198
# with raise_if_cubed_computes():
194199
# actual = v.argmin(dim="x")
195200
# self.assertLazyAndAllClose(u.argmin(dim="x"), actual)
196-
# self.assertLazyAndAllClose((u > 1).any(), (v > 1).any())
197-
# self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x"))
201+
# TODO: median no longer raises NotImplementedError in xarray
198202
# with pytest.raises(NotImplementedError, match=r"only works along an axis"):
199203
# v.median()
200204
# with pytest.raises(NotImplementedError, match=r"only works along an axis"):
201205
# v.median(v.dims)
202-
# with raise_if_cubed_computes():
203-
# v.reduce(duck_array_ops.mean)
204206

205207
def test_missing_values(self):
206208
values = np.array([0, 1, np.nan, 3])
@@ -210,7 +212,7 @@ def test_missing_values(self):
210212
lazy_var = Variable("x", data)
211213
self.assertLazyAndIdentical(eager_var, lazy_var.fillna(lazy_var))
212214
self.assertLazyAndIdentical(Variable("x", range(4)), lazy_var.fillna(2))
213-
# self.assertLazyAndIdentical(eager_var.count(), lazy_var.count()) # TODO: doesn't use array API
215+
self.assertLazyAndIdentical(eager_var.count(), lazy_var.count())
214216

215217
def test_concat(self):
216218
u = self.eager_var
@@ -423,7 +425,6 @@ def test_ufuncs(self):
423425
v = self.lazy_array
424426
self.assertLazyAndAllClose(np.sin(u), np.sin(v))
425427

426-
@pytest.mark.xfail(reason="failure in cubed")
427428
def test_where_dispatching(self):
428429
a = np.arange(10)
429430
b = a > 3
@@ -665,16 +666,15 @@ def test_unify_chunks(map_ds):
665666
assert out_a.chunks == ((4, 4, 2), (5, 5, 5, 5))
666667
assert out_b.chunks == expected_chunks
667668

668-
# TODO: following fails
669-
# # Test unordered dims
670-
# da = ds_copy["cxy"]
671-
# out_a, out_b = xr.unify_chunks(da.chunk({"x": -1}), da.T.chunk({"y": -1}))
672-
# assert out_a.chunks == ((4, 4, 2), (5, 5, 5, 5))
673-
# assert out_b.chunks == ((5, 5, 5, 5), (4, 4, 2))
669+
# Test unordered dims
670+
da = ds_copy["cxy"]
671+
out_a, out_b = xr.unify_chunks(da.chunk({"x": -1}), da.T.chunk({"y": -1}))
672+
assert out_a.chunks == ((4, 4, 2), (10, 10))
673+
assert out_b.chunks == ((10, 10), (4, 4, 2))
674674

675-
# # Test mismatch
676-
# with pytest.raises(ValueError, match=r"Dimension 'x' size mismatch: 10 != 2"):
677-
# xr.unify_chunks(da, da.isel(x=slice(2)))
675+
# Test mismatch
676+
with pytest.raises(ValueError, match=r"Dimension 'x' size mismatch: 10 != 2"):
677+
xr.unify_chunks(da, da.isel(x=slice(2)))
678678

679679

680680
@pytest.mark.parametrize("obj", [make_ds(), make_da()])

0 commit comments

Comments
 (0)