|
| 1 | +import numpy as np |
| 2 | +import plotly.graph_objs as go |
| 3 | +from nicegui import ui |
| 4 | +from nicegui.events import ValueChangeEventArguments |
| 5 | +from sklearn.cluster import AgglomerativeClustering, KMeans |
| 6 | +from sklearn.datasets import load_digits, make_circles |
| 7 | +from sklearn.decomposition import PCA |
| 8 | +from umap import UMAP |
| 9 | + |
| 10 | +from tdamapper.core import TrivialClustering, TrivialCover |
| 11 | +from tdamapper.cover import BallCover, CubicalCover, KNNCover |
| 12 | +from tdamapper.learn import MapperAlgorithm |
| 13 | +from tdamapper.plot import MapperPlot |
| 14 | + |
| 15 | + |
| 16 | +def mode(arr): |
| 17 | + values, counts = np.unique(arr, return_counts=True) |
| 18 | + max_count = np.max(counts) |
| 19 | + mode_values = values[counts == max_count] |
| 20 | + return np.nanmean(mode_values) |
| 21 | + |
| 22 | + |
| 23 | +def _identity(X): |
| 24 | + return X |
| 25 | + |
| 26 | + |
| 27 | +def _pca(n_components): |
| 28 | + pca = PCA(n_components=n_components, random_state=42) |
| 29 | + |
| 30 | + def _func(X): |
| 31 | + return pca.fit_transform(X) |
| 32 | + |
| 33 | + return _func |
| 34 | + |
| 35 | + |
| 36 | +def _umap(n_components): |
| 37 | + um = UMAP(n_components=n_components, random_state=42) |
| 38 | + |
| 39 | + def _func(X): |
| 40 | + return um.fit_transform(X) |
| 41 | + |
| 42 | + return _func |
| 43 | + |
| 44 | + |
| 45 | +class App: |
| 46 | + |
| 47 | + def build_lens(self): |
| 48 | + self.opt_lens_id = "Identity" |
| 49 | + self.opt_lens_pca = "PCA" |
| 50 | + self.opt_lens_umap = "UMAP" |
| 51 | + |
| 52 | + self.lens_type = ui.select( |
| 53 | + label="Lens type", |
| 54 | + options=[ |
| 55 | + self.opt_lens_id, |
| 56 | + self.opt_lens_pca, |
| 57 | + self.opt_lens_umap, |
| 58 | + ], |
| 59 | + value=self.opt_lens_pca, |
| 60 | + on_change=self.update, |
| 61 | + ).classes("w-full") |
| 62 | + self.pca_n_components = ui.number( |
| 63 | + label="PCA Components", |
| 64 | + min=1, |
| 65 | + max=10, |
| 66 | + value=2, |
| 67 | + on_change=self.update, |
| 68 | + ).classes("w-full") |
| 69 | + self.pca_n_components.bind_visibility_from( |
| 70 | + target_object=self.lens_type, |
| 71 | + target_name="value", |
| 72 | + value=self.opt_lens_pca, |
| 73 | + ) |
| 74 | + self.umap_n_components = ui.number( |
| 75 | + label="UMAP Components", |
| 76 | + min=1, |
| 77 | + max=10, |
| 78 | + value=2, |
| 79 | + on_change=self.update, |
| 80 | + ).classes("w-full") |
| 81 | + self.umap_n_components.bind_visibility_from( |
| 82 | + target_object=self.lens_type, |
| 83 | + target_name="value", |
| 84 | + value=self.opt_lens_umap, |
| 85 | + ) |
| 86 | + |
| 87 | + def build_cover(self): |
| 88 | + self.opt_cover_trivial = "Trivial" |
| 89 | + self.opt_cover_cubical = "Cubical" |
| 90 | + self.opt_cover_ball = "Ball" |
| 91 | + self.opt_cover_knn = "KNN" |
| 92 | + |
| 93 | + self.cover_type = ui.select( |
| 94 | + label="Cover type", |
| 95 | + options=[ |
| 96 | + self.opt_cover_trivial, |
| 97 | + self.opt_cover_cubical, |
| 98 | + self.opt_cover_ball, |
| 99 | + self.opt_cover_knn, |
| 100 | + ], |
| 101 | + value=self.opt_cover_cubical, |
| 102 | + on_change=self.update, |
| 103 | + ).classes("w-full") |
| 104 | + self.cover_cubical_n = ui.number( |
| 105 | + label="Intervals", |
| 106 | + min=1, |
| 107 | + max=10, |
| 108 | + value=2, |
| 109 | + on_change=self.update, |
| 110 | + ).classes("w-full") |
| 111 | + self.cover_cubical_n.bind_visibility_from( |
| 112 | + target_object=self.cover_type, |
| 113 | + target_name="value", |
| 114 | + value=self.opt_cover_cubical, |
| 115 | + ) |
| 116 | + self.cover_cubical_overlap = ui.number( |
| 117 | + label="Overlap", |
| 118 | + min=0.0, |
| 119 | + max=1.0, |
| 120 | + value=0.5, |
| 121 | + on_change=self.update, |
| 122 | + ).classes("w-full") |
| 123 | + self.cover_cubical_overlap.bind_visibility_from( |
| 124 | + target_object=self.cover_type, |
| 125 | + target_name="value", |
| 126 | + value=self.opt_cover_cubical, |
| 127 | + ) |
| 128 | + self.cover_ball_radius = ui.number( |
| 129 | + label="Radius", |
| 130 | + min=0.0, |
| 131 | + value=100.0, |
| 132 | + on_change=self.update, |
| 133 | + ).classes("w-full") |
| 134 | + self.cover_ball_radius.bind_visibility_from( |
| 135 | + target_object=self.cover_type, |
| 136 | + target_name="value", |
| 137 | + value=self.opt_cover_ball, |
| 138 | + ) |
| 139 | + self.cover_knn_k = ui.number( |
| 140 | + label="Neighbors", |
| 141 | + min=0, |
| 142 | + value=10, |
| 143 | + on_change=self.update, |
| 144 | + ).classes("w-full") |
| 145 | + self.cover_knn_k.bind_visibility_from( |
| 146 | + target_object=self.cover_type, |
| 147 | + target_name="value", |
| 148 | + value=self.opt_cover_knn, |
| 149 | + ) |
| 150 | + |
| 151 | + def build_clustering(self): |
| 152 | + self.opt_clustering_trivial = "Trivial" |
| 153 | + self.opt_clustering_kmeans = "KMeans" |
| 154 | + self.opt_clustering_agg = "Agglomerative" |
| 155 | + self.opt_clustering_dbscan = "DBSCAN" |
| 156 | + |
| 157 | + self.clustering_type = ui.select( |
| 158 | + label="Clustering type", |
| 159 | + options=[ |
| 160 | + self.opt_clustering_trivial, |
| 161 | + self.opt_clustering_kmeans, |
| 162 | + self.opt_clustering_agg, |
| 163 | + self.opt_clustering_dbscan, |
| 164 | + ], |
| 165 | + value=self.opt_clustering_trivial, |
| 166 | + on_change=self.update, |
| 167 | + ).classes("w-full") |
| 168 | + self.clustering_kmeans_k = ui.number( |
| 169 | + label="Clusters", |
| 170 | + min=1, |
| 171 | + value=2, |
| 172 | + on_change=self.update, |
| 173 | + ).classes("w-full") |
| 174 | + self.clustering_kmeans_k.bind_visibility_from( |
| 175 | + target_object=self.clustering_type, |
| 176 | + target_name="value", |
| 177 | + value=self.opt_clustering_kmeans, |
| 178 | + ) |
| 179 | + self.clustering_dbscan_eps = ui.number( |
| 180 | + label="Eps", |
| 181 | + min=0.0, |
| 182 | + value=0.5, |
| 183 | + on_change=self.update, |
| 184 | + ).classes("w-full") |
| 185 | + self.clustering_dbscan_eps.bind_visibility_from( |
| 186 | + target_object=self.clustering_type, |
| 187 | + target_name="value", |
| 188 | + value=self.opt_clustering_dbscan, |
| 189 | + ) |
| 190 | + self.clustering_dbscan_min_samples = ui.number( |
| 191 | + label="Min Samples", |
| 192 | + min=1, |
| 193 | + value=5, |
| 194 | + on_change=self.update, |
| 195 | + ).classes("w-full") |
| 196 | + self.clustering_dbscan_eps.bind_visibility_from( |
| 197 | + target_object=self.clustering_type, |
| 198 | + target_name="value", |
| 199 | + value=self.opt_clustering_dbscan, |
| 200 | + ) |
| 201 | + self.clustering_agg_n = ui.number( |
| 202 | + label="Clusters", |
| 203 | + min=1, |
| 204 | + value=2, |
| 205 | + on_change=self.update, |
| 206 | + ).classes("w-full") |
| 207 | + self.clustering_agg_n.bind_visibility_from( |
| 208 | + target_object=self.clustering_type, |
| 209 | + target_name="value", |
| 210 | + value=self.opt_clustering_agg, |
| 211 | + ) |
| 212 | + |
| 213 | + def build_plot(self): |
| 214 | + self.plot = ui.plotly(go.Figure()) |
| 215 | + |
| 216 | + def render_lens(self): |
| 217 | + print(f"Lens type: {self.lens_type.value}") |
| 218 | + if self.lens_type.value == self.opt_lens_id: |
| 219 | + return _identity |
| 220 | + elif self.lens_type.value == self.opt_lens_pca: |
| 221 | + n = int(self.pca_n_components.value) |
| 222 | + return _pca(n) |
| 223 | + elif self.lens_type.value == self.opt_lens_umap: |
| 224 | + n = int(self.umap_n_components.value) |
| 225 | + return _umap(n) |
| 226 | + |
| 227 | + def render_cover(self): |
| 228 | + if self.cover_type.value == self.opt_cover_trivial: |
| 229 | + return TrivialCover() |
| 230 | + elif self.cover_type.value == self.opt_cover_ball: |
| 231 | + r = float(self.cover_ball_radius.value) |
| 232 | + return BallCover(radius=r) |
| 233 | + elif self.cover_type.value == self.opt_cover_cubical: |
| 234 | + n = int(self.cover_cubical_n.value) |
| 235 | + overlap = float(self.cover_cubical_overlap.value) |
| 236 | + return CubicalCover(n_intervals=n, overlap_frac=overlap) |
| 237 | + elif self.cover_type.value == self.opt_cover_knn: |
| 238 | + k = int(self.cover_knn_k.value) |
| 239 | + return KNNCover(neighbors=k) |
| 240 | + |
| 241 | + def render_clustering(self): |
| 242 | + if self.clustering_type.value == self.opt_clustering_trivial: |
| 243 | + return TrivialClustering() |
| 244 | + elif self.clustering_type.value == self.opt_clustering_kmeans: |
| 245 | + k = int(self.clustering_kmeans_k.value) |
| 246 | + return KMeans(k) |
| 247 | + elif self.clustering_type.value == self.opt_clustering_dbscan: |
| 248 | + eps = float(self.clustering_dbscan_eps.value) |
| 249 | + min_samples = int(self.clustering_dbscan_min_samples.value) |
| 250 | + return DBSCAN(eps=eps) |
| 251 | + |
| 252 | + def update(self, _=None): |
| 253 | + X, labels = load_digits(return_X_y=True) |
| 254 | + lens = self.render_lens() |
| 255 | + if lens is None: |
| 256 | + print("Lens is None") |
| 257 | + return |
| 258 | + y = lens(X) |
| 259 | + |
| 260 | + cover = self.render_cover() |
| 261 | + if cover is None: |
| 262 | + print("Cover is None") |
| 263 | + return |
| 264 | + |
| 265 | + clustering = self.render_clustering() |
| 266 | + if clustering is None: |
| 267 | + print("Clustering is None") |
| 268 | + return |
| 269 | + |
| 270 | + mapper_algo = MapperAlgorithm( |
| 271 | + cover=cover, |
| 272 | + clustering=clustering, |
| 273 | + verbose=False, |
| 274 | + ) |
| 275 | + |
| 276 | + mapper_graph = mapper_algo.fit_transform(X, y) |
| 277 | + |
| 278 | + mapper_plot = MapperPlot(mapper_graph, dim=3, iterations=400, seed=42) |
| 279 | + |
| 280 | + mapper_fig = mapper_plot.plot_plotly( |
| 281 | + colors=labels, |
| 282 | + cmap=["jet", "viridis", "cividis"], |
| 283 | + agg=mode, |
| 284 | + title="mode of digits", |
| 285 | + width=800, |
| 286 | + height=800, |
| 287 | + node_size=0.5, |
| 288 | + ) |
| 289 | + if mapper_fig.layout.width is not None: |
| 290 | + mapper_fig.layout.width = None |
| 291 | + if not mapper_fig.layout.autosize: |
| 292 | + mapper_fig.layout.autosize = True |
| 293 | + mapper_fig.layout.autosize = True |
| 294 | + self.plot.update_figure(mapper_fig) |
| 295 | + |
| 296 | + def build(self): |
| 297 | + with ui.left_drawer().classes("w-[400px]"): |
| 298 | + self.build_lens() |
| 299 | + ui.separator() |
| 300 | + self.build_cover() |
| 301 | + ui.separator() |
| 302 | + self.build_clustering() |
| 303 | + self.build_plot() |
| 304 | + self.update() |
| 305 | + |
| 306 | + |
| 307 | +app = App() |
| 308 | +app.build() |
| 309 | +ui.run() |
0 commit comments