|
| 1 | +import logging |
| 2 | +from dataclasses import asdict, dataclass |
| 3 | + |
| 4 | +import pandas as pd |
| 5 | +from nicegui import app, run, ui |
| 6 | +from sklearn.cluster import DBSCAN, AgglomerativeClustering, KMeans |
| 7 | +from sklearn.decomposition import PCA |
| 8 | +from umap import UMAP |
| 9 | + |
| 10 | +from tdamapper.clustering import TrivialClustering |
| 11 | +from tdamapper.cover import BallCover, CubicalCover, KNNCover |
| 12 | +from tdamapper.learn import MapperAlgorithm |
| 13 | +from tdamapper.plot import MapperPlot |
| 14 | + |
# Configure logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Lens labels: shown in the "Lens" select box and matched in run_mapper.
LENS_IDENTITY = "Identity"
LENS_PCA = "PCA"
LENS_UMAP = "UMAP"

# Default output dimensionality for the PCA and UMAP lenses.
LENS_PCA_N_COMPONENTS = 2
LENS_UMAP_N_COMPONENTS = 2

# Cover labels: shown in the "Cover" select box and matched in run_mapper.
COVER_CUBICAL = "Cubical Cover"
COVER_BALL = "Ball Cover"
COVER_KNN = "KNN Cover"

# Clustering labels: shown in the "Clustering" select box and matched in
# run_mapper.
CLUSTERING_TRIVIAL = "Skip"
CLUSTERING_KMEANS = "KMeans"
CLUSTERING_DBSCAN = "DBSCAN"
CLUSTERING_AGGLOMERATIVE = "Agglomerative Clustering"

# Default cover parameters, used as MapperConfig defaults and as the initial
# values of the corresponding UI inputs.
COVER_CUBICAL_N_INTERVALS = 10
COVER_CUBICAL_OVERLAP_FRAC = 0.25
COVER_KNN_NEIGHBORS = 10
COVER_BALL_RADIUS = 100.0

# Default clustering parameters, used as MapperConfig defaults and as the
# initial values of the corresponding UI inputs.
CLUSTERING_KMEANS_N_CLUSTERS = 2
CLUSTERING_DBSCAN_EPS = 0.5
CLUSTERING_DBSCAN_MIN_SAMPLES = 5
CLUSTERING_AGGLOMERATIVE_N_CLUSTERS = 2

# Seed passed as random_state to the estimators built in this module.
RANDOM_SEED = 42
| 46 | + |
| 47 | + |
@dataclass
class MapperConfig:
    """Flat collection of the settings that drive a single Mapper run.

    Every field has a default taken from the module-level constants, so an
    instance can be built from any subset of keyword arguments. The UI turns
    an instance into a plain dict with ``dataclasses.asdict`` and
    ``run_mapper`` rebuilds it from those keyword arguments.
    """

    # Selected algorithms, stored as the LENS_* / COVER_* / CLUSTERING_* labels.
    lens_type: str = LENS_PCA
    cover_type: str = COVER_CUBICAL
    clustering_type: str = CLUSTERING_TRIVIAL

    # Lens parameters.
    lens_pca_n_components: int = LENS_PCA_N_COMPONENTS
    lens_umap_n_components: int = LENS_UMAP_N_COMPONENTS

    # Cover parameters.
    cover_cubical_n_intervals: int = COVER_CUBICAL_N_INTERVALS
    # NOTE(review): the page's "Ball Radius" input is what fills this field
    # from the UI, despite the name -- confirm the intended meaning.
    cover_cubical_overlap_frac: float = COVER_CUBICAL_OVERLAP_FRAC
    cover_knn_neighbors: int = COVER_KNN_NEIGHBORS

    # Clustering parameters.
    clustering_kmeans_n_clusters: int = CLUSTERING_KMEANS_N_CLUSTERS
    clustering_dbscan_eps: float = CLUSTERING_DBSCAN_EPS
    clustering_dbscan_min_samples: int = CLUSTERING_DBSCAN_MIN_SAMPLES
    clustering_agglomerative_n_clusters: int = CLUSTERING_AGGLOMERATIVE_N_CLUSTERS
| 62 | + |
| 63 | + |
def identity(X):
    """Identity lens: hand back the input object unchanged."""
    return X
| 66 | + |
| 67 | + |
def pca(n_components):
    """Build a PCA lens.

    Args:
        n_components: Number of principal components the lens keeps.

    Returns:
        A one-argument callable that fits a new PCA model (seeded with
        ``RANDOM_SEED``) on its input and returns the transformed data.
    """

    def _project(data):
        # A fresh model is fitted on every call; nothing is cached.
        reducer = PCA(random_state=RANDOM_SEED, n_components=n_components)
        return reducer.fit_transform(data)

    return _project
| 75 | + |
| 76 | + |
def umap(n_components):
    """Build a UMAP lens.

    Args:
        n_components: Dimension of the embedding produced by the lens.

    Returns:
        A one-argument callable that fits a new UMAP model (seeded with
        ``RANDOM_SEED``) on its input and returns the embedded data.
    """

    def _umap(X):
        # Use the imported ``UMAP`` class directly: in this module the name
        # ``umap`` is this factory function, not the umap package, so
        # ``umap.UMAP`` would raise AttributeError.
        umap_model = UMAP(n_components=n_components, random_state=RANDOM_SEED)
        return umap_model.fit_transform(X)

    return _umap
| 84 | + |
| 85 | + |
def _build_lens(config):
    """Return the lens callable selected by ``config.lens_type``, or None if unknown."""
    if config.lens_type == LENS_IDENTITY:
        return identity
    if config.lens_type == LENS_PCA:
        return pca(n_components=config.lens_pca_n_components)
    if config.lens_type == LENS_UMAP:
        return umap(n_components=config.lens_umap_n_components)
    return None


def _build_cover(config):
    """Return the cover selected by ``config.cover_type``, or None if unknown."""
    if config.cover_type == COVER_CUBICAL:
        return CubicalCover(n_intervals=config.cover_cubical_n_intervals)
    if config.cover_type == COVER_BALL:
        # BallCover is parameterized by ``radius``.
        # NOTE(review): the page stores its "Ball Radius" input in
        # ``cover_cubical_overlap_frac``, so that slot is read here; a
        # dedicated config field would be clearer.
        return BallCover(radius=config.cover_cubical_overlap_frac)
    if config.cover_type == COVER_KNN:
        return KNNCover(neighbors=config.cover_knn_neighbors)
    return None


def _build_clustering(config):
    """Return the clustering selected by ``config.clustering_type``, or None if unknown."""
    if config.clustering_type == CLUSTERING_TRIVIAL:
        return TrivialClustering()
    if config.clustering_type == CLUSTERING_KMEANS:
        return KMeans(
            n_clusters=config.clustering_kmeans_n_clusters,
            random_state=RANDOM_SEED,
        )
    if config.clustering_type == CLUSTERING_DBSCAN:
        # DBSCAN does not accept ``random_state``.
        return DBSCAN(
            eps=config.clustering_dbscan_eps,
            min_samples=config.clustering_dbscan_min_samples,
        )
    if config.clustering_type == CLUSTERING_AGGLOMERATIVE:
        # AgglomerativeClustering does not accept ``random_state``.
        return AgglomerativeClustering(
            n_clusters=config.clustering_agglomerative_n_clusters,
        )
    return None


def run_mapper(df, **kwargs):
    """Compute a Mapper graph for ``df`` and return its plot.

    Args:
        df: DataFrame holding the data set, or None when nothing was uploaded.
            It is converted with ``to_numpy()`` before use.
        **kwargs: Field values for ``MapperConfig``; omitted fields keep
            their defaults.

    Returns:
        The figure returned by ``MapperPlot.plot_plotly``, or None when
        ``df`` is None or the lens, cover or clustering type is not
        recognised (the reason is logged).
    """
    if df is None:
        logger.error("No data found. Please upload a file first.")
        return None
    logger.info("Computing Mapper.")

    mapper_config = MapperConfig(**kwargs)

    lens = _build_lens(mapper_config)
    if lens is None:
        logger.error("Unknown lens type: %s", mapper_config.lens_type)
        return None

    cover = _build_cover(mapper_config)
    if cover is None:
        logger.error("Unknown cover type: %s", mapper_config.cover_type)
        return None

    clustering = _build_clustering(mapper_config)
    if clustering is None:
        logger.error("Unknown clustering type: %s", mapper_config.clustering_type)
        return None

    mapper = MapperAlgorithm(cover=cover, clustering=clustering)
    X = df.to_numpy()
    y = lens(X)
    mapper_graph = mapper.fit_transform(X, y)
    mapper_fig = MapperPlot(
        mapper_graph,
        dim=3,
    ).plot_plotly(
        colors=X,
        height=800,
        node_size=[0.0, 0.5, 1.0],
    )
    logger.info("Mapper run completed successfully.")
    return mapper_fig
| 162 | + |
| 163 | + |
@ui.page("/")
def index():
    """Render the Mapper page: upload panel, parameter cards and plot area.

    The uploaded DataFrame is kept in ``app.storage.client`` under ``"df"``.
    The nested callbacks read the widgets created further down through
    closures, so they must only be invoked after the layout has been built.
    """
    storage = app.storage.client

    def upload_file(file):
        """Parse the uploaded CSV and store the resulting DataFrame."""
        if file is None:
            logger.info("No file uploaded.")
            return
        try:
            df = pd.read_csv(file.content)
        except ValueError:
            # pandas' ParserError and EmptyDataError, as well as
            # UnicodeDecodeError, all derive from ValueError.
            logger.exception("Could not parse the uploaded file as CSV.")
            ui.notify("Could not parse the uploaded file as CSV.", type="negative")
            return
        storage["df"] = df

        logger.info("File uploaded successfully.")
        logger.info(f"{df.head()}")

    def load_file():
        """Log whether a DataFrame is currently stored for this client."""
        df = storage.get("df")
        if df is not None:
            logger.info("Data loaded successfully.")
        else:
            logger.warning("No data found. Please upload a file first.")

    def _value_or(widget, cast, default):
        """Return ``cast(widget.value)``, or ``default`` when the value is falsy."""
        raw = widget.value
        return cast(raw) if raw else default

    def get_mapper_config():
        """Collect the current widget values into a MapperConfig."""
        return MapperConfig(
            lens_type=_value_or(lens, str, LENS_PCA),
            cover_type=_value_or(cover, str, COVER_CUBICAL),
            clustering_type=_value_or(clustering, str, CLUSTERING_TRIVIAL),
            lens_pca_n_components=_value_or(
                pca_n_components, int, LENS_PCA_N_COMPONENTS
            ),
            lens_umap_n_components=_value_or(
                umap_n_components, int, LENS_UMAP_N_COMPONENTS
            ),
            cover_cubical_n_intervals=_value_or(
                n_intervals, int, COVER_CUBICAL_N_INTERVALS
            ),
            cover_cubical_overlap_frac=_value_or(
                overlap_frac, float, COVER_CUBICAL_OVERLAP_FRAC
            ),
            cover_knn_neighbors=_value_or(neighbors, int, COVER_KNN_NEIGHBORS),
            clustering_kmeans_n_clusters=_value_or(
                kmeans_n_clusters, int, CLUSTERING_KMEANS_N_CLUSTERS
            ),
            clustering_dbscan_eps=_value_or(dbscan_eps, float, CLUSTERING_DBSCAN_EPS),
            clustering_dbscan_min_samples=_value_or(
                dbscan_min_samples, int, CLUSTERING_DBSCAN_MIN_SAMPLES
            ),
            clustering_agglomerative_n_clusters=_value_or(
                agglomerative_n_clusters, int, CLUSTERING_AGGLOMERATIVE_N_CLUSTERS
            ),
        )

    async def async_run_mapper():
        """Run Mapper via ``run.cpu_bound`` and show the resulting plot."""
        df = storage.get("df")
        if df is None:
            logger.warning("No data found. Please upload a file first.")
            ui.notify("No data found. Please upload a file first.", type="warning")
            return
        mapper_config = get_mapper_config()
        mapper_fig = await run.cpu_bound(run_mapper, df, **asdict(mapper_config))
        if mapper_fig is None:
            # run_mapper returns None (after logging why) when it cannot build
            # the pipeline; there is no figure to lay out or display.
            logger.error("Mapper did not produce a figure.")
            ui.notify(
                "Mapper could not be computed. Check the logs for details.",
                type="negative",
            )
            return
        # Drop the fixed width so the figure can size itself to the container.
        mapper_fig.layout.width = None
        mapper_fig.layout.autosize = True
        plot_container.clear()
        with plot_container:
            logger.info("Displaying Mapper plot.")
            ui.plotly(mapper_fig)

    def _number_for(label, default, selector, shown_for):
        """Create a number input that is visible only while ``selector`` equals ``shown_for``."""
        field = ui.number(label=label, value=default).classes("w-full")
        field.bind_visibility_from(
            target_object=selector,
            target_name="value",
            value=shown_for,
        )
        return field

    with ui.row().classes("w-full h-screen m-0 p-0 gap-0 overflow-hidden"):

        # Left column: upload panel, parameter cards and the run button.
        with ui.column().classes("w-64 h-full m-0 p-0"):
            with ui.card().tight().classes("w-full"):
                ui.upload(
                    on_upload=upload_file,
                    auto_upload=True,
                    label="Upload CSV File",
                ).classes("w-full")
                with ui.card_section().classes("w-full"):
                    ui.button("Load", on_click=load_file).classes("w-full")

            with ui.card().classes("w-full"):
                lens = ui.select(
                    options=[
                        LENS_IDENTITY,
                        LENS_PCA,
                        LENS_UMAP,
                    ],
                    label="Lens",
                    value=LENS_PCA,
                ).classes("w-full")

                pca_n_components = _number_for(
                    "PCA Components", LENS_PCA_N_COMPONENTS, lens, LENS_PCA
                )
                umap_n_components = _number_for(
                    "UMAP Components", LENS_UMAP_N_COMPONENTS, lens, LENS_UMAP
                )

            with ui.card().classes("w-full"):
                cover = ui.select(
                    options=[
                        COVER_CUBICAL,
                        COVER_BALL,
                        COVER_KNN,
                    ],
                    label="Cover",
                    value=COVER_CUBICAL,
                ).classes("w-full")

                n_intervals = _number_for(
                    "Number of Intervals",
                    COVER_CUBICAL_N_INTERVALS,
                    cover,
                    COVER_CUBICAL,
                )
                # NOTE(review): this input is the ball radius, yet it is bound
                # to the name ``overlap_frac`` and sent as
                # ``cover_cubical_overlap_frac`` -- confirm the naming.
                overlap_frac = _number_for(
                    "Ball Radius", COVER_BALL_RADIUS, cover, COVER_BALL
                )
                neighbors = _number_for(
                    "Number of Neighbors", COVER_KNN_NEIGHBORS, cover, COVER_KNN
                )

            with ui.card().classes("w-full"):
                clustering = ui.select(
                    options=[
                        CLUSTERING_TRIVIAL,
                        CLUSTERING_KMEANS,
                        CLUSTERING_DBSCAN,
                        CLUSTERING_AGGLOMERATIVE,
                    ],
                    label="Clustering",
                    value=CLUSTERING_TRIVIAL,
                ).classes("w-full")

                kmeans_n_clusters = _number_for(
                    "Number of Clusters",
                    CLUSTERING_KMEANS_N_CLUSTERS,
                    clustering,
                    CLUSTERING_KMEANS,
                )
                dbscan_eps = _number_for(
                    "Epsilon", CLUSTERING_DBSCAN_EPS, clustering, CLUSTERING_DBSCAN
                )
                dbscan_min_samples = _number_for(
                    "Min Samples",
                    CLUSTERING_DBSCAN_MIN_SAMPLES,
                    clustering,
                    CLUSTERING_DBSCAN,
                )
                agglomerative_n_clusters = _number_for(
                    "Number of Clusters",
                    CLUSTERING_AGGLOMERATIVE_N_CLUSTERS,
                    clustering,
                    CLUSTERING_AGGLOMERATIVE,
                )

            ui.button(
                "Run Mapper",
                on_click=async_run_mapper,
            ).classes("w-full")

        # Right column: container that receives the Mapper plot.
        with ui.column().classes("flex-1 h-full overflow-hidden m-0 p-0"):
            plot_container = ui.element("div").classes("w-full h-full")
| 388 | + |
| 389 | + |
# Start the NiceGUI application.
# NOTE(review): the storage secret is a hard-coded literal; presumably it
# should come from configuration or the environment for real deployments --
# confirm.
ui.run(storage_secret="secret")
0 commit comments