Skip to content

Commit 8e5fe26

Browse files
committed
Added colormap choice. Minor improvements
1 parent 34b10c4 commit 8e5fe26

1 file changed

Lines changed: 66 additions & 43 deletions

File tree

app/streamlit_app.py

Lines changed: 66 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@
3030

3131
OPENML_URL = 'https://www.openml.org/search?type=data&sort=runs&status=active'
3232

33-
DATA_INFO = 'Non-numeric and NaN features get dropped. NaN rows get replaced by mean'
34-
3533
GIT_REPO_URL = 'https://github.com/lucasimi/tda-mapper-python'
3634

3735
REPORT_BUG = f'{GIT_REPO_URL}/issues'
@@ -52,7 +50,7 @@
5250

5351
APP_TITLE = 'TDA Mapper App'
5452

55-
# V_* are reusable values for widgets
53+
# V_* are reusable constant values
5654

5755
V_LENS_IDENTITY = 'Identity'
5856

@@ -72,17 +70,21 @@
7270

7371
V_DATA_SUMMARY_HIST = 'histogram'
7472

75-
V_DATA_SUMMARY_COLOR = 'color'
76-
7773
V_DATA_SUMMARY_BINS = 15
7874

79-
# VD_* are reusable default values for widgets
75+
V_AGGREGATION_MEAN = 'Mean'
76+
77+
V_AGGREGATION_STD = 'Std'
78+
79+
V_AGGREGATION_QUANTILE = 'Quantile'
80+
81+
# VD_* are reusable default values
8082

8183
VD_SEED = 42
8284

8385
VD_DIM = 3
8486

85-
# S_* are reusable manually managed stored objects
87+
# S_* are reusable session stored objects
8688

8789
S_RESULTS = 'stored_results'
8890

@@ -102,6 +104,7 @@ def __init__(self):
102104
self.mapper_plot = None
103105
self.mapper_fig = self._init_fig()
104106
self.mapper_fig_outdated = True
107+
self.auto_rendering = self._auto_rendering()
105108

106109
def _init_fig(self):
107110
fig = go.Figure(
@@ -129,6 +132,10 @@ def _init_fig(self):
129132
showline=True,
130133
ticks='outside')))
131134
return fig
135+
136+
def _auto_rendering(self):
137+
nodes_num = self.mapper_graph.number_of_nodes()
138+
return nodes_num <= MAX_NODES
132139

133140
def set_df(self, X, y):
134141
self.df_X = fix_data(X)
@@ -143,6 +150,7 @@ def set_df(self, X, y):
143150
self.mapper_plot = None
144151
self.mapper_fig = self._init_fig()
145152
self.mapper_fig_outdated = True
153+
self.auto_rendering = self._auto_rendering()
146154

147155
def set_mapper(self, mapper_graph):
148156
self.mapper_graph = mapper_graph
@@ -155,6 +163,7 @@ def set_mapper(self, mapper_graph):
155163
seed=VD_SEED)
156164
self.mapper_fig = self._init_fig()
157165
self.mapper_fig_outdated = True
166+
self.auto_rendering = self._auto_rendering()
158167

159168
def set_mapper_fig(self, mapper_fig):
160169
self.mapper_fig = mapper_fig
@@ -217,12 +226,6 @@ def _get_data_summary(df_X, df_y):
217226
return df_summary
218227

219228

220-
def auto_rendering():
221-
mapper_graph = st.session_state[S_RESULTS].mapper_graph
222-
nodes_num = mapper_graph.number_of_nodes()
223-
return nodes_num <= MAX_NODES
224-
225-
226229
def _mapper_caption():
227230
mapper_graph = st.session_state[S_RESULTS].mapper_graph
228231
nodes_num = 0
@@ -352,7 +355,7 @@ def _update_data(data_source):
352355
st.toast(f'# {err}', icon='🚨')
353356
df_X, df_y = fix_data(X), fix_data(y)
354357
st.session_state[S_RESULTS].set_df(df_X, df_y)
355-
st.toast('Successfully Loaded Data', icon='')
358+
st.toast('Successfully Loaded Data', icon='📦')
356359

357360

358361
def _data_caption():
@@ -442,8 +445,9 @@ def _update_mapper(X, lens, cover, clustering):
442445
verbose=False))
443446
mapper_graph = mapper_algo.fit_transform(X, lens)
444447
st.session_state[S_RESULTS].set_mapper(mapper_graph)
445-
st.toast('Successfully Computed Mapper', icon='✅')
446-
if not auto_rendering():
448+
st.toast('Successfully Computed Mapper', icon='🚀')
449+
auto_rendering = st.session_state[S_RESULTS].auto_rendering
450+
if not auto_rendering:
447451
st.toast('Automatic Rendering Disabled: Graph Too Large', icon='⚠️')
448452

449453

@@ -541,51 +545,56 @@ def mapper_output_section():
541545
_mapper_download()
542546

543547

544-
def _update_fig(seed, colors, agg):
548+
def _update_fig(seed, colors, agg, cmap, title):
545549
mapper_plot = st.session_state[S_RESULTS].mapper_plot
546550
if mapper_plot is None:
547551
return
548552
mapper_plot.update(
549553
colors=colors,
550554
seed=seed,
551-
agg=agg)
555+
agg=agg,
556+
title=title,
557+
cmap=cmap)
552558
mapper_fig = mapper_plot.plot()
553559
mapper_fig.update_layout(
554560
uirevision='constant',
555561
margin=dict(b=0, l=0, r=0, t=0))
556562
st.session_state[S_RESULTS].set_mapper_fig(mapper_fig)
557-
st.toast('Successfully Rendered Graph', icon='')
563+
st.toast('Successfully Rendered Graph', icon='🖌️')
558564

559565

560566
def _update_mapper_fig_outdated():
561567
st.session_state[S_RESULTS].mapper_fig_outdated = True
562568

563569

564-
def _mapper_colors():
570+
def _mapper_color_feature():
565571
X = st.session_state[S_RESULTS].X
566572
df_all = st.session_state[S_RESULTS].df_all
567-
colors = X
568573
col_feat = st.selectbox(
569-
'Color',
574+
'🎨 Color',
570575
options=list(df_all.columns),
571576
on_change=_update_mapper_fig_outdated)
572-
if col_feat in df_all.columns:
573-
df_col = df_all[col_feat]
574-
colors = df_col.to_numpy()
575-
return colors
577+
return col_feat
576578

577579

578-
def _mapper_aggregation():
579-
agg = None
580+
def _mapper_aggregation_type():
580581
agg_type = st.selectbox(
581-
'Aggregation',
582-
options=['Mean', 'Std', 'Quantile'],
582+
'🍲 Aggregation',
583+
options=[
584+
V_AGGREGATION_MEAN,
585+
V_AGGREGATION_STD,
586+
V_AGGREGATION_QUANTILE],
583587
on_change=_update_mapper_fig_outdated)
584-
if agg_type == 'Mean':
588+
return agg_type
589+
590+
591+
def _mapper_aggregation(agg_type):
592+
agg = None
593+
if agg_type == V_AGGREGATION_MEAN:
585594
agg = np.nanmean
586-
elif agg_type == 'Std':
595+
elif agg_type == V_AGGREGATION_STD:
587596
agg = np.nanstd
588-
elif agg_type == 'Quantile':
597+
elif agg_type == V_AGGREGATION_QUANTILE:
589598
q = st.slider(
590599
'Rank',
591600
value=0.5,
@@ -598,26 +607,40 @@ def _mapper_aggregation():
598607

599608
def _mapper_seed():
600609
seed = st.number_input(
601-
'Seed',
610+
'🎲 Seed',
602611
value=VD_SEED,
603612
help='Changing this value alters the shape',
604613
on_change=_update_mapper_fig_outdated)
605614
return seed
606615

607616

608-
def mapper_draw_section(colors):
617+
def mapper_draw_section(color_feat):
618+
df_all = st.session_state[S_RESULTS].df_all
619+
if color_feat in df_all.columns:
620+
df_col = df_all[color_feat]
621+
colors = df_col.to_numpy()
609622
seed = _mapper_seed()
610-
agg = _mapper_aggregation()
623+
cmap = st.selectbox(
624+
'🌈 Colormap',
625+
options=[
626+
'Jet',
627+
'Viridis',
628+
'HSV'],
629+
on_change=_update_mapper_fig_outdated)
630+
agg_type = _mapper_aggregation_type()
631+
agg = _mapper_aggregation(agg_type)
611632
mapper_plot = st.session_state[S_RESULTS].mapper_plot
612633
update_button = st.button(
613-
'🎨 Draw',
634+
'🖌️ Draw',
614635
use_container_width=True,
615636
disabled=mapper_plot is None)
616637
mapper_fig_outdated = st.session_state[S_RESULTS].mapper_fig_outdated
617-
if auto_rendering() and mapper_fig_outdated:
618-
_update_fig(seed, colors, agg)
638+
auto_rendering = st.session_state[S_RESULTS].auto_rendering
639+
title = f'{agg_type} of {color_feat}'
640+
if auto_rendering and mapper_fig_outdated:
641+
_update_fig(seed, colors, agg, cmap, title)
619642
elif update_button:
620-
_update_fig(seed, colors, agg)
643+
_update_fig(seed, colors, agg, cmap, title)
621644

622645

623646
def mapper_rendering_section():
@@ -638,9 +661,9 @@ def main():
638661
lens_type, cover_type, clustering_type = mapper_settings_section()
639662
with st.popover('🚀 Run', use_container_width=True):
640663
mapper_run_section(lens_type, cover_type, clustering_type)
641-
colors = _mapper_colors()
642-
with st.popover('🎨 Draw', use_container_width=True):
643-
mapper_draw_section(colors)
664+
color_feat = _mapper_color_feature()
665+
with st.popover('🖌️ Draw', use_container_width=True):
666+
mapper_draw_section(color_feat)
644667
with st.popover('ℹ️ More', use_container_width=True):
645668
tab_0, tab_1, tab_2 = st.tabs(['📈 Features', '🗒️ Data', '📊 Mapper'])
646669
with tab_0:

0 commit comments

Comments
 (0)