44import multiprocessing as mp
55from _weakrefset import WeakSet
66from pathlib import Path
7+ from types import MappingProxyType
78
89import numpy as np
910import pytest
1011
1112from psi_io import wrhdf_3d
1213
14+ import sys
15+ sys .path .insert (0 , str (Path (__file__ ).parent .parent ))
16+ from tests .utils import _MESH_RESOLUTION , _TRACING_DIRECTION , _OLD_MAS , _BASE_MESH_PARAMS , _BASE_LPS_PARAMS , _MESH_RESO_MAPPING , _TRACING_DIR_MAPPING , _OLD_MAS_MAPPING , _DOMAIN_RANGES
17+
1318
1419@pytest .fixture (autouse = True )
1520def reset_singleton (monkeypatch ):
@@ -45,103 +50,144 @@ def launch_point_factory():
4550
4651
4752@pytest .fixture (scope = "session" )
48- def default_params ():
49- from tests .utils import read_defaults
50- return read_defaults ()
53+ def reference_traces ():
54+ ref_path = Path (__file__ ).parent / "data" / "reference_traces.npz"
55+ ref_traces = np .load (ref_path , allow_pickle = True )
56+ return ref_traces
5157
5258
53- @pytest .fixture (scope = "session" )
54- def default_mesh_params (default_params ):
55- return default_params ["mesh" ]
59+ @pytest .fixture (scope = "session" ,
60+ params = _MESH_RESOLUTION )
61+ def mesh_resolution (request ):
62+ return request .param
5663
5764
58- @pytest .fixture (scope = "session" )
59- def default_lps_params (default_params ):
60- return default_params ["lps" ]
65+ @pytest .fixture (scope = "session" ,
66+ params = _TRACING_DIRECTION )
67+ def tracing_direction (request ):
68+ return request .param
69+
70+
71+ @pytest .fixture (scope = "session" ,
72+ params = _OLD_MAS )
73+ def mas_type (request ):
74+ return request .param
6175
6276
6377@pytest .fixture (scope = "session" )
64- def reference_traces ():
65- ref_path = Path (__file__ ).parent / "data" / "reference_traces.npz"
66- ref_traces = np .load (ref_path , allow_pickle = True )
67- return ref_traces
78+ def tmp_data_dir (tmp_path_factory ):
79+ # One temp directory for the whole test session
80+ return tmp_path_factory .mktemp ("test_data_dir" )
6881
6982
7083@pytest .fixture (scope = "session" )
71- def _default_fields_cached (default_mesh_params ):
72- from tests .utils import dipole_field
73- return dipole_field (** default_mesh_params ["base" ], ** default_mesh_params ["params" ]["normal" ])
84+ def _default_fields_cached (dipole_field_factory , mas_type ):
85+ return dipole_field_factory (** _BASE_MESH_PARAMS , ** _MESH_RESO_MAPPING ["normal" ], ** _OLD_MAS_MAPPING [mas_type ])
7486
7587
76- @pytest .fixture (scope = "session" , params = [ "coarse" , "normal" , "fine" ], ids = lambda x : x )
77- def _mesh_fields_cached (tmp_path_factory , default_mesh_params , dipole_field_factory , request ):
78- level = request . param
88+ @pytest .fixture (scope = "session" )
89+ def _mesh_fields_cached (tmp_data_dir , dipole_field_factory ,
90+ mesh_resolution , mas_type ):
7991 br , bt , bp , r , t , p = dipole_field_factory (
80- ** default_mesh_params ["base" ],
81- ** default_mesh_params ["params" ][level ],
92+ ** _BASE_MESH_PARAMS ,
93+ ** _MESH_RESO_MAPPING [mesh_resolution ],
94+ ** _OLD_MAS_MAPPING [mas_type ]
8295 )
83- data_dir = tmp_path_factory .mktemp (f"{ level } _magnetic_field_files" )
96+ data_dir = tmp_data_dir / mesh_resolution / mas_type
97+ data_dir .mkdir (parents = True , exist_ok = True )
8498 for dim , data in zip (['br' , 'bt' , 'bp' ], [br , bt , bp ]):
8599 filepath = data_dir / f"{ dim } .h5"
86- wrhdf_3d ( str ( filepath ), r , t , p , data )
87- yield level , tuple ( data_dir / f" { dim } .h5" for dim in [ 'br' , 'bt' , 'bp' ] ), ( br , bt , bp , r , t , p )
88- shutil . rmtree (data_dir , ignore_errors = True )
100+ if not filepath . exists ():
101+ wrhdf_3d ( str ( filepath ), r , t , p , data )
102+ return tuple (data_dir / f" { dim } .h5" for dim in [ 'br' , 'bt' , 'bp' ]), ( br , bt , bp , r , t , p )
89103
90104
91- @pytest .fixture (scope = "session" , params = ["fwd" , "bwd" , "both" ], ids = lambda x : x )
92- def _launch_points_cached (default_lps_params , launch_point_factory , request ):
93- level = request .param
105+ @pytest .fixture (scope = "session" )
106+ def _launch_points_cached (launch_point_factory , tracing_direction ):
94107 lps = launch_point_factory (
95- ** default_lps_params [ 'base' ] ,
96- ** default_lps_params [ 'params' ][ level ]
108+ ** _BASE_LPS_PARAMS ,
109+ ** _TRACING_DIR_MAPPING [ tracing_direction ],
97110 )
98- return level , lps
111+ return lps
99112
100113
101- @pytest .fixture (scope = "session" , params = [ "coarse" , "normal" , "fine" ], ids = lambda x : x )
102- def interdomain_files (tmp_path_factory , default_params , dipole_field_factory , request ):
103- level = request . param
114+ @pytest .fixture (scope = "session" )
115+ def interdomain_files (tmp_data_dir , dipole_field_factory ,
116+ mesh_resolution , mas_type ):
104117 base_params = {
105- ** default_params ["mesh" ]["base" ],
106- ** default_params ["mesh" ]["params" ][level ]
118+ ** _BASE_MESH_PARAMS ,
119+ ** _MESH_RESO_MAPPING [mesh_resolution ],
120+ ** _OLD_MAS_MAPPING [mas_type ]
107121 }
108- cor_params = {** base_params , ** default_params [ "_testing" ][ "domain_ranges" ][ " cor" ]}
109- hel_params = {** base_params , ** default_params [ "_testing" ][ "domain_ranges" ][ " hel" ]}
122+ cor_params = {** base_params , ** _DOMAIN_RANGES [ ' cor' ]}
123+ hel_params = {** base_params , ** _DOMAIN_RANGES [ ' hel' ]}
110124
111125 br_cor , bt_cor , bp_cor , r_cor , t_cor , p_cor = dipole_field_factory (** cor_params )
112126 br_hel , bt_hel , bp_hel , r_hel , t_hel , p_hel = dipole_field_factory (** hel_params )
113127
114- data_dir = tmp_path_factory .mktemp (f"{ level } _magnetic_field_files" )
128+ data_dir = tmp_data_dir / mesh_resolution / mas_type
129+ data_dir .mkdir (parents = True , exist_ok = True )
115130 for dim , data in zip (['br_cor' , 'bt_cor' , 'bp_cor' ], [br_cor , bt_cor , bp_cor ]):
116131 filepath = data_dir / f"{ dim } .h5"
117- wrhdf_3d (str (filepath ), r_cor , t_cor , p_cor , data )
132+ if not filepath .exists ():
133+ wrhdf_3d (str (filepath ), r_cor , t_cor , p_cor , data )
118134 for dim , data in zip (['br_hel' , 'bt_hel' , 'bp_hel' ], [br_hel , bt_hel , bp_hel ]):
119135 filepath = data_dir / f"{ dim } .h5"
120- wrhdf_3d (str (filepath ), r_hel , t_hel , p_hel , data )
121- yield level , tuple (data_dir / f"{ dim } _{ dom } .h5" for dom in ['cor' , 'hel' ] for dim in ['br' , 'bt' , 'bp' ])
122- shutil .rmtree (data_dir , ignore_errors = True )
136+ if not filepath .exists ():
137+ wrhdf_3d (str (filepath ), r_hel , t_hel , p_hel , data )
138+ return tuple (data_dir / f"{ dim } _{ dom } .h5" for dom in ['cor' , 'hel' ] for dim in ['br' , 'bt' , 'bp' ])
139+
140+
141+ @pytest .fixture (scope = "session" )
142+ def old_and_new_mas (tmp_data_dir , dipole_field_factory , mesh_resolution ):
143+ data_dir_old_mas = tmp_data_dir / mesh_resolution / 'old_mas'
144+ data_dir_new_mas = tmp_data_dir / mesh_resolution / 'new_mas'
145+ if any (not (data_dir_old_mas / f'{ dim } .h5' ).exists () for dim in ('br' , 'bt' , 'bp' )):
146+ br_old , bt_old , bp_old , r_old , t_old , p_old = dipole_field_factory (
147+ ** _BASE_MESH_PARAMS ,
148+ ** _MESH_RESO_MAPPING [mesh_resolution ],
149+ ** _OLD_MAS_MAPPING ['old_mas' ]
150+ )
151+ data_dir_old_mas .mkdir (parents = True , exist_ok = True )
152+ for dim , data in zip (['br' , 'bt' , 'bp' ], [br_old , bt_old , bp_old ]):
153+ fp = data_dir_old_mas / f"{ dim } .h5"
154+ wrhdf_3d (str (fp ), r_old , t_old , p_old , data )
155+
156+ if any (not (data_dir_new_mas / f'{ dim } .h5' ).exists () for dim in ('br' , 'bt' , 'bp' )):
157+ br_new , bt_new , bp_new , r_new , t_new , p_new = dipole_field_factory (
158+ ** _BASE_MESH_PARAMS ,
159+ ** _MESH_RESO_MAPPING [mesh_resolution ],
160+ ** _OLD_MAS_MAPPING ['new_mas' ]
161+ )
162+ data_dir_new_mas .mkdir (parents = True , exist_ok = True )
163+ for dim , data in zip (['br' , 'bt' , 'bp' ], [br_new , bt_new , bp_new ]):
164+ fp = data_dir_new_mas / f"{ dim } .h5"
165+ wrhdf_3d (str (fp ), r_new , t_new , p_new , data )
166+
167+ return (tuple (data_dir_old_mas / f"{ dim } .h5" for dim in ['br' , 'bt' , 'bp' ]),
168+ tuple (data_dir_new_mas / f"{ dim } .h5" for dim in ['br' , 'bt' , 'bp' ]))
123169
124170
125171@pytest .fixture
126172def mesh_fields_asarray (_mesh_fields_cached ):
127- level , _ , fields = _mesh_fields_cached
173+ _ , fields = _mesh_fields_cached
128174 # Return fresh copies of arrays for each test
129175 copied = tuple (np .copy (a ) if hasattr (a , "dtype" ) else a for a in fields )
130- return level , copied
176+ return copied
131177
132178
133179@pytest .fixture
134180def mesh_fields_aspaths (_mesh_fields_cached ):
135- level , paths , _ = _mesh_fields_cached
136- return level , tuple (str (p ) for p in paths )
181+ paths , _ = _mesh_fields_cached
182+ return tuple (str (p ) for p in paths )
137183
138184
139185@pytest .fixture
140186def launch_points (_launch_points_cached ):
141- level , lps = _launch_points_cached
187+ lps = _launch_points_cached
142188 # Return fresh copies of arrays for each test
143189 copied = np .copy (lps )
144- return level , copied
190+ return copied
145191
146192
147193@pytest .fixture
@@ -161,8 +207,7 @@ def _default_datadir(tmp_path_factory, _default_fields_cached):
161207 for dim , data in zip (['br' , 'bt' , 'bp' ], [br , bt , bp ]):
162208 filepath = data_dir / f"{ dim } .h5"
163209 wrhdf_3d (str (filepath ), r , t , p , data )
164- yield data_dir
165- shutil .rmtree (data_dir , ignore_errors = True )
210+ return data_dir
166211
167212
168213@pytest .fixture (scope = "session" )
0 commit comments