2222from tdamapper .plot import MapperLayoutInteractive
2323
2424
25- MAX_NODES = 500
25+ MAX_NODES = 1000
2626
2727MAX_SAMPLES = 1000
2828
7474
7575V_DATA_SUMMARY_COLOR = 'color'
7676
77- V_DATA_SUMMARY_BINS = 5
77+ V_DATA_SUMMARY_BINS = 15
7878
7979# VD_* are reusable default values for widgets
8080
@@ -95,12 +95,13 @@ def __init__(self):
9595 self .df_X_sample = pd .DataFrame ()
9696 self .df_y_sample = pd .DataFrame ()
9797 self .df_all = pd .DataFrame ()
98+ self .df_all_sample = pd .DataFrame ()
9899 self .X = self .df_X .to_numpy ()
99100 self .df_summary = pd .DataFrame ()
100101 self .mapper_graph = nx .Graph ()
101102 self .mapper_plot = None
102103 self .mapper_fig = self ._init_fig ()
103- self .auto_rendering = None
104+ self .mapper_fig_outdated = True
104105
105106 def _init_fig (self ):
106107 fig = go .Figure (
@@ -135,32 +136,29 @@ def set_df(self, X, y):
135136 self .df_X_sample = get_sample (self .df_X )
136137 self .df_y_sample = get_sample (self .df_y )
137138 self .df_all = pd .concat ([self .df_y , self .df_X ], axis = 1 )
139+ self .df_all_sample = pd .concat ([self .df_y_sample , self .df_X_sample ], axis = 1 )
138140 self .X = self .df_X .to_numpy ()
139141 self .df_summary = _get_data_summary (self .df_X , self .df_y )
140142 self .mapper_graph = nx .Graph ()
141143 self .mapper_plot = None
142144 self .mapper_fig = self ._init_fig ()
143- self .auto_rendering = None
145+ self .mapper_fig_outdated = True
144146
def set_mapper(self, mapper_graph):
    """Store a freshly computed Mapper graph and rebuild its plot.

    A new ``MapperLayoutInteractive`` is created for the graph using the
    default dimension and seed and colored with ``self.X``. The figure is
    reset through ``self._init_fig()`` and flagged as outdated.

    Args:
        mapper_graph: The Mapper graph to keep in the results.
    """
    self.mapper_graph = mapper_graph
    layout = MapperLayoutInteractive(
        self.mapper_graph,
        dim=VD_DIM,
        height=500,
        width=500,
        colors=self.X,
        seed=VD_SEED,
    )
    self.mapper_plot = layout
    self.mapper_fig = self._init_fig()
    # The figure just created does not show the new graph yet
    self.mapper_fig_outdated = True
160158
def set_mapper_fig(self, mapper_fig):
    """Store a rendered figure and mark it as up to date.

    Args:
        mapper_fig: The figure to keep as the current Mapper figure.
    """
    # Clear the outdated flag together with storing the new figure
    self.mapper_fig_outdated = False
    self.mapper_fig = mapper_fig
164162
165163
166164def lp_metric (p ):
@@ -216,10 +214,15 @@ def _get_data_summary(df_X, df_y):
216214 df_summary = pd .DataFrame ({
217215 V_DATA_SUMMARY_FEAT : df .columns ,
218216 V_DATA_SUMMARY_HIST : df_hist .values .tolist ()})
219- df_summary [V_DATA_SUMMARY_COLOR ] = False
220217 return df_summary
221218
222219
def auto_rendering():
    """Tell whether the current Mapper graph may be rendered automatically.

    Returns:
        True when the graph stored in the session results has at most
        ``MAX_NODES`` nodes, False otherwise.
    """
    graph = st.session_state[S_RESULTS].mapper_graph
    return graph.number_of_nodes() <= MAX_NODES
224+
225+
223226def _mapper_caption ():
224227 mapper_graph = st .session_state [S_RESULTS ].mapper_graph
225228 nodes_num = 0
@@ -364,10 +367,10 @@ def _data_caption():
364367 st .caption (cap )
365368
366369
def _data_preview():
    """Display the first 50 rows of the sampled dataset."""
    sample = st.session_state[S_RESULTS].df_all_sample
    preview = sample.head(50)
    st.dataframe(preview, use_container_width=True, height=200)
373376
@@ -386,10 +389,31 @@ def _data_download():
386389
def data_output_section():
    """Render the data tab: caption, preview table and download widget."""
    # Order matters: widgets appear on the page in call order
    _data_caption()
    _data_preview()
    _data_download()
391394
392395
def _data_summary():
    """Display the per-feature summary table from the session results.

    The histogram column is drawn as an area chart and the feature-name
    column as text.
    """
    summary = st.session_state[S_RESULTS].df_summary
    hist_column = st.column_config.AreaChartColumn(width='large')
    feat_column = st.column_config.TextColumn(width='small', disabled=True)
    st.dataframe(
        summary,
        hide_index=True,
        height=250,
        column_config={
            V_DATA_SUMMARY_HIST: hist_column,
            V_DATA_SUMMARY_FEAT: feat_column,
        },
        use_container_width=True)
409+
410+
def data_summary_section():
    """Render the features tab: data caption followed by the summary table."""
    # NOTE: the download widget (_data_download) is not rendered here.
    _data_caption()
    _data_summary()
415+
416+
393417def data_input_section ():
394418 data_source = None
395419 data_source_type = st .selectbox (
@@ -419,8 +443,7 @@ def _update_mapper(X, lens, cover, clustering):
419443 mapper_graph = mapper_algo .fit_transform (X , lens )
420444 st .session_state [S_RESULTS ].set_mapper (mapper_graph )
421445 st .toast ('Successfully Computed Mapper' , icon = '✅' )
422- auto_rendering = st .session_state [S_RESULTS ].auto_rendering
423- if auto_rendering is False :
446+ if not auto_rendering ():
424447 st .toast ('Automatic Rendering Disabled: Graph Too Large' , icon = '⚠️' )
425448
426449
@@ -534,97 +557,67 @@ def _update_fig(seed, colors, agg):
534557 st .toast ('Successfully Rendered Graph' , icon = '✅' )
535558
536559
def _update_mapper_fig_outdated():
    """Widget callback: flag the current Mapper figure as outdated."""
    results = st.session_state[S_RESULTS]
    results.mapper_fig_outdated = True
542562
543563
def _mapper_colors():
    """Let the user pick the column used to color the Mapper graph.

    Renders a 'Color' select box listing the columns of the full dataset.
    Changing the selection flags the figure as outdated.

    Returns:
        The selected column as a NumPy array, or the stored ``X`` array
        when the selection is not one of the dataset columns.
    """
    results = st.session_state[S_RESULTS]
    fallback = results.X
    df_all = results.df_all
    chosen = st.selectbox(
        'Color',
        options=list(df_all.columns),
        on_change=_update_mapper_fig_outdated)
    if chosen not in df_all.columns:
        # Nothing usable selected: fall back to the stored X array
        return fallback
    return df_all[chosen].to_numpy()
572576
573577
def _mapper_aggregation():
    """Let the user pick the function used to aggregate color values.

    Renders an 'Aggregation' select box and, for 'Quantile', a 'Rank'
    slider. Changing either widget flags the figure as outdated.

    Returns:
        ``np.nanmean`` for 'Mean', ``np.nanstd`` for 'Std', a callable
        computing ``np.nanquantile`` at the chosen rank for 'Quantile',
        or None when the selection matches none of these.
    """
    choice = st.selectbox(
        'Aggregation',
        options=['Mean', 'Std', 'Quantile'],
        on_change=_update_mapper_fig_outdated)

    if choice == 'Quantile':
        # The rank slider is only shown for the quantile aggregation
        rank = st.slider(
            'Rank',
            value=0.5,
            min_value=0.0,
            max_value=1.0,
            on_change=_update_mapper_fig_outdated)

        def quantile(x):
            return np.nanquantile(x, q=rank)

        return quantile

    simple_aggregations = {'Mean': np.nanmean, 'Std': np.nanstd}
    return simple_aggregations.get(choice)
602-
603-
def _mapper_seed():
    """Render the 'Seed' number input and return its current value.

    The input starts at ``VD_SEED``; changing it flags the figure as
    outdated.
    """
    return st.number_input(
        'Seed',
        value=VD_SEED,
        help='Changing this value alters the shape',
        on_change=_update_mapper_fig_outdated,
    )
611606
612607
def mapper_draw_section(colors):
    """Render the drawing controls and redraw the Mapper figure when needed.

    Shows the seed and aggregation widgets plus a '🎨 Draw' button. The
    figure is redrawn automatically when a mapper plot exists, the graph is
    small enough for automatic rendering and the current figure is flagged
    as outdated. Otherwise it is redrawn only when the button is clicked.

    Args:
        colors: Color values forwarded to ``_update_fig``.
    """
    seed = _mapper_seed()
    agg = _mapper_aggregation()
    results = st.session_state[S_RESULTS]
    has_plot = results.mapper_plot is not None
    # The button is always rendered; it stays disabled until a plot exists
    draw_clicked = st.button(
        '🎨 Draw',
        use_container_width=True,
        disabled=not has_plot)
    # An empty graph passes the node-count check in auto_rendering(), so an
    # actual plot is required before redrawing automatically.
    auto_redraw = (
        has_plot
        and auto_rendering()
        and results.mapper_fig_outdated)
    if auto_redraw or draw_clicked:
        _update_fig(seed, colors, agg)
629622
630623def mapper_rendering_section ():
@@ -645,13 +638,16 @@ def main():
645638 lens_type , cover_type , clustering_type = mapper_settings_section ()
646639 with st .popover ('🚀 Run' , use_container_width = True ):
647640 mapper_run_section (lens_type , cover_type , clustering_type )
641+ colors = _mapper_colors ()
648642 with st .popover ('🎨 Draw' , use_container_width = True ):
649- mapper_draw_section ()
643+ mapper_draw_section (colors )
650644 with st .popover ('ℹ️ More' , use_container_width = True ):
651- tab_0 , tab_1 = st .tabs (['🗒️ Data' , '📊 Mapper' ])
645+ tab_0 , tab_1 , tab_2 = st .tabs (['📈 Features' , '🗒️ Data' , '📊 Mapper' ])
652646 with tab_0 :
653- data_output_section ()
647+ data_summary_section ()
654648 with tab_1 :
649+ data_output_section ()
650+ with tab_2 :
655651 mapper_output_section ()
656652 with col_1 :
657653 mapper_rendering_section ()
0 commit comments