Skip to content

Commit ade0375

Browse files
committed
Added ability to set state like np.random.seed(seed)
1 parent 9efa0f5 commit ade0375

File tree

2 files changed

+56
-1
lines changed

2 files changed

+56
-1
lines changed

conftest.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,18 @@
1414
# Set RNG state here
1515
# }
1616

17+
pytest.LAST_RNG_STATE = None
18+
pytest.FIXED_RNG_STATE = {
19+
# DO NOT CHANGE/ALTER!!
20+
"bit_generator": "PCG64",
21+
"state": {
22+
"state": 195349167630453735115769518810051464980,
23+
"inc": 247589055400886363559049235690497450585,
24+
},
25+
"has_uint32": 0,
26+
"uinteger": 0,
27+
}
28+
1729

1830
def pytest_configure(config):
1931
"""
@@ -22,3 +34,44 @@ def pytest_configure(config):
2234
"""
2335
state = json.dumps(pytest.RNG.bit_generator.state, indent=4)
2436
print(f"conftest.py: pytest.RNG.bit_generator.state = {state}")
37+
38+
39+
def _get_rng_state():
40+
"""
41+
Get a copy of the current RNG state
42+
"""
43+
return pytest.RNG.bit_generator.state.copy()
44+
45+
46+
def _set_rng_state(state):
47+
"""
48+
Store existing RNG state and set RNG state
49+
50+
If no state is provided, a hardcoded, safe state will be used
51+
"""
52+
if state is None:
53+
state = pytest.SEED_ZERO.copy()
54+
55+
pytest.LAST_RNG_STATE = pytest.RNG.bit_generator.state.copy()
56+
pytest.RNG.bit_generator.state = state
57+
58+
59+
def _reset_rng_state():
60+
"""
61+
Restore the RNG state to the last recorded RNG state
62+
"""
63+
pytest.RNG.bit_generator.state = pytest.LAST_RNG_STATE.copy()
64+
pytest.LAST_RNG_STATE = None
65+
66+
67+
def _fix_rng_state():
68+
"""
69+
Set the RNG state to a fixed, hardcoded, safe state
70+
"""
71+
_set_rng_state(pytest.FIXED_RNG_STATE)
72+
73+
74+
pytest.get_rng_state = _get_rng_state
75+
pytest.set_rng_state = _set_rng_state
76+
pytest.fix_rng_state = _fix_rng_state
77+
pytest.unfix_rng_state = _reset_rng_state

tests/test_core.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1906,7 +1906,7 @@ def test_update_incremental_PI_egressTrue_MemoryCheck():
19061906
# a new data point is appended. However, the updated matrix profile index for the
19071907
# middle subsequence `s` should still refer to the first subsequence in
19081908
# the historical data.
1909-
1909+
pytest.fix_rng_state()
19101910
T = pytest.RNG.random(64)
19111911
m = 3
19121912
excl_zone = int(np.ceil(m / config.STUMPY_EXCL_ZONE_DENOM))
@@ -1965,6 +1965,8 @@ def test_update_incremental_PI_egressTrue_MemoryCheck():
19651965
npt.assert_almost_equal(P_ref, P_comp)
19661966
npt.assert_almost_equal(I_ref, I_comp)
19671967

1968+
pytest.unfix_rng_state()
1969+
19681970

19691971
def test_check_self_join():
19701972
with pytest.warns(UserWarning):

0 commit comments

Comments
 (0)