Skip to content

Commit db38258

Browse files
committed
Improved UI. Added storage
1 parent 67a2dc3 commit db38258

1 file changed

Lines changed: 81 additions & 48 deletions

File tree

src/tdamapper/app.py

Lines changed: 81 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -5,30 +5,41 @@
55
import numpy as np
66
import pandas as pd
77
import plotly.graph_objs as go
8-
from nicegui import run, ui
8+
from nicegui import app, run, ui
99
from sklearn.cluster import DBSCAN, AgglomerativeClustering, KMeans
1010
from sklearn.datasets import fetch_openml, load_digits, load_iris
1111
from sklearn.decomposition import PCA
1212
from umap import UMAP
1313

14-
from tdamapper.core import TrivialClustering, TrivialCover
14+
from tdamapper.core import Cover, TrivialClustering, TrivialCover
1515
from tdamapper.cover import BallCover, CubicalCover, KNNCover
1616
from tdamapper.learn import MapperAlgorithm
1717
from tdamapper.plot import MapperPlot
1818

19+
RANDOM_STATE = 42
1920
LENS_IDENTITY = "Identity"
2021
LENS_PCA = "PCA"
22+
LENS_PCA_N_COMPONENTS = 2
2123
LENS_UMAP = "UMAP"
24+
LENS_UMAP_N_COMPONENTS = 2
2225

2326
COVER_TRIVIAL = "Trivial"
2427
COVER_CUBICAL = "Cubical"
28+
COVER_CUBICAL_N_INTERVALS = 2
29+
COVER_CUBICAL_OVERLAP_FRAC = 0.25
2530
COVER_BALL = "Ball"
31+
COVER_BALL_RADIUS = 100.0
2632
COVER_KNN = "KNN"
33+
COVER_KNN_NEIGHBORS = 10
2734

2835
CLUSTERING_TRIVIAL = "Trivial"
2936
CLUSTERING_KMEANS = "KMeans"
37+
CLUSTERING_KMEANS_N_CLUSTERS = 2
3038
CLUSTERING_AGGLOMERATIVE = "Agglomerative"
39+
CLUSTERING_AGGLOMERATIVE_N_CLUSTERS = 2
3140
CLUSTERING_DBSCAN = "DBSCAN"
41+
CLUSTERING_DBSCAN_EPS = 0.5
42+
CLUSTERING_DBSCAN_MIN_SAMPLES = 5
3243

3344
DATA_SOURCE_EXAMPLE = "Example"
3445
DATA_SOURCE_CSV = "CSV"
@@ -57,21 +68,21 @@ class State:
5768
source_csv: str = "tmp/data.csv"
5869
source_openml: str = SOURCE_OPENML
5970
lens_type: str = LENS_PCA
60-
lens_pca_n_components: int = 2
61-
lens_umap_n_components: int = 2
71+
lens_pca_n_components: int = LENS_PCA_N_COMPONENTS
72+
lens_umap_n_components: int = LENS_UMAP_N_COMPONENTS
6273
cover_type: str = COVER_CUBICAL
63-
cover_cubical_n_intervals: int = 2
64-
cover_cubical_overlap_frac: float = 0.25
65-
cover_knn_neighbors: int = 10
66-
cover_ball_radius: float = 100.0
74+
cover_cubical_n_intervals: int = COVER_CUBICAL_N_INTERVALS
75+
cover_cubical_overlap_frac: float = COVER_CUBICAL_OVERLAP_FRAC
76+
cover_knn_neighbors: int = COVER_KNN_NEIGHBORS
77+
cover_ball_radius: float = COVER_BALL_RADIUS
6778
clustering_type: str = CLUSTERING_TRIVIAL
68-
clustering_kmeans_n_clusters: int = 2
69-
clustering_dbscan_eps: float = 0.5
70-
clustering_dbscan_min_samples: int = 5
71-
clustering_agglomerative_n_clusters: int = 2
72-
draw_dim: int = 3
79+
clustering_kmeans_n_clusters: int = CLUSTERING_KMEANS_N_CLUSTERS
80+
clustering_dbscan_eps: float = CLUSTERING_DBSCAN_EPS
81+
clustering_dbscan_min_samples: int = CLUSTERING_DBSCAN_MIN_SAMPLES
82+
clustering_agglomerative_n_clusters: int = CLUSTERING_AGGLOMERATIVE_N_CLUSTERS
83+
draw_dim: str = DRAW_3D
7384
draw_aggregation: str = DRAW_MEAN
74-
draw_iterations: int = 50
85+
draw_iterations: int = DRAW_ITERATIONS
7586

7687

7788
def _fix_data(data):
@@ -115,58 +126,69 @@ def get_dataset(state: State):
115126

116127

117128
def get_lens(state: State):
129+
def _pca(n):
130+
pca = PCA(n_components=n, random_state=RANDOM_STATE)
131+
return lambda X: pca.fit_transform(X)
132+
133+
def _umap(n):
134+
umap = UMAP(n_components=n, random_state=RANDOM_STATE)
135+
return lambda X: umap.fit_transform(X)
136+
137+
def _identity():
138+
return lambda X: X
139+
140+
lens = _pca(2)
118141
lens_type = state.lens_type
119142
if lens_type == LENS_IDENTITY:
120-
return lambda X: X
143+
lens = _identity()
121144
elif lens_type == LENS_PCA:
122-
n_components = int(state.lens_pca_n_components)
123-
pca = PCA(n_components=n_components, random_state=42)
124-
return lambda X: pca.fit_transform(X)
145+
lens = _pca(state.lens_pca_n_components)
125146
elif lens_type == LENS_UMAP:
126-
n_components = int(state.lens_umap_n_components)
127-
umap = UMAP(n_components=n_components, random_state=42)
128-
return lambda X: umap.fit_transform(X)
147+
lens = _umap(state.lens_umap_n_components)
129148
else:
130-
logger.error(f"Unknown lens type: {lens_type}")
131-
return None
149+
logger.error("Defaulting to PCA lens")
150+
return lens
132151

133152

134-
def get_cover(state: State):
153+
def get_cover(state: State) -> Cover:
135154
cover_type = state.cover_type
155+
cover: Cover = CubicalCover(n_intervals=2, overlap_frac=0.25)
136156
if cover_type == COVER_TRIVIAL:
137-
return TrivialCover()
157+
cover = TrivialCover()
138158
elif cover_type == COVER_CUBICAL:
139-
n_intervals = int(state.cover_cubical_n_intervals)
140-
overlap_frac = float(state.cover_cubical_overlap_frac)
141-
return CubicalCover(n_intervals=n_intervals, overlap_frac=overlap_frac)
159+
cover = CubicalCover(
160+
n_intervals=state.cover_cubical_n_intervals,
161+
overlap_frac=state.cover_cubical_overlap_frac,
162+
)
142163
elif cover_type == COVER_BALL:
143-
radius = float(state.cover_ball_radius)
144-
return BallCover(radius=radius)
164+
cover = BallCover(radius=state.cover_ball_radius)
145165
elif cover_type == COVER_KNN:
146-
neighbors = int(state.cover_knn_neighbors)
147-
return KNNCover(neighbors=neighbors)
166+
cover = KNNCover(neighbors=state.cover_knn_neighbors)
148167
else:
149-
logger.error(f"Unknown cover type: {cover_type}")
150-
return None
168+
logger.error("Defaulting to CubicalCover")
169+
return cover
151170

152171

153172
def get_clustering(state: State):
154173
clustering_type = state.clustering_type
155174
if clustering_type == CLUSTERING_TRIVIAL:
156175
return TrivialClustering()
157176
elif clustering_type == CLUSTERING_KMEANS:
158-
n_clusters = state.clustering_kmeans_n_clusters
159-
return KMeans(n_clusters=n_clusters, random_state=42)
177+
return KMeans(
178+
n_clusters=state.clustering_kmeans_n_clusters, random_state=RANDOM_STATE
179+
)
160180
elif clustering_type == CLUSTERING_DBSCAN:
161-
eps = state.clustering_dbscan_eps
162-
min_samples = state.clustering_dbscan_min_samples
163-
return DBSCAN(eps=eps, min_samples=min_samples)
181+
return DBSCAN(
182+
eps=state.clustering_dbscan_eps,
183+
min_samples=state.clustering_dbscan_min_samples,
184+
)
164185
elif clustering_type == CLUSTERING_AGGLOMERATIVE:
165-
n_clusters = state.clustering_agglomerative_n_clusters
166-
return AgglomerativeClustering(n_clusters=n_clusters)
186+
return AgglomerativeClustering(
187+
n_clusters=state.clustering_agglomerative_n_clusters
188+
)
167189
else:
168-
logger.error(f"Unknown clustering type: {clustering_type}")
169-
return None
190+
logger.error("Defaulting to TrivialClustering")
191+
return TrivialClustering()
170192

171193

172194
def compute_mapper(**kwargs):
@@ -203,9 +225,11 @@ def compute_mapper(**kwargs):
203225
logger.info(f"Mapper configuration: {mapper_algo}")
204226
mapper_graph = mapper_algo.fit_transform(X, y)
205227

228+
dim = 3 if state.draw_dim == DRAW_3D else 2
229+
206230
mapper_plot = MapperPlot(
207231
mapper_graph,
208-
dim=state.draw_dim,
232+
dim=dim,
209233
iterations=state.draw_iterations,
210234
seed=42,
211235
)
@@ -497,7 +521,7 @@ async def update(self, _=None):
497521
self.state.clustering_agglomerative_n_clusters = int(
498522
self.clustering_agglomerative_n_clusters.value
499523
)
500-
self.state.draw_dim = 3 if self.draw_3d.value == DRAW_3D else 2
524+
self.state.draw_dim = str(self.draw_3d.value)
501525
self.state.draw_iterations = int(self.draw_iterations.value)
502526
self.state.draw_aggregation = self.draw_aggregation.value
503527
await self.render()
@@ -508,6 +532,9 @@ async def render(self):
508532
**asdict(self.state),
509533
)
510534

535+
self.storage["mapper_graph"] = mapper_graph
536+
self.storage["mapper_fig"] = mapper_fig
537+
511538
self.plot_container.clear()
512539
with self.plot_container:
513540
ui.plotly(mapper_fig)
@@ -517,7 +544,8 @@ async def update_csv_handler(self, file):
517544
await self.update_dataset_handler()
518545
await self.update_plot_handler()
519546

520-
def __init__(self):
547+
def __init__(self, storage):
548+
self.storage = storage
521549
self.state = State()
522550
with ui.row().classes("w-full h-screen m-0 p-0 gap-0 overflow-hidden"):
523551
with ui.column().classes("w-64 h-full m-0 p-0"):
@@ -544,9 +572,14 @@ def __init__(self):
544572
ui.plotly(mapper_fig)
545573

546574

575+
@ui.page("/")
576+
def main_page():
577+
storage = app.storage.client
578+
App(storage=storage)
579+
580+
547581
def main():
548-
App()
549-
ui.run()
582+
ui.run(storage_secret="tdamapper_secret", title="TDA Mapper App", port=8080)
550583

551584

552585
if __name__ in {"__main__", "__mp_main__", "tdamapper.app"}:

0 commit comments

Comments
 (0)