77from correlation_heatmap import intermediate_stats , aggregate_stats , compute
88
99
10+ # TODO: put into io_helper.testing when convenient
1011def round_dict (d , precision = 3 ):
1112 """Round all numerical values in a dictionary recursively."""
1213 d = copy .deepcopy (d )
@@ -34,18 +35,16 @@ def test_compute(mock_save_results, mock_fetch_data):
3435
3536 compute ()
3637 results = json .loads (mock_save_results .call_args [0 ][0 ])
37- assert round_dict (results ) == round_dict ( [
38+ assert round_dict (results ) == [
3839 {
39- 'type' :
40- 'heatmap' ,
40+ 'type' : 'heatmap' ,
41+ 'z' : [[ - 0.429 , - 0.543 , 1.0 ], [ 0.417 , 1.0 , - 0.543 ], [ 1.0 , 0.417 , - 0.429 ]] ,
4142 'x' : ['iq' , 'score_test1' , 'stress_before_test1' ],
42- 'y' : ['iq' , 'score_test1' , 'stress_before_test1' ],
43- 'z' : [
44- [1.0 , 0.4168913285 , - 0.4287450417 ], [0.4168913285 , 1.0 , - 0.5426534614 ],
45- [- 0.4287450417 , - 0.5426534614 , 1.0 ]
46- ]
43+ 'y' : ['stress_before_test1' , 'score_test1' , 'iq' ],
44+ 'zmin' : - 1 ,
45+ 'zmax' : 1
4746 }
48- ])
47+ ]
4948
5049
5150@mock .patch ('correlation_heatmap.io_helper.fetch_data' )
@@ -70,17 +69,14 @@ def test_intermediate_stats(mock_save_results, mock_fetch_data):
7069
7170 intermediate_stats ()
7271 results = json .loads (mock_save_results .call_args [0 ][0 ])
73- assert round_dict (results ) == round_dict ( {
72+ assert round_dict (results ) == {
7473 'columns' : ['iq' , 'score_test1' , 'stress_before_test1' ],
75- 'means' : [73.8815754762 , 1096.5049055743 , 52.9296397352 ],
76- 'X^T * X' : [
77- [32751.4170961055 , 486164.9357124355 , 23458.8913944936 ],
78- [486164.9357124355 , 7321018.913143162 , 345715.382923219 ],
79- [23458.8913944936 , 345715.382923219 , 17009.1219934008 ]
80- ],
74+ 'means' : [73.882 , 1096.505 , 52.93 ],
75+ 'X^T * X' :
76+ [[32751.417 , 486164.936 , 23458.891 ], [486164.936 , 7321018.913 , 345715.383 ], [23458.891 , 345715.383 , 17009.122 ]],
8177 'count' :
8278 6
83- })
79+ }
8480
8581
8682def intermediate_data_1 ():
@@ -112,11 +108,13 @@ def mock_results(job_id):
112108
113109 aggregate_stats ([1 , 2 ])
114110 results = json .loads (mock_save_results .call_args [0 ][0 ])
115- assert results == [
111+ assert round_dict ( results ) == [
116112 {
117113 'type' : 'heatmap' ,
114+ 'z' : [[- 0.429 , 1.0 ], [1.0 , - 0.429 ]],
118115 'x' : ['iq' , 'stress_before_test1' ],
119- 'y' : ['iq' , 'stress_before_test1' ],
120- 'z' : [[1.0 , - 0.4287450502 ], [- 0.4287450502 , 1.0 ]]
116+ 'y' : ['stress_before_test1' , 'iq' ],
117+ 'zmin' : - 1 ,
118+ 'zmax' : 1
121119 }
122120 ]
0 commit comments