Skip to content

Commit 102452d

Browse files
authored
Merge pull request LREN-CHUV#40 from LREN-CHUV/better-correlation-heatmap
better correlation heatmap
2 parents ca663f6 + 1e70ad5 commit 102452d

2 files changed

Lines changed: 24 additions & 22 deletions

File tree

python-correlation-heatmap/correlation_heatmap.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,13 @@ def _aggregate_results(results):
130130

131131
def _save_corr_heatmap(corr, columns):
132132
"""Generate heatmap from correlation matrix and return it in plotly format"""
133-
trace = go.Heatmap(z=corr,
133+
# revert y-axis so that diagonal goes from top left to bottom right
134+
trace = go.Heatmap(z=corr[::-1, :],
134135
x=columns,
135-
y=columns)
136+
y=columns[::-1],
137+
zmin=-1,
138+
zmax=1,
139+
)
136140
data = [trace]
137141

138142
logging.info("Results:\n{}".format(data))

python-correlation-heatmap/tests/unit/test_correlation_heatmap.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from correlation_heatmap import intermediate_stats, aggregate_stats, compute
88

99

10+
# TODO: put into io_helper.testing when convenient
1011
def round_dict(d, precision=3):
1112
"""Round all numerical values in a dictionary recursively."""
1213
d = copy.deepcopy(d)
@@ -34,18 +35,16 @@ def test_compute(mock_save_results, mock_fetch_data):
3435

3536
compute()
3637
results = json.loads(mock_save_results.call_args[0][0])
37-
assert round_dict(results) == round_dict([
38+
assert round_dict(results) == [
3839
{
39-
'type':
40-
'heatmap',
40+
'type': 'heatmap',
41+
'z': [[-0.429, -0.543, 1.0], [0.417, 1.0, -0.543], [1.0, 0.417, -0.429]],
4142
'x': ['iq', 'score_test1', 'stress_before_test1'],
42-
'y': ['iq', 'score_test1', 'stress_before_test1'],
43-
'z': [
44-
[1.0, 0.4168913285, -0.4287450417], [0.4168913285, 1.0, -0.5426534614],
45-
[-0.4287450417, -0.5426534614, 1.0]
46-
]
43+
'y': ['stress_before_test1', 'score_test1', 'iq'],
44+
'zmin': -1,
45+
'zmax': 1
4746
}
48-
])
47+
]
4948

5049

5150
@mock.patch('correlation_heatmap.io_helper.fetch_data')
@@ -70,17 +69,14 @@ def test_intermediate_stats(mock_save_results, mock_fetch_data):
7069

7170
intermediate_stats()
7271
results = json.loads(mock_save_results.call_args[0][0])
73-
assert round_dict(results) == round_dict({
72+
assert round_dict(results) == {
7473
'columns': ['iq', 'score_test1', 'stress_before_test1'],
75-
'means': [73.8815754762, 1096.5049055743, 52.9296397352],
76-
'X^T * X': [
77-
[32751.4170961055, 486164.9357124355, 23458.8913944936],
78-
[486164.9357124355, 7321018.913143162, 345715.382923219],
79-
[23458.8913944936, 345715.382923219, 17009.1219934008]
80-
],
74+
'means': [73.882, 1096.505, 52.93],
75+
'X^T * X':
76+
[[32751.417, 486164.936, 23458.891], [486164.936, 7321018.913, 345715.383], [23458.891, 345715.383, 17009.122]],
8177
'count':
8278
6
83-
})
79+
}
8480

8581

8682
def intermediate_data_1():
@@ -112,11 +108,13 @@ def mock_results(job_id):
112108

113109
aggregate_stats([1, 2])
114110
results = json.loads(mock_save_results.call_args[0][0])
115-
assert results == [
111+
assert round_dict(results) == [
116112
{
117113
'type': 'heatmap',
114+
'z': [[-0.429, 1.0], [1.0, -0.429]],
118115
'x': ['iq', 'stress_before_test1'],
119-
'y': ['iq', 'stress_before_test1'],
120-
'z': [[1.0, -0.4287450502], [-0.4287450502, 1.0]]
116+
'y': ['stress_before_test1', 'iq'],
117+
'zmin': -1,
118+
'zmax': 1
121119
}
122120
]

0 commit comments

Comments
 (0)