Skip to content

Commit 34b10c4

Browse files
committed
Improved UI
1 parent 3a69f1d commit 34b10c4

1 file changed

Lines changed: 85 additions & 89 deletions

File tree

app/streamlit_app.py

Lines changed: 85 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from tdamapper.plot import MapperLayoutInteractive
2323

2424

25-
MAX_NODES = 500
25+
MAX_NODES = 1000
2626

2727
MAX_SAMPLES = 1000
2828

@@ -74,7 +74,7 @@
7474

7575
V_DATA_SUMMARY_COLOR = 'color'
7676

77-
V_DATA_SUMMARY_BINS = 5
77+
V_DATA_SUMMARY_BINS = 15
7878

7979
# VD_* are reusable default values for widgets
8080

@@ -95,12 +95,13 @@ def __init__(self):
9595
self.df_X_sample = pd.DataFrame()
9696
self.df_y_sample = pd.DataFrame()
9797
self.df_all = pd.DataFrame()
98+
self.df_all_sample = pd.DataFrame()
9899
self.X = self.df_X.to_numpy()
99100
self.df_summary = pd.DataFrame()
100101
self.mapper_graph = nx.Graph()
101102
self.mapper_plot = None
102103
self.mapper_fig = self._init_fig()
103-
self.auto_rendering = None
104+
self.mapper_fig_outdated = True
104105

105106
def _init_fig(self):
106107
fig = go.Figure(
@@ -135,32 +136,29 @@ def set_df(self, X, y):
135136
self.df_X_sample = get_sample(self.df_X)
136137
self.df_y_sample = get_sample(self.df_y)
137138
self.df_all = pd.concat([self.df_y, self.df_X], axis=1)
139+
self.df_all_sample = pd.concat([self.df_y_sample, self.df_X_sample], axis=1)
138140
self.X = self.df_X.to_numpy()
139141
self.df_summary = _get_data_summary(self.df_X, self.df_y)
140142
self.mapper_graph = nx.Graph()
141143
self.mapper_plot = None
142144
self.mapper_fig = self._init_fig()
143-
self.auto_rendering = None
145+
self.mapper_fig_outdated = True
144146

145147
def set_mapper(self, mapper_graph):
146148
self.mapper_graph = mapper_graph
147149
self.mapper_plot = MapperLayoutInteractive(
148150
self.mapper_graph,
149151
dim=VD_DIM,
150-
height=450,
151-
width=450,
152+
height=500,
153+
width=500,
152154
colors=self.X,
153155
seed=VD_SEED)
154156
self.mapper_fig = self._init_fig()
155-
nodes_num = mapper_graph.number_of_nodes()
156-
if nodes_num <= MAX_NODES:
157-
self.auto_rendering = True
158-
else:
159-
self.auto_rendering = False
157+
self.mapper_fig_outdated = True
160158

161159
def set_mapper_fig(self, mapper_fig):
162160
self.mapper_fig = mapper_fig
163-
self.auto_rendering = None
161+
self.mapper_fig_outdated = False
164162

165163

166164
def lp_metric(p):
@@ -216,10 +214,15 @@ def _get_data_summary(df_X, df_y):
216214
df_summary = pd.DataFrame({
217215
V_DATA_SUMMARY_FEAT: df.columns,
218216
V_DATA_SUMMARY_HIST: df_hist.values.tolist()})
219-
df_summary[V_DATA_SUMMARY_COLOR] = False
220217
return df_summary
221218

222219

220+
def auto_rendering():
221+
mapper_graph = st.session_state[S_RESULTS].mapper_graph
222+
nodes_num = mapper_graph.number_of_nodes()
223+
return nodes_num <= MAX_NODES
224+
225+
223226
def _mapper_caption():
224227
mapper_graph = st.session_state[S_RESULTS].mapper_graph
225228
nodes_num = 0
@@ -364,10 +367,10 @@ def _data_caption():
364367
st.caption(cap)
365368

366369

367-
def _data_summary():
368-
df_all = st.session_state[S_RESULTS].df_all
370+
def _data_preview():
371+
df_all_sample = st.session_state[S_RESULTS].df_all_sample
369372
st.dataframe(
370-
df_all.head(50),
373+
df_all_sample.head(50),
371374
use_container_width=True,
372375
height=200)
373376

@@ -386,10 +389,31 @@ def _data_download():
386389

387390
def data_output_section():
388391
_data_caption()
389-
_data_summary()
392+
_data_preview()
390393
_data_download()
391394

392395

396+
def _data_summary():
397+
df_summary = st.session_state[S_RESULTS].df_summary
398+
st.dataframe(df_summary,
399+
hide_index=True,
400+
height=250,
401+
column_config={
402+
V_DATA_SUMMARY_HIST: st.column_config.AreaChartColumn(
403+
width='large'),
404+
V_DATA_SUMMARY_FEAT: st.column_config.TextColumn(
405+
width='small',
406+
disabled=True)
407+
},
408+
use_container_width=True)
409+
410+
411+
def data_summary_section():
412+
_data_caption()
413+
_data_summary()
414+
#_data_download()
415+
416+
393417
def data_input_section():
394418
data_source = None
395419
data_source_type = st.selectbox(
@@ -419,8 +443,7 @@ def _update_mapper(X, lens, cover, clustering):
419443
mapper_graph = mapper_algo.fit_transform(X, lens)
420444
st.session_state[S_RESULTS].set_mapper(mapper_graph)
421445
st.toast('Successfully Computed Mapper', icon='✅')
422-
auto_rendering = st.session_state[S_RESULTS].auto_rendering
423-
if auto_rendering is False:
446+
if not auto_rendering():
424447
st.toast('Automatic Rendering Disabled: Graph Too Large', icon='⚠️')
425448

426449

@@ -534,97 +557,67 @@ def _update_fig(seed, colors, agg):
534557
st.toast('Successfully Rendered Graph', icon='✅')
535558

536559

537-
def _update_auto_rendering():
538-
mapper_graph = st.session_state[S_RESULTS].mapper_graph
539-
nodes_num = mapper_graph.number_of_nodes() if mapper_graph else 0
540-
if nodes_num <= MAX_NODES:
541-
st.session_state[S_RESULTS].auto_rendering = True
560+
def _update_mapper_fig_outdated():
561+
st.session_state[S_RESULTS].mapper_fig_outdated = True
542562

543563

544564
def _mapper_colors():
545565
X = st.session_state[S_RESULTS].X
546566
df_all = st.session_state[S_RESULTS].df_all
547-
df_summary = st.session_state[S_RESULTS].df_summary
548567
colors = X
549-
data_edit = st.data_editor(
550-
df_summary,
551-
height=250,
552-
hide_index=True,
553-
disabled=(c for c in df_summary.columns if c != V_DATA_SUMMARY_COLOR),
554-
use_container_width=True,
555-
column_config={
556-
V_DATA_SUMMARY_HIST: st.column_config.AreaChartColumn(
557-
width='small'),
558-
V_DATA_SUMMARY_FEAT: st.column_config.TextColumn(
559-
width='small',
560-
disabled=True),
561-
V_DATA_SUMMARY_COLOR: st.column_config.CheckboxColumn(
562-
width='small',
563-
disabled=False)
564-
}, on_change=_update_auto_rendering)
565-
if not data_edit.empty:
566-
color_features = data_edit[data_edit[V_DATA_SUMMARY_COLOR]][V_DATA_SUMMARY_FEAT]
567-
if not color_features.empty:
568-
selected = pd.concat([df_all[c] for c in color_features], axis=1)
569-
if not selected.empty:
570-
colors = selected.to_numpy()
568+
col_feat = st.selectbox(
569+
'Color',
570+
options=list(df_all.columns),
571+
on_change=_update_mapper_fig_outdated)
572+
if col_feat in df_all.columns:
573+
df_col = df_all[col_feat]
574+
colors = df_col.to_numpy()
571575
return colors
572576

573577

574-
def _agg(name, axis):
578+
def _mapper_aggregation():
575579
agg = None
576-
agg_sel = st.selectbox(
577-
name,
580+
agg_type = st.selectbox(
581+
'Aggregation',
578582
options=['Mean', 'Std', 'Quantile'],
579-
on_change=_update_auto_rendering)
580-
if agg_sel == 'Mean':
581-
agg = lambda x: np.nanmean(x, axis=axis)
582-
elif agg_sel == 'Std':
583-
agg = lambda x: np.nanstd(x, axis=axis)
584-
elif agg_sel == 'Quantile':
585-
r = st.slider('Rank', min_value=0.0, value=0.5, max_value=1.0)
586-
agg = lambda x: np.quantile(x, q=r, axis=axis)
583+
on_change=_update_mapper_fig_outdated)
584+
if agg_type == 'Mean':
585+
agg = np.nanmean
586+
elif agg_type == 'Std':
587+
agg = np.nanstd
588+
elif agg_type == 'Quantile':
589+
q = st.slider(
590+
'Rank',
591+
value=0.5,
592+
min_value=0.0,
593+
max_value=1.0,
594+
on_change=_update_mapper_fig_outdated)
595+
agg = lambda x: np.nanquantile(x, q=q)
587596
return agg
588597

589598

590-
def _mapper_feature_agg():
591-
return _agg('Feature Aggregation', axis=1)
592-
593-
594-
def _mapper_node_agg():
595-
return _agg('Node Aggregation', axis=0)
596-
597-
598-
def _mapper_agg():
599-
feat_agg = _mapper_feature_agg()
600-
node_agg = _mapper_node_agg()
601-
return lambda x: node_agg(feat_agg(x))
602-
603-
604599
def _mapper_seed():
605600
seed = st.number_input(
606601
'Seed',
607602
value=VD_SEED,
608603
help='Changing this value alters the shape',
609-
on_change=_update_auto_rendering)
604+
on_change=_update_mapper_fig_outdated)
610605
return seed
611606

612607

613-
def mapper_draw_section():
614-
colors = _mapper_colors()
615-
agg = _mapper_agg()
608+
def mapper_draw_section(colors):
616609
seed = _mapper_seed()
617-
auto_rendering = st.session_state[S_RESULTS].auto_rendering
618-
if auto_rendering:
610+
agg = _mapper_aggregation()
611+
mapper_plot = st.session_state[S_RESULTS].mapper_plot
612+
update_button = st.button(
613+
'🎨 Draw',
614+
use_container_width=True,
615+
disabled=mapper_plot is None)
616+
mapper_fig_outdated = st.session_state[S_RESULTS].mapper_fig_outdated
617+
if auto_rendering() and mapper_fig_outdated:
618+
_update_fig(seed, colors, agg)
619+
elif update_button:
619620
_update_fig(seed, colors, agg)
620-
else:
621-
mapper_plot = st.session_state[S_RESULTS].mapper_plot
622-
update_button = st.button(
623-
'🎨 Draw',
624-
use_container_width=True,
625-
disabled=mapper_plot is None)
626-
if update_button:
627-
_update_fig(seed, colors, agg)
628621

629622

630623
def mapper_rendering_section():
@@ -645,13 +638,16 @@ def main():
645638
lens_type, cover_type, clustering_type = mapper_settings_section()
646639
with st.popover('🚀 Run', use_container_width=True):
647640
mapper_run_section(lens_type, cover_type, clustering_type)
641+
colors = _mapper_colors()
648642
with st.popover('🎨 Draw', use_container_width=True):
649-
mapper_draw_section()
643+
mapper_draw_section(colors)
650644
with st.popover('ℹ️ More', use_container_width=True):
651-
tab_0, tab_1 = st.tabs(['🗒️ Data', '📊 Mapper'])
645+
tab_0, tab_1, tab_2 = st.tabs(['📈 Features', '🗒️ Data', '📊 Mapper'])
652646
with tab_0:
653-
data_output_section()
647+
data_summary_section()
654648
with tab_1:
649+
data_output_section()
650+
with tab_2:
655651
mapper_output_section()
656652
with col_1:
657653
mapper_rendering_section()

0 commit comments

Comments
 (0)