Skip to content

Commit cc33f65

Browse files
committed
Moved ui elements to plotly plot
1 parent 2207be8 commit cc33f65

1 file changed

Lines changed: 117 additions & 20 deletions

File tree

app/streamlit_app.py

Lines changed: 117 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from sklearn.decomposition import PCA
2323
from umap import UMAP
2424

25+
from tdamapper._plot_plotly import _marker_size
2526
from tdamapper.cover import BallCover, CubicalCover
2627
from tdamapper.learn import MapperAlgorithm
2728
from tdamapper.plot import MapperPlot
@@ -100,6 +101,16 @@
100101

101102
V_CMAP_TWILIGHT = "Twilight (Cyclic)"
102103

104+
V_CMAPS = {
105+
V_CMAP_JET: "Jet",
106+
V_CMAP_VIRIDIS: "Viridis",
107+
V_CMAP_CIVIDIS: "Cividis",
108+
V_CMAP_SPECTRAL: "Spectral",
109+
V_CMAP_PORTLAND: "Portland",
110+
V_CMAP_HSV: "HSV",
111+
V_CMAP_TWILIGHT: "Twilight",
112+
}
113+
103114
GIT_REPO_URL = "https://github.com/lucasimi/tda-mapper-python"
104115

105116
ICON_URL = f"{GIT_REPO_URL}/raw/main/docs/source/logos/tda-mapper-logo-icon.png"
@@ -574,21 +585,7 @@ def plot_cmap_input_section():
574585
V_CMAP_TWILIGHT,
575586
],
576587
)
577-
cmap = None
578-
if cmap_type == V_CMAP_JET:
579-
cmap = "Jet"
580-
elif cmap_type == V_CMAP_VIRIDIS:
581-
cmap = "Viridis"
582-
elif cmap_type == V_CMAP_CIVIDIS:
583-
cmap = "Cividis"
584-
elif cmap_type == V_CMAP_PORTLAND:
585-
cmap = "Portland"
586-
elif cmap_type == V_CMAP_SPECTRAL:
587-
cmap = "Spectral"
588-
elif cmap_type == V_CMAP_HSV:
589-
cmap = "HSV"
590-
elif cmap_type == V_CMAP_TWILIGHT:
591-
cmap = "Twilight"
588+
cmap = V_CMAPS.get(cmap_type, "Jet")
592589
return cmap
593590

594591

@@ -667,21 +664,20 @@ def compute_mapper_fig(
667664
width=600,
668665
height=600,
669666
)
667+
mapper_fig.update_layout(uirevision="constant")
670668
return mapper_fig
671669

672670

673671
def mapper_figure_section(df_X, df_y, mapper_plot):
674672
st.header("🎨 Plot")
675673
agg, agg_name = plot_agg_input_section()
676-
cmap = plot_cmap_input_section()
677674
colors, colors_feat = plot_color_input_section(df_X, df_y)
678-
node_size = st.slider("Node size", min_value=0.1, max_value=10.0, value=1.0)
679675
mapper_fig = compute_mapper_fig(
680676
mapper_plot,
681677
colors=colors,
682-
node_size=node_size,
678+
node_size=1.0,
683679
_agg=agg,
684-
cmap=cmap,
680+
cmap="Viridis",
685681
agg_name=agg_name,
686682
colors_feat=colors_feat,
687683
)
@@ -699,16 +695,117 @@ def mapper_figure_section(df_X, df_y, mapper_plot):
699695
scaleanchor="x",
700696
scaleratio=1,
701697
)
698+
699+
if dim == 2:
700+
_set_cmap_buttons_2d(mapper_fig)
701+
702+
elif dim == 3:
703+
_set_cmap_buttons_3d(mapper_fig)
704+
705+
_set_node_size_slider(mapper_plot, mapper_fig)
706+
702707
return mapper_fig
703708

704709

710+
def _set_cmap_buttons_2d(fig):
711+
fig.update_layout(
712+
updatemenus=[
713+
dict(
714+
buttons=[
715+
dict(
716+
label=cmap_name,
717+
method="restyle",
718+
args=[
719+
{
720+
"marker.colorscale": [cmap],
721+
"marker.line.colorscale": [cmap],
722+
},
723+
[0, 1], # Trace indices
724+
],
725+
)
726+
for cmap_name, cmap in V_CMAPS.items()
727+
],
728+
x=0,
729+
xanchor="left",
730+
y=0.75,
731+
yanchor="bottom",
732+
direction="down",
733+
)
734+
],
735+
uirevision="constant",
736+
)
737+
738+
739+
def _set_cmap_buttons_3d(fig):
740+
fig.update_layout(
741+
updatemenus=[
742+
dict(
743+
buttons=[
744+
dict(
745+
label=cmap_name,
746+
method="restyle",
747+
args=[
748+
{
749+
"marker.colorscale": [cmap, cmap],
750+
"marker.line.colorscale": [cmap, cmap],
751+
"line.colorscale": [cmap, cmap],
752+
},
753+
[0, 1], # update both traces
754+
],
755+
)
756+
for cmap_name, cmap in V_CMAPS.items()
757+
],
758+
x=0,
759+
xanchor="left",
760+
y=0.75,
761+
yanchor="bottom",
762+
direction="down",
763+
)
764+
],
765+
uirevision="constant",
766+
)
767+
768+
769+
def _set_node_size_slider(mapper_plot, fig):
770+
steps = []
771+
for node_size in [x / 10.0 for x in range(1, 20)]: # Sizes from 5 to 30
772+
steps.append(
773+
dict(
774+
method="restyle",
775+
label=f"{node_size}",
776+
args=[
777+
{"marker.size": [_marker_size(mapper_plot, node_size)]},
778+
[1],
779+
], # Update marker size for trace 0
780+
)
781+
)
782+
783+
fig.update_layout(
784+
sliders=[
785+
dict(
786+
active=len(steps) // 2,
787+
currentvalue={"prefix": "Node size: "},
788+
steps=steps,
789+
x=0,
790+
y=0.85,
791+
xanchor="left",
792+
len=0.15,
793+
yanchor="bottom",
794+
)
795+
],
796+
uirevision="constant",
797+
)
798+
799+
705800
def mapper_rendering_section(mapper_graph, mapper_fig):
706801
config = {
707802
"scrollZoom": True,
708803
"displaylogo": False,
709804
"modeBarButtonsToRemove": ["zoom", "pan"],
710805
}
711-
st.plotly_chart(mapper_fig, use_container_width=True, config=config)
806+
st.plotly_chart(
807+
mapper_fig, use_container_width=True, config=config, key="mapper_plot"
808+
)
712809

713810

714811
def data_summary_section(df_X, df_y, mapper_graph):

0 commit comments

Comments
 (0)