@@ -20,78 +20,81 @@ def dask_cluster():
2020 cluster .close ()
2121
2222
23- @pytest .mark .parametrize (
24- "seed" , np .random .choice (np .arange (10000 ), size = 25 , replace = False )
25- )
26- def test_random_ostinato (seed ):
27- m = 50
28- np .random .seed (seed )
29- Ts = [np .random .rand (n ) for n in [64 , 128 , 256 ]]
30-
31- ref_radius , ref_Ts_idx , ref_subseq_idx = naive .aamp_ostinato (Ts , m )
32- comp_radius , comp_Ts_idx , comp_subseq_idx = aamp_ostinato (Ts , m )
33-
34- npt .assert_almost_equal (ref_radius , comp_radius )
35- npt .assert_almost_equal (ref_Ts_idx , comp_Ts_idx )
36- npt .assert_almost_equal (ref_subseq_idx , comp_subseq_idx )
37-
38-
39- @pytest .mark .parametrize ("seed" , [41 , 88 , 290 , 292 , 310 , 328 , 538 , 556 , 563 , 570 ])
40- def test_deterministic_ostinato (seed ):
41- m = 50
42- np .random .seed (seed )
43- Ts = [np .random .rand (n ) for n in [64 , 128 , 256 ]]
23+ def test_random_ostinato ():
24+ for _ in range (25 ):
25+ m = 50
26+ Ts = [pytest .RNG .random (n ) for n in [64 , 128 , 256 ]]
4427
45- for p in [1.0 , 2.0 , 3.0 ]:
46- ref_radius , ref_Ts_idx , ref_subseq_idx = naive .aamp_ostinato (Ts , m , p = p )
47- comp_radius , comp_Ts_idx , comp_subseq_idx = aamp_ostinato (Ts , m , p = p )
28+ ref_radius , ref_Ts_idx , ref_subseq_idx = naive .aamp_ostinato (Ts , m )
29+ comp_radius , comp_Ts_idx , comp_subseq_idx = aamp_ostinato (Ts , m )
4830
4931 npt .assert_almost_equal (ref_radius , comp_radius )
5032 npt .assert_almost_equal (ref_Ts_idx , comp_Ts_idx )
5133 npt .assert_almost_equal (ref_subseq_idx , comp_subseq_idx )
5234
5335
54- @pytest .mark .parametrize (
55- "seed" , np .random .choice (np .arange (10000 ), size = 25 , replace = False )
56- )
57- def test_random_ostinatoed (seed , dask_cluster ):
58- with Client (dask_cluster ) as dask_client :
36+ def test_deterministic_ostinato ():
37+ pytest .fix_rng_state ()
38+
39+ for _ in range (10 ):
5940 m = 50
60- np .random .seed (seed )
61- Ts = [np .random .rand (n ) for n in [64 , 128 , 256 ]]
41+ Ts = [pytest .RNG .random (n ) for n in [64 , 128 , 256 ]]
6242
63- ref_radius , ref_Ts_idx , ref_subseq_idx = naive .aamp_ostinato (Ts , m )
64- comp_radius , comp_Ts_idx , comp_subseq_idx = aamp_ostinatoed (dask_client , Ts , m )
43+ for p in [1.0 , 2.0 , 3.0 ]:
44+ ref_radius , ref_Ts_idx , ref_subseq_idx = naive .aamp_ostinato (Ts , m , p = p )
45+ comp_radius , comp_Ts_idx , comp_subseq_idx = aamp_ostinato (Ts , m , p = p )
6546
66- npt .assert_almost_equal (ref_radius , comp_radius )
67- npt .assert_almost_equal (ref_Ts_idx , comp_Ts_idx )
68- npt .assert_almost_equal (ref_subseq_idx , comp_subseq_idx )
47+ npt .assert_almost_equal (ref_radius , comp_radius )
48+ npt .assert_almost_equal (ref_Ts_idx , comp_Ts_idx )
49+ npt .assert_almost_equal (ref_subseq_idx , comp_subseq_idx )
6950
51+ pytest .unfix_rng_state ()
7052
71- @ pytest . mark . parametrize ( "seed" , [ 41 , 88 , 290 , 292 , 310 , 328 , 538 , 556 , 563 , 570 ])
72- def test_deterministic_ostinatoed ( seed , dask_cluster ):
53+
54+ def test_random_ostinatoed ( dask_cluster ):
7355 with Client (dask_cluster ) as dask_client :
74- m = 50
75- np . random . seed ( seed )
76- Ts = [np . random . rand (n ) for n in [64 , 128 , 256 ]]
56+ for _ in range ( 25 ):
57+ m = 50
58+ Ts = [pytest . RNG . random (n ) for n in [64 , 128 , 256 ]]
7759
78- for p in [1.0 , 2.0 , 3.0 ]:
79- ref_radius , ref_Ts_idx , ref_subseq_idx = naive .aamp_ostinato (Ts , m , p = p )
60+ ref_radius , ref_Ts_idx , ref_subseq_idx = naive .aamp_ostinato (Ts , m )
8061 comp_radius , comp_Ts_idx , comp_subseq_idx = aamp_ostinatoed (
81- dask_client , Ts , m , p = p
62+ dask_client , Ts , m
8263 )
8364
8465 npt .assert_almost_equal (ref_radius , comp_radius )
8566 npt .assert_almost_equal (ref_Ts_idx , comp_Ts_idx )
8667 npt .assert_almost_equal (ref_subseq_idx , comp_subseq_idx )
8768
8869
70+ @pytest .mark .parametrize ("seed" , [41 , 88 , 290 , 292 , 310 , 328 , 538 , 556 , 563 , 570 ])
71+ def test_deterministic_ostinatoed (seed , dask_cluster ):
72+ pytest .fix_rng_state ()
73+
74+ with Client (dask_cluster ) as dask_client :
75+ for _ in range (10 ):
76+ m = 50
77+ Ts = [pytest .RNG .random (n ) for n in [64 , 128 , 256 ]]
78+
79+ for p in [1.0 , 2.0 , 3.0 ]:
80+ ref_radius , ref_Ts_idx , ref_subseq_idx = naive .aamp_ostinato (Ts , m , p = p )
81+ comp_radius , comp_Ts_idx , comp_subseq_idx = aamp_ostinatoed (
82+ dask_client , Ts , m , p = p
83+ )
84+
85+ npt .assert_almost_equal (ref_radius , comp_radius )
86+ npt .assert_almost_equal (ref_Ts_idx , comp_Ts_idx )
87+ npt .assert_almost_equal (ref_subseq_idx , comp_subseq_idx )
88+
89+ pytest .unfix_rng_state ()
90+
91+
8992def test_input_not_overwritten_ostinato ():
9093 # aamp_ostinato preprocesses its input, a list of time series,
9194 # by replacing nan value with 0 in each time series.
9295 # This test ensures that the original input is not overwritten
9396 m = 50
94- Ts = [np . random . rand (n ) for n in [64 , 128 , 256 ]]
97+ Ts = [pytest . RNG . random (n ) for n in [64 , 128 , 256 ]]
9598 for T in Ts :
9699 T [0 ] = np .nan
97100
@@ -107,7 +110,7 @@ def test_input_not_overwritten_ostinato():
107110def test_extract_several_consensus_ostinato ():
108111 # This test is to further ensure that the function `aamp_ostinato`
109112 # does not tamper with the original data.
110- Ts = [np . random . rand (n ) for n in [256 , 512 , 1024 ]]
113+ Ts = [pytest . RNG . random (n ) for n in [256 , 512 , 1024 ]]
111114 Ts_ref = [T .copy () for T in Ts ]
112115 Ts_comp = [T .copy () for T in Ts ]
113116
@@ -145,7 +148,7 @@ def test_input_not_overwritten_ostinatoed(dask_cluster):
145148 # This test ensures that the original input is not overwritten
146149 with Client (dask_cluster ) as dask_client :
147150 m = 50
148- Ts = [np . random . rand (n ) for n in [64 , 128 , 256 ]]
151+ Ts = [pytest . RNG . random (n ) for n in [64 , 128 , 256 ]]
149152 for T in Ts :
150153 T [0 ] = np .nan
151154
@@ -163,7 +166,7 @@ def test_input_not_overwritten_ostinatoed(dask_cluster):
163166def test_extract_several_consensus_ostinatoed (dask_cluster ):
164167 # This test is to further ensure that the function `ostinatoed`
165168 # does not tamper with the original data.
166- Ts = [np . random . rand (n ) for n in [256 , 512 , 1024 ]]
169+ Ts = [pytest . RNG . random (n ) for n in [256 , 512 , 1024 ]]
167170 Ts_ref = [T .copy () for T in Ts ]
168171 Ts_comp = [T .copy () for T in Ts ]
169172
0 commit comments