77import pandas as pd
88import numpy as np
99import plotly .express as px
10- import plotly .graph_objects as go
1110
1211import networkx as nx
1312from networkx .readwrite .json_graph import adjacency_data
9089
9190VD_SEED = 42
9291
93- VD_DIM = 3
94-
9592# S_* are reusable session stored objects
9693
9794S_RESULTS = 'stored_results'
@@ -108,40 +105,16 @@ def __init__(self):
108105 self .df_all_sample = pd .DataFrame ()
109106 self .X = self .df_X .to_numpy ()
110107 self .df_summary = pd .DataFrame ()
111- self .mapper_graph = nx . Graph ()
108+ self .mapper_graph = None
112109 self .mapper_plot = None
113- self .mapper_fig = self ._init_fig ()
110+ self .mapper_dim = None
111+ self .mapper_fig = None
114112 self .mapper_fig_outdated = True
115113 self .auto_rendering = self ._auto_rendering ()
116114
117- def _init_fig (self ):
118- fig = go .Figure (
119- data = [go .Scatter3d (
120- x = [],
121- y = [],
122- z = [],
123- mode = 'markers' )])
124- fig .update_layout (
125- uirevision = 'constant' ,
126- scene = dict (
127- xaxis = dict (
128- showgrid = True ,
129- zeroline = True ,
130- showline = True ,
131- ticks = 'outside' ),
132- yaxis = dict (
133- showgrid = True ,
134- zeroline = True ,
135- showline = True ,
136- ticks = 'outside' ),
137- zaxis = dict (
138- showgrid = True ,
139- zeroline = True ,
140- showline = True ,
141- ticks = 'outside' )))
142- return fig
143-
144115 def _auto_rendering (self ):
116+ if self .mapper_graph is None :
117+ return False
145118 nodes_num = self .mapper_graph .number_of_nodes ()
146119 return nodes_num <= MAX_NODES
147120
@@ -154,27 +127,57 @@ def set_df(self, X, y):
154127 self .df_all_sample = pd .concat ([self .df_y_sample , self .df_X_sample ], axis = 1 )
155128 self .X = self .df_X .to_numpy ()
156129 self .df_summary = _get_data_summary (self .df_X , self .df_y )
157- self .mapper_graph = nx . Graph ()
130+ self .mapper_graph = None
158131 self .mapper_plot = None
159- self .mapper_fig = self ._init_fig ()
132+ self .mapper_dim = None
133+ self .mapper_fig = None
160134 self .mapper_fig_outdated = True
161135 self .auto_rendering = self ._auto_rendering ()
162136
163137 def set_mapper (self , mapper_graph ):
164138 self .mapper_graph = mapper_graph
165- self .mapper_plot = MapperLayoutInteractive (
166- self .mapper_graph ,
167- dim = VD_DIM ,
168- height = 500 ,
169- width = 500 ,
170- colors = self .X ,
171- seed = VD_SEED )
172- self .mapper_fig = self ._init_fig ()
139+ self .mapper_plot = None
140+ self .mapper_dim = None
141+ self .mapper_fig = None
173142 self .mapper_fig_outdated = True
174143 self .auto_rendering = self ._auto_rendering ()
175144
176- def set_mapper_fig (self , mapper_fig ):
177- self .mapper_fig = mapper_fig
145+ def set_mapper_fig (self , dim , seed , color_feat , agg , cmap , title ):
146+ colors = self .X
147+ df_all = st .session_state [S_RESULTS ].df_all
148+ if color_feat in df_all .columns :
149+ df_col = df_all [color_feat ]
150+ colors = df_col .to_numpy ()
151+ if (self .mapper_plot is None ) or (dim != self .mapper_dim ):
152+ self .mapper_plot = MapperLayoutInteractive (
153+ self .mapper_graph ,
154+ dim = dim ,
155+ seed = seed ,
156+ colors = colors ,
157+ agg = agg ,
158+ cmap = cmap ,
159+ title = title ,
160+ height = 500 ,
161+ width = 500 )
162+ self .mapper_dim = dim
163+ else :
164+ self .mapper_plot .update (
165+ seed = seed ,
166+ colors = colors ,
167+ agg = agg ,
168+ cmap = cmap ,
169+ title = title )
170+ self .mapper_fig = self .mapper_plot .plot ()
171+ self .mapper_fig .update_layout (
172+ uirevision = 'constant' ,
173+ margin = dict (b = 0 , l = 0 , r = 0 , t = 0 ))
174+ self .mapper_fig .update_xaxes (
175+ #constrain='domain',
176+ showline = False )
177+ self .mapper_fig .update_yaxes (
178+ showline = False ,
179+ scaleanchor = 'x' ,
180+ scaleratio = 1 )
178181 self .mapper_fig_outdated = False
179182
180183
@@ -295,27 +298,8 @@ def _update_mapper(X, lens, cover, clustering):
295298 st .toast ('Automatic Rendering Disabled: Graph Too Large' , icon = '⚠️' )
296299
297300
298- def _update_fig (seed , color_feat , agg , cmap , title ):
299- mapper_plot = st .session_state [S_RESULTS ].mapper_plot
300- if mapper_plot is None :
301- return
302- X = st .session_state [S_RESULTS ].X
303- colors = X
304- df_all = st .session_state [S_RESULTS ].df_all
305- if color_feat in df_all .columns :
306- df_col = df_all [color_feat ]
307- colors = df_col .to_numpy ()
308- mapper_plot .update (
309- colors = colors ,
310- seed = seed ,
311- agg = agg ,
312- title = title ,
313- cmap = cmap )
314- mapper_fig = mapper_plot .plot ()
315- mapper_fig .update_layout (
316- uirevision = 'constant' ,
317- margin = dict (b = 0 , l = 0 , r = 0 , t = 0 ))
318- st .session_state [S_RESULTS ].set_mapper_fig (mapper_fig )
301+ def _update_fig (dim , seed , color_feat , agg , cmap , title ):
302+ st .session_state [S_RESULTS ].set_mapper_fig (dim , seed , color_feat , agg , cmap , title )
319303 st .toast ('Successfully Rendered Graph' , icon = '🖌️' )
320304
321305
@@ -374,11 +358,12 @@ def _clustering_tuning(clustering_type):
374358 if clustering_type == V_CLUSTERING_TRIVIAL :
375359 clustering = TrivialClustering ()
376360 elif clustering_type == V_CLUSTERING_AGGLOMERATIVE :
377- clust_n = st .number_input (
361+ clust_num = st .number_input (
378362 '🧮 Clusters' ,
379363 value = 2 ,
380364 min_value = 1 )
381- clustering = AgglomerativeClustering (n_clusters = clust_n )
365+ n_clusters = int (clust_num )
366+ clustering = AgglomerativeClustering (n_clusters = n_clusters )
382367 return clustering
383368
384369
@@ -457,6 +442,15 @@ def _mapper_aggregation():
457442 return agg , agg_name
458443
459444
445+ def _mapper_dim ():
446+ toggle_3d = st .toggle (
447+ '3D Rendering' ,
448+ value = True ,
449+ on_change = _update_mapper_fig_outdated )
450+ dim = 3 if toggle_3d else 2
451+ return dim
452+
453+
460454def _mapper_seed ():
461455 seed = st .number_input (
462456 '🎲 Seed' ,
@@ -500,29 +494,35 @@ def _mapper_caption():
500494
501495def _mapper_histogram ():
502496 mapper_graph = st .session_state [S_RESULTS ].mapper_graph
503- ccs = nx .connected_components (mapper_graph )
504- size = nx .get_node_attributes (mapper_graph , ATTR_SIZE )
505- node_cc , node_size = {}, {}
506- node_cc_max , node_size_max = 1 , 1
507- for cc in ccs :
508- cc_len = len (cc )
509- for u in cc :
510- u_size = size [u ]
511- node_cc [u ] = cc_len
512- node_size [u ] = u_size
513- if u_size > node_size_max :
514- node_size_max = u_size
515- if cc_len > node_cc_max :
516- node_cc_max = cc_len
517- arr_size = np .array ([node_size [u ]/ node_size_max for u in mapper_graph .nodes ()])
518- arr_cc = np .array ([node_cc [u ]/ node_cc_max for u in mapper_graph .nodes ()])
497+ node_size_feat = 'node size (rel.)'
498+ cc_size_feat = 'conn. comp. size (rel.)'
519499 df = pd .DataFrame (dict (
520- series = np .concatenate ((
521- ['node size (rel.)' ] * len (arr_size ),
522- ['conn. comp. size (rel.)' ] * len (arr_cc ))),
523- data = np .concatenate ((
524- arr_size ,
525- arr_cc ))))
500+ series = [node_size_feat , cc_size_feat ],
501+ data = [0.0 , 0.0 ]))
502+ if mapper_graph :
503+ ccs = nx .connected_components (mapper_graph )
504+ size = nx .get_node_attributes (mapper_graph , ATTR_SIZE )
505+ node_cc , node_size = {}, {}
506+ node_cc_max , node_size_max = 1 , 1
507+ for cc in ccs :
508+ cc_len = len (cc )
509+ for u in cc :
510+ u_size = size [u ]
511+ node_cc [u ] = cc_len
512+ node_size [u ] = u_size
513+ if u_size > node_size_max :
514+ node_size_max = u_size
515+ if cc_len > node_cc_max :
516+ node_cc_max = cc_len
517+ arr_size = np .array ([node_size [u ]/ node_size_max for u in mapper_graph .nodes ()])
518+ arr_cc = np .array ([node_cc [u ]/ node_cc_max for u in mapper_graph .nodes ()])
519+ df = pd .DataFrame (dict (
520+ series = np .concatenate ((
521+ [node_size_feat ] * len (arr_size ),
522+ [cc_size_feat ] * len (arr_cc ))),
523+ data = np .concatenate ((
524+ arr_size ,
525+ arr_cc ))))
526526 fig = px .histogram (
527527 df ,
528528 nbins = 20 ,
@@ -559,11 +559,10 @@ def _mapper_download():
559559 mapper_graph = st .session_state [S_RESULTS ].mapper_graph
560560 mapper_adj = {} if mapper_graph is None else adjacency_data (mapper_graph )
561561 mapper_json = json .dumps (mapper_adj , default = int )
562- nodes_num = mapper_graph .number_of_nodes ()
563562 return st .download_button (
564563 '📥 Download Mapper' ,
565564 data = get_gzip_bytes (mapper_json ),
566- disabled = nodes_num < 1 ,
565+ disabled = mapper_graph is None ,
567566 use_container_width = True ,
568567 file_name = f'mapper_graph_{ int (time .time ())} .json.gzip' )
569568
@@ -622,7 +621,6 @@ def mapper_run_section(lens_type, cover_type, clustering_type):
622621
623622
624623def mapper_color_section ():
625- X = st .session_state [S_RESULTS ].X
626624 df_all = st .session_state [S_RESULTS ].df_all
627625 col_feat = st .selectbox (
628626 '🎨 Color' ,
@@ -632,6 +630,7 @@ def mapper_color_section():
632630
633631
634632def mapper_draw_section (color_feat ):
633+ dim = _mapper_dim ()
635634 seed = _mapper_seed ()
636635 cmap = _mapper_cmap ()
637636 agg , agg_name = _mapper_aggregation ()
@@ -644,9 +643,9 @@ def mapper_draw_section(color_feat):
644643 auto_rendering = st .session_state [S_RESULTS ].auto_rendering
645644 title = f'{ agg_name } of { color_feat } '
646645 if auto_rendering and mapper_fig_outdated :
647- _update_fig (seed , color_feat , agg , cmap , title )
646+ _update_fig (dim , seed , color_feat , agg , cmap , title )
648647 elif update_button :
649- _update_fig (seed , color_feat , agg , cmap , title )
648+ _update_fig (dim , seed , color_feat , agg , cmap , title )
650649
651650
652651def data_summary_section ():
@@ -669,11 +668,12 @@ def mapper_output_section():
669668
670669def mapper_rendering_section ():
671670 mapper_fig = st .session_state [S_RESULTS ].mapper_fig
672- with st .container (border = False ):
673- st .plotly_chart (
674- mapper_fig ,
675- height = 350 ,
676- use_container_width = True )
671+ if mapper_fig is not None :
672+ with st .container (border = False ):
673+ st .plotly_chart (
674+ mapper_fig ,
675+ height = 350 ,
676+ use_container_width = True )
677677
678678
679679def main ():
0 commit comments