Skip to content

Commit cfe757b

Browse files
committed
add a pytest fixture for some of the xr ds
1 parent 64b019b commit cfe757b

1 file changed

Lines changed: 44 additions & 54 deletions

File tree

tests/test_accesor.py

Lines changed: 44 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,25 @@
66
from yt_xarray._utilities import construct_minimal_ds
77

88

9+
@pytest.fixture()
10+
def ds_xr():
11+
# a base xarray ds to be used in various places.
12+
tfield = "a_new_field"
13+
n_x = 3
14+
n_y = 4
15+
n_z = 5
16+
ds = construct_minimal_ds(
17+
field_name=tfield,
18+
n_fields=3,
19+
n_x=n_x,
20+
n_y=n_y,
21+
n_z=n_z,
22+
z_name="depth",
23+
coord_order=["z", "y", "x"],
24+
)
25+
return ds
26+
27+
928
def test_accessor():
1029

1130
tfield = "a_new_field"
@@ -71,33 +90,23 @@ def test_bbox():
7190
# the test dataset.
7291

7392

74-
def test_load_uniform_grid():
93+
def test_load_uniform_grid(ds_xr):
7594

76-
tfield = "a_new_field"
77-
n_x = 3
78-
n_y = 4
79-
n_z = 5
80-
ds = construct_minimal_ds(
81-
field_name=tfield,
82-
n_fields=3,
83-
n_x=n_x,
84-
n_y=n_y,
85-
n_z=n_z,
86-
z_name="depth",
87-
coord_order=["z", "y", "x"],
88-
)
89-
90-
flds = [tfield + "_0", tfield + "_1"]
91-
ds_yt = ds.yt.load_uniform_grid(flds)
95+
flds = ["a_new_field_0", "a_new_field_1"]
96+
ds_yt = ds_xr.yt.load_uniform_grid(flds)
9297
assert ds_yt.coordinates.name == "internal_geographic"
9398
expected_field_list = [("stream", f) for f in flds]
9499
assert all([f in expected_field_list] for f in ds_yt.field_list)
95100

96-
ds_yt = ds.yt.load_uniform_grid() # should generate a ds with all fields
97-
flds = [tfield + "_0", tfield + "_1", tfield + "_2"]
101+
ds_yt = ds_xr.yt.load_uniform_grid() # should generate a ds with all fields
102+
flds = flds + [
103+
"a_new_field_2",
104+
]
98105
expected_field_list = [("stream", f) for f in flds]
99106
assert all([f in expected_field_list] for f in ds_yt.field_list)
100107

108+
tfield = "nice_field"
109+
n_x, n_y, n_z = (7, 5, 17)
101110
ds = construct_minimal_ds(
102111
field_name=tfield,
103112
n_fields=3,
@@ -122,6 +131,9 @@ def test_load_uniform_grid():
122131
y_name="y",
123132
coord_order=["z", "y", "x"],
124133
)
134+
flds = [
135+
tfield + "_0",
136+
]
125137
ds_yt = ds.yt.load_uniform_grid(flds, length_unit="km")
126138
assert ds_yt.coordinates.name == "cartesian"
127139
assert all([f in expected_field_list] for f in ds_yt.field_list)
@@ -130,46 +142,24 @@ def test_load_uniform_grid():
130142
@pytest.mark.skipif(
131143
yt.__version__.startswith("4.1") is False, reason="requires yt>=4.1.0"
132144
)
133-
def test_load_grid_from_callable():
134-
tfield = "a_new_field"
135-
n_x = 3
136-
n_y = 4
137-
n_z = 5
138-
ds_xr = construct_minimal_ds(
139-
field_name=tfield,
140-
n_fields=3,
141-
n_x=n_x,
142-
n_y=n_y,
143-
n_z=n_z,
144-
z_name="depth",
145-
coord_order=["z", "y", "x"],
146-
)
147-
148-
flds = [tfield + "_0", tfield + "_1"]
149-
145+
def test_load_grid_from_callable(ds_xr):
150146
ds = ds_xr.yt.load_grid_from_callable()
147+
flds = list(ds_xr.data_vars)
151148
for fld in flds:
152149
assert ("stream", fld) in ds.field_list
153150

154151
f = ds.all_data()[flds[0]]
155-
assert len(f) == n_x * n_y * n_z
152+
assert len(f) == ds_xr.data_vars[flds[0]].size
156153

157154

158-
def test_yt_ds_attr():
159-
tfield = "a_new_field"
160-
n_x = 3
161-
n_y = 4
162-
n_z = 5
163-
ds_xr = construct_minimal_ds(
164-
field_name=tfield,
165-
n_fields=3,
166-
n_x=n_x,
167-
n_y=n_y,
168-
n_z=n_z,
169-
z_name="depth",
170-
coord_order=["z", "y", "x"],
171-
)
172-
173-
ds = ds_xr.yt.ds() # alias to load_grid_from_callable but good to check
174-
for fld in list(ds_xr.data_vars):
155+
@pytest.mark.skipif(
156+
yt.__version__.startswith("4.1") is False, reason="requires yt>=4.1.0"
157+
)
158+
def test_yt_ds_attr(ds_xr):
159+
ds = ds_xr.yt.ds()
160+
flds = list(ds_xr.data_vars)
161+
for fld in flds:
175162
assert ("stream", fld) in ds.field_list
163+
164+
f = ds.all_data()[flds[0]]
165+
assert len(f) == ds_xr.data_vars[flds[0]].size

0 commit comments

Comments
 (0)