Skip to content

Commit c7fc1a5

Browse files
committed
Added data examples. Improved UI
1 parent 6a326cd commit c7fc1a5

1 file changed

Lines changed: 125 additions & 55 deletions

File tree

src/tdamapper/app.py

Lines changed: 125 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pandas as pd
55
from nicegui import app, run, ui
66
from sklearn.cluster import DBSCAN, AgglomerativeClustering, KMeans
7+
from sklearn.datasets import load_digits, load_iris
78
from sklearn.decomposition import PCA
89
from sklearn.preprocessing import StandardScaler
910
from umap import UMAP
@@ -24,6 +25,11 @@
2425
LOGO_URL = f"{GIT_REPO_URL}/raw/main/docs/source/logos/tda-mapper-logo-horizontal.png"
2526

2627

28+
LOAD_EXAMPLE = "Example"
29+
LOAD_EXAMPLE_DIGITS = "Digits"
30+
LOAD_EXAMPLE_IRIS = "Iris"
31+
LOAD_CSV = "CSV"
32+
2733
LENS_IDENTITY = "Identity"
2834
LENS_PCA = "PCA"
2935
LENS_UMAP = "UMAP"
@@ -106,8 +112,8 @@ def _umap(X):
106112
return _umap
107113

108114

109-
def run_mapper(df, **kwargs):
110-
if df is None:
115+
def run_mapper(df, labels, **kwargs):
116+
if df is None or df.empty:
111117
logger.error("No data found. Please upload a file first.")
112118
return
113119
logger.info("Computing Mapper.")
@@ -133,12 +139,14 @@ def run_mapper(df, **kwargs):
133139
)
134140

135141
lens = lens_pca(n_components=LENS_PCA_N_COMPONENTS)
142+
lens_name = LENS_PCA
136143
if lens_type == LENS_IDENTITY:
137144
lens = lens_identity
138145
elif lens_type == LENS_PCA:
139146
lens = lens_pca(n_components=lens_pca_n_components)
140147
elif lens_type == LENS_UMAP:
141148
lens = lens_umap(n_components=lens_umap_n_components)
149+
lens_name = lens_type
142150

143151
if cover_type == COVER_CUBICAL:
144152
cover = CubicalCover(
@@ -177,16 +185,20 @@ def run_mapper(df, **kwargs):
177185
df_fixed = fix_data(df)
178186
X = df_fixed.to_numpy()
179187
y = lens(X)
188+
df_y = pd.DataFrame(y, columns=[f"{lens_type} {i}" for i in range(y.shape[1])])
189+
df_labels = pd.DataFrame(labels) if labels is not None else pd.DataFrame()
180190
if cover_scale_data:
181191
y = StandardScaler().fit_transform(y)
182192
if clustering_scale_data:
183193
X = StandardScaler().fit_transform(X)
194+
df_colors = pd.concat([df_labels, df_y, df_fixed], axis=1)
184195
mapper_graph = mapper.fit_transform(X, y)
185196
mapper_fig = MapperPlot(
186197
mapper_graph,
187198
dim=3,
188199
).plot_plotly(
189-
colors=X,
200+
colors=df_colors.to_numpy(),
201+
title=df_colors.columns.to_list(),
190202
height=800,
191203
node_size=[i * 0.125 for i in range(17)],
192204
)
@@ -198,41 +210,80 @@ class App:
198210

199211
def __init__(self, storage):
200212
self.storage = storage
201-
with ui.row().classes("w-full h-screen overflow-hidden p-0 m-0"):
202-
with ui.column().classes("w-96 h-full p-1 m-0"):
203-
with ui.link(target=GIT_REPO_URL, new_tab=True).classes(
204-
"w-full p-1 m-0"
205-
):
206-
ui.image(LOGO_URL)
207-
with ui.column().classes("w-full h-full overflow-y-auto p-1 m-0"):
208-
self._init_file_upload()
209-
self._init_lens()
210-
self._init_cover()
211-
self._init_clustering()
212-
213-
ui.button(
214-
"Run Mapper",
215-
on_click=self.async_run_mapper,
216-
color="primary",
217-
).classes("w-full")
218-
with ui.column().classes("flex-1 h-full overflow-hidden p-1 m-0"):
219-
self._init_plot()
213+
with ui.left_drawer(elevated=True).classes(
214+
"w-96 h-full overflow-y-auto gap-12"
215+
):
216+
with ui.link(target=GIT_REPO_URL, new_tab=True).classes("w-full"):
217+
ui.image(LOGO_URL)
218+
219+
with ui.column().classes("w-full gap-2"):
220+
self._init_file_upload()
221+
222+
ui.button(
223+
"Load Data",
224+
on_click=self.load_file,
225+
color="primary",
226+
).classes("w-full")
220227

221-
def _init_file_upload(self):
222-
with ui.card().tight().classes("w-full"):
223-
ui.upload(
224-
on_upload=self.upload_file,
225-
auto_upload=True,
226-
label="Upload CSV File",
228+
with ui.column().classes("w-full gap-2"):
229+
self._init_lens()
230+
231+
with ui.column().classes("w-full gap-2"):
232+
self._init_cover()
233+
234+
with ui.column().classes("w-full gap-2"):
235+
self._init_clustering()
236+
237+
ui.button(
238+
"Run Mapper",
239+
on_click=self.async_run_mapper,
240+
color="primary",
227241
).classes("w-full")
228-
with ui.card_section().classes("w-full"):
229-
ui.button("Load", on_click=self.load_file).classes("w-full")
242+
243+
ui.label(
244+
text="If you like this project, please consider giving it a ⭐ on GitHub! Made with ❤️ and ☕️ in Rome."
245+
).classes("text-caption text-gray-500").classes(
246+
"text-caption text-gray-500"
247+
)
248+
249+
with ui.column().classes("w-full h-screen overflow-hidden"):
250+
self._init_plot()
251+
252+
def _init_file_upload(self):
253+
ui.label("📊 Data").classes("text-h6")
254+
255+
self.load_type = ui.select(
256+
options=[LOAD_EXAMPLE, LOAD_CSV],
257+
label="Data Source",
258+
value=LOAD_EXAMPLE,
259+
).classes("w-full")
260+
261+
upload = ui.upload(
262+
on_upload=self.upload_file,
263+
auto_upload=True,
264+
label="Upload CSV File",
265+
).classes("w-full mt-4")
266+
upload.props("accept=.csv")
267+
upload.bind_visibility_from(
268+
target_object=self.load_type,
269+
target_name="value",
270+
value=LOAD_CSV,
271+
)
272+
273+
self.load_example = ui.select(
274+
options=[LOAD_EXAMPLE_DIGITS, LOAD_EXAMPLE_IRIS],
275+
label="Dataset",
276+
value=LOAD_EXAMPLE_DIGITS,
277+
).classes("w-full")
278+
self.load_example.bind_visibility_from(
279+
target_object=self.load_type,
280+
target_name="value",
281+
value=LOAD_EXAMPLE,
282+
)
230283

231284
def _init_lens(self):
232-
with ui.card().tight().classes("w-full"):
233-
with ui.card_section().classes("w-full"):
234-
ui.markdown("##### 🔎 Lens").classes("w-full")
235-
self._init_lens_settings()
285+
ui.label("🔎 Lens").classes("text-h6")
286+
self._init_lens_settings()
236287

237288
def _init_lens_settings(self):
238289
self.lens_type = ui.select(
@@ -266,16 +317,13 @@ def _init_lens_settings(self):
266317
)
267318

268319
def _init_cover(self):
269-
with ui.card().tight().classes("w-full"):
270-
with ui.card_section().classes("w-full"):
271-
with ui.row().classes("w-full"):
272-
ui.markdown("##### 🌐 Cover").classes("flex-1")
273-
274-
self.cover_scale = ui.switch(
275-
text="Scale Data",
276-
value=COVER_SCALE_DATA,
277-
).classes("flex-none")
278-
self._init_cover_settings()
320+
with ui.row().classes("w-full items-center justify-between"):
321+
ui.label("🌐 Cover").classes("text-h6")
322+
self.cover_scale = ui.switch(
323+
text="Scaling",
324+
value=COVER_SCALE_DATA,
325+
)
326+
self._init_cover_settings()
279327

280328
def _init_cover_settings(self):
281329
self.cover_type = ui.select(
@@ -329,15 +377,13 @@ def _init_cover_settings(self):
329377
)
330378

331379
def _init_clustering(self):
332-
with ui.card().tight().classes("w-full"):
333-
with ui.card_section().classes("w-full"):
334-
with ui.row().classes("w-full"):
335-
ui.markdown("##### 🧮 Clustering").classes("flex-1")
336-
self.clustering_scale = ui.switch(
337-
text="Scale Data",
338-
value=CLUSTERING_SCALE_DATA,
339-
).classes("flex-none")
340-
self._init_clustering_settings()
380+
with ui.row().classes("w-full items-center justify-between"):
381+
ui.label("🧮 Clustering").classes("text-h6")
382+
self.clustering_scale = ui.switch(
383+
text="Scaling",
384+
value=CLUSTERING_SCALE_DATA,
385+
)
386+
self._init_clustering_settings()
341387

342388
def _init_clustering_settings(self):
343389
self.clustering_type = ui.select(
@@ -391,7 +437,7 @@ def _init_clustering_settings(self):
391437
)
392438

393439
def _init_plot(self):
394-
self.plot_container = ui.card().classes("w-full h-full")
440+
self.plot_container = ui.element("div").classes("w-full h-full")
395441

396442
def get_mapper_config(self):
397443
return MapperConfig(
@@ -460,12 +506,33 @@ def upload_file(self, file):
460506
if file is not None:
461507
df = pd.read_csv(file.content)
462508
self.storage["df"] = df
509+
self.storage["labels"] = None
463510
logger.info("File uploaded successfully.")
464511
ui.notify("File uploaded successfully.", type="info")
465512
else:
466513
logger.info("No file uploaded.")
467514

468515
def load_file(self):
516+
if self.load_type.value == LOAD_EXAMPLE:
517+
if self.load_example.value == LOAD_EXAMPLE_DIGITS:
518+
df, labels = load_digits(as_frame=True, return_X_y=True)
519+
elif self.load_example.value == LOAD_EXAMPLE_IRIS:
520+
df, labels = load_iris(as_frame=True, return_X_y=True)
521+
else:
522+
logger.error("Unknown example dataset selected.")
523+
return
524+
self.storage["df"] = df
525+
self.storage["labels"] = labels
526+
elif self.load_type.value == LOAD_CSV:
527+
df = self.storage.get("df")
528+
if df is None:
529+
logger.warning("No data found. Please upload a file first.")
530+
ui.notify("No data found. Please upload a file first.", type="warning")
531+
return
532+
else:
533+
logger.error("Unknown load type selected.")
534+
return
535+
469536
df = self.storage.get("df")
470537
if df is not None:
471538
logger.info("Data loaded successfully.")
@@ -480,11 +547,14 @@ async def async_run_mapper(self):
480547
logger.warning("No data found. Please upload a file first.")
481548
ui.notify("No data found. Please upload a file first.", type="warning")
482549
return
550+
labels = self.storage.get("labels")
483551
notification = ui.notification(timeout=None, type="ongoing")
484552
notification.message = "Running Mapper..."
485553
notification.spinner = True
486554
mapper_config = self.get_mapper_config()
487-
mapper_fig = await run.cpu_bound(run_mapper, df, **asdict(mapper_config))
555+
mapper_fig = await run.cpu_bound(
556+
run_mapper, df, labels, **asdict(mapper_config)
557+
)
488558
mapper_fig.layout.width = None
489559
mapper_fig.layout.height = None
490560
mapper_fig.layout.autosize = True

0 commit comments

Comments
 (0)