Skip to content

Commit f09ef98

Browse files
TST: Directly reference defaults in testing utils module; include fixtures for old/new mas type meshes
1 parent d2db495 commit f09ef98

4 files changed

Lines changed: 217 additions & 173 deletions

File tree

tests/conftest.py

Lines changed: 96 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,17 @@
44
import multiprocessing as mp
55
from _weakrefset import WeakSet
66
from pathlib import Path
7+
from types import MappingProxyType
78

89
import numpy as np
910
import pytest
1011

1112
from psi_io import wrhdf_3d
1213

14+
import sys
15+
sys.path.insert(0, str(Path(__file__).parent.parent))
16+
from tests.utils import _MESH_RESOLUTION, _TRACING_DIRECTION, _OLD_MAS, _BASE_MESH_PARAMS, _BASE_LPS_PARAMS, _MESH_RESO_MAPPING, _TRACING_DIR_MAPPING, _OLD_MAS_MAPPING, _DOMAIN_RANGES
17+
1318

1419
@pytest.fixture(autouse=True)
1520
def reset_singleton(monkeypatch):
@@ -45,103 +50,144 @@ def launch_point_factory():
4550

4651

4752
@pytest.fixture(scope="session")
48-
def default_params():
49-
from tests.utils import read_defaults
50-
return read_defaults()
53+
def reference_traces():
54+
ref_path = Path(__file__).parent / "data" / "reference_traces.npz"
55+
ref_traces = np.load(ref_path, allow_pickle=True)
56+
return ref_traces
5157

5258

53-
@pytest.fixture(scope="session")
54-
def default_mesh_params(default_params):
55-
return default_params["mesh"]
59+
@pytest.fixture(scope="session",
60+
params=_MESH_RESOLUTION)
61+
def mesh_resolution(request):
62+
return request.param
5663

5764

58-
@pytest.fixture(scope="session")
59-
def default_lps_params(default_params):
60-
return default_params["lps"]
65+
@pytest.fixture(scope="session",
66+
params=_TRACING_DIRECTION)
67+
def tracing_direction(request):
68+
return request.param
69+
70+
71+
@pytest.fixture(scope="session",
72+
params=_OLD_MAS)
73+
def mas_type(request):
74+
return request.param
6175

6276

6377
@pytest.fixture(scope="session")
64-
def reference_traces():
65-
ref_path = Path(__file__).parent / "data" / "reference_traces.npz"
66-
ref_traces = np.load(ref_path, allow_pickle=True)
67-
return ref_traces
78+
def tmp_data_dir(tmp_path_factory):
79+
# One temp directory for the whole test session
80+
return tmp_path_factory.mktemp("test_data_dir")
6881

6982

7083
@pytest.fixture(scope="session")
71-
def _default_fields_cached(default_mesh_params):
72-
from tests.utils import dipole_field
73-
return dipole_field(**default_mesh_params["base"], **default_mesh_params["params"]["normal"])
84+
def _default_fields_cached(dipole_field_factory, mas_type):
85+
return dipole_field_factory(**_BASE_MESH_PARAMS, **_MESH_RESO_MAPPING["normal"], **_OLD_MAS_MAPPING[mas_type])
7486

7587

76-
@pytest.fixture(scope="session", params=["coarse", "normal", "fine"], ids=lambda x: x)
77-
def _mesh_fields_cached(tmp_path_factory, default_mesh_params, dipole_field_factory, request):
78-
level = request.param
88+
@pytest.fixture(scope="session")
89+
def _mesh_fields_cached(tmp_data_dir, dipole_field_factory,
90+
mesh_resolution, mas_type):
7991
br, bt, bp, r, t, p = dipole_field_factory(
80-
**default_mesh_params["base"],
81-
**default_mesh_params["params"][level],
92+
**_BASE_MESH_PARAMS,
93+
**_MESH_RESO_MAPPING[mesh_resolution],
94+
**_OLD_MAS_MAPPING[mas_type]
8295
)
83-
data_dir = tmp_path_factory.mktemp(f"{level}_magnetic_field_files")
96+
data_dir = tmp_data_dir / mesh_resolution / mas_type
97+
data_dir.mkdir(parents=True, exist_ok=True)
8498
for dim, data in zip(['br', 'bt', 'bp'], [br, bt, bp]):
8599
filepath = data_dir / f"{dim}.h5"
86-
wrhdf_3d(str(filepath), r, t, p, data)
87-
yield level, tuple(data_dir / f"{dim}.h5" for dim in ['br', 'bt', 'bp']), (br, bt, bp, r, t, p)
88-
shutil.rmtree(data_dir, ignore_errors=True)
100+
if not filepath.exists():
101+
wrhdf_3d(str(filepath), r, t, p, data)
102+
return tuple(data_dir / f"{dim}.h5" for dim in ['br', 'bt', 'bp']), (br, bt, bp, r, t, p)
89103

90104

91-
@pytest.fixture(scope="session", params=["fwd", "bwd", "both"], ids=lambda x: x)
92-
def _launch_points_cached(default_lps_params, launch_point_factory, request):
93-
level = request.param
105+
@pytest.fixture(scope="session")
106+
def _launch_points_cached(launch_point_factory, tracing_direction):
94107
lps = launch_point_factory(
95-
**default_lps_params['base'],
96-
**default_lps_params['params'][level]
108+
**_BASE_LPS_PARAMS,
109+
**_TRACING_DIR_MAPPING[tracing_direction],
97110
)
98-
return level, lps
111+
return lps
99112

100113

101-
@pytest.fixture(scope="session", params=["coarse", "normal", "fine"], ids=lambda x: x)
102-
def interdomain_files(tmp_path_factory, default_params, dipole_field_factory, request):
103-
level = request.param
114+
@pytest.fixture(scope="session")
115+
def interdomain_files(tmp_data_dir, dipole_field_factory,
116+
mesh_resolution, mas_type):
104117
base_params = {
105-
**default_params["mesh"]["base"],
106-
**default_params["mesh"]["params"][level]
118+
**_BASE_MESH_PARAMS,
119+
**_MESH_RESO_MAPPING[mesh_resolution],
120+
**_OLD_MAS_MAPPING[mas_type]
107121
}
108-
cor_params = {**base_params, **default_params["_testing"]["domain_ranges"]["cor"]}
109-
hel_params = {**base_params, **default_params["_testing"]["domain_ranges"]["hel"]}
122+
cor_params = {**base_params, **_DOMAIN_RANGES['cor']}
123+
hel_params = {**base_params, **_DOMAIN_RANGES['hel']}
110124

111125
br_cor, bt_cor, bp_cor, r_cor, t_cor, p_cor = dipole_field_factory(**cor_params)
112126
br_hel, bt_hel, bp_hel, r_hel, t_hel, p_hel = dipole_field_factory(**hel_params)
113127

114-
data_dir = tmp_path_factory.mktemp(f"{level}_magnetic_field_files")
128+
data_dir = tmp_data_dir / mesh_resolution / mas_type
129+
data_dir.mkdir(parents=True, exist_ok=True)
115130
for dim, data in zip(['br_cor', 'bt_cor', 'bp_cor'], [br_cor, bt_cor, bp_cor]):
116131
filepath = data_dir / f"{dim}.h5"
117-
wrhdf_3d(str(filepath), r_cor, t_cor, p_cor, data)
132+
if not filepath.exists():
133+
wrhdf_3d(str(filepath), r_cor, t_cor, p_cor, data)
118134
for dim, data in zip(['br_hel', 'bt_hel', 'bp_hel'], [br_hel, bt_hel, bp_hel]):
119135
filepath = data_dir / f"{dim}.h5"
120-
wrhdf_3d(str(filepath), r_hel, t_hel, p_hel, data)
121-
yield level, tuple(data_dir / f"{dim}_{dom}.h5" for dom in ['cor', 'hel'] for dim in ['br', 'bt', 'bp'])
122-
shutil.rmtree(data_dir, ignore_errors=True)
136+
if not filepath.exists():
137+
wrhdf_3d(str(filepath), r_hel, t_hel, p_hel, data)
138+
return tuple(data_dir / f"{dim}_{dom}.h5" for dom in ['cor', 'hel'] for dim in ['br', 'bt', 'bp'])
139+
140+
141+
@pytest.fixture(scope="session")
142+
def old_and_new_mas(tmp_data_dir, dipole_field_factory, mesh_resolution):
143+
data_dir_old_mas = tmp_data_dir / mesh_resolution /'old_mas'
144+
data_dir_new_mas = tmp_data_dir / mesh_resolution /'new_mas'
145+
if any(not (data_dir_old_mas / f'{dim}.h5').exists() for dim in ('br', 'bt', 'bp')):
146+
br_old, bt_old, bp_old, r_old, t_old, p_old = dipole_field_factory(
147+
**_BASE_MESH_PARAMS,
148+
**_MESH_RESO_MAPPING[mesh_resolution],
149+
**_OLD_MAS_MAPPING['old_mas']
150+
)
151+
data_dir_old_mas.mkdir(parents=True, exist_ok=True)
152+
for dim, data in zip(['br', 'bt', 'bp'], [br_old, bt_old, bp_old]):
153+
fp = data_dir_old_mas / f"{dim}.h5"
154+
wrhdf_3d(str(fp), r_old, t_old, p_old, data)
155+
156+
if any(not (data_dir_new_mas / f'{dim}.h5').exists() for dim in ('br', 'bt', 'bp')):
157+
br_new, bt_new, bp_new, r_new, t_new, p_new = dipole_field_factory(
158+
**_BASE_MESH_PARAMS,
159+
**_MESH_RESO_MAPPING[mesh_resolution],
160+
**_OLD_MAS_MAPPING['new_mas']
161+
)
162+
data_dir_new_mas.mkdir(parents=True, exist_ok=True)
163+
for dim, data in zip(['br', 'bt', 'bp'], [br_new, bt_new, bp_new]):
164+
fp = data_dir_new_mas / f"{dim}.h5"
165+
wrhdf_3d(str(fp), r_new, t_new, p_new, data)
166+
167+
return (tuple(data_dir_old_mas / f"{dim}.h5" for dim in ['br', 'bt', 'bp']),
168+
tuple(data_dir_new_mas / f"{dim}.h5" for dim in ['br', 'bt', 'bp']))
123169

124170

125171
@pytest.fixture
126172
def mesh_fields_asarray(_mesh_fields_cached):
127-
level, _, fields = _mesh_fields_cached
173+
_, fields = _mesh_fields_cached
128174
# Return fresh copies of arrays for each test
129175
copied = tuple(np.copy(a) if hasattr(a, "dtype") else a for a in fields)
130-
return level, copied
176+
return copied
131177

132178

133179
@pytest.fixture
134180
def mesh_fields_aspaths(_mesh_fields_cached):
135-
level, paths, _ = _mesh_fields_cached
136-
return level, tuple(str(p) for p in paths)
181+
paths, _= _mesh_fields_cached
182+
return tuple(str(p) for p in paths)
137183

138184

139185
@pytest.fixture
140186
def launch_points(_launch_points_cached):
141-
level, lps = _launch_points_cached
187+
lps = _launch_points_cached
142188
# Return fresh copies of arrays for each test
143189
copied = np.copy(lps)
144-
return level, copied
190+
return copied
145191

146192

147193
@pytest.fixture
@@ -161,8 +207,7 @@ def _default_datadir(tmp_path_factory, _default_fields_cached):
161207
for dim, data in zip(['br', 'bt', 'bp'], [br, bt, bp]):
162208
filepath = data_dir / f"{dim}.h5"
163209
wrhdf_3d(str(filepath), r, t, p, data)
164-
yield data_dir
165-
shutil.rmtree(data_dir, ignore_errors=True)
210+
return data_dir
166211

167212

168213
@pytest.fixture(scope="session")

tests/data/defaults.json

Lines changed: 0 additions & 66 deletions
This file was deleted.

tests/data/reference_traces.npz

8.12 MB
Binary file not shown.

0 commit comments

Comments
 (0)