Skip to content

Commit c7af7f9

Browse files
authored
Merge pull request #121 from lucasimi/develop
Develop
2 parents fc0437b + f5d7072 commit c7af7f9

File tree

1 file changed

+99
-99
lines changed

1 file changed

+99
-99
lines changed

app/streamlit_app.py

Lines changed: 99 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import pandas as pd
88
import numpy as np
99
import plotly.express as px
10-
import plotly.graph_objects as go
1110

1211
import networkx as nx
1312
from networkx.readwrite.json_graph import adjacency_data
@@ -90,8 +89,6 @@
9089

9190
VD_SEED = 42
9291

93-
VD_DIM = 3
94-
9592
# S_* are reusable session stored objects
9693

9794
S_RESULTS = 'stored_results'
@@ -108,40 +105,16 @@ def __init__(self):
108105
self.df_all_sample = pd.DataFrame()
109106
self.X = self.df_X.to_numpy()
110107
self.df_summary = pd.DataFrame()
111-
self.mapper_graph = nx.Graph()
108+
self.mapper_graph = None
112109
self.mapper_plot = None
113-
self.mapper_fig = self._init_fig()
110+
self.mapper_dim = None
111+
self.mapper_fig = None
114112
self.mapper_fig_outdated = True
115113
self.auto_rendering = self._auto_rendering()
116114

117-
def _init_fig(self):
118-
fig = go.Figure(
119-
data=[go.Scatter3d(
120-
x=[],
121-
y=[],
122-
z=[],
123-
mode='markers')])
124-
fig.update_layout(
125-
uirevision='constant',
126-
scene=dict(
127-
xaxis=dict(
128-
showgrid=True,
129-
zeroline=True,
130-
showline=True,
131-
ticks='outside'),
132-
yaxis=dict(
133-
showgrid=True,
134-
zeroline=True,
135-
showline=True,
136-
ticks='outside'),
137-
zaxis=dict(
138-
showgrid=True,
139-
zeroline=True,
140-
showline=True,
141-
ticks='outside')))
142-
return fig
143-
144115
def _auto_rendering(self):
116+
if self.mapper_graph is None:
117+
return False
145118
nodes_num = self.mapper_graph.number_of_nodes()
146119
return nodes_num <= MAX_NODES
147120

@@ -154,27 +127,57 @@ def set_df(self, X, y):
154127
self.df_all_sample = pd.concat([self.df_y_sample, self.df_X_sample], axis=1)
155128
self.X = self.df_X.to_numpy()
156129
self.df_summary = _get_data_summary(self.df_X, self.df_y)
157-
self.mapper_graph = nx.Graph()
130+
self.mapper_graph = None
158131
self.mapper_plot = None
159-
self.mapper_fig = self._init_fig()
132+
self.mapper_dim = None
133+
self.mapper_fig = None
160134
self.mapper_fig_outdated = True
161135
self.auto_rendering = self._auto_rendering()
162136

163137
def set_mapper(self, mapper_graph):
164138
self.mapper_graph = mapper_graph
165-
self.mapper_plot = MapperLayoutInteractive(
166-
self.mapper_graph,
167-
dim=VD_DIM,
168-
height=500,
169-
width=500,
170-
colors=self.X,
171-
seed=VD_SEED)
172-
self.mapper_fig = self._init_fig()
139+
self.mapper_plot = None
140+
self.mapper_dim = None
141+
self.mapper_fig = None
173142
self.mapper_fig_outdated = True
174143
self.auto_rendering = self._auto_rendering()
175144

176-
def set_mapper_fig(self, mapper_fig):
177-
self.mapper_fig = mapper_fig
145+
def set_mapper_fig(self, dim, seed, color_feat, agg, cmap, title):
146+
colors = self.X
147+
df_all = st.session_state[S_RESULTS].df_all
148+
if color_feat in df_all.columns:
149+
df_col = df_all[color_feat]
150+
colors = df_col.to_numpy()
151+
if (self.mapper_plot is None) or (dim != self.mapper_dim):
152+
self.mapper_plot = MapperLayoutInteractive(
153+
self.mapper_graph,
154+
dim=dim,
155+
seed=seed,
156+
colors=colors,
157+
agg=agg,
158+
cmap=cmap,
159+
title=title,
160+
height=500,
161+
width=500)
162+
self.mapper_dim = dim
163+
else:
164+
self.mapper_plot.update(
165+
seed=seed,
166+
colors=colors,
167+
agg=agg,
168+
cmap=cmap,
169+
title=title)
170+
self.mapper_fig = self.mapper_plot.plot()
171+
self.mapper_fig.update_layout(
172+
uirevision='constant',
173+
margin=dict(b=0, l=0, r=0, t=0))
174+
self.mapper_fig.update_xaxes(
175+
#constrain='domain',
176+
showline=False)
177+
self.mapper_fig.update_yaxes(
178+
showline=False,
179+
scaleanchor='x',
180+
scaleratio = 1)
178181
self.mapper_fig_outdated = False
179182

180183

@@ -295,27 +298,8 @@ def _update_mapper(X, lens, cover, clustering):
295298
st.toast('Automatic Rendering Disabled: Graph Too Large', icon='⚠️')
296299

297300

298-
def _update_fig(seed, color_feat, agg, cmap, title):
299-
mapper_plot = st.session_state[S_RESULTS].mapper_plot
300-
if mapper_plot is None:
301-
return
302-
X = st.session_state[S_RESULTS].X
303-
colors = X
304-
df_all = st.session_state[S_RESULTS].df_all
305-
if color_feat in df_all.columns:
306-
df_col = df_all[color_feat]
307-
colors = df_col.to_numpy()
308-
mapper_plot.update(
309-
colors=colors,
310-
seed=seed,
311-
agg=agg,
312-
title=title,
313-
cmap=cmap)
314-
mapper_fig = mapper_plot.plot()
315-
mapper_fig.update_layout(
316-
uirevision='constant',
317-
margin=dict(b=0, l=0, r=0, t=0))
318-
st.session_state[S_RESULTS].set_mapper_fig(mapper_fig)
301+
def _update_fig(dim, seed, color_feat, agg, cmap, title):
302+
st.session_state[S_RESULTS].set_mapper_fig(dim, seed, color_feat, agg, cmap, title)
319303
st.toast('Successfully Rendered Graph', icon='🖌️')
320304

321305

@@ -374,11 +358,12 @@ def _clustering_tuning(clustering_type):
374358
if clustering_type == V_CLUSTERING_TRIVIAL:
375359
clustering = TrivialClustering()
376360
elif clustering_type == V_CLUSTERING_AGGLOMERATIVE:
377-
clust_n = st.number_input(
361+
clust_num = st.number_input(
378362
'🧮 Clusters',
379363
value=2,
380364
min_value=1)
381-
clustering = AgglomerativeClustering(n_clusters=clust_n)
365+
n_clusters = int(clust_num)
366+
clustering = AgglomerativeClustering(n_clusters=n_clusters)
382367
return clustering
383368

384369

@@ -457,6 +442,15 @@ def _mapper_aggregation():
457442
return agg, agg_name
458443

459444

445+
def _mapper_dim():
446+
toggle_3d = st.toggle(
447+
'3D Rendering',
448+
value=True,
449+
on_change=_update_mapper_fig_outdated)
450+
dim = 3 if toggle_3d else 2
451+
return dim
452+
453+
460454
def _mapper_seed():
461455
seed = st.number_input(
462456
'🎲 Seed',
@@ -500,29 +494,35 @@ def _mapper_caption():
500494

501495
def _mapper_histogram():
502496
mapper_graph = st.session_state[S_RESULTS].mapper_graph
503-
ccs = nx.connected_components(mapper_graph)
504-
size = nx.get_node_attributes(mapper_graph, ATTR_SIZE)
505-
node_cc, node_size = {}, {}
506-
node_cc_max, node_size_max = 1, 1
507-
for cc in ccs:
508-
cc_len = len(cc)
509-
for u in cc:
510-
u_size = size[u]
511-
node_cc[u] = cc_len
512-
node_size[u] = u_size
513-
if u_size > node_size_max:
514-
node_size_max = u_size
515-
if cc_len > node_cc_max:
516-
node_cc_max = cc_len
517-
arr_size = np.array([node_size[u]/node_size_max for u in mapper_graph.nodes()])
518-
arr_cc = np.array([node_cc[u]/node_cc_max for u in mapper_graph.nodes()])
497+
node_size_feat = 'node size (rel.)'
498+
cc_size_feat = 'conn. comp. size (rel.)'
519499
df = pd.DataFrame(dict(
520-
series=np.concatenate((
521-
['node size (rel.)'] * len(arr_size),
522-
['conn. comp. size (rel.)'] * len(arr_cc))),
523-
data=np.concatenate((
524-
arr_size,
525-
arr_cc))))
500+
series=[node_size_feat, cc_size_feat],
501+
data=[0.0, 0.0]))
502+
if mapper_graph:
503+
ccs = nx.connected_components(mapper_graph)
504+
size = nx.get_node_attributes(mapper_graph, ATTR_SIZE)
505+
node_cc, node_size = {}, {}
506+
node_cc_max, node_size_max = 1, 1
507+
for cc in ccs:
508+
cc_len = len(cc)
509+
for u in cc:
510+
u_size = size[u]
511+
node_cc[u] = cc_len
512+
node_size[u] = u_size
513+
if u_size > node_size_max:
514+
node_size_max = u_size
515+
if cc_len > node_cc_max:
516+
node_cc_max = cc_len
517+
arr_size = np.array([node_size[u]/node_size_max for u in mapper_graph.nodes()])
518+
arr_cc = np.array([node_cc[u]/node_cc_max for u in mapper_graph.nodes()])
519+
df = pd.DataFrame(dict(
520+
series=np.concatenate((
521+
[node_size_feat] * len(arr_size),
522+
[cc_size_feat] * len(arr_cc))),
523+
data=np.concatenate((
524+
arr_size,
525+
arr_cc))))
526526
fig = px.histogram(
527527
df,
528528
nbins=20,
@@ -559,11 +559,10 @@ def _mapper_download():
559559
mapper_graph = st.session_state[S_RESULTS].mapper_graph
560560
mapper_adj = {} if mapper_graph is None else adjacency_data(mapper_graph)
561561
mapper_json = json.dumps(mapper_adj, default=int)
562-
nodes_num = mapper_graph.number_of_nodes()
563562
return st.download_button(
564563
'📥 Download Mapper',
565564
data=get_gzip_bytes(mapper_json),
566-
disabled=nodes_num < 1,
565+
disabled=mapper_graph is None,
567566
use_container_width=True,
568567
file_name=f'mapper_graph_{int(time.time())}.json.gzip')
569568

@@ -622,7 +621,6 @@ def mapper_run_section(lens_type, cover_type, clustering_type):
622621

623622

624623
def mapper_color_section():
625-
X = st.session_state[S_RESULTS].X
626624
df_all = st.session_state[S_RESULTS].df_all
627625
col_feat = st.selectbox(
628626
'🎨 Color',
@@ -632,6 +630,7 @@ def mapper_color_section():
632630

633631

634632
def mapper_draw_section(color_feat):
633+
dim = _mapper_dim()
635634
seed = _mapper_seed()
636635
cmap = _mapper_cmap()
637636
agg, agg_name = _mapper_aggregation()
@@ -644,9 +643,9 @@ def mapper_draw_section(color_feat):
644643
auto_rendering = st.session_state[S_RESULTS].auto_rendering
645644
title = f'{agg_name} of {color_feat}'
646645
if auto_rendering and mapper_fig_outdated:
647-
_update_fig(seed, color_feat, agg, cmap, title)
646+
_update_fig(dim, seed, color_feat, agg, cmap, title)
648647
elif update_button:
649-
_update_fig(seed, color_feat, agg, cmap, title)
648+
_update_fig(dim, seed, color_feat, agg, cmap, title)
650649

651650

652651
def data_summary_section():
@@ -669,11 +668,12 @@ def mapper_output_section():
669668

670669
def mapper_rendering_section():
671670
mapper_fig = st.session_state[S_RESULTS].mapper_fig
672-
with st.container(border=False):
673-
st.plotly_chart(
674-
mapper_fig,
675-
height=350,
676-
use_container_width=True)
671+
if mapper_fig is not None:
672+
with st.container(border=False):
673+
st.plotly_chart(
674+
mapper_fig,
675+
height=350,
676+
use_container_width=True)
677677

678678

679679
def main():

0 commit comments

Comments
 (0)