55import numpy as np
66import pandas as pd
77import plotly .graph_objs as go
8- from nicegui import run , ui
8+ from nicegui import app , run , ui
99from sklearn .cluster import DBSCAN , AgglomerativeClustering , KMeans
1010from sklearn .datasets import fetch_openml , load_digits , load_iris
1111from sklearn .decomposition import PCA
1212from umap import UMAP
1313
14- from tdamapper .core import TrivialClustering , TrivialCover
14+ from tdamapper .core import Cover , TrivialClustering , TrivialCover
1515from tdamapper .cover import BallCover , CubicalCover , KNNCover
1616from tdamapper .learn import MapperAlgorithm
1717from tdamapper .plot import MapperPlot
1818
# Seed passed as ``random_state`` to the PCA, UMAP and KMeans estimators below.
RANDOM_STATE = 42

# Lens choices and the default parameters of the parameterized ones.
LENS_IDENTITY = "Identity"
LENS_PCA = "PCA"
LENS_PCA_N_COMPONENTS = 2
LENS_UMAP = "UMAP"
LENS_UMAP_N_COMPONENTS = 2

# Cover choices and their default parameters.
COVER_TRIVIAL = "Trivial"
COVER_CUBICAL = "Cubical"
COVER_CUBICAL_N_INTERVALS = 2
COVER_CUBICAL_OVERLAP_FRAC = 0.25
COVER_BALL = "Ball"
COVER_BALL_RADIUS = 100.0
COVER_KNN = "KNN"
COVER_KNN_NEIGHBORS = 10

# Clustering choices and their default parameters.
CLUSTERING_TRIVIAL = "Trivial"
CLUSTERING_KMEANS = "KMeans"
CLUSTERING_KMEANS_N_CLUSTERS = 2
CLUSTERING_AGGLOMERATIVE = "Agglomerative"
CLUSTERING_AGGLOMERATIVE_N_CLUSTERS = 2
CLUSTERING_DBSCAN = "DBSCAN"
CLUSTERING_DBSCAN_EPS = 0.5
CLUSTERING_DBSCAN_MIN_SAMPLES = 5

# Data source choices.
DATA_SOURCE_EXAMPLE = "Example"
DATA_SOURCE_CSV = "CSV"
@@ -57,21 +68,21 @@ class State:
5768 source_csv : str = "tmp/data.csv"
5869 source_openml : str = SOURCE_OPENML
5970 lens_type : str = LENS_PCA
60- lens_pca_n_components : int = 2
61- lens_umap_n_components : int = 2
71+ lens_pca_n_components : int = LENS_PCA_N_COMPONENTS
72+ lens_umap_n_components : int = LENS_UMAP_N_COMPONENTS
6273 cover_type : str = COVER_CUBICAL
63- cover_cubical_n_intervals : int = 2
64- cover_cubical_overlap_frac : float = 0.25
65- cover_knn_neighbors : int = 10
66- cover_ball_radius : float = 100.0
74+ cover_cubical_n_intervals : int = COVER_CUBICAL_N_INTERVALS
75+ cover_cubical_overlap_frac : float = COVER_CUBICAL_OVERLAP_FRAC
76+ cover_knn_neighbors : int = COVER_KNN_NEIGHBORS
77+ cover_ball_radius : float = COVER_BALL_RADIUS
6778 clustering_type : str = CLUSTERING_TRIVIAL
68- clustering_kmeans_n_clusters : int = 2
69- clustering_dbscan_eps : float = 0.5
70- clustering_dbscan_min_samples : int = 5
71- clustering_agglomerative_n_clusters : int = 2
72- draw_dim : int = 3
79+ clustering_kmeans_n_clusters : int = CLUSTERING_KMEANS_N_CLUSTERS
80+ clustering_dbscan_eps : float = CLUSTERING_DBSCAN_EPS
81+ clustering_dbscan_min_samples : int = CLUSTERING_DBSCAN_MIN_SAMPLES
82+ clustering_agglomerative_n_clusters : int = CLUSTERING_AGGLOMERATIVE_N_CLUSTERS
83+ draw_dim : str = DRAW_3D
7384 draw_aggregation : str = DRAW_MEAN
74- draw_iterations : int = 50
85+ draw_iterations : int = DRAW_ITERATIONS
7586
7687
7788def _fix_data (data ):
@@ -115,58 +126,69 @@ def get_dataset(state: State):
115126
116127
def get_lens(state: State):
    """Build the lens function selected by ``state.lens_type``.

    Args:
        state: Application state; ``lens_type`` picks the lens and the
            matching ``lens_*_n_components`` field parameterizes it.

    Returns:
        A callable taking the data ``X`` and returning its image under the
        lens: ``X`` itself for the identity lens, otherwise the result of
        ``fit_transform`` on a PCA or UMAP estimator seeded with
        ``RANDOM_STATE``. An unknown lens type is logged and replaced by a
        PCA lens with ``LENS_PCA_N_COMPONENTS`` components.
    """
    lens_type = state.lens_type

    if lens_type == LENS_IDENTITY:
        return lambda X: X

    if lens_type == LENS_UMAP:
        umap = UMAP(
            n_components=state.lens_umap_n_components,
            random_state=RANDOM_STATE,
        )
        return lambda X: umap.fit_transform(X)

    if lens_type == LENS_PCA:
        n_components = state.lens_pca_n_components
    else:
        # Name the offending value so a misconfiguration can be diagnosed
        # from the log, then fall back to the documented default.
        logger.error(f"Unknown lens type: {lens_type!r}. Defaulting to PCA lens")
        n_components = LENS_PCA_N_COMPONENTS

    # The estimator is only built once a PCA lens is actually needed.
    pca = PCA(n_components=n_components, random_state=RANDOM_STATE)
    return lambda X: pca.fit_transform(X)
132151
133152
def get_cover(state: State) -> Cover:
    """Build the cover selected by ``state.cover_type``.

    Args:
        state: Application state; ``cover_type`` picks the cover and the
            matching ``cover_*`` fields parameterize it.

    Returns:
        A ``TrivialCover``, ``CubicalCover``, ``BallCover`` or ``KNNCover``.
        An unknown cover type is logged and replaced by a ``CubicalCover``
        built from ``COVER_CUBICAL_N_INTERVALS`` and
        ``COVER_CUBICAL_OVERLAP_FRAC``.
    """
    cover_type = state.cover_type

    if cover_type == COVER_TRIVIAL:
        return TrivialCover()

    if cover_type == COVER_CUBICAL:
        return CubicalCover(
            n_intervals=state.cover_cubical_n_intervals,
            overlap_frac=state.cover_cubical_overlap_frac,
        )

    if cover_type == COVER_BALL:
        return BallCover(radius=state.cover_ball_radius)

    if cover_type == COVER_KNN:
        return KNNCover(neighbors=state.cover_knn_neighbors)

    # Name the offending value in the log; the fallback uses the module
    # defaults rather than repeating their literal values here.
    logger.error(f"Unknown cover type: {cover_type!r}. Defaulting to CubicalCover")
    return CubicalCover(
        n_intervals=COVER_CUBICAL_N_INTERVALS,
        overlap_frac=COVER_CUBICAL_OVERLAP_FRAC,
    )
151170
152171
def get_clustering(state: State):
    """Build the clustering estimator selected by ``state.clustering_type``.

    Args:
        state: Application state; ``clustering_type`` picks the algorithm and
            the matching ``clustering_*`` fields parameterize it.

    Returns:
        A ``TrivialClustering``, ``KMeans`` (seeded with ``RANDOM_STATE``),
        ``DBSCAN`` or ``AgglomerativeClustering`` instance. An unknown
        clustering type is logged and replaced by ``TrivialClustering``.
    """
    clustering_type = state.clustering_type

    if clustering_type == CLUSTERING_TRIVIAL:
        return TrivialClustering()

    if clustering_type == CLUSTERING_KMEANS:
        return KMeans(
            n_clusters=state.clustering_kmeans_n_clusters,
            random_state=RANDOM_STATE,
        )

    if clustering_type == CLUSTERING_DBSCAN:
        return DBSCAN(
            eps=state.clustering_dbscan_eps,
            min_samples=state.clustering_dbscan_min_samples,
        )

    if clustering_type == CLUSTERING_AGGLOMERATIVE:
        return AgglomerativeClustering(
            n_clusters=state.clustering_agglomerative_n_clusters,
        )

    # Name the offending value so the fallback is traceable from the log.
    logger.error(
        f"Unknown clustering type: {clustering_type!r}. "
        "Defaulting to TrivialClustering"
    )
    return TrivialClustering()
170192
171193
172194def compute_mapper (** kwargs ):
@@ -203,9 +225,11 @@ def compute_mapper(**kwargs):
203225 logger .info (f"Mapper configuration: { mapper_algo } " )
204226 mapper_graph = mapper_algo .fit_transform (X , y )
205227
228+ dim = 3 if state .draw_dim == DRAW_3D else 2
229+
206230 mapper_plot = MapperPlot (
207231 mapper_graph ,
208- dim = state . draw_dim ,
232+ dim = dim ,
209233 iterations = state .draw_iterations ,
210234 seed = 42 ,
211235 )
@@ -497,7 +521,7 @@ async def update(self, _=None):
497521 self .state .clustering_agglomerative_n_clusters = int (
498522 self .clustering_agglomerative_n_clusters .value
499523 )
500- self .state .draw_dim = 3 if self .draw_3d .value == DRAW_3D else 2
524+ self .state .draw_dim = str ( self .draw_3d .value )
501525 self .state .draw_iterations = int (self .draw_iterations .value )
502526 self .state .draw_aggregation = self .draw_aggregation .value
503527 await self .render ()
@@ -508,6 +532,9 @@ async def render(self):
508532 ** asdict (self .state ),
509533 )
510534
535+ self .storage ["mapper_graph" ] = mapper_graph
536+ self .storage ["mapper_fig" ] = mapper_fig
537+
511538 self .plot_container .clear ()
512539 with self .plot_container :
513540 ui .plotly (mapper_fig )
@@ -517,7 +544,8 @@ async def update_csv_handler(self, file):
517544 await self .update_dataset_handler ()
518545 await self .update_plot_handler ()
519546
520- def __init__ (self ):
547+ def __init__ (self , storage ):
548+ self .storage = storage
521549 self .state = State ()
522550 with ui .row ().classes ("w-full h-screen m-0 p-0 gap-0 overflow-hidden" ):
523551 with ui .column ().classes ("w-64 h-full m-0 p-0" ):
@@ -544,9 +572,14 @@ def __init__(self):
544572 ui .plotly (mapper_fig )
545573
546574
@ui.page("/")
def main_page():
    """Build the UI for the root page.

    Registered with ``ui.page("/")``; creates an ``App`` that is handed
    ``app.storage.client`` as its storage.
    """
    App(storage=app.storage.client)
579+
580+
def main():
    """Start the NiceGUI server for the TDA Mapper app on port 8080.

    The storage secret is read from the ``TDAMAPPER_STORAGE_SECRET``
    environment variable when it is set, so a deployment does not have to
    rely on the value committed in the source. The previous literal is kept
    as the default so existing setups behave as before.
    """
    import os

    storage_secret = os.environ.get("TDAMAPPER_STORAGE_SECRET", "tdamapper_secret")
    ui.run(storage_secret=storage_secret, title="TDA Mapper App", port=8080)
550583
551584
552585if __name__ in {"__main__" , "__mp_main__" , "tdamapper.app" }:
0 commit comments