Skip to content

Commit a5d2674

Browse files
committed
First implementation using nicegui
1 parent 1805f01 commit a5d2674

1 file changed

Lines changed: 309 additions & 0 deletions

File tree

app/nicegui_app.py

Lines changed: 309 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
import numpy as np
2+
import plotly.graph_objs as go
3+
from nicegui import ui
4+
from nicegui.events import ValueChangeEventArguments
5+
from sklearn.cluster import AgglomerativeClustering, KMeans
6+
from sklearn.datasets import load_digits, make_circles
7+
from sklearn.decomposition import PCA
8+
from umap import UMAP
9+
10+
from tdamapper.core import TrivialClustering, TrivialCover
11+
from tdamapper.cover import BallCover, CubicalCover, KNNCover
12+
from tdamapper.learn import MapperAlgorithm
13+
from tdamapper.plot import MapperPlot
14+
15+
16+
def mode(arr):
17+
values, counts = np.unique(arr, return_counts=True)
18+
max_count = np.max(counts)
19+
mode_values = values[counts == max_count]
20+
return np.nanmean(mode_values)
21+
22+
23+
def _identity(X):
24+
return X
25+
26+
27+
def _pca(n_components):
28+
pca = PCA(n_components=n_components, random_state=42)
29+
30+
def _func(X):
31+
return pca.fit_transform(X)
32+
33+
return _func
34+
35+
36+
def _umap(n_components):
37+
um = UMAP(n_components=n_components, random_state=42)
38+
39+
def _func(X):
40+
return um.fit_transform(X)
41+
42+
return _func
43+
44+
45+
class App:
46+
47+
def build_lens(self):
48+
self.opt_lens_id = "Identity"
49+
self.opt_lens_pca = "PCA"
50+
self.opt_lens_umap = "UMAP"
51+
52+
self.lens_type = ui.select(
53+
label="Lens type",
54+
options=[
55+
self.opt_lens_id,
56+
self.opt_lens_pca,
57+
self.opt_lens_umap,
58+
],
59+
value=self.opt_lens_pca,
60+
on_change=self.update,
61+
).classes("w-full")
62+
self.pca_n_components = ui.number(
63+
label="PCA Components",
64+
min=1,
65+
max=10,
66+
value=2,
67+
on_change=self.update,
68+
).classes("w-full")
69+
self.pca_n_components.bind_visibility_from(
70+
target_object=self.lens_type,
71+
target_name="value",
72+
value=self.opt_lens_pca,
73+
)
74+
self.umap_n_components = ui.number(
75+
label="UMAP Components",
76+
min=1,
77+
max=10,
78+
value=2,
79+
on_change=self.update,
80+
).classes("w-full")
81+
self.umap_n_components.bind_visibility_from(
82+
target_object=self.lens_type,
83+
target_name="value",
84+
value=self.opt_lens_umap,
85+
)
86+
87+
def build_cover(self):
88+
self.opt_cover_trivial = "Trivial"
89+
self.opt_cover_cubical = "Cubical"
90+
self.opt_cover_ball = "Ball"
91+
self.opt_cover_knn = "KNN"
92+
93+
self.cover_type = ui.select(
94+
label="Cover type",
95+
options=[
96+
self.opt_cover_trivial,
97+
self.opt_cover_cubical,
98+
self.opt_cover_ball,
99+
self.opt_cover_knn,
100+
],
101+
value=self.opt_cover_cubical,
102+
on_change=self.update,
103+
).classes("w-full")
104+
self.cover_cubical_n = ui.number(
105+
label="Intervals",
106+
min=1,
107+
max=10,
108+
value=2,
109+
on_change=self.update,
110+
).classes("w-full")
111+
self.cover_cubical_n.bind_visibility_from(
112+
target_object=self.cover_type,
113+
target_name="value",
114+
value=self.opt_cover_cubical,
115+
)
116+
self.cover_cubical_overlap = ui.number(
117+
label="Overlap",
118+
min=0.0,
119+
max=1.0,
120+
value=0.5,
121+
on_change=self.update,
122+
).classes("w-full")
123+
self.cover_cubical_overlap.bind_visibility_from(
124+
target_object=self.cover_type,
125+
target_name="value",
126+
value=self.opt_cover_cubical,
127+
)
128+
self.cover_ball_radius = ui.number(
129+
label="Radius",
130+
min=0.0,
131+
value=100.0,
132+
on_change=self.update,
133+
).classes("w-full")
134+
self.cover_ball_radius.bind_visibility_from(
135+
target_object=self.cover_type,
136+
target_name="value",
137+
value=self.opt_cover_ball,
138+
)
139+
self.cover_knn_k = ui.number(
140+
label="Neighbors",
141+
min=0,
142+
value=10,
143+
on_change=self.update,
144+
).classes("w-full")
145+
self.cover_knn_k.bind_visibility_from(
146+
target_object=self.cover_type,
147+
target_name="value",
148+
value=self.opt_cover_knn,
149+
)
150+
151+
def build_clustering(self):
152+
self.opt_clustering_trivial = "Trivial"
153+
self.opt_clustering_kmeans = "KMeans"
154+
self.opt_clustering_agg = "Agglomerative"
155+
self.opt_clustering_dbscan = "DBSCAN"
156+
157+
self.clustering_type = ui.select(
158+
label="Clustering type",
159+
options=[
160+
self.opt_clustering_trivial,
161+
self.opt_clustering_kmeans,
162+
self.opt_clustering_agg,
163+
self.opt_clustering_dbscan,
164+
],
165+
value=self.opt_clustering_trivial,
166+
on_change=self.update,
167+
).classes("w-full")
168+
self.clustering_kmeans_k = ui.number(
169+
label="Clusters",
170+
min=1,
171+
value=2,
172+
on_change=self.update,
173+
).classes("w-full")
174+
self.clustering_kmeans_k.bind_visibility_from(
175+
target_object=self.clustering_type,
176+
target_name="value",
177+
value=self.opt_clustering_kmeans,
178+
)
179+
self.clustering_dbscan_eps = ui.number(
180+
label="Eps",
181+
min=0.0,
182+
value=0.5,
183+
on_change=self.update,
184+
).classes("w-full")
185+
self.clustering_dbscan_eps.bind_visibility_from(
186+
target_object=self.clustering_type,
187+
target_name="value",
188+
value=self.opt_clustering_dbscan,
189+
)
190+
self.clustering_dbscan_min_samples = ui.number(
191+
label="Min Samples",
192+
min=1,
193+
value=5,
194+
on_change=self.update,
195+
).classes("w-full")
196+
self.clustering_dbscan_eps.bind_visibility_from(
197+
target_object=self.clustering_type,
198+
target_name="value",
199+
value=self.opt_clustering_dbscan,
200+
)
201+
self.clustering_agg_n = ui.number(
202+
label="Clusters",
203+
min=1,
204+
value=2,
205+
on_change=self.update,
206+
).classes("w-full")
207+
self.clustering_agg_n.bind_visibility_from(
208+
target_object=self.clustering_type,
209+
target_name="value",
210+
value=self.opt_clustering_agg,
211+
)
212+
213+
def build_plot(self):
214+
self.plot = ui.plotly(go.Figure())
215+
216+
def render_lens(self):
217+
print(f"Lens type: {self.lens_type.value}")
218+
if self.lens_type.value == self.opt_lens_id:
219+
return _identity
220+
elif self.lens_type.value == self.opt_lens_pca:
221+
n = int(self.pca_n_components.value)
222+
return _pca(n)
223+
elif self.lens_type.value == self.opt_lens_umap:
224+
n = int(self.umap_n_components.value)
225+
return _umap(n)
226+
227+
def render_cover(self):
228+
if self.cover_type.value == self.opt_cover_trivial:
229+
return TrivialCover()
230+
elif self.cover_type.value == self.opt_cover_ball:
231+
r = float(self.cover_ball_radius.value)
232+
return BallCover(radius=r)
233+
elif self.cover_type.value == self.opt_cover_cubical:
234+
n = int(self.cover_cubical_n.value)
235+
overlap = float(self.cover_cubical_overlap.value)
236+
return CubicalCover(n_intervals=n, overlap_frac=overlap)
237+
elif self.cover_type.value == self.opt_cover_knn:
238+
k = int(self.cover_knn_k.value)
239+
return KNNCover(neighbors=k)
240+
241+
def render_clustering(self):
242+
if self.clustering_type.value == self.opt_clustering_trivial:
243+
return TrivialClustering()
244+
elif self.clustering_type.value == self.opt_clustering_kmeans:
245+
k = int(self.clustering_kmeans_k.value)
246+
return KMeans(k)
247+
elif self.clustering_type.value == self.opt_clustering_dbscan:
248+
eps = float(self.clustering_dbscan_eps.value)
249+
min_samples = int(self.clustering_dbscan_min_samples.value)
250+
return DBSCAN(eps=eps)
251+
252+
def update(self, _=None):
253+
X, labels = load_digits(return_X_y=True)
254+
lens = self.render_lens()
255+
if lens is None:
256+
print("Lens is None")
257+
return
258+
y = lens(X)
259+
260+
cover = self.render_cover()
261+
if cover is None:
262+
print("Cover is None")
263+
return
264+
265+
clustering = self.render_clustering()
266+
if clustering is None:
267+
print("Clustering is None")
268+
return
269+
270+
mapper_algo = MapperAlgorithm(
271+
cover=cover,
272+
clustering=clustering,
273+
verbose=False,
274+
)
275+
276+
mapper_graph = mapper_algo.fit_transform(X, y)
277+
278+
mapper_plot = MapperPlot(mapper_graph, dim=3, iterations=400, seed=42)
279+
280+
mapper_fig = mapper_plot.plot_plotly(
281+
colors=labels,
282+
cmap=["jet", "viridis", "cividis"],
283+
agg=mode,
284+
title="mode of digits",
285+
width=800,
286+
height=800,
287+
node_size=0.5,
288+
)
289+
if mapper_fig.layout.width is not None:
290+
mapper_fig.layout.width = None
291+
if not mapper_fig.layout.autosize:
292+
mapper_fig.layout.autosize = True
293+
mapper_fig.layout.autosize = True
294+
self.plot.update_figure(mapper_fig)
295+
296+
def build(self):
297+
with ui.left_drawer().classes("w-[400px]"):
298+
self.build_lens()
299+
ui.separator()
300+
self.build_cover()
301+
ui.separator()
302+
self.build_clustering()
303+
self.build_plot()
304+
self.update()
305+
306+
307+
app = App()
308+
app.build()
309+
ui.run()

0 commit comments

Comments
 (0)