-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtest_regression.py
More file actions
155 lines (126 loc) · 4.89 KB
/
test_regression.py
File metadata and controls
155 lines (126 loc) · 4.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from __future__ import annotations
import json
from pathlib import Path
from typing import Any, Dict
import numpy as np
import pytest
from tests.regression.helpers import run_codeentropy_with_config
def _group_index(payload: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
"""Return the groups mapping from a regression payload.
Args:
payload: Parsed JSON payload.
Returns:
Mapping of group id to group data.
Raises:
TypeError: If payload["groups"] is not a dict.
"""
groups = payload.get("groups", {})
if not isinstance(groups, dict):
raise TypeError("payload['groups'] must be a dict")
return groups
def _compare_grouped(
    *,
    got_payload: Dict[str, Any],
    baseline_payload: Dict[str, Any],
    rtol: float,
    atol: float,
) -> None:
    """Compare grouped regression outputs against baseline values.

    Every group present in the baseline must be present in the output. For
    each such group, every baseline component, and the group ``total`` when
    the baseline has one, must match the output within the given tolerances.
    All differences are collected and reported together in one assertion so
    a single run shows every mismatch.

    Args:
        got_payload: Newly produced payload.
        baseline_payload: Baseline payload.
        rtol: Relative tolerance passed to ``np.testing.assert_allclose``.
        atol: Absolute tolerance passed to ``np.testing.assert_allclose``.

    Raises:
        AssertionError: If any baseline group is missing from the output, or
            if any required component/total is missing, non-numeric, or
            differs from the baseline beyond tolerance.
        TypeError: If either payload's ``groups`` entry is not a dict.
    """
    got_groups = _group_index(got_payload)
    base_groups = _group_index(baseline_payload)

    missing_groups = sorted(set(base_groups) - set(got_groups))
    assert not missing_groups, f"Missing groups in output: {missing_groups}"

    mismatches: list[str] = []

    def _check(label: str, expected: Any, actual: Any) -> None:
        """Record a mismatch unless ``actual`` is close to ``expected``."""
        try:
            np.testing.assert_allclose(
                float(actual), float(expected), rtol=rtol, atol=atol
            )
        except (TypeError, ValueError):
            # A non-numeric value (e.g. None or a string) is a mismatch to
            # report, not a reason to abort the whole comparison.
            mismatches.append(
                f"{label}: non-numeric value expected={expected!r} got={actual!r}"
            )
        except AssertionError:
            mismatches.append(f"{label}: expected={expected} got={actual}")

    for gid, base_g in base_groups.items():
        got_g = got_groups[gid]

        base_components = base_g.get("components", {})
        got_components = got_g.get("components", {})
        if isinstance(base_components, dict) and isinstance(got_components, dict):
            missing_keys = sorted(set(base_components) - set(got_components))
            if missing_keys:
                mismatches.append(
                    f"group {gid}: missing component keys: {missing_keys}"
                )
            # Keep comparing the components that are present so one run
            # surfaces every regression, not just the missing keys.
            for k, expected in base_components.items():
                if k in got_components:
                    _check(f"group {gid} component {k}", expected, got_components[k])
        else:
            mismatches.append(f"group {gid}: components must be dicts")

        if "total" in base_g:
            if "total" not in got_g:
                # Do not default a missing total to 0.0: with a loose atol a
                # small baseline total would then pass silently.
                mismatches.append(
                    f"group {gid} total: missing in output "
                    f"(expected={base_g['total']})"
                )
            else:
                _check(f"group {gid} total", base_g["total"], got_g["total"])

    assert not mismatches, "Mismatches:\n" + "\n".join(" " + m for m in mismatches)
@pytest.mark.regression
@pytest.mark.parametrize(
    "system",
    [
        pytest.param("benzaldehyde", marks=pytest.mark.slow),
        pytest.param("benzene", marks=pytest.mark.slow),
        pytest.param("cyclohexane", marks=pytest.mark.slow),
        "dna",
        pytest.param("ethyl-acetate", marks=pytest.mark.slow),
        "methane",
        "methanol",
        pytest.param("octonol", marks=pytest.mark.slow),
        "water",
    ],
)
def test_regression_matches_baseline(
    tmp_path: Path, system: str, request: pytest.FixtureRequest
) -> None:
    """Run a regression test for one system and compare to its baseline.

    With ``--update-baselines`` the freshly produced payload is written to the
    system's baseline file (created if it does not exist yet) and the test is
    skipped. Otherwise the payload is compared against the stored baseline.

    Args:
        tmp_path: Pytest-provided temporary directory.
        system: System name parameter.
        request: Pytest request fixture for reading CLI options.
    """
    repo_root = Path(__file__).resolve().parents[2]
    regression_dir = repo_root / "tests" / "regression"
    config_path = regression_dir / "configs" / system / "config.yaml"
    baseline_path = regression_dir / "baselines" / f"{system}.json"

    assert config_path.exists(), f"Missing config: {config_path}"

    update_baselines = request.config.getoption("--update-baselines")
    baseline_payload = None
    if not update_baselines:
        # Only a comparison run needs an existing, parseable baseline. It is
        # loaded before the (potentially slow) run so such problems fail fast,
        # while an update run can create or regenerate the file.
        assert baseline_path.exists(), f"Missing baseline: {baseline_path}"
        baseline_payload = json.loads(baseline_path.read_text(encoding="utf-8"))

    run = run_codeentropy_with_config(workdir=tmp_path, config_src=config_path)

    if request.config.getoption("--codeentropy-debug"):
        print("\n[CodeEntropy regression debug]")
        print("workdir:", run.workdir)
        print("job_dir:", run.job_dir)
        print("output_json:", run.output_json)
        print("payload copy saved:", run.workdir / "codeentropy_output.json")

    if update_baselines:
        baseline_path.parent.mkdir(parents=True, exist_ok=True)
        baseline_path.write_text(json.dumps(run.payload, indent=2), encoding="utf-8")
        pytest.skip(f"Baseline updated for {system}: {baseline_path}")

    _compare_grouped(
        got_payload=run.payload,
        baseline_payload=baseline_payload,
        rtol=1e-9,
        atol=0.5,
    )