Skip to content

Commit f22dd53

Browse files
TST: Update testing suite to include parameterized 'mas_type' meshes; compare old-style and new-style traces against one another
1 parent f09ef98 commit f22dd53

6 files changed

Lines changed: 218 additions & 59 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ classifiers = [
6262
requires-python = ">=3.10"
6363
dependencies = [
6464
"numpy>=2.1.0",
65-
"psi-io>=2.0.2",
65+
"psi-io>=2.0.5",
6666
]
6767

6868
# Project URLS

tests/test_module.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,3 @@ def test_shared_object():
1313
from tests.utils import check_shared_object
1414
location = check_shared_object()
1515
assert location and isinstance(location, str)
16-
17-
18-
def test_reference_tracers_meta(reference_traces, default_params):
19-
tracer_metadata = json.loads(reference_traces["__meta__"].item())
20-
tracer_defaults = tracer_metadata["defaults"]
21-
for key, value in default_params.items():
22-
if key != '_testing':
23-
if isinstance(value, dict):
24-
for subkey, subvalue in value.items():
25-
assert tracer_defaults[key][subkey] == subvalue
26-
else:
27-
assert tracer_defaults[key] == value

tests/test_tracer/test_fortran_calls.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,12 @@ def test_trace_without_fieldline_validation(tracer_instance, default_fields_asar
2929
assert isinstance(response, Traces)
3030

3131

32+
def test_multiple_run_calls(tracer_instance, default_fields_asarrays):
33+
br_in, bt_in, bp_in, r_in, t_in, p_in = default_fields_asarrays
34+
tracer_instance.br = br_in, r_in, t_in, p_in
35+
tracer_instance.bt = bt_in, r_in, t_in, p_in
36+
tracer_instance.bp = bp_in, r_in, t_in, p_in
37+
tracer_instance.run()
38+
tracer_instance.run()
39+
40+
Lines changed: 82 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,105 @@
11
import pytest
22
from numpy.testing import assert_allclose
33

4-
from tests.utils import compute_weighted_fieldline_difference
4+
from tests.utils import compute_weighted_fieldline_difference, _BUFFER_SIZE, _TEST_TOLERANCES
55
from mapflpy.utils import trim_fieldline_nan_buffer, combine_fwd_bwd_traces
66

77

88
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
9-
def test_tracing_against_reference_traces(tracer_instance, mesh_fields_asarray, launch_points, default_params, reference_traces):
10-
mesh_id, (br, bt, bp, r, t, p) = mesh_fields_asarray
11-
lps_id, lps = launch_points
9+
def test_tracing_against_reference_traces(tracer_instance,
10+
mesh_resolution,
11+
mas_type,
12+
tracing_direction,
13+
mesh_fields_asarray,
14+
launch_points,
15+
reference_traces):
16+
br, bt, bp, r, t, p = mesh_fields_asarray
17+
lps = launch_points
1218
tracer_instance.br = br, r, t, p
1319
tracer_instance.bt = bt, r, t, p
1420
tracer_instance.bp = bp, r, t, p
1521

16-
match lps_id:
22+
match tracing_direction:
1723
case 'fwd':
1824
tracer_instance.set_tracing_direction('f')
19-
traces = tracer_instance.trace(lps, buffer_size=default_params['lps']['BUFFER'])
25+
traces = tracer_instance.trace(lps, buffer_size=_BUFFER_SIZE)
2026
case 'bwd':
2127
tracer_instance.set_tracing_direction('b')
22-
traces = tracer_instance.trace(lps, buffer_size=default_params['lps']['BUFFER'])
28+
traces = tracer_instance.trace(lps, buffer_size=_BUFFER_SIZE)
2329
case 'both':
2430
tracer_instance.set_tracing_direction('f')
25-
fwd_traces = tracer_instance.trace(lps, buffer_size=default_params['lps']['BUFFER'])
31+
fwd_traces = tracer_instance.trace(lps, buffer_size=_BUFFER_SIZE)
2632
tracer_instance.set_tracing_direction('b')
27-
bwd_traces = tracer_instance.trace(lps, buffer_size=default_params['lps']['BUFFER'])
33+
bwd_traces = tracer_instance.trace(lps, buffer_size=_BUFFER_SIZE)
2834
traces = combine_fwd_bwd_traces(fwd_traces, bwd_traces)
2935
case _:
30-
raise ValueError(f'Unknown launch points id: {lps_id}')
36+
raise ValueError(f'Unknown launch points id: {tracing_direction}')
3137

3238
traces_trimmed = trim_fieldline_nan_buffer(traces)
3339
for i, arr in enumerate(traces_trimmed):
34-
wdist = compute_weighted_fieldline_difference(arr, reference_traces[f'{mesh_id}_{lps_id}_{i}'])
35-
assert_allclose(wdist, 0, atol=default_params['_testing']['tolerances']['atol_exact'])
40+
wdist = compute_weighted_fieldline_difference(arr, reference_traces[f'{mesh_resolution}_{mas_type}_{tracing_direction}_{i}'])
41+
assert_allclose(wdist, 0, atol=_TEST_TOLERANCES['atol_exact'])
42+
43+
44+
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
45+
def test_old_mas_traces_against_new_mas_traces(tracer_instance,
46+
mesh_resolution,
47+
tracing_direction,
48+
old_and_new_mas,
49+
launch_points,
50+
reference_traces):
51+
mas_old, mas_new = old_and_new_mas
52+
tracer_instance.load_fields(*mas_old)
53+
54+
match tracing_direction:
55+
case 'fwd':
56+
tracer_instance.set_tracing_direction('f')
57+
traces = tracer_instance.trace(launch_points, buffer_size=_BUFFER_SIZE)
58+
case 'bwd':
59+
tracer_instance.set_tracing_direction('b')
60+
traces = tracer_instance.trace(launch_points, buffer_size=_BUFFER_SIZE)
61+
case 'both':
62+
tracer_instance.set_tracing_direction('f')
63+
fwd_traces = tracer_instance.trace(launch_points, buffer_size=_BUFFER_SIZE)
64+
tracer_instance.set_tracing_direction('b')
65+
bwd_traces = tracer_instance.trace(launch_points, buffer_size=_BUFFER_SIZE)
66+
traces = combine_fwd_bwd_traces(fwd_traces, bwd_traces)
67+
case _:
68+
raise ValueError(f'Unknown launch points id: {tracing_direction}')
69+
70+
traces_trimmed = trim_fieldline_nan_buffer(traces)
71+
for i, arr in enumerate(traces_trimmed):
72+
wdist = compute_weighted_fieldline_difference(arr, reference_traces[f'{mesh_resolution}_new_mas_{tracing_direction}_{i}'])
73+
assert_allclose(wdist, 0, atol=_TEST_TOLERANCES['atol_fuzzy'])
74+
75+
76+
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
77+
def test_new_mas_traces_against_old_mas_traces(tracer_instance,
78+
mesh_resolution,
79+
tracing_direction,
80+
old_and_new_mas,
81+
launch_points,
82+
reference_traces):
83+
mas_old, mas_new = old_and_new_mas
84+
tracer_instance.load_fields(*mas_new)
85+
86+
match tracing_direction:
87+
case 'fwd':
88+
tracer_instance.set_tracing_direction('f')
89+
traces = tracer_instance.trace(launch_points, buffer_size=_BUFFER_SIZE)
90+
case 'bwd':
91+
tracer_instance.set_tracing_direction('b')
92+
traces = tracer_instance.trace(launch_points, buffer_size=_BUFFER_SIZE)
93+
case 'both':
94+
tracer_instance.set_tracing_direction('f')
95+
fwd_traces = tracer_instance.trace(launch_points, buffer_size=_BUFFER_SIZE)
96+
tracer_instance.set_tracing_direction('b')
97+
bwd_traces = tracer_instance.trace(launch_points, buffer_size=_BUFFER_SIZE)
98+
traces = combine_fwd_bwd_traces(fwd_traces, bwd_traces)
99+
case _:
100+
raise ValueError(f'Unknown launch points id: {tracing_direction}')
101+
102+
traces_trimmed = trim_fieldline_nan_buffer(traces)
103+
for i, arr in enumerate(traces_trimmed):
104+
wdist = compute_weighted_fieldline_difference(arr, reference_traces[f'{mesh_resolution}_old_mas_{tracing_direction}_{i}'])
105+
assert_allclose(wdist, 0, atol=_TEST_TOLERANCES['atol_fuzzy'])

tests/test_tracermp/test_fortran_calls_.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,14 @@ def test_trace_without_fieldline_validation(tracermp_instance, default_fields_as
2525
tracermp_instance.bp = bp_in
2626
tracermp_instance.run()
2727
response = tracermp_instance.trace(launch_points=[1, 1, 1], buffer_size=10)
28-
assert isinstance(response, Traces)
28+
assert isinstance(response, Traces)
29+
30+
31+
def test_multiple_run_calls(tracermp_instance, default_fields_aspaths):
32+
tracermp_instance['verbose_'] = True
33+
br_in, bt_in, bp_in = default_fields_aspaths
34+
tracermp_instance.br = br_in
35+
tracermp_instance.bt = bt_in
36+
tracermp_instance.bp = bp_in
37+
tracermp_instance.run()
38+
tracermp_instance.run()

tests/test_tracermp/test_trace_routines_.py

Lines changed: 115 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,78 +2,159 @@
22
from numpy.testing import assert_allclose
33

44
from mapflpy.scripts import run_forward_tracing, run_backward_tracing, run_fwdbwd_tracing, inter_domain_tracing
5-
from tests.utils import compute_fieldline_length, compute_weighted_fieldline_difference
5+
from tests.utils import compute_fieldline_length, compute_weighted_fieldline_difference, _BUFFER_SIZE, _TEST_TOLERANCES, _DOMAIN_RANGES
66
from mapflpy.utils import trim_fieldline_nan_buffer, combine_fwd_bwd_traces
77

88

99
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
10-
def test_tracing_against_reference_traces(tracermp_instance, mesh_fields_aspaths, launch_points, default_params, reference_traces):
11-
mesh_id, (br, bt, bp) = mesh_fields_aspaths
12-
lps_id, lps = launch_points
10+
def test_tracing_against_reference_traces(tracermp_instance,
11+
mesh_resolution,
12+
mas_type,
13+
tracing_direction,
14+
mesh_fields_aspaths,
15+
launch_points,
16+
reference_traces):
17+
br, bt, bp = mesh_fields_aspaths
18+
lps = launch_points
1319
tracermp_instance.br = br
1420
tracermp_instance.bt = bt
1521
tracermp_instance.bp = bp
1622

17-
match lps_id:
23+
match tracing_direction:
1824
case 'fwd':
1925
tracermp_instance.set_tracing_direction('f')
20-
traces = tracermp_instance.trace(lps, buffer_size=default_params['lps']['BUFFER'])
26+
traces = tracermp_instance.trace(lps, buffer_size=_BUFFER_SIZE)
2127
case 'bwd':
2228
tracermp_instance.set_tracing_direction('b')
23-
traces = tracermp_instance.trace(lps, buffer_size=default_params['lps']['BUFFER'])
29+
traces = tracermp_instance.trace(lps, buffer_size=_BUFFER_SIZE)
2430
case 'both':
2531
tracermp_instance.set_tracing_direction('f')
26-
fwd_traces = tracermp_instance.trace(lps, buffer_size=default_params['lps']['BUFFER'])
32+
fwd_traces = tracermp_instance.trace(lps, buffer_size=_BUFFER_SIZE)
2733
tracermp_instance.set_tracing_direction('b')
28-
bwd_traces = tracermp_instance.trace(lps, buffer_size=default_params['lps']['BUFFER'])
34+
bwd_traces = tracermp_instance.trace(lps, buffer_size=_BUFFER_SIZE)
2935
traces = combine_fwd_bwd_traces(fwd_traces, bwd_traces)
3036
case _:
31-
raise ValueError(f'Unknown launch points id: {lps_id}')
37+
raise ValueError(f'Unknown launch points id: {tracing_direction}')
3238

3339
traces_trimmed = trim_fieldline_nan_buffer(traces)
3440
for i, arr in enumerate(traces_trimmed):
35-
wdist = compute_weighted_fieldline_difference(arr, reference_traces[f'{mesh_id}_{lps_id}_{i}'])
36-
assert_allclose(wdist, 0, atol=default_params['_testing']['tolerances']['atol_exact'])
41+
wdist = compute_weighted_fieldline_difference(arr, reference_traces[f'{mesh_resolution}_{mas_type}_{tracing_direction}_{i}'])
42+
assert_allclose(wdist, 0, atol=_TEST_TOLERANCES['atol_exact'])
43+
44+
45+
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
46+
def test_old_mas_traces_against_new_mas_traces(tracermp_instance,
47+
mesh_resolution,
48+
tracing_direction,
49+
old_and_new_mas,
50+
launch_points,
51+
reference_traces):
52+
mas_old, mas_new = old_and_new_mas
53+
tracermp_instance.load_fields(*mas_old)
54+
55+
match tracing_direction:
56+
case 'fwd':
57+
tracermp_instance.set_tracing_direction('f')
58+
traces = tracermp_instance.trace(launch_points, buffer_size=_BUFFER_SIZE)
59+
case 'bwd':
60+
tracermp_instance.set_tracing_direction('b')
61+
traces = tracermp_instance.trace(launch_points, buffer_size=_BUFFER_SIZE)
62+
case 'both':
63+
tracermp_instance.set_tracing_direction('f')
64+
fwd_traces = tracermp_instance.trace(launch_points, buffer_size=_BUFFER_SIZE)
65+
tracermp_instance.set_tracing_direction('b')
66+
bwd_traces = tracermp_instance.trace(launch_points, buffer_size=_BUFFER_SIZE)
67+
traces = combine_fwd_bwd_traces(fwd_traces, bwd_traces)
68+
case _:
69+
raise ValueError(f'Unknown launch points id: {tracing_direction}')
70+
71+
traces_trimmed = trim_fieldline_nan_buffer(traces)
72+
for i, arr in enumerate(traces_trimmed):
73+
wdist = compute_weighted_fieldline_difference(arr, reference_traces[f'{mesh_resolution}_new_mas_{tracing_direction}_{i}'])
74+
assert_allclose(wdist, 0, atol=_TEST_TOLERANCES['atol_fuzzy'])
75+
76+
77+
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
78+
def test_new_mas_traces_against_old_mas_traces(tracermp_instance,
79+
mesh_resolution,
80+
tracing_direction,
81+
old_and_new_mas,
82+
launch_points,
83+
reference_traces):
84+
mas_old, mas_new = old_and_new_mas
85+
tracermp_instance.load_fields(*mas_new)
86+
87+
match tracing_direction:
88+
case 'fwd':
89+
tracermp_instance.set_tracing_direction('f')
90+
traces = tracermp_instance.trace(launch_points, buffer_size=_BUFFER_SIZE)
91+
case 'bwd':
92+
tracermp_instance.set_tracing_direction('b')
93+
traces = tracermp_instance.trace(launch_points, buffer_size=_BUFFER_SIZE)
94+
case 'both':
95+
tracermp_instance.set_tracing_direction('f')
96+
fwd_traces = tracermp_instance.trace(launch_points, buffer_size=_BUFFER_SIZE)
97+
tracermp_instance.set_tracing_direction('b')
98+
bwd_traces = tracermp_instance.trace(launch_points, buffer_size=_BUFFER_SIZE)
99+
traces = combine_fwd_bwd_traces(fwd_traces, bwd_traces)
100+
case _:
101+
raise ValueError(f'Unknown launch points id: {tracing_direction}')
102+
103+
traces_trimmed = trim_fieldline_nan_buffer(traces)
104+
for i, arr in enumerate(traces_trimmed):
105+
wdist = compute_weighted_fieldline_difference(arr, reference_traces[f'{mesh_resolution}_old_mas_{tracing_direction}_{i}'])
106+
assert_allclose(wdist, 0, atol=_TEST_TOLERANCES['atol_fuzzy'])
37107

38108

39109
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
40-
def test_tracing_scripts_against_reference_traces(mesh_fields_aspaths, launch_points, default_params, reference_traces):
41-
mesh_id, (br, bt, bp) = mesh_fields_aspaths
42-
lps_id, lps = launch_points
110+
def test_tracing_scripts_against_reference_traces(mesh_resolution,
111+
mas_type,
112+
tracing_direction,
113+
mesh_fields_aspaths,
114+
launch_points,
115+
reference_traces):
116+
br, bt, bp = mesh_fields_aspaths
117+
lps = launch_points
43118

44-
match lps_id:
119+
match tracing_direction:
45120
case 'fwd':
46-
traces = run_forward_tracing(br, bt, bp, lps, buffer_size=default_params['lps']['BUFFER'])
121+
traces = run_forward_tracing(br, bt, bp, lps, buffer_size=_BUFFER_SIZE)
47122
case 'bwd':
48-
traces = run_backward_tracing(br, bt, bp, lps, buffer_size=default_params['lps']['BUFFER'])
123+
traces = run_backward_tracing(br, bt, bp, lps, buffer_size=_BUFFER_SIZE)
49124
case 'both':
50-
traces = run_fwdbwd_tracing(br, bt, bp, lps, buffer_size=default_params['lps']['BUFFER'])
125+
traces = run_fwdbwd_tracing(br, bt, bp, lps, buffer_size=_BUFFER_SIZE)
51126
case _:
52-
raise ValueError(f'Unknown launch points id: {lps_id}')
127+
raise ValueError(f'Unknown launch points id: {tracing_direction}')
53128

54129
traces_trimmed = trim_fieldline_nan_buffer(traces)
55130
for i, arr in enumerate(traces_trimmed):
56-
wdist = compute_weighted_fieldline_difference(arr, reference_traces[f'{mesh_id}_{lps_id}_{i}'])
57-
assert_allclose(wdist, 0, atol=default_params['_testing']['tolerances']['atol_exact'])
131+
wdist = compute_weighted_fieldline_difference(arr, reference_traces[f'{mesh_resolution}_{mas_type}_{tracing_direction}_{i}'])
132+
assert_allclose(wdist, 0, atol=_TEST_TOLERANCES['atol_exact'])
133+
58134

59135
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
60-
def test_interdomain_tracing_against_reference_traces(interdomain_files, launch_points, default_params, reference_traces):
136+
def test_interdomain_tracing_against_reference_traces(mesh_resolution,
137+
mas_type,
138+
tracing_direction,
139+
interdomain_files,
140+
launch_points,
141+
reference_traces):
61142
"""
62143
This test compares an interdomain trace where the domain has been split at r_interface to the reference traces
63144
where there was no split-domain. These traces can differ more because exactly where the interface lies along
64145
a reference field line segment can vary, which is effectively like seeding a part (or parts) of the trace
65146
with a slightly different start location. This means the traces *will not* be the same length and we must
66147
use the fuzzy tolerances because the errors are related to the mesh and discretization of B (not the tracer itself).
67148
"""
68-
mesh_id, (br_cor, bt_cor, bp_cor, br_hel, bt_hel, bp_hel) = interdomain_files
69-
lps_id, lps = launch_points
70-
buffer = default_params['lps']['BUFFER']
149+
br_cor, bt_cor, bp_cor, br_hel, bt_hel, bp_hel = interdomain_files
150+
lps = launch_points
151+
buffer = _BUFFER_SIZE
71152

72-
assert default_params["_testing"]["domain_ranges"]['cor']['r1'] == default_params["_testing"]["domain_ranges"]['hel']['r0'], \
153+
assert _DOMAIN_RANGES['cor']['r1'] == _DOMAIN_RANGES['hel']['r0'], \
73154
"Inconsistent domain interface radii."
74-
r_interface = default_params["_testing"]["domain_ranges"]['cor']['r1']
155+
r_interface = _DOMAIN_RANGES['cor']['r1']
75156

76-
if lps_id != 'both':
157+
if tracing_direction != 'both':
77158
pytest.skip("Interdomain tracing only implemented for both directions.")
78159
else:
79160
traces, *_ = inter_domain_tracing(br_cor,
@@ -86,10 +167,11 @@ def test_interdomain_tracing_against_reference_traces(interdomain_files, launch_
86167
r_interface=r_interface,
87168
buffer_size=buffer)
88169
for i, arr in enumerate(traces):
170+
reference_trace = reference_traces[f'{mesh_resolution}_{mas_type}_{tracing_direction}_{i}']
89171
# compare the distance of the first and last points (footprints)
90-
wdist = compute_weighted_fieldline_difference(arr[:, [0, -1]], reference_traces[f'{mesh_id}_{lps_id}_{i}'][:, [0, -1]])
91-
assert_allclose(wdist, 0, atol=default_params["_testing"]["tolerances"]['atol_fuzzy'])
172+
wdist = compute_weighted_fieldline_difference(arr[:, [0, -1]], reference_trace[:, [0, -1]])
173+
assert_allclose(wdist, 0, atol=_TEST_TOLERANCES['atol_fuzzy'])
92174
# compare the lengths of the traces
93175
len_test = compute_fieldline_length(arr)
94-
len_ref = compute_fieldline_length(reference_traces[f'{mesh_id}_{lps_id}_{i}'])
95-
assert_allclose(len_test, len_ref, rtol=default_params["_testing"]["tolerances"]['rtol_fuzzy'])
176+
len_ref = compute_fieldline_length(reference_trace)
177+
assert_allclose(len_test, len_ref, rtol=_TEST_TOLERANCES['rtol_fuzzy'])

0 commit comments

Comments
 (0)