Skip to content

Commit ee5e8b5

Browse files
committed
Added ui elements for cover and clustering. Added async worker function for plot update
1 parent a5d2674 commit ee5e8b5

1 file changed

Lines changed: 121 additions & 106 deletions

File tree

app/nicegui_app.py

Lines changed: 121 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import numpy as np
22
import plotly.graph_objs as go
3-
from nicegui import ui
4-
from nicegui.events import ValueChangeEventArguments
5-
from sklearn.cluster import AgglomerativeClustering, KMeans
6-
from sklearn.datasets import load_digits, make_circles
3+
from nicegui import run, ui
4+
from sklearn.cluster import DBSCAN, AgglomerativeClustering, KMeans
5+
from sklearn.datasets import load_digits
76
from sklearn.decomposition import PCA
87
from umap import UMAP
98

@@ -42,229 +41,238 @@ def _func(X):
4241
return _func
4342

4443

44+
LENS_IDENTITY = "Identity"
45+
LENS_PCA = "PCA"
46+
LENS_UMAP = "UMAP"
47+
48+
COVER_TRIVIAL = "Trivial"
49+
COVER_CUBICAL = "Cubical"
50+
COVER_BALL = "Ball"
51+
COVER_KNN = "KNN"
52+
53+
CLUSTERING_TRIVIAL = "Trivial"
54+
CLUSTERING_KMEANS = "KMeans"
55+
CLUSTERING_AGGLOMERATIVE = "Agglomerative"
56+
CLUSTERING_DBSCAN = "DBSCAN"
57+
58+
4559
class App:
4660

4761
def build_lens(self):
48-
self.opt_lens_id = "Identity"
49-
self.opt_lens_pca = "PCA"
50-
self.opt_lens_umap = "UMAP"
51-
5262
self.lens_type = ui.select(
5363
label="Lens type",
5464
options=[
55-
self.opt_lens_id,
56-
self.opt_lens_pca,
57-
self.opt_lens_umap,
65+
LENS_IDENTITY,
66+
LENS_PCA,
67+
LENS_UMAP,
5868
],
59-
value=self.opt_lens_pca,
60-
on_change=self.update,
69+
value=LENS_PCA,
70+
on_change=self.update_handler,
6171
).classes("w-full")
6272
self.pca_n_components = ui.number(
6373
label="PCA Components",
6474
min=1,
6575
max=10,
6676
value=2,
67-
on_change=self.update,
77+
on_change=self.update_handler,
6878
).classes("w-full")
6979
self.pca_n_components.bind_visibility_from(
7080
target_object=self.lens_type,
7181
target_name="value",
72-
value=self.opt_lens_pca,
82+
value=LENS_PCA,
7383
)
7484
self.umap_n_components = ui.number(
7585
label="UMAP Components",
7686
min=1,
7787
max=10,
7888
value=2,
79-
on_change=self.update,
89+
on_change=self.update_handler,
8090
).classes("w-full")
8191
self.umap_n_components.bind_visibility_from(
8292
target_object=self.lens_type,
8393
target_name="value",
84-
value=self.opt_lens_umap,
94+
value=LENS_UMAP,
8595
)
8696

8797
def build_cover(self):
88-
self.opt_cover_trivial = "Trivial"
89-
self.opt_cover_cubical = "Cubical"
90-
self.opt_cover_ball = "Ball"
91-
self.opt_cover_knn = "KNN"
9298

9399
self.cover_type = ui.select(
94100
label="Cover type",
95101
options=[
96-
self.opt_cover_trivial,
97-
self.opt_cover_cubical,
98-
self.opt_cover_ball,
99-
self.opt_cover_knn,
102+
COVER_TRIVIAL,
103+
COVER_CUBICAL,
104+
COVER_BALL,
105+
COVER_KNN,
100106
],
101-
value=self.opt_cover_cubical,
102-
on_change=self.update,
107+
value=COVER_CUBICAL,
108+
on_change=self.update_handler,
103109
).classes("w-full")
104-
self.cover_cubical_n = ui.number(
110+
self.cover_cubical_n_intervals = ui.number(
105111
label="Intervals",
106112
min=1,
107-
max=10,
113+
max=100,
108114
value=2,
109-
on_change=self.update,
115+
on_change=self.update_handler,
110116
).classes("w-full")
111-
self.cover_cubical_n.bind_visibility_from(
117+
self.cover_cubical_n_intervals.bind_visibility_from(
112118
target_object=self.cover_type,
113119
target_name="value",
114-
value=self.opt_cover_cubical,
120+
value=COVER_CUBICAL,
115121
)
116-
self.cover_cubical_overlap = ui.number(
122+
self.cover_cubical_overlap_frac = ui.number(
117123
label="Overlap",
118124
min=0.0,
119125
max=1.0,
120126
value=0.5,
121-
on_change=self.update,
127+
on_change=self.update_handler,
122128
).classes("w-full")
123-
self.cover_cubical_overlap.bind_visibility_from(
129+
self.cover_cubical_overlap_frac.bind_visibility_from(
124130
target_object=self.cover_type,
125131
target_name="value",
126-
value=self.opt_cover_cubical,
132+
value=COVER_CUBICAL,
127133
)
128134
self.cover_ball_radius = ui.number(
129135
label="Radius",
130136
min=0.0,
131137
value=100.0,
132-
on_change=self.update,
138+
on_change=self.update_handler,
133139
).classes("w-full")
134140
self.cover_ball_radius.bind_visibility_from(
135141
target_object=self.cover_type,
136142
target_name="value",
137-
value=self.opt_cover_ball,
143+
value=COVER_BALL,
138144
)
139-
self.cover_knn_k = ui.number(
145+
self.cover_knn_neighbors = ui.number(
140146
label="Neighbors",
141147
min=0,
142148
value=10,
143-
on_change=self.update,
149+
on_change=self.update_handler,
144150
).classes("w-full")
145-
self.cover_knn_k.bind_visibility_from(
151+
self.cover_knn_neighbors.bind_visibility_from(
146152
target_object=self.cover_type,
147153
target_name="value",
148-
value=self.opt_cover_knn,
154+
value=COVER_KNN,
149155
)
150156

151157
def build_clustering(self):
152-
self.opt_clustering_trivial = "Trivial"
153-
self.opt_clustering_kmeans = "KMeans"
154-
self.opt_clustering_agg = "Agglomerative"
155-
self.opt_clustering_dbscan = "DBSCAN"
156-
157158
self.clustering_type = ui.select(
158159
label="Clustering type",
159160
options=[
160-
self.opt_clustering_trivial,
161-
self.opt_clustering_kmeans,
162-
self.opt_clustering_agg,
163-
self.opt_clustering_dbscan,
161+
CLUSTERING_TRIVIAL,
162+
CLUSTERING_KMEANS,
163+
CLUSTERING_AGGLOMERATIVE,
164+
CLUSTERING_DBSCAN,
164165
],
165-
value=self.opt_clustering_trivial,
166-
on_change=self.update,
166+
value=CLUSTERING_TRIVIAL,
167+
on_change=self.update_handler,
167168
).classes("w-full")
168-
self.clustering_kmeans_k = ui.number(
169+
self.clustering_kmeans_n_clusters = ui.number(
169170
label="Clusters",
170171
min=1,
171172
value=2,
172-
on_change=self.update,
173+
on_change=self.update_handler,
173174
).classes("w-full")
174-
self.clustering_kmeans_k.bind_visibility_from(
175+
self.clustering_kmeans_n_clusters.bind_visibility_from(
175176
target_object=self.clustering_type,
176177
target_name="value",
177-
value=self.opt_clustering_kmeans,
178+
value=CLUSTERING_KMEANS,
178179
)
179180
self.clustering_dbscan_eps = ui.number(
180181
label="Eps",
181182
min=0.0,
182183
value=0.5,
183-
on_change=self.update,
184+
on_change=self.update_handler,
184185
).classes("w-full")
185186
self.clustering_dbscan_eps.bind_visibility_from(
186187
target_object=self.clustering_type,
187188
target_name="value",
188-
value=self.opt_clustering_dbscan,
189+
value=CLUSTERING_DBSCAN,
189190
)
190191
self.clustering_dbscan_min_samples = ui.number(
191192
label="Min Samples",
192193
min=1,
193194
value=5,
194-
on_change=self.update,
195+
on_change=self.update_handler,
195196
).classes("w-full")
196-
self.clustering_dbscan_eps.bind_visibility_from(
197+
self.clustering_dbscan_min_samples.bind_visibility_from(
197198
target_object=self.clustering_type,
198199
target_name="value",
199-
value=self.opt_clustering_dbscan,
200+
value=CLUSTERING_DBSCAN,
200201
)
201-
self.clustering_agg_n = ui.number(
202+
self.clustering_agglomerative_n_clusters = ui.number(
202203
label="Clusters",
203204
min=1,
204205
value=2,
205-
on_change=self.update,
206+
on_change=self.update_handler,
206207
).classes("w-full")
207-
self.clustering_agg_n.bind_visibility_from(
208+
self.clustering_agglomerative_n_clusters.bind_visibility_from(
208209
target_object=self.clustering_type,
209210
target_name="value",
210-
value=self.opt_clustering_agg,
211+
value=CLUSTERING_AGGLOMERATIVE,
211212
)
212213

213214
def build_plot(self):
214-
self.plot = ui.plotly(go.Figure())
215+
fig = go.Figure()
216+
fig.layout.width = None
217+
fig.layout.autosize = True
218+
self.plot_container = ui.element("div").classes("w-full h-full")
219+
with self.plot_container:
220+
ui.plotly(go.Figure())
215221

216222
def render_lens(self):
217-
print(f"Lens type: {self.lens_type.value}")
218-
if self.lens_type.value == self.opt_lens_id:
223+
if self.lens_type.value == LENS_IDENTITY:
219224
return _identity
220-
elif self.lens_type.value == self.opt_lens_pca:
225+
elif self.lens_type.value == LENS_PCA:
221226
n = int(self.pca_n_components.value)
222227
return _pca(n)
223-
elif self.lens_type.value == self.opt_lens_umap:
228+
elif self.lens_type.value == LENS_UMAP:
224229
n = int(self.umap_n_components.value)
225230
return _umap(n)
226231

227232
def render_cover(self):
228-
if self.cover_type.value == self.opt_cover_trivial:
233+
if self.cover_type.value == COVER_TRIVIAL:
229234
return TrivialCover()
230-
elif self.cover_type.value == self.opt_cover_ball:
231-
r = float(self.cover_ball_radius.value)
232-
return BallCover(radius=r)
233-
elif self.cover_type.value == self.opt_cover_cubical:
234-
n = int(self.cover_cubical_n.value)
235-
overlap = float(self.cover_cubical_overlap.value)
236-
return CubicalCover(n_intervals=n, overlap_frac=overlap)
237-
elif self.cover_type.value == self.opt_cover_knn:
238-
k = int(self.cover_knn_k.value)
239-
return KNNCover(neighbors=k)
235+
elif self.cover_type.value == COVER_BALL:
236+
radius = float(self.cover_ball_radius.value)
237+
return BallCover(radius=radius)
238+
elif self.cover_type.value == COVER_CUBICAL:
239+
n_intervals = int(self.cover_cubical_n_intervals.value)
240+
overlap_frac = float(self.cover_cubical_overlap_frac.value)
241+
return CubicalCover(n_intervals=n_intervals, overlap_frac=overlap_frac)
242+
elif self.cover_type.value == COVER_KNN:
243+
neighbors = int(self.cover_knn_neighbors.value)
244+
return KNNCover(neighbors=neighbors)
240245

241246
def render_clustering(self):
242-
if self.clustering_type.value == self.opt_clustering_trivial:
247+
if self.clustering_type.value == CLUSTERING_TRIVIAL:
243248
return TrivialClustering()
244-
elif self.clustering_type.value == self.opt_clustering_kmeans:
245-
k = int(self.clustering_kmeans_k.value)
246-
return KMeans(k)
247-
elif self.clustering_type.value == self.opt_clustering_dbscan:
249+
elif self.clustering_type.value == CLUSTERING_KMEANS:
250+
n_clusters = int(self.clustering_kmeans_n_clusters.value)
251+
return KMeans(n_clusters)
252+
elif self.clustering_type.value == CLUSTERING_DBSCAN:
248253
eps = float(self.clustering_dbscan_eps.value)
249254
min_samples = int(self.clustering_dbscan_min_samples.value)
250-
return DBSCAN(eps=eps)
255+
return DBSCAN(eps=eps, min_samples=min_samples)
256+
elif self.clustering_type == CLUSTERING_AGGLOMERATIVE:
257+
n_clusters = int(self.clustering_agglomerative_n_clusters.value)
258+
return AgglomerativeClustering(n_clusters=n_clusters)
259+
260+
async def update_handler(self, _=None):
261+
await run.io_bound(self.update)
251262

252263
def update(self, _=None):
253264
X, labels = load_digits(return_X_y=True)
254265
lens = self.render_lens()
255266
if lens is None:
256-
print("Lens is None")
257267
return
258268
y = lens(X)
259269

260270
cover = self.render_cover()
261271
if cover is None:
262-
print("Cover is None")
263272
return
264273

265274
clustering = self.render_clustering()
266275
if clustering is None:
267-
print("Clustering is None")
268276
return
269277

270278
mapper_algo = MapperAlgorithm(
@@ -286,24 +294,31 @@ def update(self, _=None):
286294
height=800,
287295
node_size=0.5,
288296
)
289-
if mapper_fig.layout.width is not None:
290-
mapper_fig.layout.width = None
291-
if not mapper_fig.layout.autosize:
292-
mapper_fig.layout.autosize = True
297+
# if mapper_fig.layout.width is not None:
298+
mapper_fig.layout.width = None
299+
# if not mapper_fig.layout.autosize:
293300
mapper_fig.layout.autosize = True
294-
self.plot.update_figure(mapper_fig)
295-
296-
def build(self):
297-
with ui.left_drawer().classes("w-[400px]"):
298-
self.build_lens()
299-
ui.separator()
300-
self.build_cover()
301-
ui.separator()
302-
self.build_clustering()
303-
self.build_plot()
301+
self.plot_container.clear()
302+
with self.plot_container:
303+
ui.plotly(mapper_fig)
304+
305+
def __init__(self):
306+
with ui.row().classes("w-full h-full m-0 p-0 gap-0 overflow-hidden"):
307+
with ui.column().classes("w-64 h-full overflow-y-auto m-0 p-3 gap-2"):
308+
with ui.card().classes("w-full"):
309+
ui.markdown("#### 🔎 Lens")
310+
self.build_lens()
311+
with ui.card().classes("w-full"):
312+
ui.markdown("#### 🌐 Cover")
313+
self.build_cover()
314+
with ui.card().classes("w-full"):
315+
ui.markdown("#### 🧮 Clustering")
316+
self.build_clustering()
317+
318+
with ui.column().classes("flex-1 h-full overflow-hidden m-0 p-0"):
319+
self.build_plot()
304320
self.update()
305321

306322

307323
app = App()
308-
app.build()
309324
ui.run()

0 commit comments

Comments
 (0)