44import pandas as pd
55from nicegui import app , run , ui
66from sklearn .cluster import DBSCAN , AgglomerativeClustering , KMeans
7+ from sklearn .datasets import load_digits , load_iris
78from sklearn .decomposition import PCA
89from sklearn .preprocessing import StandardScaler
910from umap import UMAP
2425LOGO_URL = f"{ GIT_REPO_URL } /raw/main/docs/source/logos/tda-mapper-logo-horizontal.png"
2526
2627
# Entries of the "Data Source" select: a bundled example dataset or an
# uploaded CSV file.
LOAD_EXAMPLE = "Example"
# Entries of the "Dataset" select shown when the data source is LOAD_EXAMPLE.
LOAD_EXAMPLE_DIGITS = "Digits"
LOAD_EXAMPLE_IRIS = "Iris"
LOAD_CSV = "CSV"

# Lens choices; run_mapper compares its lens_type argument against these
# to pick the lens function.
LENS_IDENTITY = "Identity"
LENS_PCA = "PCA"
LENS_UMAP = "UMAP"
@@ -106,8 +112,8 @@ def _umap(X):
106112 return _umap
107113
108114
109- def run_mapper (df , ** kwargs ):
110- if df is None :
115+ def run_mapper (df , labels , ** kwargs ):
116+ if df is None or df . empty :
111117 logger .error ("No data found. Please upload a file first." )
112118 return
113119 logger .info ("Computing Mapper." )
@@ -133,12 +139,14 @@ def run_mapper(df, **kwargs):
133139 )
134140
135141 lens = lens_pca (n_components = LENS_PCA_N_COMPONENTS )
142+ lens_name = LENS_PCA
136143 if lens_type == LENS_IDENTITY :
137144 lens = lens_identity
138145 elif lens_type == LENS_PCA :
139146 lens = lens_pca (n_components = lens_pca_n_components )
140147 elif lens_type == LENS_UMAP :
141148 lens = lens_umap (n_components = lens_umap_n_components )
149+ lens_name = lens_type
142150
143151 if cover_type == COVER_CUBICAL :
144152 cover = CubicalCover (
@@ -177,16 +185,20 @@ def run_mapper(df, **kwargs):
177185 df_fixed = fix_data (df )
178186 X = df_fixed .to_numpy ()
179187 y = lens (X )
188+ df_y = pd .DataFrame (y , columns = [f"{ lens_type } { i } " for i in range (y .shape [1 ])])
189+ df_labels = pd .DataFrame (labels ) if labels is not None else pd .DataFrame ()
180190 if cover_scale_data :
181191 y = StandardScaler ().fit_transform (y )
182192 if clustering_scale_data :
183193 X = StandardScaler ().fit_transform (X )
194+ df_colors = pd .concat ([df_labels , df_y , df_fixed ], axis = 1 )
184195 mapper_graph = mapper .fit_transform (X , y )
185196 mapper_fig = MapperPlot (
186197 mapper_graph ,
187198 dim = 3 ,
188199 ).plot_plotly (
189- colors = X ,
200+ colors = df_colors .to_numpy (),
201+ title = df_colors .columns .to_list (),
190202 height = 800 ,
191203 node_size = [i * 0.125 for i in range (17 )],
192204 )
@@ -198,41 +210,80 @@ class App:
198210
def __init__(self, storage):
    """Build the page: a left drawer with the controls and a plot area.

    Args:
        storage: Mapping-like store kept on ``self.storage``; the other
            methods of this class read and write its ``"df"`` and
            ``"labels"`` entries.
    """
    self.storage = storage

    # NOTE(review): the nesting below was reconstructed from a flattened
    # diff. Each section column is assumed to wrap exactly one _init_*
    # call, with both buttons and the footer as direct children of the
    # drawer — confirm against the rendered layout.
    with ui.left_drawer(elevated=True).classes(
        "w-96 h-full overflow-y-auto gap-12"
    ):
        with ui.link(target=GIT_REPO_URL, new_tab=True).classes("w-full"):
            ui.image(LOGO_URL)

        with ui.column().classes("w-full gap-2"):
            self._init_file_upload()

        ui.button(
            "Load Data",
            on_click=self.load_file,
            color="primary",
        ).classes("w-full")

        with ui.column().classes("w-full gap-2"):
            self._init_lens()

        with ui.column().classes("w-full gap-2"):
            self._init_cover()

        with ui.column().classes("w-full gap-2"):
            self._init_clustering()

        ui.button(
            "Run Mapper",
            on_click=self.async_run_mapper,
            color="primary",
        ).classes("w-full")

        # The classes are applied once; the original chained the very same
        # .classes(...) call twice, which added nothing.
        ui.label(
            text="If you like this project, please consider giving it a ⭐ on GitHub! Made with ❤️ and ☕️ in Rome."
        ).classes("text-caption text-gray-500")

    with ui.column().classes("w-full h-screen overflow-hidden"):
        self._init_plot()
251+
def _init_file_upload(self):
    """Create the "Data" section: source selector, CSV upload, example picker.

    Stores the source selector on ``self.load_type`` and the example
    dataset selector on ``self.load_example``. The upload widget is shown
    only while the source is ``LOAD_CSV`` and the example selector only
    while it is ``LOAD_EXAMPLE``.
    """
    ui.label("📊 Data").classes("text-h6")

    source_select = ui.select(
        options=[LOAD_EXAMPLE, LOAD_CSV],
        label="Data Source",
        value=LOAD_EXAMPLE,
    )
    source_select.classes("w-full")
    self.load_type = source_select

    # Restrict the picker to CSV files and show it only for the CSV source.
    csv_upload = ui.upload(
        on_upload=self.upload_file,
        auto_upload=True,
        label="Upload CSV File",
    )
    csv_upload.classes("w-full mt-4")
    csv_upload.props("accept=.csv")
    csv_upload.bind_visibility_from(
        target_object=self.load_type,
        target_name="value",
        value=LOAD_CSV,
    )

    example_select = ui.select(
        options=[LOAD_EXAMPLE_DIGITS, LOAD_EXAMPLE_IRIS],
        label="Dataset",
        value=LOAD_EXAMPLE_DIGITS,
    )
    example_select.classes("w-full")
    self.load_example = example_select
    self.load_example.bind_visibility_from(
        target_object=self.load_type,
        target_name="value",
        value=LOAD_EXAMPLE,
    )
230283
def _init_lens(self):
    """Add the "Lens" section heading, then build the lens settings."""
    heading = ui.label("🔎 Lens")
    heading.classes("text-h6")
    self._init_lens_settings()
236287
237288 def _init_lens_settings (self ):
238289 self .lens_type = ui .select (
@@ -266,16 +317,13 @@ def _init_lens_settings(self):
266317 )
267318
def _init_cover(self):
    """Add the "Cover" heading row with its scaling switch, then the settings.

    The switch is stored on ``self.cover_scale`` and starts at
    ``COVER_SCALE_DATA``.
    """
    # NOTE(review): nesting reconstructed from a flattened diff; the
    # settings are assumed to sit below the header row, not inside it —
    # confirm.
    header_row = ui.row().classes("w-full items-center justify-between")
    with header_row:
        ui.label("🌐 Cover").classes("text-h6")
        self.cover_scale = ui.switch(text="Scaling", value=COVER_SCALE_DATA)
    self._init_cover_settings()
279327
280328 def _init_cover_settings (self ):
281329 self .cover_type = ui .select (
@@ -329,15 +377,13 @@ def _init_cover_settings(self):
329377 )
330378
def _init_clustering(self):
    """Add the "Clustering" heading row with its scaling switch, then the settings.

    The switch is stored on ``self.clustering_scale`` and starts at
    ``CLUSTERING_SCALE_DATA``.
    """
    # NOTE(review): nesting reconstructed from a flattened diff; the
    # settings are assumed to sit below the header row, not inside it —
    # confirm.
    header_row = ui.row().classes("w-full items-center justify-between")
    with header_row:
        ui.label("🧮 Clustering").classes("text-h6")
        self.clustering_scale = ui.switch(
            text="Scaling", value=CLUSTERING_SCALE_DATA
        )
    self._init_clustering_settings()
341387
342388 def _init_clustering_settings (self ):
343389 self .clustering_type = ui .select (
@@ -391,7 +437,7 @@ def _init_clustering_settings(self):
391437 )
392438
def _init_plot(self):
    """Create the full-size ``div`` element kept on ``self.plot_container``."""
    container = ui.element("div")
    container.classes("w-full h-full")
    self.plot_container = container
395441
396442 def get_mapper_config (self ):
397443 return MapperConfig (
@@ -460,12 +506,33 @@ def upload_file(self, file):
460506 if file is not None :
461507 df = pd .read_csv (file .content )
462508 self .storage ["df" ] = df
509+ self .storage ["labels" ] = None
463510 logger .info ("File uploaded successfully." )
464511 ui .notify ("File uploaded successfully." , type = "info" )
465512 else :
466513 logger .info ("No file uploaded." )
467514
468515 def load_file (self ):
516+ if self .load_type .value == LOAD_EXAMPLE :
517+ if self .load_example .value == LOAD_EXAMPLE_DIGITS :
518+ df , labels = load_digits (as_frame = True , return_X_y = True )
519+ elif self .load_example .value == LOAD_EXAMPLE_IRIS :
520+ df , labels = load_iris (as_frame = True , return_X_y = True )
521+ else :
522+ logger .error ("Unknown example dataset selected." )
523+ return
524+ self .storage ["df" ] = df
525+ self .storage ["labels" ] = labels
526+ elif self .load_type .value == LOAD_CSV :
527+ df = self .storage .get ("df" )
528+ if df is None :
529+ logger .warning ("No data found. Please upload a file first." )
530+ ui .notify ("No data found. Please upload a file first." , type = "warning" )
531+ return
532+ else :
533+ logger .error ("Unknown load type selected." )
534+ return
535+
469536 df = self .storage .get ("df" )
470537 if df is not None :
471538 logger .info ("Data loaded successfully." )
@@ -480,11 +547,14 @@ async def async_run_mapper(self):
480547 logger .warning ("No data found. Please upload a file first." )
481548 ui .notify ("No data found. Please upload a file first." , type = "warning" )
482549 return
550+ labels = self .storage .get ("labels" )
483551 notification = ui .notification (timeout = None , type = "ongoing" )
484552 notification .message = "Running Mapper..."
485553 notification .spinner = True
486554 mapper_config = self .get_mapper_config ()
487- mapper_fig = await run .cpu_bound (run_mapper , df , ** asdict (mapper_config ))
555+ mapper_fig = await run .cpu_bound (
556+ run_mapper , df , labels , ** asdict (mapper_config )
557+ )
488558 mapper_fig .layout .width = None
489559 mapper_fig .layout .height = None
490560 mapper_fig .layout .autosize = True
0 commit comments