Skip to content

Commit 266d67c

Browse files
committed
Added feature button
1 parent 36560bf commit 266d67c

1 file changed

Lines changed: 174 additions & 58 deletions

File tree

app/streamlit_app.py

Lines changed: 174 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from umap import UMAP
2424

2525
from tdamapper._plot_plotly import _marker_size
26+
from tdamapper.core import aggregate_graph
2627
from tdamapper.cover import BallCover, CubicalCover
2728
from tdamapper.learn import MapperAlgorithm
2829
from tdamapper.plot import MapperPlot
@@ -677,7 +678,7 @@ def mapper_figure_section(df_X, df_y, mapper_plot):
677678
colors=colors,
678679
node_size=1.0,
679680
_agg=agg,
680-
cmap="Viridis",
681+
cmap="Jet",
681682
agg_name=agg_name,
682683
colors_feat=colors_feat,
683684
)
@@ -696,73 +697,76 @@ def mapper_figure_section(df_X, df_y, mapper_plot):
696697
scaleratio=1,
697698
)
698699

700+
menu_cmap = {}
699701
if dim == 2:
700-
_set_cmap_buttons_2d(mapper_fig)
702+
menu_cmap = _set_cmap_buttons_2d(mapper_fig)
703+
menu_feature = _set_feature_select_button_2d(
704+
mapper_plot, mapper_fig, df_X, df_y, agg
705+
)
701706

702707
elif dim == 3:
703-
_set_cmap_buttons_3d(mapper_fig)
708+
menu_cmap = _set_cmap_buttons_3d(mapper_fig)
709+
menu_feature = _set_feature_select_button_3d(
710+
mapper_plot, mapper_fig, df_X, df_y, agg
711+
)
704712

705-
_set_node_size_slider(mapper_plot, mapper_fig)
713+
slider_size = _set_node_size_slider(mapper_plot, mapper_fig)
714+
715+
mapper_fig.update_layout(
716+
updatemenus=[menu_cmap, menu_feature],
717+
sliders=[slider_size],
718+
uirevision="constant",
719+
)
706720

707721
return mapper_fig
708722

709723

710724
def _set_cmap_buttons_2d(fig):
711-
fig.update_layout(
712-
updatemenus=[
725+
return dict(
726+
buttons=[
713727
dict(
714-
buttons=[
715-
dict(
716-
label=cmap_name,
717-
method="restyle",
718-
args=[
719-
{
720-
"marker.colorscale": [cmap],
721-
"marker.line.colorscale": [cmap],
722-
},
723-
[0, 1], # Trace indices
724-
],
725-
)
726-
for cmap_name, cmap in V_CMAPS.items()
728+
label=cmap_name,
729+
method="restyle",
730+
args=[
731+
{
732+
"marker.colorscale": [cmap],
733+
"marker.line.colorscale": [cmap],
734+
},
735+
[0, 1], # Trace indices
727736
],
728-
x=0,
729-
xanchor="left",
730-
y=0.75,
731-
yanchor="bottom",
732-
direction="down",
733737
)
738+
for cmap_name, cmap in V_CMAPS.items()
734739
],
735-
uirevision="constant",
740+
x=0,
741+
xanchor="left",
742+
y=0.75,
743+
yanchor="bottom",
744+
direction="down",
736745
)
737746

738747

739748
def _set_cmap_buttons_3d(fig):
740-
fig.update_layout(
741-
updatemenus=[
749+
return dict(
750+
buttons=[
742751
dict(
743-
buttons=[
744-
dict(
745-
label=cmap_name,
746-
method="restyle",
747-
args=[
748-
{
749-
"marker.colorscale": [cmap, cmap],
750-
"marker.line.colorscale": [cmap, cmap],
751-
"line.colorscale": [cmap, cmap],
752-
},
753-
[0, 1], # update both traces
754-
],
755-
)
756-
for cmap_name, cmap in V_CMAPS.items()
752+
label=cmap_name,
753+
method="restyle",
754+
args=[
755+
{
756+
"marker.colorscale": [cmap, cmap],
757+
"marker.line.colorscale": [cmap, cmap],
758+
"line.colorscale": [cmap, cmap],
759+
},
760+
[0, 1], # update both traces
757761
],
758-
x=0,
759-
xanchor="left",
760-
y=0.75,
761-
yanchor="bottom",
762-
direction="down",
763762
)
763+
for cmap_name, cmap in V_CMAPS.items()
764764
],
765-
uirevision="constant",
765+
x=0,
766+
xanchor="left",
767+
y=0.75,
768+
yanchor="bottom",
769+
direction="down",
766770
)
767771

768772

@@ -780,20 +784,132 @@ def _set_node_size_slider(mapper_plot, fig):
780784
)
781785
)
782786

783-
fig.update_layout(
784-
sliders=[
787+
return dict(
788+
active=len(steps) // 2,
789+
currentvalue={"prefix": "Node size: "},
790+
steps=steps,
791+
x=0,
792+
y=0.85,
793+
xanchor="left",
794+
len=0.15,
795+
yanchor="bottom",
796+
)
797+
798+
799+
def _compute_colors_agg(mapper_plot, df_X, df_y, col_feat, agg):
800+
X_cols = list(df_X.columns)
801+
y_cols = list(df_y.columns)
802+
if col_feat in X_cols:
803+
colors = df_X[col_feat].to_numpy()
804+
elif col_feat in y_cols:
805+
colors = df_y[col_feat].to_numpy()
806+
return aggregate_graph(colors, mapper_plot.graph, agg)
807+
808+
809+
def _compute_colors(mapper_plot, df_X, df_y, col_feat, agg):
810+
nodes_col = _compute_colors_agg(mapper_plot, df_X, df_y, col_feat, agg)
811+
return list(nodes_col.values())
812+
813+
814+
def _set_feature_select_button_2d(mapper_plot, mapper_fig, df_X, df_y, agg):
815+
col_feats = list(df_X.columns) + list(df_y.columns)
816+
return dict(
817+
buttons=[
785818
dict(
786-
active=len(steps) // 2,
787-
currentvalue={"prefix": "Node size: "},
788-
steps=steps,
789-
x=0,
790-
y=0.85,
791-
xanchor="left",
792-
len=0.15,
793-
yanchor="bottom",
819+
label=f"Feature: {col_feat}",
820+
method="restyle",
821+
args=[
822+
{
823+
"marker.color": [
824+
_compute_colors(mapper_plot, df_X, df_y, col_feat, agg)
825+
],
826+
"marker.cmax": [
827+
max(
828+
_compute_colors(mapper_plot, df_X, df_y, col_feat, agg),
829+
default=None,
830+
)
831+
],
832+
"marker.cmin": [
833+
min(
834+
_compute_colors(mapper_plot, df_X, df_y, col_feat, agg),
835+
default=None,
836+
)
837+
],
838+
},
839+
[1], # Trace indices
840+
],
794841
)
842+
for col_feat in col_feats
795843
],
796-
uirevision="constant",
844+
x=0,
845+
xanchor="left",
846+
y=0.65,
847+
yanchor="bottom",
848+
direction="down",
849+
)
850+
851+
852+
def _edge_colors(mapper_plot, df_X, df_y, col_feat, agg):
853+
colors_avg = []
854+
colors_agg = _compute_colors_agg(mapper_plot, df_X, df_y, col_feat, agg)
855+
for edge in mapper_plot.graph.edges():
856+
c0, c1 = colors_agg[edge[0]], colors_agg[edge[1]]
857+
colors_avg.append(c0)
858+
colors_avg.append(c1)
859+
colors_avg.append(c1)
860+
return colors_avg
861+
862+
863+
def _set_feature_select_button_3d(mapper_plot, mapper_fig, df_X, df_y, agg):
864+
col_feats = list(df_X.columns) + list(df_y.columns)
865+
return dict(
866+
buttons=[
867+
dict(
868+
label=f"Feature: {col_feat}",
869+
method="restyle",
870+
args=[
871+
{
872+
"marker.color": [
873+
_compute_colors(mapper_plot, df_X, df_y, col_feat, agg)
874+
],
875+
"marker.cmax": [
876+
max(
877+
_compute_colors(mapper_plot, df_X, df_y, col_feat, agg),
878+
default=None,
879+
)
880+
],
881+
"marker.cmin": [
882+
min(
883+
_compute_colors(mapper_plot, df_X, df_y, col_feat, agg),
884+
default=None,
885+
)
886+
],
887+
"line.color": [
888+
_edge_colors(mapper_plot, df_X, df_y, col_feat, agg)
889+
],
890+
"line.cmax": [
891+
max(
892+
_compute_colors(mapper_plot, df_X, df_y, col_feat, agg),
893+
default=None,
894+
)
895+
],
896+
"line.cmin": [
897+
min(
898+
_compute_colors(mapper_plot, df_X, df_y, col_feat, agg),
899+
default=None,
900+
)
901+
],
902+
},
903+
[1], # Trace indices
904+
],
905+
)
906+
for col_feat in col_feats
907+
],
908+
x=0,
909+
xanchor="left",
910+
y=0.65,
911+
yanchor="bottom",
912+
direction="down",
797913
)
798914

799915

0 commit comments

Comments
 (0)