Skip to content

Commit 821b8d5

Browse files
committed
Improved data handling
1 parent 14b1f53 commit 821b8d5

File tree

1 file changed

+45
-37
lines changed

1 file changed

+45
-37
lines changed

app/streamlit_app.py

Lines changed: 45 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,9 @@ def __init__(self):
108108
self.df_all_sample = pd.DataFrame()
109109
self.X = self.df_X.to_numpy()
110110
self.df_summary = pd.DataFrame()
111-
self.mapper_graph = nx.Graph()
111+
self.mapper_graph = None
112112
self.mapper_plot = None
113-
self.mapper_fig = self._init_fig()
113+
self.mapper_fig = None
114114
self.mapper_fig_outdated = True
115115
self.auto_rendering = self._auto_rendering()
116116

@@ -142,6 +142,8 @@ def _init_fig(self):
142142
return fig
143143

144144
def _auto_rendering(self):
145+
if self.mapper_graph is None:
146+
return False
145147
nodes_num = self.mapper_graph.number_of_nodes()
146148
return nodes_num <= MAX_NODES
147149

@@ -154,9 +156,9 @@ def set_df(self, X, y):
154156
self.df_all_sample = pd.concat([self.df_y_sample, self.df_X_sample], axis=1)
155157
self.X = self.df_X.to_numpy()
156158
self.df_summary = _get_data_summary(self.df_X, self.df_y)
157-
self.mapper_graph = nx.Graph()
159+
self.mapper_graph = None
158160
self.mapper_plot = None
159-
self.mapper_fig = self._init_fig()
161+
self.mapper_fig = None
160162
self.mapper_fig_outdated = True
161163
self.auto_rendering = self._auto_rendering()
162164

@@ -169,7 +171,7 @@ def set_mapper(self, mapper_graph):
169171
width=500,
170172
colors=self.X,
171173
seed=VD_SEED)
172-
self.mapper_fig = self._init_fig()
174+
self.mapper_fig = None
173175
self.mapper_fig_outdated = True
174176
self.auto_rendering = self._auto_rendering()
175177

@@ -374,11 +376,12 @@ def _clustering_tuning(clustering_type):
374376
if clustering_type == V_CLUSTERING_TRIVIAL:
375377
clustering = TrivialClustering()
376378
elif clustering_type == V_CLUSTERING_AGGLOMERATIVE:
377-
clust_n = st.number_input(
379+
clust_num = st.number_input(
378380
'🧮 Clusters',
379381
value=2,
380382
min_value=1)
381-
clustering = AgglomerativeClustering(n_clusters=clust_n)
383+
n_clusters = int(clust_num)
384+
clustering = AgglomerativeClustering(n_clusters=n_clusters)
382385
return clustering
383386

384387

@@ -500,29 +503,35 @@ def _mapper_caption():
500503

501504
def _mapper_histogram():
502505
mapper_graph = st.session_state[S_RESULTS].mapper_graph
503-
ccs = nx.connected_components(mapper_graph)
504-
size = nx.get_node_attributes(mapper_graph, ATTR_SIZE)
505-
node_cc, node_size = {}, {}
506-
node_cc_max, node_size_max = 1, 1
507-
for cc in ccs:
508-
cc_len = len(cc)
509-
for u in cc:
510-
u_size = size[u]
511-
node_cc[u] = cc_len
512-
node_size[u] = u_size
513-
if u_size > node_size_max:
514-
node_size_max = u_size
515-
if cc_len > node_cc_max:
516-
node_cc_max = cc_len
517-
arr_size = np.array([node_size[u]/node_size_max for u in mapper_graph.nodes()])
518-
arr_cc = np.array([node_cc[u]/node_cc_max for u in mapper_graph.nodes()])
506+
node_size_feat = 'node size (rel.)'
507+
cc_size_feat = 'conn. comp. size (rel.)'
519508
df = pd.DataFrame(dict(
520-
series=np.concatenate((
521-
['node size (rel.)'] * len(arr_size),
522-
['conn. comp. size (rel.)'] * len(arr_cc))),
523-
data=np.concatenate((
524-
arr_size,
525-
arr_cc))))
509+
series=[node_size_feat, cc_size_feat],
510+
data=[0.0, 0.0]))
511+
if mapper_graph:
512+
ccs = nx.connected_components(mapper_graph)
513+
size = nx.get_node_attributes(mapper_graph, ATTR_SIZE)
514+
node_cc, node_size = {}, {}
515+
node_cc_max, node_size_max = 1, 1
516+
for cc in ccs:
517+
cc_len = len(cc)
518+
for u in cc:
519+
u_size = size[u]
520+
node_cc[u] = cc_len
521+
node_size[u] = u_size
522+
if u_size > node_size_max:
523+
node_size_max = u_size
524+
if cc_len > node_cc_max:
525+
node_cc_max = cc_len
526+
arr_size = np.array([node_size[u]/node_size_max for u in mapper_graph.nodes()])
527+
arr_cc = np.array([node_cc[u]/node_cc_max for u in mapper_graph.nodes()])
528+
df = pd.DataFrame(dict(
529+
series=np.concatenate((
530+
[node_size_feat] * len(arr_size),
531+
[cc_size_feat] * len(arr_cc))),
532+
data=np.concatenate((
533+
arr_size,
534+
arr_cc))))
526535
fig = px.histogram(
527536
df,
528537
nbins=20,
@@ -559,11 +568,10 @@ def _mapper_download():
559568
mapper_graph = st.session_state[S_RESULTS].mapper_graph
560569
mapper_adj = {} if mapper_graph is None else adjacency_data(mapper_graph)
561570
mapper_json = json.dumps(mapper_adj, default=int)
562-
nodes_num = mapper_graph.number_of_nodes()
563571
return st.download_button(
564572
'📥 Download Mapper',
565573
data=get_gzip_bytes(mapper_json),
566-
disabled=nodes_num < 1,
574+
disabled=mapper_graph is None,
567575
use_container_width=True,
568576
file_name=f'mapper_graph_{int(time.time())}.json.gzip')
569577

@@ -622,7 +630,6 @@ def mapper_run_section(lens_type, cover_type, clustering_type):
622630

623631

624632
def mapper_color_section():
625-
X = st.session_state[S_RESULTS].X
626633
df_all = st.session_state[S_RESULTS].df_all
627634
col_feat = st.selectbox(
628635
'🎨 Color',
@@ -669,11 +676,12 @@ def mapper_output_section():
669676

670677
def mapper_rendering_section():
671678
mapper_fig = st.session_state[S_RESULTS].mapper_fig
672-
with st.container(border=False):
673-
st.plotly_chart(
674-
mapper_fig,
675-
height=350,
676-
use_container_width=True)
679+
if mapper_fig is not None:
680+
with st.container(border=False):
681+
st.plotly_chart(
682+
mapper_fig,
683+
height=350,
684+
use_container_width=True)
677685

678686

679687
def main():

0 commit comments

Comments
 (0)