11import numpy as np
22import plotly .graph_objs as go
3- from nicegui import ui
4- from nicegui .events import ValueChangeEventArguments
5- from sklearn .cluster import AgglomerativeClustering , KMeans
6- from sklearn .datasets import load_digits , make_circles
3+ from nicegui import run , ui
4+ from sklearn .cluster import DBSCAN , AgglomerativeClustering , KMeans
5+ from sklearn .datasets import load_digits
76from sklearn .decomposition import PCA
87from umap import UMAP
98
@@ -42,229 +41,238 @@ def _func(X):
4241 return _func
4342
4443
44+ LENS_IDENTITY = "Identity"
45+ LENS_PCA = "PCA"
46+ LENS_UMAP = "UMAP"
47+
48+ COVER_TRIVIAL = "Trivial"
49+ COVER_CUBICAL = "Cubical"
50+ COVER_BALL = "Ball"
51+ COVER_KNN = "KNN"
52+
53+ CLUSTERING_TRIVIAL = "Trivial"
54+ CLUSTERING_KMEANS = "KMeans"
55+ CLUSTERING_AGGLOMERATIVE = "Agglomerative"
56+ CLUSTERING_DBSCAN = "DBSCAN"
57+
58+
4559class App :
4660
4761 def build_lens (self ):
48- self .opt_lens_id = "Identity"
49- self .opt_lens_pca = "PCA"
50- self .opt_lens_umap = "UMAP"
51-
5262 self .lens_type = ui .select (
5363 label = "Lens type" ,
5464 options = [
55- self . opt_lens_id ,
56- self . opt_lens_pca ,
57- self . opt_lens_umap ,
65+ LENS_IDENTITY ,
66+ LENS_PCA ,
67+ LENS_UMAP ,
5868 ],
59- value = self . opt_lens_pca ,
60- on_change = self .update ,
69+ value = LENS_PCA ,
70+ on_change = self .update_handler ,
6171 ).classes ("w-full" )
6272 self .pca_n_components = ui .number (
6373 label = "PCA Components" ,
6474 min = 1 ,
6575 max = 10 ,
6676 value = 2 ,
67- on_change = self .update ,
77+ on_change = self .update_handler ,
6878 ).classes ("w-full" )
6979 self .pca_n_components .bind_visibility_from (
7080 target_object = self .lens_type ,
7181 target_name = "value" ,
72- value = self . opt_lens_pca ,
82+ value = LENS_PCA ,
7383 )
7484 self .umap_n_components = ui .number (
7585 label = "UMAP Components" ,
7686 min = 1 ,
7787 max = 10 ,
7888 value = 2 ,
79- on_change = self .update ,
89+ on_change = self .update_handler ,
8090 ).classes ("w-full" )
8191 self .umap_n_components .bind_visibility_from (
8292 target_object = self .lens_type ,
8393 target_name = "value" ,
84- value = self . opt_lens_umap ,
94+ value = LENS_UMAP ,
8595 )
8696
8797 def build_cover (self ):
88- self .opt_cover_trivial = "Trivial"
89- self .opt_cover_cubical = "Cubical"
90- self .opt_cover_ball = "Ball"
91- self .opt_cover_knn = "KNN"
9298
9399 self .cover_type = ui .select (
94100 label = "Cover type" ,
95101 options = [
96- self . opt_cover_trivial ,
97- self . opt_cover_cubical ,
98- self . opt_cover_ball ,
99- self . opt_cover_knn ,
102+ COVER_TRIVIAL ,
103+ COVER_CUBICAL ,
104+ COVER_BALL ,
105+ COVER_KNN ,
100106 ],
101- value = self . opt_cover_cubical ,
102- on_change = self .update ,
107+ value = COVER_CUBICAL ,
108+ on_change = self .update_handler ,
103109 ).classes ("w-full" )
104- self .cover_cubical_n = ui .number (
110+ self .cover_cubical_n_intervals = ui .number (
105111 label = "Intervals" ,
106112 min = 1 ,
107- max = 10 ,
113+ max = 100 ,
108114 value = 2 ,
109- on_change = self .update ,
115+ on_change = self .update_handler ,
110116 ).classes ("w-full" )
111- self .cover_cubical_n .bind_visibility_from (
117+ self .cover_cubical_n_intervals .bind_visibility_from (
112118 target_object = self .cover_type ,
113119 target_name = "value" ,
114- value = self . opt_cover_cubical ,
120+ value = COVER_CUBICAL ,
115121 )
116- self .cover_cubical_overlap = ui .number (
122+ self .cover_cubical_overlap_frac = ui .number (
117123 label = "Overlap" ,
118124 min = 0.0 ,
119125 max = 1.0 ,
120126 value = 0.5 ,
121- on_change = self .update ,
127+ on_change = self .update_handler ,
122128 ).classes ("w-full" )
123- self .cover_cubical_overlap .bind_visibility_from (
129+ self .cover_cubical_overlap_frac .bind_visibility_from (
124130 target_object = self .cover_type ,
125131 target_name = "value" ,
126- value = self . opt_cover_cubical ,
132+ value = COVER_CUBICAL ,
127133 )
128134 self .cover_ball_radius = ui .number (
129135 label = "Radius" ,
130136 min = 0.0 ,
131137 value = 100.0 ,
132- on_change = self .update ,
138+ on_change = self .update_handler ,
133139 ).classes ("w-full" )
134140 self .cover_ball_radius .bind_visibility_from (
135141 target_object = self .cover_type ,
136142 target_name = "value" ,
137- value = self . opt_cover_ball ,
143+ value = COVER_BALL ,
138144 )
139- self .cover_knn_k = ui .number (
145+ self .cover_knn_neighbors = ui .number (
140146 label = "Neighbors" ,
141147 min = 0 ,
142148 value = 10 ,
143- on_change = self .update ,
149+ on_change = self .update_handler ,
144150 ).classes ("w-full" )
145- self .cover_knn_k .bind_visibility_from (
151+ self .cover_knn_neighbors .bind_visibility_from (
146152 target_object = self .cover_type ,
147153 target_name = "value" ,
148- value = self . opt_cover_knn ,
154+ value = COVER_KNN ,
149155 )
150156
151157 def build_clustering (self ):
152- self .opt_clustering_trivial = "Trivial"
153- self .opt_clustering_kmeans = "KMeans"
154- self .opt_clustering_agg = "Agglomerative"
155- self .opt_clustering_dbscan = "DBSCAN"
156-
157158 self .clustering_type = ui .select (
158159 label = "Clustering type" ,
159160 options = [
160- self . opt_clustering_trivial ,
161- self . opt_clustering_kmeans ,
162- self . opt_clustering_agg ,
163- self . opt_clustering_dbscan ,
161+ CLUSTERING_TRIVIAL ,
162+ CLUSTERING_KMEANS ,
163+ CLUSTERING_AGGLOMERATIVE ,
164+ CLUSTERING_DBSCAN ,
164165 ],
165- value = self . opt_clustering_trivial ,
166- on_change = self .update ,
166+ value = CLUSTERING_TRIVIAL ,
167+ on_change = self .update_handler ,
167168 ).classes ("w-full" )
168- self .clustering_kmeans_k = ui .number (
169+ self .clustering_kmeans_n_clusters = ui .number (
169170 label = "Clusters" ,
170171 min = 1 ,
171172 value = 2 ,
172- on_change = self .update ,
173+ on_change = self .update_handler ,
173174 ).classes ("w-full" )
174- self .clustering_kmeans_k .bind_visibility_from (
175+ self .clustering_kmeans_n_clusters .bind_visibility_from (
175176 target_object = self .clustering_type ,
176177 target_name = "value" ,
177- value = self . opt_clustering_kmeans ,
178+ value = CLUSTERING_KMEANS ,
178179 )
179180 self .clustering_dbscan_eps = ui .number (
180181 label = "Eps" ,
181182 min = 0.0 ,
182183 value = 0.5 ,
183- on_change = self .update ,
184+ on_change = self .update_handler ,
184185 ).classes ("w-full" )
185186 self .clustering_dbscan_eps .bind_visibility_from (
186187 target_object = self .clustering_type ,
187188 target_name = "value" ,
188- value = self . opt_clustering_dbscan ,
189+ value = CLUSTERING_DBSCAN ,
189190 )
190191 self .clustering_dbscan_min_samples = ui .number (
191192 label = "Min Samples" ,
192193 min = 1 ,
193194 value = 5 ,
194- on_change = self .update ,
195+ on_change = self .update_handler ,
195196 ).classes ("w-full" )
196- self .clustering_dbscan_eps .bind_visibility_from (
197+ self .clustering_dbscan_min_samples .bind_visibility_from (
197198 target_object = self .clustering_type ,
198199 target_name = "value" ,
199- value = self . opt_clustering_dbscan ,
200+ value = CLUSTERING_DBSCAN ,
200201 )
201- self .clustering_agg_n = ui .number (
202+ self .clustering_agglomerative_n_clusters = ui .number (
202203 label = "Clusters" ,
203204 min = 1 ,
204205 value = 2 ,
205- on_change = self .update ,
206+ on_change = self .update_handler ,
206207 ).classes ("w-full" )
207- self .clustering_agg_n .bind_visibility_from (
208+ self .clustering_agglomerative_n_clusters .bind_visibility_from (
208209 target_object = self .clustering_type ,
209210 target_name = "value" ,
210- value = self . opt_clustering_agg ,
211+ value = CLUSTERING_AGGLOMERATIVE ,
211212 )
212213
213214 def build_plot (self ):
214- self .plot = ui .plotly (go .Figure ())
215+ fig = go .Figure ()
216+ fig .layout .width = None
217+ fig .layout .autosize = True
218+ self .plot_container = ui .element ("div" ).classes ("w-full h-full" )
219+ with self .plot_container :
220+ ui .plotly (go .Figure ())
215221
216222 def render_lens (self ):
217- print (f"Lens type: { self .lens_type .value } " )
218- if self .lens_type .value == self .opt_lens_id :
223+ if self .lens_type .value == LENS_IDENTITY :
219224 return _identity
220- elif self .lens_type .value == self . opt_lens_pca :
225+ elif self .lens_type .value == LENS_PCA :
221226 n = int (self .pca_n_components .value )
222227 return _pca (n )
223- elif self .lens_type .value == self . opt_lens_umap :
228+ elif self .lens_type .value == LENS_UMAP :
224229 n = int (self .umap_n_components .value )
225230 return _umap (n )
226231
227232 def render_cover (self ):
228- if self .cover_type .value == self . opt_cover_trivial :
233+ if self .cover_type .value == COVER_TRIVIAL :
229234 return TrivialCover ()
230- elif self .cover_type .value == self . opt_cover_ball :
231- r = float (self .cover_ball_radius .value )
232- return BallCover (radius = r )
233- elif self .cover_type .value == self . opt_cover_cubical :
234- n = int (self .cover_cubical_n .value )
235- overlap = float (self .cover_cubical_overlap .value )
236- return CubicalCover (n_intervals = n , overlap_frac = overlap )
237- elif self .cover_type .value == self . opt_cover_knn :
238- k = int (self .cover_knn_k .value )
239- return KNNCover (neighbors = k )
235+ elif self .cover_type .value == COVER_BALL :
236+ radius = float (self .cover_ball_radius .value )
237+ return BallCover (radius = radius )
238+ elif self .cover_type .value == COVER_CUBICAL :
239+ n_intervals = int (self .cover_cubical_n_intervals .value )
240+ overlap_frac = float (self .cover_cubical_overlap_frac .value )
241+ return CubicalCover (n_intervals = n_intervals , overlap_frac = overlap_frac )
242+ elif self .cover_type .value == COVER_KNN :
243+ neighbors = int (self .cover_knn_neighbors .value )
244+ return KNNCover (neighbors = neighbors )
240245
241246 def render_clustering (self ):
242- if self .clustering_type .value == self . opt_clustering_trivial :
247+ if self .clustering_type .value == CLUSTERING_TRIVIAL :
243248 return TrivialClustering ()
244- elif self .clustering_type .value == self . opt_clustering_kmeans :
245- k = int (self .clustering_kmeans_k .value )
246- return KMeans (k )
247- elif self .clustering_type .value == self . opt_clustering_dbscan :
249+ elif self .clustering_type .value == CLUSTERING_KMEANS :
250+ n_clusters = int (self .clustering_kmeans_n_clusters .value )
251+ return KMeans (n_clusters )
252+ elif self .clustering_type .value == CLUSTERING_DBSCAN :
248253 eps = float (self .clustering_dbscan_eps .value )
249254 min_samples = int (self .clustering_dbscan_min_samples .value )
250- return DBSCAN (eps = eps )
255+ return DBSCAN (eps = eps , min_samples = min_samples )
256+ elif self .clustering_type == CLUSTERING_AGGLOMERATIVE :
257+ n_clusters = int (self .clustering_agglomerative_n_clusters .value )
258+ return AgglomerativeClustering (n_clusters = n_clusters )
259+
260+ async def update_handler (self , _ = None ):
261+ await run .io_bound (self .update )
251262
252263 def update (self , _ = None ):
253264 X , labels = load_digits (return_X_y = True )
254265 lens = self .render_lens ()
255266 if lens is None :
256- print ("Lens is None" )
257267 return
258268 y = lens (X )
259269
260270 cover = self .render_cover ()
261271 if cover is None :
262- print ("Cover is None" )
263272 return
264273
265274 clustering = self .render_clustering ()
266275 if clustering is None :
267- print ("Clustering is None" )
268276 return
269277
270278 mapper_algo = MapperAlgorithm (
@@ -286,24 +294,31 @@ def update(self, _=None):
286294 height = 800 ,
287295 node_size = 0.5 ,
288296 )
289- if mapper_fig .layout .width is not None :
290- mapper_fig .layout .width = None
291- if not mapper_fig .layout .autosize :
292- mapper_fig .layout .autosize = True
297+ # if mapper_fig.layout.width is not None:
298+ mapper_fig .layout .width = None
299+ # if not mapper_fig.layout.autosize:
293300 mapper_fig .layout .autosize = True
294- self .plot .update_figure (mapper_fig )
295-
296- def build (self ):
297- with ui .left_drawer ().classes ("w-[400px]" ):
298- self .build_lens ()
299- ui .separator ()
300- self .build_cover ()
301- ui .separator ()
302- self .build_clustering ()
303- self .build_plot ()
301+ self .plot_container .clear ()
302+ with self .plot_container :
303+ ui .plotly (mapper_fig )
304+
305+ def __init__ (self ):
306+ with ui .row ().classes ("w-full h-full m-0 p-0 gap-0 overflow-hidden" ):
307+ with ui .column ().classes ("w-64 h-full overflow-y-auto m-0 p-3 gap-2" ):
308+ with ui .card ().classes ("w-full" ):
309+ ui .markdown ("#### 🔎 Lens" )
310+ self .build_lens ()
311+ with ui .card ().classes ("w-full" ):
312+ ui .markdown ("#### 🌐 Cover" )
313+ self .build_cover ()
314+ with ui .card ().classes ("w-full" ):
315+ ui .markdown ("#### 🧮 Clustering" )
316+ self .build_clustering ()
317+
318+ with ui .column ().classes ("flex-1 h-full overflow-hidden m-0 p-0" ):
319+ self .build_plot ()
304320 self .update ()
305321
306322
307323app = App ()
308- app .build ()
309324ui .run ()
0 commit comments