2323from umap import UMAP
2424
2525from tdamapper ._plot_plotly import _marker_size
26+ from tdamapper .core import aggregate_graph
2627from tdamapper .cover import BallCover , CubicalCover
2728from tdamapper .learn import MapperAlgorithm
2829from tdamapper .plot import MapperPlot
@@ -677,7 +678,7 @@ def mapper_figure_section(df_X, df_y, mapper_plot):
677678 colors = colors ,
678679 node_size = 1.0 ,
679680 _agg = agg ,
680- cmap = "Viridis " ,
681+ cmap = "Jet " ,
681682 agg_name = agg_name ,
682683 colors_feat = colors_feat ,
683684 )
@@ -696,73 +697,76 @@ def mapper_figure_section(df_X, df_y, mapper_plot):
696697 scaleratio = 1 ,
697698 )
698699
700+ menu_cmap = {}
699701 if dim == 2 :
700- _set_cmap_buttons_2d (mapper_fig )
702+ menu_cmap = _set_cmap_buttons_2d (mapper_fig )
703+ menu_feature = _set_feature_select_button_2d (
704+ mapper_plot , mapper_fig , df_X , df_y , agg
705+ )
701706
702707 elif dim == 3 :
703- _set_cmap_buttons_3d (mapper_fig )
708+ menu_cmap = _set_cmap_buttons_3d (mapper_fig )
709+ menu_feature = _set_feature_select_button_3d (
710+ mapper_plot , mapper_fig , df_X , df_y , agg
711+ )
704712
705- _set_node_size_slider (mapper_plot , mapper_fig )
713+ slider_size = _set_node_size_slider (mapper_plot , mapper_fig )
714+
715+ mapper_fig .update_layout (
716+ updatemenus = [menu_cmap , menu_feature ],
717+ sliders = [slider_size ],
718+ uirevision = "constant" ,
719+ )
706720
707721 return mapper_fig
708722
709723
710724def _set_cmap_buttons_2d (fig ):
711- fig . update_layout (
712- updatemenus = [
725+ return dict (
726+ buttons = [
713727 dict (
714- buttons = [
715- dict (
716- label = cmap_name ,
717- method = "restyle" ,
718- args = [
719- {
720- "marker.colorscale" : [cmap ],
721- "marker.line.colorscale" : [cmap ],
722- },
723- [0 , 1 ], # Trace indices
724- ],
725- )
726- for cmap_name , cmap in V_CMAPS .items ()
728+ label = cmap_name ,
729+ method = "restyle" ,
730+ args = [
731+ {
732+ "marker.colorscale" : [cmap ],
733+ "marker.line.colorscale" : [cmap ],
734+ },
735+ [0 , 1 ], # Trace indices
727736 ],
728- x = 0 ,
729- xanchor = "left" ,
730- y = 0.75 ,
731- yanchor = "bottom" ,
732- direction = "down" ,
733737 )
738+ for cmap_name , cmap in V_CMAPS .items ()
734739 ],
735- uirevision = "constant" ,
740+ x = 0 ,
741+ xanchor = "left" ,
742+ y = 0.75 ,
743+ yanchor = "bottom" ,
744+ direction = "down" ,
736745 )
737746
738747
739748def _set_cmap_buttons_3d (fig ):
740- fig . update_layout (
741- updatemenus = [
749+ return dict (
750+ buttons = [
742751 dict (
743- buttons = [
744- dict (
745- label = cmap_name ,
746- method = "restyle" ,
747- args = [
748- {
749- "marker.colorscale" : [cmap , cmap ],
750- "marker.line.colorscale" : [cmap , cmap ],
751- "line.colorscale" : [cmap , cmap ],
752- },
753- [0 , 1 ], # update both traces
754- ],
755- )
756- for cmap_name , cmap in V_CMAPS .items ()
752+ label = cmap_name ,
753+ method = "restyle" ,
754+ args = [
755+ {
756+ "marker.colorscale" : [cmap , cmap ],
757+ "marker.line.colorscale" : [cmap , cmap ],
758+ "line.colorscale" : [cmap , cmap ],
759+ },
760+ [0 , 1 ], # update both traces
757761 ],
758- x = 0 ,
759- xanchor = "left" ,
760- y = 0.75 ,
761- yanchor = "bottom" ,
762- direction = "down" ,
763762 )
763+ for cmap_name , cmap in V_CMAPS .items ()
764764 ],
765- uirevision = "constant" ,
765+ x = 0 ,
766+ xanchor = "left" ,
767+ y = 0.75 ,
768+ yanchor = "bottom" ,
769+ direction = "down" ,
766770 )
767771
768772
@@ -780,20 +784,132 @@ def _set_node_size_slider(mapper_plot, fig):
780784 )
781785 )
782786
783- fig .update_layout (
784- sliders = [
787+ return dict (
788+ active = len (steps ) // 2 ,
789+ currentvalue = {"prefix" : "Node size: " },
790+ steps = steps ,
791+ x = 0 ,
792+ y = 0.85 ,
793+ xanchor = "left" ,
794+ len = 0.15 ,
795+ yanchor = "bottom" ,
796+ )
797+
798+
799+ def _compute_colors_agg (mapper_plot , df_X , df_y , col_feat , agg ):
800+ X_cols = list (df_X .columns )
801+ y_cols = list (df_y .columns )
802+ if col_feat in X_cols :
803+ colors = df_X [col_feat ].to_numpy ()
804+ elif col_feat in y_cols :
805+ colors = df_y [col_feat ].to_numpy ()
806+ return aggregate_graph (colors , mapper_plot .graph , agg )
807+
808+
809+ def _compute_colors (mapper_plot , df_X , df_y , col_feat , agg ):
810+ nodes_col = _compute_colors_agg (mapper_plot , df_X , df_y , col_feat , agg )
811+ return list (nodes_col .values ())
812+
813+
814+ def _set_feature_select_button_2d (mapper_plot , mapper_fig , df_X , df_y , agg ):
815+ col_feats = list (df_X .columns ) + list (df_y .columns )
816+ return dict (
817+ buttons = [
785818 dict (
786- active = len (steps ) // 2 ,
787- currentvalue = {"prefix" : "Node size: " },
788- steps = steps ,
789- x = 0 ,
790- y = 0.85 ,
791- xanchor = "left" ,
792- len = 0.15 ,
793- yanchor = "bottom" ,
819+ label = f"Feature: { col_feat } " ,
820+ method = "restyle" ,
821+ args = [
822+ {
823+ "marker.color" : [
824+ _compute_colors (mapper_plot , df_X , df_y , col_feat , agg )
825+ ],
826+ "marker.cmax" : [
827+ max (
828+ _compute_colors (mapper_plot , df_X , df_y , col_feat , agg ),
829+ default = None ,
830+ )
831+ ],
832+ "marker.cmin" : [
833+ min (
834+ _compute_colors (mapper_plot , df_X , df_y , col_feat , agg ),
835+ default = None ,
836+ )
837+ ],
838+ },
839+ [1 ], # Trace indices
840+ ],
794841 )
842+ for col_feat in col_feats
795843 ],
796- uirevision = "constant" ,
844+ x = 0 ,
845+ xanchor = "left" ,
846+ y = 0.65 ,
847+ yanchor = "bottom" ,
848+ direction = "down" ,
849+ )
850+
851+
852+ def _edge_colors (mapper_plot , df_X , df_y , col_feat , agg ):
853+ colors_avg = []
854+ colors_agg = _compute_colors_agg (mapper_plot , df_X , df_y , col_feat , agg )
855+ for edge in mapper_plot .graph .edges ():
856+ c0 , c1 = colors_agg [edge [0 ]], colors_agg [edge [1 ]]
857+ colors_avg .append (c0 )
858+ colors_avg .append (c1 )
859+ colors_avg .append (c1 )
860+ return colors_avg
861+
862+
863+ def _set_feature_select_button_3d (mapper_plot , mapper_fig , df_X , df_y , agg ):
864+ col_feats = list (df_X .columns ) + list (df_y .columns )
865+ return dict (
866+ buttons = [
867+ dict (
868+ label = f"Feature: { col_feat } " ,
869+ method = "restyle" ,
870+ args = [
871+ {
872+ "marker.color" : [
873+ _compute_colors (mapper_plot , df_X , df_y , col_feat , agg )
874+ ],
875+ "marker.cmax" : [
876+ max (
877+ _compute_colors (mapper_plot , df_X , df_y , col_feat , agg ),
878+ default = None ,
879+ )
880+ ],
881+ "marker.cmin" : [
882+ min (
883+ _compute_colors (mapper_plot , df_X , df_y , col_feat , agg ),
884+ default = None ,
885+ )
886+ ],
887+ "line.color" : [
888+ _edge_colors (mapper_plot , df_X , df_y , col_feat , agg )
889+ ],
890+ "line.cmax" : [
891+ max (
892+ _compute_colors (mapper_plot , df_X , df_y , col_feat , agg ),
893+ default = None ,
894+ )
895+ ],
896+ "line.cmin" : [
897+ min (
898+ _compute_colors (mapper_plot , df_X , df_y , col_feat , agg ),
899+ default = None ,
900+ )
901+ ],
902+ },
903+ [1 ], # Trace indices
904+ ],
905+ )
906+ for col_feat in col_feats
907+ ],
908+ x = 0 ,
909+ xanchor = "left" ,
910+ y = 0.65 ,
911+ yanchor = "bottom" ,
912+ direction = "down" ,
797913 )
798914
799915
0 commit comments