@@ -108,9 +108,9 @@ def __init__(self):
108108 self .df_all_sample = pd .DataFrame ()
109109 self .X = self .df_X .to_numpy ()
110110 self .df_summary = pd .DataFrame ()
111- self .mapper_graph = nx . Graph ()
111+ self .mapper_graph = None
112112 self .mapper_plot = None
113- self .mapper_fig = self . _init_fig ()
113+ self .mapper_fig = None
114114 self .mapper_fig_outdated = True
115115 self .auto_rendering = self ._auto_rendering ()
116116
@@ -142,6 +142,8 @@ def _init_fig(self):
142142 return fig
143143
144144 def _auto_rendering (self ):
145+ if self .mapper_graph is None :
146+ return False
145147 nodes_num = self .mapper_graph .number_of_nodes ()
146148 return nodes_num <= MAX_NODES
147149
@@ -154,9 +156,9 @@ def set_df(self, X, y):
154156 self .df_all_sample = pd .concat ([self .df_y_sample , self .df_X_sample ], axis = 1 )
155157 self .X = self .df_X .to_numpy ()
156158 self .df_summary = _get_data_summary (self .df_X , self .df_y )
157- self .mapper_graph = nx . Graph ()
159+ self .mapper_graph = None
158160 self .mapper_plot = None
159- self .mapper_fig = self . _init_fig ()
161+ self .mapper_fig = None
160162 self .mapper_fig_outdated = True
161163 self .auto_rendering = self ._auto_rendering ()
162164
@@ -169,7 +171,7 @@ def set_mapper(self, mapper_graph):
169171 width = 500 ,
170172 colors = self .X ,
171173 seed = VD_SEED )
172- self .mapper_fig = self . _init_fig ()
174+ self .mapper_fig = None
173175 self .mapper_fig_outdated = True
174176 self .auto_rendering = self ._auto_rendering ()
175177
@@ -374,11 +376,12 @@ def _clustering_tuning(clustering_type):
374376 if clustering_type == V_CLUSTERING_TRIVIAL :
375377 clustering = TrivialClustering ()
376378 elif clustering_type == V_CLUSTERING_AGGLOMERATIVE :
377- clust_n = st .number_input (
379+ clust_num = st .number_input (
378380 '🧮 Clusters' ,
379381 value = 2 ,
380382 min_value = 1 )
381- clustering = AgglomerativeClustering (n_clusters = clust_n )
383+ n_clusters = int (clust_num )
384+ clustering = AgglomerativeClustering (n_clusters = n_clusters )
382385 return clustering
383386
384387
@@ -500,29 +503,35 @@ def _mapper_caption():
500503
501504def _mapper_histogram ():
502505 mapper_graph = st .session_state [S_RESULTS ].mapper_graph
503- ccs = nx .connected_components (mapper_graph )
504- size = nx .get_node_attributes (mapper_graph , ATTR_SIZE )
505- node_cc , node_size = {}, {}
506- node_cc_max , node_size_max = 1 , 1
507- for cc in ccs :
508- cc_len = len (cc )
509- for u in cc :
510- u_size = size [u ]
511- node_cc [u ] = cc_len
512- node_size [u ] = u_size
513- if u_size > node_size_max :
514- node_size_max = u_size
515- if cc_len > node_cc_max :
516- node_cc_max = cc_len
517- arr_size = np .array ([node_size [u ]/ node_size_max for u in mapper_graph .nodes ()])
518- arr_cc = np .array ([node_cc [u ]/ node_cc_max for u in mapper_graph .nodes ()])
506+ node_size_feat = 'node size (rel.)'
507+ cc_size_feat = 'conn. comp. size (rel.)'
519508 df = pd .DataFrame (dict (
520- series = np .concatenate ((
521- ['node size (rel.)' ] * len (arr_size ),
522- ['conn. comp. size (rel.)' ] * len (arr_cc ))),
523- data = np .concatenate ((
524- arr_size ,
525- arr_cc ))))
509+ series = [node_size_feat , cc_size_feat ],
510+ data = [0.0 , 0.0 ]))
511+ if mapper_graph :
512+ ccs = nx .connected_components (mapper_graph )
513+ size = nx .get_node_attributes (mapper_graph , ATTR_SIZE )
514+ node_cc , node_size = {}, {}
515+ node_cc_max , node_size_max = 1 , 1
516+ for cc in ccs :
517+ cc_len = len (cc )
518+ for u in cc :
519+ u_size = size [u ]
520+ node_cc [u ] = cc_len
521+ node_size [u ] = u_size
522+ if u_size > node_size_max :
523+ node_size_max = u_size
524+ if cc_len > node_cc_max :
525+ node_cc_max = cc_len
526+ arr_size = np .array ([node_size [u ]/ node_size_max for u in mapper_graph .nodes ()])
527+ arr_cc = np .array ([node_cc [u ]/ node_cc_max for u in mapper_graph .nodes ()])
528+ df = pd .DataFrame (dict (
529+ series = np .concatenate ((
530+ [node_size_feat ] * len (arr_size ),
531+ [cc_size_feat ] * len (arr_cc ))),
532+ data = np .concatenate ((
533+ arr_size ,
534+ arr_cc ))))
526535 fig = px .histogram (
527536 df ,
528537 nbins = 20 ,
@@ -559,11 +568,10 @@ def _mapper_download():
559568 mapper_graph = st .session_state [S_RESULTS ].mapper_graph
560569 mapper_adj = {} if mapper_graph is None else adjacency_data (mapper_graph )
561570 mapper_json = json .dumps (mapper_adj , default = int )
562- nodes_num = mapper_graph .number_of_nodes ()
563571 return st .download_button (
564572 '📥 Download Mapper' ,
565573 data = get_gzip_bytes (mapper_json ),
566- disabled = nodes_num < 1 ,
574+ disabled = mapper_graph is None ,
567575 use_container_width = True ,
568576 file_name = f'mapper_graph_{ int (time .time ())} .json.gzip' )
569577
@@ -622,7 +630,6 @@ def mapper_run_section(lens_type, cover_type, clustering_type):
622630
623631
624632def mapper_color_section ():
625- X = st .session_state [S_RESULTS ].X
626633 df_all = st .session_state [S_RESULTS ].df_all
627634 col_feat = st .selectbox (
628635 '🎨 Color' ,
@@ -669,11 +676,12 @@ def mapper_output_section():
669676
670677def mapper_rendering_section ():
671678 mapper_fig = st .session_state [S_RESULTS ].mapper_fig
672- with st .container (border = False ):
673- st .plotly_chart (
674- mapper_fig ,
675- height = 350 ,
676- use_container_width = True )
679+ if mapper_fig is not None :
680+ with st .container (border = False ):
681+ st .plotly_chart (
682+ mapper_fig ,
683+ height = 350 ,
684+ use_container_width = True )
677685
678686
679687def main ():
0 commit comments