2222from sklearn .decomposition import PCA
2323from umap import UMAP
2424
25+ from tdamapper ._plot_plotly import _marker_size
2526from tdamapper .cover import BallCover , CubicalCover
2627from tdamapper .learn import MapperAlgorithm
2728from tdamapper .plot import MapperPlot
100101
101102V_CMAP_TWILIGHT = "Twilight (Cyclic)"
102103
104+ V_CMAPS = {
105+ V_CMAP_JET : "Jet" ,
106+ V_CMAP_VIRIDIS : "Viridis" ,
107+ V_CMAP_CIVIDIS : "Cividis" ,
108+ V_CMAP_SPECTRAL : "Spectral" ,
109+ V_CMAP_PORTLAND : "Portland" ,
110+ V_CMAP_HSV : "HSV" ,
111+ V_CMAP_TWILIGHT : "Twilight" ,
112+ }
113+
103114GIT_REPO_URL = "https://github.com/lucasimi/tda-mapper-python"
104115
105116ICON_URL = f"{ GIT_REPO_URL } /raw/main/docs/source/logos/tda-mapper-logo-icon.png"
@@ -574,21 +585,7 @@ def plot_cmap_input_section():
574585 V_CMAP_TWILIGHT ,
575586 ],
576587 )
577- cmap = None
578- if cmap_type == V_CMAP_JET :
579- cmap = "Jet"
580- elif cmap_type == V_CMAP_VIRIDIS :
581- cmap = "Viridis"
582- elif cmap_type == V_CMAP_CIVIDIS :
583- cmap = "Cividis"
584- elif cmap_type == V_CMAP_PORTLAND :
585- cmap = "Portland"
586- elif cmap_type == V_CMAP_SPECTRAL :
587- cmap = "Spectral"
588- elif cmap_type == V_CMAP_HSV :
589- cmap = "HSV"
590- elif cmap_type == V_CMAP_TWILIGHT :
591- cmap = "Twilight"
588+ cmap = V_CMAPS .get (cmap_type , "Jet" )
592589 return cmap
593590
594591
@@ -667,21 +664,20 @@ def compute_mapper_fig(
667664 width = 600 ,
668665 height = 600 ,
669666 )
667+ mapper_fig .update_layout (uirevision = "constant" )
670668 return mapper_fig
671669
672670
673671def mapper_figure_section (df_X , df_y , mapper_plot ):
674672 st .header ("🎨 Plot" )
675673 agg , agg_name = plot_agg_input_section ()
676- cmap = plot_cmap_input_section ()
677674 colors , colors_feat = plot_color_input_section (df_X , df_y )
678- node_size = st .slider ("Node size" , min_value = 0.1 , max_value = 10.0 , value = 1.0 )
679675 mapper_fig = compute_mapper_fig (
680676 mapper_plot ,
681677 colors = colors ,
682- node_size = node_size ,
678+ node_size = 1.0 ,
683679 _agg = agg ,
684- cmap = cmap ,
680+ cmap = "Viridis" ,
685681 agg_name = agg_name ,
686682 colors_feat = colors_feat ,
687683 )
@@ -699,16 +695,117 @@ def mapper_figure_section(df_X, df_y, mapper_plot):
699695 scaleanchor = "x" ,
700696 scaleratio = 1 ,
701697 )
698+
699+ if dim == 2 :
700+ _set_cmap_buttons_2d (mapper_fig )
701+
702+ elif dim == 3 :
703+ _set_cmap_buttons_3d (mapper_fig )
704+
705+ _set_node_size_slider (mapper_plot , mapper_fig )
706+
702707 return mapper_fig
703708
704709
710+ def _set_cmap_buttons_2d (fig ):
711+ fig .update_layout (
712+ updatemenus = [
713+ dict (
714+ buttons = [
715+ dict (
716+ label = cmap_name ,
717+ method = "restyle" ,
718+ args = [
719+ {
720+ "marker.colorscale" : [cmap ],
721+ "marker.line.colorscale" : [cmap ],
722+ },
723+ [0 , 1 ], # Trace indices
724+ ],
725+ )
726+ for cmap_name , cmap in V_CMAPS .items ()
727+ ],
728+ x = 0 ,
729+ xanchor = "left" ,
730+ y = 0.75 ,
731+ yanchor = "bottom" ,
732+ direction = "down" ,
733+ )
734+ ],
735+ uirevision = "constant" ,
736+ )
737+
738+
739+ def _set_cmap_buttons_3d (fig ):
740+ fig .update_layout (
741+ updatemenus = [
742+ dict (
743+ buttons = [
744+ dict (
745+ label = cmap_name ,
746+ method = "restyle" ,
747+ args = [
748+ {
749+ "marker.colorscale" : [cmap , cmap ],
750+ "marker.line.colorscale" : [cmap , cmap ],
751+ "line.colorscale" : [cmap , cmap ],
752+ },
753+ [0 , 1 ], # update both traces
754+ ],
755+ )
756+ for cmap_name , cmap in V_CMAPS .items ()
757+ ],
758+ x = 0 ,
759+ xanchor = "left" ,
760+ y = 0.75 ,
761+ yanchor = "bottom" ,
762+ direction = "down" ,
763+ )
764+ ],
765+ uirevision = "constant" ,
766+ )
767+
768+
769+ def _set_node_size_slider (mapper_plot , fig ):
770+ steps = []
771+ for node_size in [x / 10.0 for x in range (1 , 20 )]: # Sizes from 5 to 30
772+ steps .append (
773+ dict (
774+ method = "restyle" ,
775+ label = f"{ node_size } " ,
776+ args = [
777+ {"marker.size" : [_marker_size (mapper_plot , node_size )]},
778+ [1 ],
779+ ], # Update marker size for trace 0
780+ )
781+ )
782+
783+ fig .update_layout (
784+ sliders = [
785+ dict (
786+ active = len (steps ) // 2 ,
787+ currentvalue = {"prefix" : "Node size: " },
788+ steps = steps ,
789+ x = 0 ,
790+ y = 0.85 ,
791+ xanchor = "left" ,
792+ len = 0.15 ,
793+ yanchor = "bottom" ,
794+ )
795+ ],
796+ uirevision = "constant" ,
797+ )
798+
799+
705800def mapper_rendering_section (mapper_graph , mapper_fig ):
706801 config = {
707802 "scrollZoom" : True ,
708803 "displaylogo" : False ,
709804 "modeBarButtonsToRemove" : ["zoom" , "pan" ],
710805 }
711- st .plotly_chart (mapper_fig , use_container_width = True , config = config )
806+ st .plotly_chart (
807+ mapper_fig , use_container_width = True , config = config , key = "mapper_plot"
808+ )
712809
713810
714811def data_summary_section (df_X , df_y , mapper_graph ):
0 commit comments