Skip to content

Commit 838a1d6

Browse files
committed
Merge branch 'develop' of github.com:lucasimi/tda-mapper-python into feature/default-overlap-frac
2 parents b9021aa + 35efd43 commit 838a1d6

File tree

6 files changed

+259
-186
lines changed

6 files changed

+259
-186
lines changed

src/tdamapper/_plot_matplotlib.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@
2626

2727

2828
def plot_matplotlib(
29-
mapper_plot,
30-
width,
31-
height,
32-
title,
33-
colors,
34-
agg,
35-
cmap,
36-
):
29+
mapper_plot,
30+
width,
31+
height,
32+
title,
33+
colors,
34+
agg,
35+
cmap,
36+
):
3737
px = 1 / plt.rcParams['figure.dpi'] # pixel in inches
3838
fig, ax = plt.subplots(figsize=(width * px, height * px))
3939
ax.get_xaxis().set_visible(False)

src/tdamapper/_plot_plotly.py

Lines changed: 75 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -48,29 +48,29 @@ def _edge_pos_array(graph, dim, node_pos):
4848

4949

5050
def plot_plotly(
51-
mapper_plot,
52-
width,
53-
height,
54-
title,
55-
colors,
56-
agg=np.nanmean,
57-
cmap='jet'
58-
):
51+
mapper_plot,
52+
width,
53+
height,
54+
title,
55+
colors,
56+
agg=np.nanmean,
57+
cmap='jet',
58+
):
5959
node_col = aggregate_graph(colors, mapper_plot.graph, agg)
6060
fig = _figure(mapper_plot, node_col, width, height, title, cmap)
6161
return fig
6262

6363

6464
def plot_plotly_update(
65-
mapper_plot,
66-
fig,
67-
width=None,
68-
height=None,
69-
title=None,
70-
colors=None,
71-
agg=None,
72-
cmap=None
73-
):
65+
mapper_plot,
66+
fig,
67+
width=None,
68+
height=None,
69+
title=None,
70+
colors=None,
71+
agg=None,
72+
cmap=None,
73+
):
7474
if (colors is not None) and (agg is not None) and (cmap is not None):
7575
_update_traces_col(mapper_plot, fig, colors, agg, cmap)
7676
if cmap is not None:
@@ -105,9 +105,10 @@ def _update_edge_trace_col(mapper_plot, fig, cmap, colors_agg, colors_list):
105105
line_color=colors_avg,
106106
line_colorscale=cmap,
107107
line_cmax=max(colors_list, default=None),
108-
line_cmin=min(colors_list, default=None)),
109-
selector=dict(
110-
name='edges_trace'))
108+
line_cmin=min(colors_list, default=None),
109+
),
110+
selector=dict(name='edges_trace'),
111+
)
111112

112113

113114
def _update_node_trace_col(mapper_plot, fig, colors_agg, colors_list):
@@ -116,44 +117,53 @@ def _update_node_trace_col(mapper_plot, fig, colors_agg, colors_list):
116117
text=_text(mapper_plot, colors_agg),
117118
marker_color=colors_list,
118119
marker_cmax=max(colors_list, default=None),
119-
marker_cmin=min(colors_list, default=None)),
120-
selector=dict(
121-
name='nodes_trace'))
120+
marker_cmin=min(colors_list, default=None),
121+
),
122+
selector=dict(name='nodes_trace'),
123+
)
122124

123125

124126
def _update_traces_cmap(mapper_plot, fig, cmap):
125127
fig.update_traces(
126128
patch=dict(
127129
marker_colorscale=cmap,
128-
marker_line_colorscale=cmap),
129-
selector=dict(
130-
name='nodes_trace'))
130+
marker_line_colorscale=cmap,
131+
),
132+
selector=dict(name='nodes_trace'),
133+
)
131134
if mapper_plot.dim == 3:
132135
fig.update_traces(
133-
patch=dict(
134-
line_colorscale=cmap),
135-
selector=dict(
136-
name='edges_trace'))
136+
patch=dict(line_colorscale=cmap),
137+
selector=dict(name='edges_trace'),
138+
)
137139

138140

139141
def _update_traces_title(mapper_plot, fig, title):
140142
fig.update_traces(
141-
patch=dict(
142-
marker_colorbar=_colorbar(mapper_plot, title)),
143-
selector=dict(
144-
name='nodes_trace'))
143+
patch=dict(marker_colorbar=_colorbar(mapper_plot, title)),
144+
selector=dict(name='nodes_trace'),
145+
)
145146

146147

147148
def _update_layout(fig, width, height):
148149
fig.update_layout(
149150
width=width,
150-
height=height)
151+
height=height,
152+
)
151153

152154

153155
def _figure(mapper_plot, node_col, width, height, title, cmap):
154156
node_pos = mapper_plot.positions
155-
node_pos_arr = _node_pos_array(mapper_plot.graph, mapper_plot.dim, node_pos)
156-
edge_pos_arr = _edge_pos_array(mapper_plot.graph, mapper_plot.dim, node_pos)
157+
node_pos_arr = _node_pos_array(
158+
mapper_plot.graph,
159+
mapper_plot.dim,
160+
node_pos,
161+
)
162+
edge_pos_arr = _edge_pos_array(
163+
mapper_plot.graph,
164+
mapper_plot.dim,
165+
node_pos,
166+
)
157167
_edges_tr = _edges_trace(mapper_plot, edge_pos_arr, node_col, cmap)
158168
_nodes_tr = _nodes_trace(mapper_plot, node_pos_arr, node_col, title, cmap)
159169
_layout_ = _layout(width, height)
@@ -166,7 +176,8 @@ def _nodes_trace(mapper_plot, node_pos_arr, node_col, title, cmap):
166176
attr_size = nx.get_node_attributes(mapper_plot.graph, ATTR_SIZE)
167177
max_size = max(attr_size.values(), default=1.0)
168178
scatter_text = _text(mapper_plot, node_col)
169-
marker_size = [25.0 * math.sqrt(attr_size[n] / max_size) for n in mapper_plot.graph.nodes()]
179+
marker_size = [25.0 * math.sqrt(attr_size[n] / max_size) for n in
180+
mapper_plot.graph.nodes()]
170181
colors = list(node_col.values())
171182
scatter = dict(
172183
name='nodes_trace',
@@ -188,10 +199,11 @@ def _nodes_trace(mapper_plot, node_pos_arr, node_col, title, cmap):
188199
colorscale=cmap,
189200
cmin=min(colors, default=None),
190201
cmax=max(colors, default=None),
191-
colorbar=_colorbar(mapper_plot, title)))
202+
colorbar=_colorbar(mapper_plot, title),
203+
),
204+
)
192205
if mapper_plot.dim == 3:
193-
scatter.update(dict(
194-
z=node_pos_arr[2]))
206+
scatter.update(dict(z=node_pos_arr[2]))
195207
return go.Scatter3d(scatter)
196208
elif mapper_plot.dim == 2:
197209
return go.Scatter(scatter)
@@ -206,7 +218,8 @@ def _edges_trace(mapper_plot, edge_pos_arr, node_col, cmap):
206218
opacity=_EDGE_OPACITY,
207219
line_width=_EDGE_WIDTH,
208220
line_color=_EDGE_COLOR,
209-
hoverinfo='skip')
221+
hoverinfo='skip',
222+
)
210223
if mapper_plot.dim == 3:
211224
colors_avg = []
212225
for e in mapper_plot.graph.edges():
@@ -216,16 +229,20 @@ def _edges_trace(mapper_plot, edge_pos_arr, node_col, cmap):
216229
colors_avg.append(c1)
217230
colors = list(node_col.values())
218231
scatter.update(dict(
219-
z=edge_pos_arr[2],
220-
line_color=colors_avg,
221-
line_cmin=min(colors, default=None),
222-
line_cmax=max(colors, default=None),
223-
line_colorscale=cmap))
232+
z=edge_pos_arr[2],
233+
line_color=colors_avg,
234+
line_cmin=min(colors, default=None),
235+
line_cmax=max(colors, default=None),
236+
line_colorscale=cmap,
237+
),
238+
)
224239
return go.Scatter3d(scatter)
225240
elif mapper_plot.dim == 2:
226241
scatter.update(dict(
227-
marker_colorscale=cmap,
228-
marker_line_colorscale=cmap))
242+
marker_colorscale=cmap,
243+
marker_line_colorscale=cmap,
244+
),
245+
)
229246
return go.Scatter(scatter)
230247

231248

@@ -239,7 +256,8 @@ def _layout(width, height):
239256
showticklabels=False,
240257
showgrid=False,
241258
zeroline=False,
242-
title='')
259+
title='',
260+
)
243261
scene_axis = dict(
244262
showgrid=True,
245263
visible=True,
@@ -252,7 +270,8 @@ def _layout(width, height):
252270
linewidth=1,
253271
mirror=True,
254272
showticklabels=False,
255-
title='')
273+
title='',
274+
)
256275
return go.Layout(
257276
uirevision='constant',
258277
plot_bgcolor='rgba(0, 0, 0, 0)',
@@ -267,7 +286,9 @@ def _layout(width, height):
267286
scene=dict(
268287
xaxis=scene_axis,
269288
yaxis=scene_axis,
270-
zaxis=scene_axis))
289+
zaxis=scene_axis,
290+
),
291+
)
271292

272293

273294
def _colorbar(mapper_plot, title):
@@ -285,7 +306,8 @@ def _colorbar(mapper_plot, title):
285306
tickwidth=1,
286307
tickformat='.2g',
287308
nticks=_TICKS_NUM,
288-
tickmode='auto')
309+
tickmode='auto',
310+
)
289311
if title is not None:
290312
cbar['title'] = title
291313
if mapper_plot.dim == 3:

src/tdamapper/_plot_pyvis.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,16 @@
1414

1515

1616
def plot_pyvis(
17-
mapper_plot,
18-
notebook,
19-
output_file,
20-
colors,
21-
agg,
22-
title,
23-
width,
24-
height,
25-
cmap,
26-
):
17+
mapper_plot,
18+
notebook,
19+
output_file,
20+
colors,
21+
agg,
22+
title,
23+
width,
24+
height,
25+
cmap,
26+
):
2727
net = _compute_net(
2828
mapper_plot=mapper_plot,
2929
width=width,
@@ -37,22 +37,22 @@ def plot_pyvis(
3737

3838

3939
def _compute_net(
40-
mapper_plot,
41-
notebook,
42-
colors,
43-
agg,
44-
width,
45-
height,
46-
cmap,
47-
):
40+
mapper_plot,
41+
notebook,
42+
colors,
43+
agg,
44+
width,
45+
height,
46+
cmap,
47+
):
4848
net = Network(
4949
height=height,
5050
width=width,
5151
directed=False,
5252
notebook=notebook,
5353
select_menu=True,
5454
filter_menu=True,
55-
neighborhood_highlight=True
55+
neighborhood_highlight=True,
5656
)
5757
net.toggle_physics(False)
5858
graph = mapper_plot.graph

src/tdamapper/clustering.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,33 @@
44

55
from tdamapper.core import mapper_connected_components, TrivialCover
66
import tdamapper.core
7-
from tdamapper._common import ParamsMixin, clone
7+
from tdamapper._common import ParamsMixin, clone, warn_deprecated
88

99

1010
class TrivialClustering(tdamapper.core.TrivialClustering):
1111
"""
12-
Deprecated. Use :class:`tdamapper.core.TrivialClustering`.
12+
**DEPRECATED**: This class is deprecated and will be removed in a future
13+
release. Use :class:`tdamapper.core.TrivialClustering`.
1314
"""
14-
pass
15+
def __init__(self):
16+
warn_deprecated(
17+
TrivialClustering.__qualname__,
18+
tdamapper.core.TrivialClustering.__qualname__,
19+
)
20+
super().__init__()
1521

1622

1723
class FailSafeClustering(tdamapper.core.FailSafeClustering):
1824
"""
19-
Deprecated. Use :class:`tdamapper.core.FailSafeClustering`.
25+
**DEPRECATED**: This class is deprecated and will be removed in a future
26+
release. Use :class:`tdamapper.core.FailSafeClustering`.
2027
"""
21-
pass
28+
def __init__(self, clustering=None, verbose=True):
29+
warn_deprecated(
30+
FailSafeClustering.__qualname__,
31+
tdamapper.core.FailSafeClustering.__qualname__,
32+
)
33+
super().__init__(clustering, verbose)
2234

2335

2436
class MapperClustering(ParamsMixin):
@@ -41,11 +53,16 @@ class MapperClustering(ParamsMixin):
4153
dataset.
4254
:type clustering: A class compatible with scikit-learn estimators from
4355
:mod:`sklearn.cluster`
56+
:param n_jobs: The maximum number of parallel clustering jobs. This
57+
parameter is passed to the constructor of :class:`joblib.Parallel`.
58+
Defaults to 1.
59+
:type n_jobs: int
4460
"""
4561

46-
def __init__(self, cover=None, clustering=None):
62+
def __init__(self, cover=None, clustering=None, n_jobs=1):
4763
self.cover = cover
4864
self.clustering = clustering
65+
self.n_jobs = n_jobs
4966

5067
def fit(self, X, y=None):
5168
cover = TrivialCover() if self.cover is None \
@@ -54,7 +71,14 @@ def fit(self, X, y=None):
5471
clustering = TrivialClustering() if self.clustering is None \
5572
else self.clustering
5673
clustering = clone(clustering)
74+
n_jobs = self.n_jobs
5775
y = X if y is None else y
58-
itm_lbls = mapper_connected_components(X, y, cover, clustering)
76+
itm_lbls = mapper_connected_components(
77+
X,
78+
y,
79+
cover,
80+
clustering,
81+
n_jobs=n_jobs,
82+
)
5983
self.labels_ = [itm_lbls[i] for i, _ in enumerate(X)]
6084
return self

0 commit comments

Comments
 (0)