Skip to content

Commit e6e07d2

Browse files
authored
Merge pull request LREN-CHUV#53 from LREN-CHUV/dist-pca
distributed PCA
2 parents 5c4e605 + 8d243ae commit e6e07d2

7 files changed

Lines changed: 198 additions & 41 deletions

File tree

python-correlation-heatmap/README.md

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,39 @@
66

77
Calculate correlation heatmap, only works for **real variables**.
88

9-
You can run it on single node with `compute` or in a distributed way with
9+
You can run it on single node with `compute` and env variable `MODEL_PARAM_graph=correlation_heatmap`.
10+
11+
1. `compute`
12+
13+
Or in a distributed mode with
1014

1115
1. `compute --mode intermediate`
1216
2. `compute --mode aggregate --job-ids 1 2 3`
1317

1418
Intermediate mode calculates covariance matrix from a single node, while aggregate mode is used after intermediate to
15-
combine statistics from multiple jobs.
19+
combine statistics from multiple jobs and produce the final graph.
20+
21+
22+
# Python Distributed PCA
23+
24+
Calculate PCA and return biplot visualization and screeplot. It only works for **real variables**.
25+
26+
You can run it on single node with `compute` and env variable `MODEL_PARAM_graph=pca`.
27+
28+
1. `compute`
29+
30+
Or in a distributed mode with
31+
32+
1. `compute --mode intermediate`
33+
2. `compute --mode aggregate --job-ids 1 2 3`
34+
35+
<!--
36+
Proposal for distributed mode with PCA scores graph included. See https://trello.com/c/jfLav9K6/58-distributed-pca for
37+
discussion
38+
1. `compute --mode intermediate` (calculate covariance matrices on nodes)
39+
2. `compute --mode aggregate --job-ids 1 2 3` (calculate aggregated correlation matrix)
40+
3. `compute --mode intermediate --agg-job-id 4` (calculate covariance matrices and also sample scores for PCA)
41+
4. `compute --mode aggregate --job-ids 5 6 7 --graph pca` (produce plotly visualization) -->
1642

1743

1844
## Build (for contributors)

python-correlation-heatmap/correlation_heatmap.py

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def _compute_intermediate_result(inputs):
107107

108108

109109
@utils.catch_user_error
110-
def aggregate_stats(job_ids):
110+
def aggregate_stats(job_ids, graph_type=None):
111111
"""Get all partial statistics from all nodes and aggregate them.
112112
:input job_ids: list of job_ids with intermediate results
113113
"""
@@ -117,12 +117,14 @@ def aggregate_stats(job_ids):
117117

118118
corr, columns = _aggregate_results(results)
119119

120-
graph_type = parameters.get_parameter('MODEL_PARAM_graph', str, 'correlation_heatmap')
120+
graph_type = graph_type or parameters.get_parameter('graph', str, 'correlation_heatmap')
121121

122122
if graph_type == 'correlation_heatmap':
123123
_save_corr_heatmap(corr, columns)
124124
elif graph_type == 'pca':
125-
raise errors.UserError('MODEL_PARAM_graph=pca is not yet supported for distributed mode.')
125+
# save PCA graphs, but leave out the one with PCA scores
126+
logging.warning('Sample scores graph is not yet implemented for distributed PCA.')
127+
_save_pca(corr, columns, X=None)
126128
else:
127129
raise errors.UserError('MODEL_PARAM_graph only supports values `correlation_heatmap` and `pca`')
128130

@@ -168,7 +170,7 @@ def _save_corr_heatmap(corr, columns):
168170
logging.info("DONE")
169171

170172

171-
def _pca(corr, X):
173+
def _pca(corr, X=None):
172174
# calculate eigenvectors and eigenvalues
173175
eig_vals, eig_vecs = np.linalg.eig(corr)
174176

@@ -187,43 +189,55 @@ def _pca(corr, X):
187189

188190
# convert original data to scores
189191
# NOTE: since we are working with correlation matrix, original data must be standardized first!
190-
X_std = (X - X.mean()) / X.std()
191-
Y = X_std.dot(W)
192+
if X is not None:
193+
X_std = (X - X.mean()) / X.std()
194+
Y = X_std.dot(W)
195+
else:
196+
Y = None
192197

193198
return eig_vals, eig_vecs, Y
194199

195200

196-
def _figure(eig_vals, eig_vecs, Y, X):
201+
def _figure(eig_vals, eig_vecs, Y, columns):
202+
show_scores = Y is not None
203+
204+
titles = ['Scree plot', 'Eigen-components', 'Variables scores']
205+
titles.append('Samples scores' if show_scores else 'Samples scores <br>(not available in distributed mode)')
206+
197207
# plotting
198208
fig = tools.make_subplots(
199-
rows=2, cols=2, subplot_titles=('Samples scores', 'Variables scores', 'Scree plot', 'Eigen-components')
209+
rows=2, cols=2, subplot_titles=titles
200210
)
201211

202-
for d in _biplot_samples(Y.values):
212+
for d in _screeplot(eig_vals):
203213
fig.append_trace(d, 1, 1)
204214

205-
for d in _biplot_variables(eig_vecs, list(X.columns)):
215+
for d in _eigencomponents(eig_vecs, columns):
206216
fig.append_trace(d, 1, 2)
207217

208-
for d in _screeplot(eig_vals):
218+
for d in _biplot_variables(eig_vecs, columns):
209219
fig.append_trace(d, 2, 1)
210220

211-
for d in _eigencomponents(eig_vecs, list(X.columns)):
212-
fig.append_trace(d, 2, 2)
221+
# only show sample scores in single node mode
222+
if show_scores:
223+
for d in _biplot_samples(Y.values):
224+
fig.append_trace(d, 2, 2)
213225

214226
var_exp = _explained_variance(eig_vals)
215227

216-
fig['layout']['xaxis1'].update(title='PC1 ({:.1%})'.format(var_exp[0]))
217-
fig['layout']['yaxis1'].update(title='PC2 ({:.1%})'.format(var_exp[1]))
218-
fig['layout']['xaxis2'].update(title='PC1 ({:.1%})'.format(var_exp[0]), range=[-1.05, 1.05])
219-
fig['layout']['yaxis2'].update(title='PC2 ({:.1%})'.format(var_exp[1]), range=[-1.05, 1.05])
220-
fig['layout']['yaxis3'].update(title='Explained variance in percent')
228+
fig['layout']['yaxis1'].update(title='Explained variance in percent')
229+
fig['layout']['xaxis3'].update(title='PC1 ({:.1%})'.format(var_exp[0]), range=[-1.05, 1.05])
230+
fig['layout']['yaxis3'].update(title='PC2 ({:.1%})'.format(var_exp[1]), range=[-1.05, 1.05])
231+
232+
if show_scores:
233+
fig['layout']['xaxis4'].update(title='PC1 ({:.1%})'.format(var_exp[0]))
234+
fig['layout']['yaxis4'].update(title='PC2 ({:.1%})'.format(var_exp[1]))
221235

222236
# unit-circle for biplot
223237
circle = {
224238
'type': 'circle',
225-
'xref': 'x2',
226-
'yref': 'y2',
239+
'xref': 'x3',
240+
'yref': 'y3',
227241
'x0': -1,
228242
'y0': -1,
229243
'x1': 1,
@@ -238,10 +252,10 @@ def _figure(eig_vals, eig_vecs, Y, X):
238252
return fig
239253

240254

241-
def _save_pca(corr, columns, X):
255+
def _save_pca(corr, columns, X=None):
242256
"""Generate PCA visualization in plotly format. Inspired by https://plot.ly/ipython-notebooks/principal-component-analysis/"""
243257
eig_vals, eig_vecs, Y = _pca(corr, X)
244-
fig = _figure(eig_vals, eig_vecs, Y, X)
258+
fig = _figure(eig_vals, eig_vecs, Y, columns)
245259

246260
logging.info("Results:\n{}".format(fig))
247261
io_helper.save_results(json.dumps(fig), shapes.Shapes.PLOTLY)

python-correlation-heatmap/tests/Dockerfile

Lines changed: 0 additions & 11 deletions
This file was deleted.

python-correlation-heatmap/tests/docker-compose.yml

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,15 @@ services:
8888
MODEL_PARAM_design: "factorial"
8989
MODEL_PARAM_graph: "correlation_heatmap"
9090

91+
correlation-heatmap-single:
92+
extends: correlation-heatmap-base
93+
container_name: "correlation-heatmap-single"
94+
environment:
95+
JOB_ID: '1'
96+
PARAM_query: "SELECT lefthippocampus, minimentalstate, opticchiasm, subjectageyears FROM cde_features_a LIMIT 100"
97+
links:
98+
- "db:db"
99+
91100
correlation-heatmap-a:
92101
extends: correlation-heatmap-base
93102
container_name: "correlation-heatmap-a"
@@ -114,7 +123,7 @@ services:
114123
links:
115124
- "db:db"
116125

117-
distributed-pca:
126+
distributed-pca-base:
118127
image: "hbpmip/python-distributed-pca:latest"
119128
container_name: "distributed-pca"
120129
restart: "no"
@@ -136,9 +145,42 @@ services:
136145
PARAM_variables: "lefthippocampus"
137146
PARAM_covariables: "minimentalstate,opticchiasm,subjectageyears"
138147
PARAM_grouping: ""
139-
PARAM_query: "SELECT lefthippocampus, minimentalstate, opticchiasm, subjectageyears FROM cde_features_a LIMIT 100"
140148
PARAM_meta: "{\"lefthippocampus\":{\"code\":\"lefthippocampus\",\"type\":\"real\",\"mean\":3.0,\"std\":0.35},\"minimentalstate\":{\"code\":\"minimentalstate\",\"type\":\"real\",\"mean\":24.0,\"std\":5.0},\"opticchiasm\":{\"code\":\"opticchiasm\",\"type\":\"real\",\"mean\":0.08,\"std\":0.009},\"subjectage\":{\"code\":\"subjectage\",\"type\":\"real\",\"mean\":71.0,\"std\":8.0},\"subjectageyears\":{\"description\":\"Subject age in years.\",\"methodology\":\"mip-cde\",\"label\":\"Age Years\",\"minValue\":0,\"code\":\"subjectageyears\",\"units\":\"years\",\"length\":3,\"maxValue\":130.0,\"type\":\"integer\"}}"
141149
MODEL_PARAM_design: "factorial"
142150
MODEL_PARAM_graph: "pca"
151+
152+
distributed-pca-single:
153+
extends: distributed-pca-base
154+
container_name: "distributed-pca-single"
155+
environment:
156+
JOB_ID: '1'
157+
PARAM_query: "SELECT lefthippocampus, minimentalstate, opticchiasm, subjectageyears FROM cde_features_a LIMIT 100"
158+
links:
159+
- "db:db"
160+
161+
distributed-pca-a:
162+
extends: distributed-pca-base
163+
container_name: "distributed-pca-a"
164+
environment:
165+
JOB_ID: '1'
166+
PARAM_query: "SELECT lefthippocampus, minimentalstate, opticchiasm, subjectageyears FROM cde_features_a LIMIT 100"
167+
links:
168+
- "db:db"
169+
170+
distributed-pca-b:
171+
extends: distributed-pca-base
172+
container_name: "distributed-pca-b"
173+
environment:
174+
JOB_ID: '2'
175+
PARAM_query: "SELECT lefthippocampus, minimentalstate, opticchiasm, subjectageyears FROM cde_features_b LIMIT 100"
176+
links:
177+
- "db:db"
178+
179+
distributed-pca-agg:
180+
extends: distributed-pca-base
181+
container_name: "distributed-pca-agg"
182+
environment:
183+
JOB_ID: '3'
184+
PARAM_query: "SELECT lefthippocampus, minimentalstate, opticchiasm, subjectageyears FROM cde_features_b LIMIT 100"
143185
links:
144186
- "db:db"

python-correlation-heatmap/tests/test.sh

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,12 @@ echo "Initialise the databases..."
4646
$DOCKER_COMPOSE run sample_data_db_setup
4747
$DOCKER_COMPOSE run woken_db_setup
4848

49+
## single-node mode
4950
echo
50-
echo "Run the distributed-pca algorithm..."
51-
$DOCKER_COMPOSE run distributed-pca compute
51+
echo "Run the correlation-heatmap algorithm on single node..."
52+
$DOCKER_COMPOSE run correlation-heatmap-single compute
5253

54+
## distributed mode
5355
# echo
5456
# echo "Run the correlation-heatmap-a..."
5557
# $DOCKER_COMPOSE run correlation-heatmap-a compute --mode intermediate
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#!/usr/bin/env bash
2+
3+
set -o pipefail # trace ERR through pipes
4+
set -o errtrace # trace ERR through 'time command' and other functions
5+
set -o errexit ## set -e : exit the script if any statement returns a non-true return value
6+
7+
get_script_dir () {
8+
SOURCE="${BASH_SOURCE[0]}"
9+
10+
while [ -h "$SOURCE" ]; do
11+
DIR="$( cd -P "$( dirname "$SOURCE" )" && pwd )"
12+
SOURCE="$( readlink "$SOURCE" )"
13+
[[ $SOURCE != /* ]] && SOURCE="$DIR/$SOURCE"
14+
done
15+
cd -P "$( dirname "$SOURCE" )"
16+
pwd
17+
}
18+
19+
cd "$(get_script_dir)"
20+
21+
if [[ $NO_SUDO || -n "$CIRCLECI" ]]; then
22+
DOCKER_COMPOSE="docker-compose"
23+
elif groups $USER | grep &>/dev/null '\bdocker\b'; then
24+
DOCKER_COMPOSE="docker-compose"
25+
else
26+
DOCKER_COMPOSE="sudo docker-compose"
27+
fi
28+
29+
function _cleanup() {
30+
local error_code="$?"
31+
echo "Stopping the containers..."
32+
$DOCKER_COMPOSE stop | true
33+
$DOCKER_COMPOSE down | true
34+
$DOCKER_COMPOSE rm -f > /dev/null 2> /dev/null | true
35+
exit $error_code
36+
}
37+
trap _cleanup EXIT INT TERM
38+
39+
echo "Starting the databases..."
40+
$DOCKER_COMPOSE up -d --remove-orphans db
41+
$DOCKER_COMPOSE run wait_dbs
42+
$DOCKER_COMPOSE run create_dbs
43+
44+
echo
45+
echo "Initialise the databases..."
46+
$DOCKER_COMPOSE run sample_data_db_setup
47+
$DOCKER_COMPOSE run woken_db_setup
48+
49+
# single-node mode
50+
# echo
51+
# echo "Run the distributed-pca algorithm on single node..."
52+
# $DOCKER_COMPOSE run distributed-pca-single compute
53+
54+
## distributed mode
55+
echo
56+
echo "Run the distributed-pca-a..."
57+
$DOCKER_COMPOSE run distributed-pca-a compute --mode intermediate
58+
echo "Run the distributed-pca-b..."
59+
$DOCKER_COMPOSE run distributed-pca-b compute --mode intermediate
60+
echo "Run the distributed-pca-agg..."
61+
$DOCKER_COMPOSE run distributed-pca-agg compute --mode aggregate --job-ids 1 2
62+
63+
echo
64+
# Cleanup
65+
_cleanup

python-correlation-heatmap/tests/unit/test_correlation_heatmap.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def intermediate_data_2():
105105
@mock.patch('correlation_heatmap.io_helper.fetch_data')
106106
@mock.patch('correlation_heatmap.io_helper.get_results')
107107
@mock.patch('correlation_heatmap.io_helper.save_results')
108-
def test_aggregate_stats(mock_save_results, mock_get_results, mock_fetch):
108+
def test_aggregate_stats_correlation_heatmap(mock_save_results, mock_get_results, mock_fetch):
109109

110110
def mock_results(job_id):
111111
job_id = str(job_id)
@@ -116,7 +116,7 @@ def mock_results(job_id):
116116

117117
mock_get_results.side_effect = mock_results
118118

119-
aggregate_stats([1, 2])
119+
aggregate_stats([1, 2], graph_type='correlation_heatmap')
120120
results = json.loads(mock_save_results.call_args[0][0])
121121
assert round_dict(results) == [
122122
{
@@ -128,3 +128,22 @@ def mock_results(job_id):
128128
'zmax': 1
129129
}
130130
]
131+
132+
133+
@mock.patch('correlation_heatmap.io_helper.fetch_data')
134+
@mock.patch('correlation_heatmap.io_helper.get_results')
135+
@mock.patch('correlation_heatmap.io_helper.save_results')
136+
def test_aggregate_stats_pca(mock_save_results, mock_get_results, mock_fetch):
137+
138+
def mock_results(job_id):
139+
job_id = str(job_id)
140+
if job_id == '1':
141+
return mock.MagicMock(data=json.dumps(intermediate_data_1()))
142+
elif job_id == '2':
143+
return mock.MagicMock(data=json.dumps(intermediate_data_2()))
144+
145+
mock_get_results.side_effect = mock_results
146+
147+
aggregate_stats([1, 2], graph_type='pca')
148+
results = json.loads(mock_save_results.call_args[0][0])
149+
assert set(results.keys()) == {'layout', 'data'}

0 commit comments

Comments
 (0)