Skip to content

Commit 06dabae

Browse files
committed
Added nicegui app
1 parent e722b5f commit 06dabae

1 file changed

Lines changed: 390 additions & 0 deletions

File tree

src/tdamapper/app.py

Lines changed: 390 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,390 @@
1+
import logging
2+
from dataclasses import asdict, dataclass
3+
4+
import pandas as pd
5+
from nicegui import app, run, ui
6+
from sklearn.cluster import DBSCAN, AgglomerativeClustering, KMeans
7+
from sklearn.decomposition import PCA
8+
from umap import UMAP
9+
10+
from tdamapper.clustering import TrivialClustering
11+
from tdamapper.cover import BallCover, CubicalCover, KNNCover
12+
from tdamapper.learn import MapperAlgorithm
13+
from tdamapper.plot import MapperPlot
14+
15+
logging.basicConfig(level=logging.INFO)
16+
logger = logging.getLogger(__name__)
17+
18+
19+
LENS_IDENTITY = "Identity"
20+
LENS_PCA = "PCA"
21+
LENS_UMAP = "UMAP"
22+
23+
LENS_PCA_N_COMPONENTS = 2
24+
LENS_UMAP_N_COMPONENTS = 2
25+
26+
COVER_CUBICAL = "Cubical Cover"
27+
COVER_BALL = "Ball Cover"
28+
COVER_KNN = "KNN Cover"
29+
30+
CLUSTERING_TRIVIAL = "Skip"
31+
CLUSTERING_KMEANS = "KMeans"
32+
CLUSTERING_DBSCAN = "DBSCAN"
33+
CLUSTERING_AGGLOMERATIVE = "Agglomerative Clustering"
34+
35+
COVER_CUBICAL_N_INTERVALS = 10
36+
COVER_CUBICAL_OVERLAP_FRAC = 0.25
37+
COVER_KNN_NEIGHBORS = 10
38+
COVER_BALL_RADIUS = 100.0
39+
40+
CLUSTERING_KMEANS_N_CLUSTERS = 2
41+
CLUSTERING_DBSCAN_EPS = 0.5
42+
CLUSTERING_DBSCAN_MIN_SAMPLES = 5
43+
CLUSTERING_AGGLOMERATIVE_N_CLUSTERS = 2
44+
45+
RANDOM_SEED = 42
46+
47+
48+
@dataclass
49+
class MapperConfig:
50+
lens_type: str = LENS_PCA
51+
cover_type: str = COVER_CUBICAL
52+
clustering_type: str = CLUSTERING_TRIVIAL
53+
lens_pca_n_components: int = LENS_PCA_N_COMPONENTS
54+
lens_umap_n_components: int = LENS_UMAP_N_COMPONENTS
55+
cover_cubical_n_intervals: int = COVER_CUBICAL_N_INTERVALS
56+
cover_cubical_overlap_frac: float = COVER_CUBICAL_OVERLAP_FRAC
57+
cover_knn_neighbors: int = COVER_KNN_NEIGHBORS
58+
clustering_kmeans_n_clusters: int = CLUSTERING_KMEANS_N_CLUSTERS
59+
clustering_dbscan_eps: float = CLUSTERING_DBSCAN_EPS
60+
clustering_dbscan_min_samples: int = CLUSTERING_DBSCAN_MIN_SAMPLES
61+
clustering_agglomerative_n_clusters: int = CLUSTERING_AGGLOMERATIVE_N_CLUSTERS
62+
63+
64+
def identity(X):
65+
return X
66+
67+
68+
def pca(n_components):
69+
70+
def _pca(X):
71+
pca_model = PCA(n_components=n_components, random_state=RANDOM_SEED)
72+
return pca_model.fit_transform(X)
73+
74+
return _pca
75+
76+
77+
def umap(n_components):
78+
79+
def _umap(X):
80+
umap_model = umap.UMAP(n_components=n_components, random_state=RANDOM_SEED)
81+
return umap_model.fit_transform(X)
82+
83+
return _umap
84+
85+
86+
def run_mapper(df, **kwargs):
87+
if df is None:
88+
logger.error("No data found. Please upload a file first.")
89+
return
90+
logger.info("Computing Mapper.")
91+
92+
mapper_config = MapperConfig(**kwargs)
93+
94+
lens_type = mapper_config.lens_type
95+
cover_type = mapper_config.cover_type
96+
clustering_type = mapper_config.clustering_type
97+
lens_pca_n_components = mapper_config.lens_pca_n_components
98+
lens_umap_n_components = mapper_config.lens_umap_n_components
99+
cover_cubical_n_intervals = mapper_config.cover_cubical_n_intervals
100+
cover_cubical_overlap_frac = mapper_config.cover_cubical_overlap_frac
101+
cover_knn_neighbors = mapper_config.cover_knn_neighbors
102+
clustering_kmeans_n_clusters = mapper_config.clustering_kmeans_n_clusters
103+
clustering_dbscan_eps = mapper_config.clustering_dbscan_eps
104+
clustering_dbscan_min_samples = mapper_config.clustering_dbscan_min_samples
105+
clustering_agglomerative_n_clusters = (
106+
mapper_config.clustering_agglomerative_n_clusters
107+
)
108+
109+
if lens_type == LENS_IDENTITY:
110+
lens = identity
111+
elif lens_type == LENS_PCA:
112+
lens = pca(n_components=lens_pca_n_components)
113+
elif lens_type == LENS_UMAP:
114+
lens = umap(n_components=lens_umap_n_components)
115+
116+
if cover_type == COVER_CUBICAL:
117+
cover = CubicalCover(n_intervals=cover_cubical_n_intervals)
118+
elif cover_type == COVER_BALL:
119+
cover = BallCover(overlap_fraction=cover_cubical_overlap_frac)
120+
elif cover_type == COVER_KNN:
121+
cover = KNNCover(n_neighbors=cover_knn_neighbors)
122+
else:
123+
logger.error(f"Unknown cover type: {cover_type}")
124+
return
125+
126+
if clustering_type == CLUSTERING_TRIVIAL:
127+
clustering = TrivialClustering()
128+
elif clustering_type == CLUSTERING_KMEANS:
129+
clustering = KMeans(
130+
n_clusters=clustering_kmeans_n_clusters,
131+
random_state=RANDOM_SEED,
132+
)
133+
elif clustering_type == CLUSTERING_DBSCAN:
134+
clustering = DBSCAN(
135+
eps=clustering_dbscan_eps,
136+
min_samples=clustering_dbscan_min_samples,
137+
random_state=RANDOM_SEED,
138+
)
139+
elif clustering_type == CLUSTERING_AGGLOMERATIVE:
140+
clustering = AgglomerativeClustering(
141+
n_clusters=clustering_agglomerative_n_clusters,
142+
random_state=RANDOM_SEED,
143+
)
144+
else:
145+
logger.error(f"Unknown clustering type: {clustering_type}")
146+
return
147+
148+
mapper = MapperAlgorithm(cover=cover, clustering=clustering)
149+
X = df.to_numpy()
150+
y = lens(X)
151+
mapper_graph = mapper.fit_transform(X, y)
152+
mapper_fig = MapperPlot(
153+
mapper_graph,
154+
dim=3,
155+
).plot_plotly(
156+
colors=X,
157+
height=800,
158+
node_size=[0.0, 0.5, 1.0],
159+
)
160+
logger.info("Mapper run completed successfully.")
161+
return mapper_fig
162+
163+
164+
@ui.page("/")
165+
def index():
166+
storage = app.storage.client
167+
168+
def upload_file(file):
169+
if file is not None:
170+
df = pd.read_csv(file.content)
171+
storage["df"] = df
172+
173+
logger.info("File uploaded successfully.")
174+
logger.info(f"{df.head()}")
175+
else:
176+
logger.info("No file uploaded.")
177+
178+
def load_file():
179+
df = storage.get("df")
180+
if df is not None:
181+
logger.info("Data loaded successfully.")
182+
else:
183+
logger.warning("No data found. Please upload a file first.")
184+
185+
def get_mapper_config():
186+
return MapperConfig(
187+
lens_type=str(lens.value) if lens.value else LENS_PCA,
188+
cover_type=str(cover.value) if cover.value else COVER_CUBICAL,
189+
clustering_type=(
190+
str(clustering.value) if clustering.value else CLUSTERING_TRIVIAL
191+
),
192+
lens_pca_n_components=(
193+
int(pca_n_components.value)
194+
if pca_n_components.value
195+
else LENS_PCA_N_COMPONENTS
196+
),
197+
lens_umap_n_components=(
198+
int(umap_n_components.value)
199+
if umap_n_components.value
200+
else LENS_UMAP_N_COMPONENTS
201+
),
202+
cover_cubical_n_intervals=(
203+
int(n_intervals.value)
204+
if n_intervals.value
205+
else COVER_CUBICAL_N_INTERVALS
206+
),
207+
cover_cubical_overlap_frac=(
208+
float(overlap_frac.value)
209+
if overlap_frac.value
210+
else COVER_CUBICAL_OVERLAP_FRAC
211+
),
212+
cover_knn_neighbors=(
213+
int(neighbors.value) if neighbors.value else COVER_KNN_NEIGHBORS
214+
),
215+
clustering_kmeans_n_clusters=(
216+
int(kmeans_n_clusters.value)
217+
if kmeans_n_clusters.value
218+
else CLUSTERING_KMEANS_N_CLUSTERS
219+
),
220+
clustering_dbscan_eps=(
221+
float(dbscan_eps.value) if dbscan_eps.value else CLUSTERING_DBSCAN_EPS
222+
),
223+
clustering_dbscan_min_samples=(
224+
int(dbscan_min_samples.value)
225+
if dbscan_min_samples.value
226+
else CLUSTERING_DBSCAN_MIN_SAMPLES
227+
),
228+
clustering_agglomerative_n_clusters=(
229+
int(agglomerative_n_clusters.value)
230+
if agglomerative_n_clusters.value
231+
else CLUSTERING_AGGLOMERATIVE_N_CLUSTERS
232+
),
233+
)
234+
235+
async def async_run_mapper():
236+
df = storage.get("df")
237+
mapper_config = get_mapper_config()
238+
mapper_fig = await run.cpu_bound(run_mapper, df, **asdict(mapper_config))
239+
mapper_fig.layout.width = None
240+
mapper_fig.layout.autosize = True
241+
plot_container.clear()
242+
with plot_container:
243+
logger.info("Displaying Mapper plot.")
244+
ui.plotly(mapper_fig)
245+
246+
with ui.row().classes("w-full h-screen m-0 p-0 gap-0 overflow-hidden"):
247+
248+
with ui.column().classes("w-64 h-full m-0 p-0"):
249+
with ui.card().tight().classes("w-full"):
250+
ui.upload(
251+
on_upload=upload_file,
252+
auto_upload=True,
253+
label="Upload CSV File",
254+
).classes("w-full")
255+
with ui.card_section().classes("w-full"):
256+
ui.button("Load", on_click=load_file).classes("w-full")
257+
258+
with ui.card().classes("w-full"):
259+
lens = ui.select(
260+
options=[
261+
LENS_IDENTITY,
262+
LENS_PCA,
263+
LENS_UMAP,
264+
],
265+
label="Lens",
266+
value=LENS_PCA,
267+
).classes("w-full")
268+
269+
pca_n_components = ui.number(
270+
label="PCA Components",
271+
value=LENS_PCA_N_COMPONENTS,
272+
).classes("w-full")
273+
pca_n_components.bind_visibility_from(
274+
target_object=lens,
275+
target_name="value",
276+
value=LENS_PCA,
277+
)
278+
279+
umap_n_components = ui.number(
280+
label="UMAP Components",
281+
value=LENS_UMAP_N_COMPONENTS,
282+
).classes("w-full")
283+
umap_n_components.bind_visibility_from(
284+
target_object=lens,
285+
target_name="value",
286+
value=LENS_UMAP,
287+
)
288+
289+
with ui.card().classes("w-full"):
290+
cover = ui.select(
291+
options=[
292+
COVER_CUBICAL,
293+
COVER_BALL,
294+
COVER_KNN,
295+
],
296+
label="Cover",
297+
value=COVER_CUBICAL,
298+
).classes("w-full")
299+
300+
n_intervals = ui.number(
301+
label="Number of Intervals",
302+
value=COVER_CUBICAL_N_INTERVALS,
303+
).classes("w-full")
304+
n_intervals.bind_visibility_from(
305+
target_object=cover,
306+
target_name="value",
307+
value=COVER_CUBICAL,
308+
)
309+
310+
overlap_frac = ui.number(
311+
label="Ball Radius",
312+
value=COVER_BALL_RADIUS,
313+
).classes("w-full")
314+
overlap_frac.bind_visibility_from(
315+
target_object=cover,
316+
target_name="value",
317+
value=COVER_BALL,
318+
)
319+
320+
neighbors = ui.number(
321+
label="Number of Neighbors",
322+
value=COVER_KNN_NEIGHBORS,
323+
).classes("w-full")
324+
neighbors.bind_visibility_from(
325+
target_object=cover,
326+
target_name="value",
327+
value=COVER_KNN,
328+
)
329+
330+
with ui.card().classes("w-full"):
331+
clustering = ui.select(
332+
options=[
333+
CLUSTERING_TRIVIAL,
334+
CLUSTERING_KMEANS,
335+
CLUSTERING_DBSCAN,
336+
CLUSTERING_AGGLOMERATIVE,
337+
],
338+
label="Clustering",
339+
value=CLUSTERING_TRIVIAL,
340+
).classes("w-full")
341+
342+
kmeans_n_clusters = ui.number(
343+
label="Number of Clusters",
344+
value=CLUSTERING_KMEANS_N_CLUSTERS,
345+
).classes("w-full")
346+
kmeans_n_clusters.bind_visibility_from(
347+
target_object=clustering,
348+
target_name="value",
349+
value=CLUSTERING_KMEANS,
350+
)
351+
352+
dbscan_eps = ui.number(
353+
label="Epsilon",
354+
value=CLUSTERING_DBSCAN_EPS,
355+
).classes("w-full")
356+
dbscan_eps.bind_visibility_from(
357+
target_object=clustering,
358+
target_name="value",
359+
value=CLUSTERING_DBSCAN,
360+
)
361+
dbscan_min_samples = ui.number(
362+
label="Min Samples",
363+
value=CLUSTERING_DBSCAN_MIN_SAMPLES,
364+
).classes("w-full")
365+
dbscan_min_samples.bind_visibility_from(
366+
target_object=clustering,
367+
target_name="value",
368+
value=CLUSTERING_DBSCAN,
369+
)
370+
371+
agglomerative_n_clusters = ui.number(
372+
label="Number of Clusters",
373+
value=CLUSTERING_AGGLOMERATIVE_N_CLUSTERS,
374+
).classes("w-full")
375+
agglomerative_n_clusters.bind_visibility_from(
376+
target_object=clustering,
377+
target_name="value",
378+
value=CLUSTERING_AGGLOMERATIVE,
379+
)
380+
381+
ui.button(
382+
"Run Mapper",
383+
on_click=async_run_mapper,
384+
).classes("w-full")
385+
386+
with ui.column().classes("flex-1 h-full overflow-hidden m-0 p-0"):
387+
plot_container = ui.element("div").classes("w-full h-full")
388+
389+
390+
ui.run(storage_secret="secret")

0 commit comments

Comments
 (0)