From a5d2674a8d9c1268e90cc2cdbfa8120958bdf6f4 Mon Sep 17 00:00:00 2001 From: Luca Simi Date: Thu, 29 May 2025 21:52:19 +0200 Subject: [PATCH 01/15] First implementation using nicegui --- app/nicegui_app.py | 309 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 309 insertions(+) create mode 100644 app/nicegui_app.py diff --git a/app/nicegui_app.py b/app/nicegui_app.py new file mode 100644 index 0000000..c58ca5e --- /dev/null +++ b/app/nicegui_app.py @@ -0,0 +1,309 @@ +import numpy as np +import plotly.graph_objs as go +from nicegui import ui +from nicegui.events import ValueChangeEventArguments +from sklearn.cluster import AgglomerativeClustering, KMeans +from sklearn.datasets import load_digits, make_circles +from sklearn.decomposition import PCA +from umap import UMAP + +from tdamapper.core import TrivialClustering, TrivialCover +from tdamapper.cover import BallCover, CubicalCover, KNNCover +from tdamapper.learn import MapperAlgorithm +from tdamapper.plot import MapperPlot + + +def mode(arr): + values, counts = np.unique(arr, return_counts=True) + max_count = np.max(counts) + mode_values = values[counts == max_count] + return np.nanmean(mode_values) + + +def _identity(X): + return X + + +def _pca(n_components): + pca = PCA(n_components=n_components, random_state=42) + + def _func(X): + return pca.fit_transform(X) + + return _func + + +def _umap(n_components): + um = UMAP(n_components=n_components, random_state=42) + + def _func(X): + return um.fit_transform(X) + + return _func + + +class App: + + def build_lens(self): + self.opt_lens_id = "Identity" + self.opt_lens_pca = "PCA" + self.opt_lens_umap = "UMAP" + + self.lens_type = ui.select( + label="Lens type", + options=[ + self.opt_lens_id, + self.opt_lens_pca, + self.opt_lens_umap, + ], + value=self.opt_lens_pca, + on_change=self.update, + ).classes("w-full") + self.pca_n_components = ui.number( + label="PCA Components", + min=1, + max=10, + value=2, + on_change=self.update, + ).classes("w-full") + self.pca_n_components.bind_visibility_from( + target_object=self.lens_type, + target_name="value", + value=self.opt_lens_pca, + ) + self.umap_n_components = ui.number( + label="UMAP Components", + min=1, + max=10, + value=2, + on_change=self.update, + ).classes("w-full") + self.umap_n_components.bind_visibility_from( + target_object=self.lens_type, + target_name="value", + value=self.opt_lens_umap, + ) + + def build_cover(self): + self.opt_cover_trivial = "Trivial" + self.opt_cover_cubical = "Cubical" + self.opt_cover_ball = "Ball" + self.opt_cover_knn = "KNN" + + self.cover_type = ui.select( + label="Cover type", + options=[ + self.opt_cover_trivial, + self.opt_cover_cubical, + self.opt_cover_ball, + self.opt_cover_knn, + ], + value=self.opt_cover_cubical, + on_change=self.update, + ).classes("w-full") + self.cover_cubical_n = ui.number( + label="Intervals", + min=1, + max=10, + value=2, + on_change=self.update, + ).classes("w-full") + self.cover_cubical_n.bind_visibility_from( + target_object=self.cover_type, + target_name="value", + value=self.opt_cover_cubical, + ) + self.cover_cubical_overlap = ui.number( + label="Overlap", + min=0.0, + max=1.0, + value=0.5, + on_change=self.update, + ).classes("w-full") + self.cover_cubical_overlap.bind_visibility_from( + target_object=self.cover_type, + target_name="value", + value=self.opt_cover_cubical, + ) + self.cover_ball_radius = ui.number( + label="Radius", + min=0.0, + value=100.0, + on_change=self.update, + ).classes("w-full") + self.cover_ball_radius.bind_visibility_from( + target_object=self.cover_type, + target_name="value", + value=self.opt_cover_ball, + ) + self.cover_knn_k = ui.number( + label="Neighbors", + min=0, + value=10, + on_change=self.update, + ).classes("w-full") + self.cover_knn_k.bind_visibility_from( + target_object=self.cover_type, + target_name="value", + value=self.opt_cover_knn, + ) + + def build_clustering(self): + self.opt_clustering_trivial = "Trivial" + self.opt_clustering_kmeans = "KMeans" + self.opt_clustering_agg = "Agglomerative" + self.opt_clustering_dbscan = "DBSCAN" + + self.clustering_type = ui.select( + label="Clustering type", + options=[ + self.opt_clustering_trivial, + self.opt_clustering_kmeans, + self.opt_clustering_agg, + self.opt_clustering_dbscan, + ], + value=self.opt_clustering_trivial, + on_change=self.update, + ).classes("w-full") + self.clustering_kmeans_k = ui.number( + label="Clusters", + min=1, + value=2, + on_change=self.update, + ).classes("w-full") + self.clustering_kmeans_k.bind_visibility_from( + target_object=self.clustering_type, + target_name="value", + value=self.opt_clustering_kmeans, + ) + self.clustering_dbscan_eps = ui.number( + label="Eps", + min=0.0, + value=0.5, + on_change=self.update, + ).classes("w-full") + self.clustering_dbscan_eps.bind_visibility_from( + target_object=self.clustering_type, + target_name="value", + value=self.opt_clustering_dbscan, + ) + self.clustering_dbscan_min_samples = ui.number( + label="Min Samples", + min=1, + value=5, + on_change=self.update, + ).classes("w-full") + self.clustering_dbscan_eps.bind_visibility_from( + target_object=self.clustering_type, + target_name="value", + value=self.opt_clustering_dbscan, + ) + self.clustering_agg_n = ui.number( + label="Clusters", + min=1, + value=2, + on_change=self.update, + ).classes("w-full") + self.clustering_agg_n.bind_visibility_from( + target_object=self.clustering_type, + target_name="value", + value=self.opt_clustering_agg, + ) + + def build_plot(self): + self.plot = ui.plotly(go.Figure()) + + def render_lens(self): + print(f"Lens type: {self.lens_type.value}") + if self.lens_type.value == self.opt_lens_id: + return _identity + elif self.lens_type.value == self.opt_lens_pca: + n = int(self.pca_n_components.value) + return _pca(n) + elif self.lens_type.value == self.opt_lens_umap: + n = int(self.umap_n_components.value) + return _umap(n) + + def render_cover(self): + if self.cover_type.value == self.opt_cover_trivial: + return TrivialCover() + elif self.cover_type.value == self.opt_cover_ball: + r = float(self.cover_ball_radius.value) + return BallCover(radius=r) + elif self.cover_type.value == self.opt_cover_cubical: + n = int(self.cover_cubical_n.value) + overlap = float(self.cover_cubical_overlap.value) + return CubicalCover(n_intervals=n, overlap_frac=overlap) + elif self.cover_type.value == self.opt_cover_knn: + k = int(self.cover_knn_k.value) + return KNNCover(neighbors=k) + + def render_clustering(self): + if self.clustering_type.value == self.opt_clustering_trivial: + return TrivialClustering() + elif self.clustering_type.value == self.opt_clustering_kmeans: + k = int(self.clustering_kmeans_k.value) + return KMeans(k) + elif self.clustering_type.value == self.opt_clustering_dbscan: + eps = float(self.clustering_dbscan_eps.value) + min_samples = int(self.clustering_dbscan_min_samples.value) + return DBSCAN(eps=eps) + + def update(self, _=None): + X, labels = load_digits(return_X_y=True) + lens = self.render_lens() + if lens is None: + print("Lens is None") + return + y = lens(X) + + cover = self.render_cover() + if cover is None: + print("Cover is None") + return + + clustering = self.render_clustering() + if clustering is None: + print("Clustering is None") + return + + mapper_algo = MapperAlgorithm( + cover=cover, + clustering=clustering, + verbose=False, + ) + + mapper_graph = mapper_algo.fit_transform(X, y) + + mapper_plot = MapperPlot(mapper_graph, dim=3, iterations=400, seed=42) + + mapper_fig = mapper_plot.plot_plotly( + colors=labels, + cmap=["jet", "viridis", "cividis"], + agg=mode, + title="mode of digits", + width=800, + height=800, + node_size=0.5, + ) + if mapper_fig.layout.width is not None: + mapper_fig.layout.width = None + if not mapper_fig.layout.autosize: + mapper_fig.layout.autosize = True + mapper_fig.layout.autosize = True + self.plot.update_figure(mapper_fig) + + def build(self): + with ui.left_drawer().classes("w-[400px]"): + self.build_lens() + ui.separator() + self.build_cover() + ui.separator() + self.build_clustering() + self.build_plot() + self.update() + + +app = App() +app.build() +ui.run() From ee5e8b50a36700a0ee25ee265e9a8280f63ea1e6 Mon Sep 17 00:00:00 2001 From: Luca Simi Date: Fri, 30 May 2025 08:30:00 +0200 Subject: [PATCH 02/15] Added ui elements for cover and clustering. Added async worker function for plot update --- app/nicegui_app.py | 227 ++++++++++++++++++++++++--------------------- 1 file changed, 121 insertions(+), 106 deletions(-) diff --git a/app/nicegui_app.py b/app/nicegui_app.py index c58ca5e..e7a22d3 100644 --- a/app/nicegui_app.py +++ b/app/nicegui_app.py @@ -1,9 +1,8 @@ import numpy as np import plotly.graph_objs as go -from nicegui import ui -from nicegui.events import ValueChangeEventArguments -from sklearn.cluster import AgglomerativeClustering, KMeans -from sklearn.datasets import load_digits, make_circles +from nicegui import run, ui +from sklearn.cluster import DBSCAN, AgglomerativeClustering, KMeans +from sklearn.datasets import load_digits from sklearn.decomposition import PCA from umap import UMAP @@ -42,229 +41,238 @@ def _func(X): return _func +LENS_IDENTITY = "Identity" +LENS_PCA = "PCA" +LENS_UMAP = "UMAP" + +COVER_TRIVIAL = "Trivial" +COVER_CUBICAL = "Cubical" +COVER_BALL = "Ball" +COVER_KNN = "KNN" + +CLUSTERING_TRIVIAL = "Trivial" +CLUSTERING_KMEANS = "KMeans" +CLUSTERING_AGGLOMERATIVE = "Agglomerative" +CLUSTERING_DBSCAN = "DBSCAN" + + class App: def build_lens(self): - self.opt_lens_id = "Identity" - self.opt_lens_pca = "PCA" - self.opt_lens_umap = "UMAP" - self.lens_type = ui.select( label="Lens type", options=[ - self.opt_lens_id, - self.opt_lens_pca, - self.opt_lens_umap, + LENS_IDENTITY, + LENS_PCA, + LENS_UMAP, ], - value=self.opt_lens_pca, - on_change=self.update, + value=LENS_PCA, + on_change=self.update_handler, ).classes("w-full") self.pca_n_components = ui.number( label="PCA Components", min=1, max=10, value=2, - on_change=self.update, + on_change=self.update_handler, ).classes("w-full") self.pca_n_components.bind_visibility_from( target_object=self.lens_type, target_name="value", - value=self.opt_lens_pca, + value=LENS_PCA, ) self.umap_n_components = ui.number( label="UMAP Components", min=1, max=10, value=2, - on_change=self.update, + on_change=self.update_handler, ).classes("w-full") self.umap_n_components.bind_visibility_from( target_object=self.lens_type, target_name="value", - value=self.opt_lens_umap, + value=LENS_UMAP, ) def build_cover(self): - self.opt_cover_trivial = "Trivial" - self.opt_cover_cubical = "Cubical" - self.opt_cover_ball = "Ball" - self.opt_cover_knn = "KNN" self.cover_type = ui.select( label="Cover type", options=[ - self.opt_cover_trivial, - self.opt_cover_cubical, - self.opt_cover_ball, - self.opt_cover_knn, + COVER_TRIVIAL, + COVER_CUBICAL, + COVER_BALL, + COVER_KNN, ], - value=self.opt_cover_cubical, - on_change=self.update, + value=COVER_CUBICAL, + on_change=self.update_handler, ).classes("w-full") - self.cover_cubical_n = ui.number( + self.cover_cubical_n_intervals = ui.number( label="Intervals", min=1, - max=10, + max=100, value=2, - on_change=self.update, + on_change=self.update_handler, ).classes("w-full") - self.cover_cubical_n.bind_visibility_from( + self.cover_cubical_n_intervals.bind_visibility_from( target_object=self.cover_type, target_name="value", - value=self.opt_cover_cubical, + value=COVER_CUBICAL, ) - self.cover_cubical_overlap = ui.number( + self.cover_cubical_overlap_frac = ui.number( label="Overlap", min=0.0, max=1.0, value=0.5, - on_change=self.update, + on_change=self.update_handler, ).classes("w-full") - self.cover_cubical_overlap.bind_visibility_from( + self.cover_cubical_overlap_frac.bind_visibility_from( target_object=self.cover_type, target_name="value", - value=self.opt_cover_cubical, + value=COVER_CUBICAL, ) self.cover_ball_radius = ui.number( label="Radius", min=0.0, value=100.0, - on_change=self.update, + on_change=self.update_handler, ).classes("w-full") self.cover_ball_radius.bind_visibility_from( target_object=self.cover_type, target_name="value", - value=self.opt_cover_ball, + value=COVER_BALL, ) - self.cover_knn_k = ui.number( + self.cover_knn_neighbors = ui.number( label="Neighbors", min=0, value=10, - on_change=self.update, + on_change=self.update_handler, ).classes("w-full") - self.cover_knn_k.bind_visibility_from( + self.cover_knn_neighbors.bind_visibility_from( target_object=self.cover_type, target_name="value", - value=self.opt_cover_knn, + value=COVER_KNN, ) def build_clustering(self): - self.opt_clustering_trivial = "Trivial" - self.opt_clustering_kmeans = "KMeans" - self.opt_clustering_agg = "Agglomerative" - self.opt_clustering_dbscan = "DBSCAN" - self.clustering_type = ui.select( label="Clustering type", options=[ - self.opt_clustering_trivial, - self.opt_clustering_kmeans, - self.opt_clustering_agg, - self.opt_clustering_dbscan, + CLUSTERING_TRIVIAL, + CLUSTERING_KMEANS, + CLUSTERING_AGGLOMERATIVE, + CLUSTERING_DBSCAN, ], - value=self.opt_clustering_trivial, - on_change=self.update, + value=CLUSTERING_TRIVIAL, + on_change=self.update_handler, ).classes("w-full") - self.clustering_kmeans_k = ui.number( + self.clustering_kmeans_n_clusters = ui.number( label="Clusters", min=1, value=2, - on_change=self.update, + on_change=self.update_handler, ).classes("w-full") - self.clustering_kmeans_k.bind_visibility_from( + self.clustering_kmeans_n_clusters.bind_visibility_from( target_object=self.clustering_type, target_name="value", - value=self.opt_clustering_kmeans, + value=CLUSTERING_KMEANS, ) self.clustering_dbscan_eps = ui.number( label="Eps", min=0.0, value=0.5, - on_change=self.update, + on_change=self.update_handler, ).classes("w-full") self.clustering_dbscan_eps.bind_visibility_from( target_object=self.clustering_type, target_name="value", - value=self.opt_clustering_dbscan, + value=CLUSTERING_DBSCAN, ) self.clustering_dbscan_min_samples = ui.number( label="Min Samples", min=1, value=5, - on_change=self.update, + on_change=self.update_handler, ).classes("w-full") - self.clustering_dbscan_eps.bind_visibility_from( + self.clustering_dbscan_min_samples.bind_visibility_from( target_object=self.clustering_type, target_name="value", - value=self.opt_clustering_dbscan, + value=CLUSTERING_DBSCAN, ) - self.clustering_agg_n = ui.number( + self.clustering_agglomerative_n_clusters = ui.number( label="Clusters", min=1, value=2, - on_change=self.update, + on_change=self.update_handler, ).classes("w-full") - self.clustering_agg_n.bind_visibility_from( + self.clustering_agglomerative_n_clusters.bind_visibility_from( target_object=self.clustering_type, target_name="value", - value=self.opt_clustering_agg, + value=CLUSTERING_AGGLOMERATIVE, ) def build_plot(self): - self.plot = ui.plotly(go.Figure()) + fig = go.Figure() + fig.layout.width = None + fig.layout.autosize = True + self.plot_container = ui.element("div").classes("w-full h-full") + with self.plot_container: + ui.plotly(go.Figure()) def render_lens(self): - print(f"Lens type: {self.lens_type.value}") - if self.lens_type.value == self.opt_lens_id: + if self.lens_type.value == LENS_IDENTITY: return _identity - elif self.lens_type.value == self.opt_lens_pca: + elif self.lens_type.value == LENS_PCA: n = int(self.pca_n_components.value) return _pca(n) - elif self.lens_type.value == self.opt_lens_umap: + elif self.lens_type.value == LENS_UMAP: n = int(self.umap_n_components.value) return _umap(n) def render_cover(self): - if self.cover_type.value == self.opt_cover_trivial: + if self.cover_type.value == COVER_TRIVIAL: return TrivialCover() - elif self.cover_type.value == self.opt_cover_ball: - r = float(self.cover_ball_radius.value) - return BallCover(radius=r) - elif self.cover_type.value == self.opt_cover_cubical: - n = int(self.cover_cubical_n.value) - overlap = float(self.cover_cubical_overlap.value) - return CubicalCover(n_intervals=n, overlap_frac=overlap) - elif self.cover_type.value == self.opt_cover_knn: - k = int(self.cover_knn_k.value) - return KNNCover(neighbors=k) + elif self.cover_type.value == COVER_BALL: + radius = float(self.cover_ball_radius.value) + return BallCover(radius=radius) + elif self.cover_type.value == COVER_CUBICAL: + n_intervals = int(self.cover_cubical_n_intervals.value) + overlap_frac = float(self.cover_cubical_overlap_frac.value) + return CubicalCover(n_intervals=n_intervals, overlap_frac=overlap_frac) + elif self.cover_type.value == COVER_KNN: + neighbors = int(self.cover_knn_neighbors.value) + return KNNCover(neighbors=neighbors) def render_clustering(self): - if self.clustering_type.value == self.opt_clustering_trivial: + if self.clustering_type.value == CLUSTERING_TRIVIAL: return TrivialClustering() - elif self.clustering_type.value == self.opt_clustering_kmeans: - k = int(self.clustering_kmeans_k.value) - return KMeans(k) - elif self.clustering_type.value == self.opt_clustering_dbscan: + elif self.clustering_type.value == CLUSTERING_KMEANS: + n_clusters = int(self.clustering_kmeans_n_clusters.value) + return KMeans(n_clusters) + elif self.clustering_type.value == CLUSTERING_DBSCAN: eps = float(self.clustering_dbscan_eps.value) min_samples = int(self.clustering_dbscan_min_samples.value) - return DBSCAN(eps=eps) + return DBSCAN(eps=eps, min_samples=min_samples) + elif self.clustering_type == CLUSTERING_AGGLOMERATIVE: + n_clusters = int(self.clustering_agglomerative_n_clusters.value) + return AgglomerativeClustering(n_clusters=n_clusters) + + async def update_handler(self, _=None): + await run.io_bound(self.update) def update(self, _=None): X, labels = load_digits(return_X_y=True) lens = self.render_lens() if lens is None: - print("Lens is None") return y = lens(X) cover = self.render_cover() if cover is None: - print("Cover is None") return clustering = self.render_clustering() if clustering is None: - print("Clustering is None") return mapper_algo = MapperAlgorithm( @@ -286,24 +294,31 @@ def update(self, _=None): height=800, node_size=0.5, ) - if mapper_fig.layout.width is not None: - mapper_fig.layout.width = None - if not mapper_fig.layout.autosize: - mapper_fig.layout.autosize = True + # if mapper_fig.layout.width is not None: + mapper_fig.layout.width = None + # if not mapper_fig.layout.autosize: mapper_fig.layout.autosize = True - self.plot.update_figure(mapper_fig) - - def build(self): - with ui.left_drawer().classes("w-[400px]"): - self.build_lens() - ui.separator() - self.build_cover() - ui.separator() - self.build_clustering() - self.build_plot() + self.plot_container.clear() + with self.plot_container: + ui.plotly(mapper_fig) + + def __init__(self): + with ui.row().classes("w-full h-full m-0 p-0 gap-0 overflow-hidden"): + with ui.column().classes("w-64 h-full overflow-y-auto m-0 p-3 gap-2"): + with ui.card().classes("w-full"): + ui.markdown("#### 🔎 Lens") + self.build_lens() + with ui.card().classes("w-full"): + ui.markdown("#### 🌐 Cover") + self.build_cover() + with ui.card().classes("w-full"): + ui.markdown("#### 🧮 Clustering") + self.build_clustering() + + with ui.column().classes("flex-1 h-full overflow-hidden m-0 p-0"): + self.build_plot() self.update() app = App() -app.build() ui.run() From 322de331e55a392fb20ed5abbd77dc4b93c91847 Mon Sep 17 00:00:00 2001 From: Luca Simi Date: Sat, 31 May 2025 08:46:26 +0200 Subject: [PATCH 03/15] Added dataset choice and plot options --- app/nicegui_app.py | 195 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 151 insertions(+), 44 deletions(-) diff --git a/app/nicegui_app.py b/app/nicegui_app.py index e7a22d3..496e215 100644 --- a/app/nicegui_app.py +++ b/app/nicegui_app.py @@ -2,7 +2,7 @@ import plotly.graph_objs as go from nicegui import run, ui from sklearn.cluster import DBSCAN, AgglomerativeClustering, KMeans -from sklearn.datasets import load_digits +from sklearn.datasets import load_digits, load_iris from sklearn.decomposition import PCA from umap import UMAP @@ -55,9 +55,63 @@ def _func(X): CLUSTERING_AGGLOMERATIVE = "Agglomerative" CLUSTERING_DBSCAN = "DBSCAN" +DATA_SOURCE_EXAMPLE = "Example" +DATA_SOURCE_CSV = "CSV" +DATA_SOURCE_OPENML = "OpenML" + +DATA_SOURCE_EXAMPLE_DIGITS = "Digits" +DATA_SOURCE_EXAMPLE_IRIS = "Iris" + +DRAW_3D = "3D" +DRAW_2D = "2D" +DRAW_ITERATIONS = 50 + class App: + def build_dataset(self): + self.data_source_type = ui.select( + label="Data Source", + options=[ + DATA_SOURCE_EXAMPLE, + DATA_SOURCE_CSV, + DATA_SOURCE_OPENML, + ], + value=DATA_SOURCE_EXAMPLE, + on_change=self.update_dataset_handler, + ).classes("w-full") + self.data_source_example_file = ui.select( + label="File", + options=[ + DATA_SOURCE_EXAMPLE_DIGITS, + DATA_SOURCE_EXAMPLE_IRIS, + ], + value=DATA_SOURCE_EXAMPLE_DIGITS, + on_change=self.update_dataset_handler, + ).classes("w-full") + self.data_source_example_file.bind_visibility_from( + target_object=self.data_source_type, + target_name="value", + value=DATA_SOURCE_EXAMPLE, + ) + self.data_source_csv = ui.upload( + on_upload=self.update_dataset_handler, + ).classes("w-full") + self.data_source_csv.bind_visibility_from( + target_object=self.data_source_type, + target_name="value", + value=DATA_SOURCE_CSV, + ) + self.data_source_openml = ui.input( + label="OpenML Code", + on_change=self.update_dataset_handler, + ).classes("w-full") + self.data_source_openml.bind_visibility_from( + target_object=self.data_source_type, + target_name="value", + value=DATA_SOURCE_OPENML, + ) + def build_lens(self): self.lens_type = ui.select( label="Lens type", @@ -67,14 +121,14 @@ def build_lens(self): LENS_UMAP, ], value=LENS_PCA, - on_change=self.update_handler, + on_change=self.update_graph_handler, ).classes("w-full") self.pca_n_components = ui.number( label="PCA Components", min=1, max=10, value=2, - on_change=self.update_handler, + on_change=self.update_graph_handler, ).classes("w-full") self.pca_n_components.bind_visibility_from( target_object=self.lens_type, @@ -86,7 +140,7 @@ def build_lens(self): min=1, max=10, value=2, - on_change=self.update_handler, + on_change=self.update_graph_handler, ).classes("w-full") self.umap_n_components.bind_visibility_from( target_object=self.lens_type, @@ -95,7 +149,6 @@ def build_lens(self): ) def build_cover(self): - self.cover_type = ui.select( label="Cover type", options=[ @@ -105,14 +158,14 @@ def build_cover(self): COVER_KNN, ], value=COVER_CUBICAL, - on_change=self.update_handler, + on_change=self.update_graph_handler, ).classes("w-full") self.cover_cubical_n_intervals = ui.number( label="Intervals", min=1, max=100, value=2, - on_change=self.update_handler, + on_change=self.update_graph_handler, ).classes("w-full") self.cover_cubical_n_intervals.bind_visibility_from( target_object=self.cover_type, @@ -122,9 +175,9 @@ def build_cover(self): self.cover_cubical_overlap_frac = ui.number( label="Overlap", min=0.0, - max=1.0, - value=0.5, - on_change=self.update_handler, + max=0.5, + value=0.25, + on_change=self.update_graph_handler, ).classes("w-full") self.cover_cubical_overlap_frac.bind_visibility_from( target_object=self.cover_type, @@ -135,7 +188,7 @@ def build_cover(self): label="Radius", min=0.0, value=100.0, - on_change=self.update_handler, + on_change=self.update_graph_handler, ).classes("w-full") self.cover_ball_radius.bind_visibility_from( target_object=self.cover_type, @@ -146,7 +199,7 @@ def build_cover(self): label="Neighbors", min=0, value=10, - on_change=self.update_handler, + on_change=self.update_graph_handler, ).classes("w-full") self.cover_knn_neighbors.bind_visibility_from( target_object=self.cover_type, @@ -164,13 +217,13 @@ def build_clustering(self): CLUSTERING_DBSCAN, ], value=CLUSTERING_TRIVIAL, - on_change=self.update_handler, + on_change=self.update_graph_handler, ).classes("w-full") self.clustering_kmeans_n_clusters = ui.number( label="Clusters", min=1, value=2, - on_change=self.update_handler, + on_change=self.update_graph_handler, ).classes("w-full") self.clustering_kmeans_n_clusters.bind_visibility_from( target_object=self.clustering_type, @@ -181,7 +234,7 @@ def build_clustering(self): label="Eps", min=0.0, value=0.5, - on_change=self.update_handler, + on_change=self.update_graph_handler, ).classes("w-full") self.clustering_dbscan_eps.bind_visibility_from( target_object=self.clustering_type, @@ -192,7 +245,7 @@ def build_clustering(self): label="Min Samples", min=1, value=5, - on_change=self.update_handler, + on_change=self.update_graph_handler, ).classes("w-full") self.clustering_dbscan_min_samples.bind_visibility_from( target_object=self.clustering_type, @@ -203,7 +256,7 @@ def build_clustering(self): label="Clusters", min=1, value=2, - on_change=self.update_handler, + on_change=self.update_graph_handler, ).classes("w-full") self.clustering_agglomerative_n_clusters.bind_visibility_from( target_object=self.clustering_type, @@ -211,6 +264,20 @@ def build_clustering(self): value=CLUSTERING_AGGLOMERATIVE, ) + def build_draw(self): + self.draw_3d = ui.toggle( + options=[DRAW_2D, DRAW_3D], + value=DRAW_3D, + on_change=self.update_plot_handler, + ) + self.draw_iterations = ui.number( + label="Layout Iterations", + min=1, + max=1000, + value=DRAW_ITERATIONS, + on_change=self.update_plot_handler, + ) + def build_plot(self): fig = go.Figure() fig.layout.width = None @@ -219,6 +286,19 @@ def build_plot(self): with self.plot_container: ui.plotly(go.Figure()) + def render_dataset(self): + source_type = self.data_source_type.value + if source_type == DATA_SOURCE_EXAMPLE: + name = self.data_source_example_file.value + if name == DATA_SOURCE_EXAMPLE_DIGITS: + X, y = load_digits(return_X_y=True, as_frame=True) + return X, y + elif name == DATA_SOURCE_EXAMPLE_IRIS: + X, y = load_iris(return_X_y=True, as_frame=True) + return X, y + elif source_type == DATA_SOURCE_CSV: + pass + def render_lens(self): if self.lens_type.value == LENS_IDENTITY: return _identity @@ -257,36 +337,59 @@ def render_clustering(self): n_clusters = int(self.clustering_agglomerative_n_clusters.value) return AgglomerativeClustering(n_clusters=n_clusters) - async def update_handler(self, _=None): - await run.io_bound(self.update) + async def update_graph_handler(self, _=None): + await run.io_bound(self.update_graph) - def update(self, _=None): - X, labels = load_digits(return_X_y=True) - lens = self.render_lens() - if lens is None: - return - y = lens(X) + async def update_dataset_handler(self, _=None): + await run.io_bound(self.update_dataset) + + def update_dataset(self, _=None): + self.X, self.labels = self.render_dataset() + self.update_graph() + def update_graph(self, _=None): + self.lens = self.render_lens() + if self.lens is None: + return + if self.X is None: + return + self.y = self.lens(self.X) cover = self.render_cover() if cover is None: return - clustering = self.render_clustering() if clustering is None: return - mapper_algo = MapperAlgorithm( cover=cover, clustering=clustering, verbose=False, ) + self.mapper_graph = mapper_algo.fit_transform(self.X, self.y) + self.update_plot() - mapper_graph = mapper_algo.fit_transform(X, y) + async def update_plot_handler(self, _=None): + await run.io_bound(self.update_plot) - mapper_plot = MapperPlot(mapper_graph, dim=3, iterations=400, seed=42) + def update_plot(self): + if self.mapper_graph is None: + return + dim = 3 + if self.draw_3d.value == DRAW_3D: + dim = 3 + elif self.draw_3d.value == DRAW_2D: + dim = 2 + + iterations = int(self.draw_iterations.value) + mapper_plot = MapperPlot( + self.mapper_graph, + dim=dim, + iterations=iterations, + seed=42, + ) mapper_fig = mapper_plot.plot_plotly( - colors=labels, + colors=self.labels, cmap=["jet", "viridis", "cividis"], agg=mode, title="mode of digits", @@ -294,30 +397,34 @@ def update(self, _=None): height=800, node_size=0.5, ) - # if mapper_fig.layout.width is not None: mapper_fig.layout.width = None - # if not mapper_fig.layout.autosize: mapper_fig.layout.autosize = True self.plot_container.clear() with self.plot_container: ui.plotly(mapper_fig) def __init__(self): - with ui.row().classes("w-full h-full m-0 p-0 gap-0 overflow-hidden"): - with ui.column().classes("w-64 h-full overflow-y-auto m-0 p-3 gap-2"): - with ui.card().classes("w-full"): - ui.markdown("#### 🔎 Lens") - self.build_lens() - with ui.card().classes("w-full"): - ui.markdown("#### 🌐 Cover") - self.build_cover() - with ui.card().classes("w-full"): - ui.markdown("#### 🧮 Clustering") - self.build_clustering() + with ui.row().classes("w-full h-screen m-0 p-0 gap-0 overflow-hidden"): + with ui.column().classes("w-64 h-full m-0 p-0"): # fixed-width sidebar + with ui.column().classes("w-64 h-full overflow-y-auto p-3 gap-2"): + with ui.card().classes("w-full"): + ui.markdown("#### 📊 Data") + self.build_dataset() + with ui.card().classes("w-full"): + ui.markdown("#### 🔎 Lens") + self.build_lens() + with ui.card().classes("w-full"): + ui.markdown("#### 🌐 Cover") + self.build_cover() + with ui.card().classes("w-full"): + ui.markdown("#### 🧮 Clustering") + self.build_clustering() with ui.column().classes("flex-1 h-full overflow-hidden m-0 p-0"): + with ui.row(align_items="baseline"): + self.build_draw() self.build_plot() - self.update() + self.update_dataset() app = App() From fae1c426036f4fc851010a48df5eedd735543917 Mon Sep 17 00:00:00 2001 From: Luca Simi Date: Sun, 1 Jun 2025 12:02:35 +0200 Subject: [PATCH 04/15] Improved attribute names. Improved ui handling --- src/tdamapper/_common.py | 16 +- src/tdamapper/core.py | 42 ++-- src/tdamapper/cover.py | 28 +-- src/tdamapper/plot.py | 94 ++++---- src/tdamapper/plot_backends/__init__.py | 0 .../plot_matplotlib.py} | 0 .../plot_plotly.py} | 210 ++++++++++++++---- .../plot_pyvis.py} | 6 +- src/tdamapper/utils/unionfind.py | 18 +- src/tdamapper/utils/vptree.py | 6 +- tests/ball_tree.py | 12 +- tests/test_unit_params.py | 26 +-- 12 files changed, 289 insertions(+), 169 deletions(-) create mode 100644 src/tdamapper/plot_backends/__init__.py rename src/tdamapper/{_plot_matplotlib.py => plot_backends/plot_matplotlib.py} (100%) rename src/tdamapper/{_plot_plotly.py => plot_backends/plot_plotly.py} (73%) rename src/tdamapper/{_plot_pyvis.py => plot_backends/plot_pyvis.py} (97%) diff --git a/src/tdamapper/_common.py b/src/tdamapper/_common.py index 4a66980..b3a38d2 100644 --- a/src/tdamapper/_common.py +++ b/src/tdamapper/_common.py @@ -29,12 +29,12 @@ def warn_user(msg): class EstimatorMixin: - def __is_sparse(self, X): + def _is_sparse(self, X): # simple alternative use scipy.sparse.issparse return hasattr(X, "toarray") def _validate_X_y(self, X, y): - if self.__is_sparse(X): + if self._is_sparse(X): raise ValueError("Sparse data not supported.") X = np.asarray(X) @@ -80,10 +80,10 @@ class ParamsMixin: scikit-learn `get_params` and `set_params`. """ - def __is_param_public(self, k): + def _is_param_public(self, k): return (not k.startswith("_")) and (not k.endswith("_")) - def __split_param(self, k): + def _split_param(self, k): k_split = k.split("__") outer = k_split[0] inner = "__".join(k_split[1:]) @@ -98,7 +98,7 @@ def get_params(self, deep=True): """ params = {} for k, v in self.__dict__.items(): - if self.__is_param_public(k): + if self._is_param_public(k): params[k] = v if hasattr(v, "get_params") and deep: for _k, _v in v.get_params().items(): @@ -111,8 +111,8 @@ def set_params(self, **params): """ nested_params = [] for k, v in params.items(): - if self.__is_param_public(k): - k_outer, k_inner = self.__split_param(k) + if self._is_param_public(k): + k_outer, k_inner = self._split_param(k) if not k_inner: if hasattr(self, k_outer): setattr(self, k_outer, v) @@ -131,7 +131,7 @@ def __repr__(self): v_default = getattr(obj_noargs, k) v_default_repr = repr(v_default) v_repr = repr(v) - if self.__is_param_public(k) and not v_repr == v_default_repr: + if self._is_param_public(k) and not v_repr == v_default_repr: args_repr.append(f"{k}={v_repr}") return f"{self.__class__.__name__}({', '.join(args_repr)})" diff --git a/src/tdamapper/core.py b/src/tdamapper/core.py index 3707003..d367295 100644 --- a/src/tdamapper/core.py +++ b/src/tdamapper/core.py @@ -297,7 +297,7 @@ def fit(self, X): :return: The object itself. :rtype: self """ - self.__X = X + self._X = X return self def search(self, x): @@ -314,7 +314,7 @@ def search(self, x): dataset. :rtype: list[int] """ - return list(range(0, len(self.__X))) + return list(range(0, len(self._X))) def apply(self, X): """ @@ -385,27 +385,27 @@ def __init__( def fit(self, X, y=None): X, y = self._validate_X_y(X, y) - self.__cover = TrivialCover() if self.cover is None else self.cover - self.__clustering = ( + self._cover = TrivialCover() if self.cover is None else self.cover + self._clustering = ( TrivialClustering() if self.clustering is None else self.clustering ) - self.__verbose = self.verbose - self.__failsafe = self.failsafe - if self.__failsafe: - self.__clustering = FailSafeClustering( - clustering=self.__clustering, - verbose=self.__verbose, + self._verbose = self.verbose + self._failsafe = self.failsafe + if self._failsafe: + self._clustering = FailSafeClustering( + clustering=self._clustering, + verbose=self._verbose, ) - self.__cover = clone(self.__cover) - self.__clustering = clone(self.__clustering) - self.__n_jobs = self.n_jobs + self._cover = clone(self._cover) + self._clustering = clone(self._clustering) + self._n_jobs = self.n_jobs y = X if y is None else y self.graph_ = mapper_graph( X, y, - self.__cover, - self.__clustering, - n_jobs=self.__n_jobs, + self._cover, + self._clustering, + n_jobs=self._n_jobs, ) self._set_n_features_in(X) return self @@ -451,16 +451,16 @@ def __init__(self, clustering=None, verbose=True): self.verbose = verbose def fit(self, X, y=None): - self.__clustering = ( + self._clustering = ( TrivialClustering() if self.clustering is None else self.clustering ) - self.__verbose = self.verbose + self._verbose = self.verbose self.labels_ = None try: - self.__clustering.fit(X, y) - self.labels_ = self.__clustering.labels_ + self._clustering.fit(X, y) + self.labels_ = self._clustering.labels_ except ValueError as err: - if self.__verbose: + if self._verbose: _logger.warning("Unable to perform clustering on local chart: %s", err) self.labels_ = [0 for _ in X] return self diff --git a/src/tdamapper/cover.py b/src/tdamapper/cover.py index acfdc21..90a8925 100644 --- a/src/tdamapper/cover.py +++ b/src/tdamapper/cover.py @@ -97,10 +97,10 @@ def fit(self, X): :rtype: self """ metric = get_metric(self.metric, **(self.metric_params or {})) - self.__radius = self.radius - self.__data = list(enumerate(X)) - self.__vptree = VPTree( - self.__data, + self._radius = self.radius + self._data = list(enumerate(X)) + self._vptree = VPTree( + self._data, metric=_Pullback(_snd, metric), metric_params=None, kind=self.kind, @@ -121,11 +121,11 @@ def search(self, x): :return: The indices of the neighbors contained in the dataset. :rtype: list[int] """ - if self.__vptree is None: + if self._vptree is None: return [] - neighs = self.__vptree.ball_search( + neighs = self._vptree.ball_search( (-1, x), - self.__radius, + self._radius, inclusive=False, ) return [x for (x, _) in neighs] @@ -198,10 +198,10 @@ def fit(self, X): :rtype: self """ metric = get_metric(self.metric, **(self.metric_params or {})) - self.__neighbors = self.neighbors - self.__data = list(enumerate(X)) - self.__vptree = VPTree( - self.__data, + self._neighbors = self.neighbors + self._data = list(enumerate(X)) + self._vptree = VPTree( + self._data, metric=_Pullback(_snd, metric), metric_params=None, kind=self.kind, @@ -223,9 +223,9 @@ def search(self, x): :return: The indices of the neighbors contained in the dataset. :rtype: list[int] """ - if self.__vptree is None: + if self._vptree is None: return [] - neighs = self.__vptree.knn_search((-1, x), self.__neighbors) + neighs = self._vptree.knn_search((-1, x), self._neighbors) return [x for (x, _) in neighs] @@ -447,7 +447,7 @@ def __init__( def _landmarks(self, X): lmrks = {} for x in X: - lmrk, center = self._get_center(x) + lmrk, _ = self._get_center(x) if lmrk not in lmrks: lmrks[lmrk] = x return lmrks diff --git a/src/tdamapper/plot.py b/src/tdamapper/plot.py index 860e9dc..fea6fb9 100644 --- a/src/tdamapper/plot.py +++ b/src/tdamapper/plot.py @@ -7,9 +7,9 @@ import numpy as np from tdamapper._common import deprecated -from tdamapper._plot_matplotlib import plot_matplotlib -from tdamapper._plot_plotly import plot_plotly, plot_plotly_update -from tdamapper._plot_pyvis import plot_pyvis +from tdamapper.plot_backends.plot_matplotlib import plot_matplotlib +from tdamapper.plot_backends.plot_plotly import plot_plotly, plot_plotly_update +from tdamapper.plot_backends.plot_pyvis import plot_pyvis class MapperPlot: @@ -382,29 +382,29 @@ def __init__( height=512, cmap="jet", ): - self.__graph = graph - self.__dim = dim - self.__iterations = iterations - self.__seed = seed - self.__mapper_plot = MapperPlot( - graph=self.__graph, - dim=self.__dim, - iterations=self.__iterations, - seed=self.__seed, + self._graph = graph + self._dim = dim + self._iterations = iterations + self._seed = seed + self._mapper_plot = MapperPlot( + graph=self._graph, + dim=self._dim, + iterations=self._iterations, + seed=self._seed, ) - self.__colors = colors - self.__agg = agg - self.__title = title - self.__width = width - self.__height = height - self.__cmap = cmap - self.__fig = self.__mapper_plot.plot_plotly( - colors=self.__colors, - agg=self.__agg, - title=self.__title, - width=self.__width, - height=self.__height, - cmap=self.__cmap, + self._colors = colors + self._agg = agg + self._title = title + self._width = width + self._height = height + self._cmap = cmap + self._fig = self._mapper_plot.plot_plotly( + colors=self._colors, + agg=self._agg, + title=self._title, + width=self._width, + height=self._height, + cmap=self._cmap, ) def update( @@ -451,20 +451,20 @@ def update( """ _update_pos = False if seed is not None: - self.__seed = seed + self._seed = seed _update_pos = True if iterations is not None: - self.__iterations = iterations + self._iterations = iterations _update_pos = True if _update_pos: - self.__mapper_plot = MapperPlot( - graph=self.__graph, - dim=self.__dim, - iterations=self.__iterations, - seed=self.__seed, + self._mapper_plot = MapperPlot( + graph=self._graph, + dim=self._dim, + iterations=self._iterations, + seed=self._seed, ) - self.__mapper_plot.plot_plotly_update( - self.__fig, + self._mapper_plot.plot_plotly_update( + self._fig, colors=colors, agg=agg, title=title, @@ -482,7 +482,7 @@ def plot(self): context to be shown. :rtype: :class:`plotly.graph_objects.Figure` """ - return self.__fig + return self._fig class MapperLayoutStatic: @@ -543,12 +543,12 @@ def __init__( height=512, cmap="jet", ): - self.__colors = colors - self.__agg = agg - self.__title = title - self.__width = width - self.__height = height - self.__cmap = cmap + self._colors = colors + self._agg = agg + self._title = title + self._width = width + self._height = height + self._cmap = cmap self.mapper_plot = MapperPlot( graph=graph, dim=dim, @@ -566,10 +566,10 @@ def plot(self): :class:`matplotlib.axes.Axes` """ return self.mapper_plot.plot_matplotlib( - colors=self.__colors, - agg=self.__agg, - title=self.__title, - width=self.__width, - height=self.__height, - cmap=self.__cmap, + colors=self._colors, + agg=self._agg, + title=self._title, + width=self._width, + height=self._height, + cmap=self._cmap, ) diff --git a/src/tdamapper/plot_backends/__init__.py b/src/tdamapper/plot_backends/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/tdamapper/_plot_matplotlib.py b/src/tdamapper/plot_backends/plot_matplotlib.py similarity index 100% rename from src/tdamapper/_plot_matplotlib.py rename to src/tdamapper/plot_backends/plot_matplotlib.py diff --git a/src/tdamapper/_plot_plotly.py b/src/tdamapper/plot_backends/plot_plotly.py similarity index 73% rename from src/tdamapper/_plot_plotly.py rename to src/tdamapper/plot_backends/plot_plotly.py index 55ed41a..4ef453c 100644 --- a/src/tdamapper/_plot_plotly.py +++ b/src/tdamapper/plot_backends/plot_plotly.py @@ -53,6 +53,54 @@ def _get_plotly_colorscales(): PLOTLY_CMAPS = _get_plotly_colorscales() +def _to_cmaps(cmap: Union[str, List[str]]) -> List[str]: + """Convert a single cmap or a list of cmaps to a list of cmaps.""" + if isinstance(cmap, str): + return [cmap] + elif isinstance(cmap, list): + return cmap + else: + raise ValueError(f"Invalid cmap type: {type(cmap)}. Expected str or list[str].") + + +def _to_colors(colors: Union[np.ndarray, List[float]]) -> np.ndarray: + """Convert colors to a numpy array.""" + colors_arr = np.array(colors) + if colors_arr.ndim == 1: + return colors_arr.reshape(-1, 1) + elif colors_arr.ndim == 2: + return colors_arr + else: + raise ValueError( + f"Invalid colors shape: {colors_arr.shape}. Expected 1D or 2D array." + ) + + +def _to_titles(title, colors_num): + if title is None: + return [DEFAULT_TITLE for _ in range(colors_num)] + elif isinstance(title, str): + return [title for _ in range(colors_num)] + elif isinstance(title, list) and len(title) == colors_num: + return title + else: + raise ValueError( + f"Invalid title type: {type(title)}. Expected str or list[str]." + ) + + +def _to_node_sizes(node_size, colors_num): + if isinstance(node_size, (int, float)): + return [node_size] * colors_num + elif isinstance(node_size, list) and len(node_size) == colors_num: + return node_size + else: + raise ValueError( + f"Invalid node_size type: {type(node_size)}. " + "Expected int, float or list[int, float]." + ) + + def plot_plotly( mapper_plot, width: int, @@ -63,19 +111,17 @@ def plot_plotly( agg=np.nanmean, cmap: Union[str, List[str]] = DEFAULT_CMAP, ) -> go.Figure: - cmaps = [cmap] if isinstance(cmap, str) else cmap - colors = np.array(colors) - if colors.ndim == 1: - colors = colors.reshape(-1, 1) + cmaps = _to_cmaps(cmap) + colors = _to_colors(colors) colors_num = colors.shape[1] - titles = [f"Color {i}" for i in range(colors_num)] - if isinstance(title, str): - titles = [title for _ in range(colors_num)] - elif isinstance(title, list) and len(title) == colors_num: - titles = title - node_sizes = [node_size] if isinstance(node_size, (int, float)) else node_size + titles = _to_titles(title, colors_num) + node_sizes = _to_node_sizes(node_size, colors_num) fig = _figure(mapper_plot, width, height, node_sizes, colors, titles, agg, cmaps) - _add_ui_to_layout(mapper_plot, fig, colors, titles, node_sizes, agg, cmaps) + ui = PlotlyUI() + ui.set_menu_cmap(mapper_plot, cmaps) + ui.set_menu_color(mapper_plot, colors, titles, agg) + ui.set_slider_size(mapper_plot, node_sizes) + _set_ui(fig, ui) return fig @@ -84,24 +130,42 @@ def plot_plotly_update( fig: go.Figure, width: Optional[int] = None, height: Optional[int] = None, - title: Optional[str] = None, - node_size: Optional[int] = None, + node_size: Optional[Union[int, float, List[Union[int, float]]]] = None, colors=None, + title: Optional[Union[str, List[str]]] = None, agg=None, - cmap: Optional[str] = None, + cmap: Optional[Union[str, List[str]]] = None, ) -> go.Figure: - if (width is not None) and (height is not None): - _update_layout(fig, width, height) + ui = PlotlyUI() + cmaps = None + if cmap is not None: + cmaps = _to_cmaps(cmap) + ui.set_menu_cmap(mapper_plot, cmaps) + colors_num = 0 + if colors is not None: + colors = _to_colors(colors) + colors_num = colors.shape[1] + titles = None if title is not None: - _set_title(mapper_plot, fig, title) + titles = _to_titles(title, colors_num) + if titles is not None and colors is not None and agg is not None: + ui.set_menu_color(mapper_plot, colors, titles, agg) + node_sizes = None if node_size is not None: - _set_node_size(mapper_plot, fig, node_size) - if (colors is not None) and (agg is not None): - _set_colors(mapper_plot, fig, colors, agg) - if cmap is not None: - _set_cmap(mapper_plot, fig, cmap) - # _add_ui_to_layout(mapper_plot, fig, colors, node_size, agg, cmap) - # TODO: understand how to update this + node_sizes = _to_node_sizes(node_size, colors_num) + ui.set_slider_size(mapper_plot, node_sizes) + _update( + mapper_plot, + fig, + width=width, + height=height, + titles=titles, + node_sizes=node_sizes, + colors=colors, + agg=agg, + cmaps=cmaps, + ) + _set_ui(fig, ui) return fig @@ -214,9 +278,14 @@ def _set_node_size(mapper_plot, fig, node_size): ) -def _update_layout(fig, width, height): +def _set_width(fig, width): fig.update_layout( width=width, + ) + + +def _set_height(fig, height): + fig.update_layout( height=height, ) @@ -235,17 +304,50 @@ def _figure(mapper_plot, width, height, node_sizes, colors, titles, agg, cmaps): ) _edges_tr = _edges_trace(mapper_plot, edge_pos_arr) _nodes_tr = _nodes_trace(mapper_plot, node_pos_arr) - _layout_ = _layout(width, height) + _layout_ = _layout() fig = go.Figure(data=[_edges_tr, _nodes_tr], layout=_layout_) - _set_cmap(mapper_plot, fig, cmaps[0]) - _set_colors(mapper_plot, fig, colors[:, 0], agg) - _set_node_size(mapper_plot, fig, node_sizes[len(node_sizes) // 2]) - _set_title(mapper_plot, fig, titles[0]) + _update( + mapper_plot, + fig, + width=width, + height=height, + titles=titles, + node_sizes=node_sizes, + colors=colors, + agg=agg, + cmaps=cmaps, + ) return fig +def _update( + mapper_plot, + fig: go.Figure, + width: Optional[int] = None, + height: Optional[int] = None, + titles: Optional[List[str]] = None, + node_sizes: Optional[List[int]] = None, + colors=None, + agg=None, + cmaps: Optional[List[str]] = None, +) -> go.Figure: + if width is not None: + _set_width(fig, width) + if height is not None: + _set_height(fig, height) + if titles is not None: + _set_title(mapper_plot, fig, titles[0]) + if node_sizes is not None: + _set_node_size(mapper_plot, fig, node_sizes[len(node_sizes) // 2]) + if (colors is not None) and (agg is not None): + _set_colors(mapper_plot, fig, colors[:, 0], agg) + if cmaps is not None: + _set_cmap(mapper_plot, fig, cmaps[0]) + return fig + + def _nodes_trace(mapper_plot, node_pos_arr): scatter = dict( name=_NODES_TRACE, @@ -343,7 +445,7 @@ def _fmt(x, max_len=3): return f"{x:{fmt}}" -def _layout(width, height): +def _layout(): line_col = "rgba(230, 230, 230, 1.0)" axis = dict( showline=False, @@ -378,8 +480,6 @@ def _layout(width, height): margin=dict(b=10, l=10, r=10, t=10), xaxis=axis, yaxis=axis, - width=width, - height=height, scene=dict( xaxis=scene_axis, yaxis=scene_axis, @@ -388,19 +488,39 @@ def _layout(width, height): ) -def _add_ui_to_layout(mapper_plot, mapper_fig, colors, titles, node_sizes, agg, cmaps): - cmaps_plotly = [PLOTLY_CMAPS.get(c.lower()) for c in cmaps] - menu_color = _ui_color(mapper_plot, colors, titles, agg) - if menu_color["buttons"]: - menu_color["x"] = 0.0 - else: - menu_color["x"] = -0.25 - menu_cmap = _ui_cmap(mapper_plot, cmaps_plotly) - menu_cmap["x"] = menu_color["x"] + 0.25 - slider_size = _ui_node_size(mapper_plot, node_sizes) +class PlotlyUI: + + def __init__(self): + self.menu_cmap = None + self.menu_color = None + self.slider_size = None + + def set_menu_cmap(self, mapper_plot, cmaps): + cmaps_plotly = [PLOTLY_CMAPS.get(c.lower()) for c in cmaps] + self.menu_cmap = _ui_cmap(mapper_plot, cmaps_plotly) + + def set_menu_color(self, mapper_plot, colors, titles, agg): + self.menu_color = _ui_color(mapper_plot, colors, titles, agg) + + def set_slider_size(self, mapper_plot, node_sizes): + self.slider_size = _ui_node_size(mapper_plot, node_sizes) + + +def _set_ui(mapper_fig, plotly_ui: PlotlyUI): + menus = [] + sliders = [] + if plotly_ui.menu_cmap: + plotly_ui.menu_cmap["x"] = 0.25 + menus.append(plotly_ui.menu_cmap) + if plotly_ui.menu_color: + plotly_ui.menu_color["x"] = 0.0 + menus.append(plotly_ui.menu_color) + if plotly_ui.slider_size: + plotly_ui.slider_size["x"] = 0.0 + sliders.append(plotly_ui.slider_size) mapper_fig.update_layout( - updatemenus=[menu_cmap, menu_color], - sliders=[slider_size], + updatemenus=menus, + sliders=sliders, ) diff --git a/src/tdamapper/_plot_pyvis.py b/src/tdamapper/plot_backends/plot_pyvis.py similarity index 97% rename from src/tdamapper/_plot_pyvis.py rename to src/tdamapper/plot_backends/plot_pyvis.py index 9d47af6..cd3d05c 100644 --- a/src/tdamapper/_plot_pyvis.py +++ b/src/tdamapper/plot_backends/plot_pyvis.py @@ -19,7 +19,7 @@ _TICKS_NUM = 10 -def __fmt(x, max_len=3): +def _fmt(x, max_len=3): fmt = f".{max_len}g" return f"{x:{fmt}}" @@ -150,7 +150,7 @@ def plot_pyvis( ) colorbar = _colorbar(height=height, cmap=cmap, cmin=cmin, cmax=cmax, title=title) combined_html = _combine(net, colorbar) - with open(output_file, "w") as file: + with open(output_file, "w", encoding="utf-8") as file: file.write(combined_html) @@ -220,7 +220,7 @@ def _color(node): node_id = int(node) n_size = _size(node) node_color = _color(node) - node_stats = __fmt(node_colors[node]) + node_stats = _fmt(node_colors[node]) node_label = f"color: {node_stats}\nnode: {node_id}\nsize: {n_size}" node_pos = mapper_plot.positions[node] net.add_node( diff --git a/src/tdamapper/utils/unionfind.py b/src/tdamapper/utils/unionfind.py index 376d101..6e036f6 100644 --- a/src/tdamapper/utils/unionfind.py +++ b/src/tdamapper/utils/unionfind.py @@ -1,30 +1,30 @@ class UnionFind: def __init__(self, X): - self.__parent = {x: x for x in X} - self.__size = {x: 1 for x in X} + self._parent = {x: x for x in X} + self._size = {x: 1 for x in X} def find(self, x): root = x - while root != self.__parent[root]: - root = self.__parent[root] + while root != self._parent[root]: + root = self._parent[root] tmp = x while tmp != root: - parent = self.__parent[tmp] - self.__parent[tmp] = root + parent = self._parent[tmp] + self._parent[tmp] = root tmp = parent return root def union(self, x, y): x, y = self.find(x), self.find(y) if x != y: - x_size, y_size = self.__size[x], self.__size[y] + x_size, y_size = self._size[x], self._size[y] if x_size < y_size: to_keep, to_move = y, x else: to_keep, to_move = x, y - self.__parent[to_move] = to_keep - self.__size[to_keep] = x_size + y_size + self._parent[to_move] = to_keep + self._size[to_keep] = x_size + y_size return to_keep else: return x diff --git a/src/tdamapper/utils/vptree.py b/src/tdamapper/utils/vptree.py index 3b3a62f..adeba9e 100644 --- a/src/tdamapper/utils/vptree.py +++ b/src/tdamapper/utils/vptree.py @@ -51,7 +51,7 @@ def __init__( builder = HVPT else: raise ValueError(f"Unknown kind of vptree: {kind}") - self.__vpt = builder( + self._vpt = builder( X, metric=metric, metric_params=metric_params, @@ -79,7 +79,7 @@ def ball_search(self, point, eps, inclusive=True): query point. :rtype: list """ - return self.__vpt.ball_search(point, eps, inclusive=inclusive) + return self._vpt.ball_search(point, eps, inclusive=inclusive) def knn_search(self, point, k): """ @@ -96,4 +96,4 @@ def knn_search(self, point, k): :return: A list of the k-nearest neighbors to the given query point. :rtype: list """ - return self.__vpt.knn_search(point, k) + return self._vpt.knn_search(point, k) diff --git a/tests/ball_tree.py b/tests/ball_tree.py index 0205316..c3cfbd4 100644 --- a/tests/ball_tree.py +++ b/tests/ball_tree.py @@ -12,8 +12,8 @@ def __init__( pivoting=None, **kwargs, ): - self.__dataset = X - self.__ball_tree = BallTree( + self._dataset = X + self._ball_tree = BallTree( X, leaf_size=leaf_capacity, metric=metric, @@ -21,21 +21,21 @@ def __init__( ) def ball_search(self, point, eps, inclusive=True): - ids = self.__ball_tree.query_radius( + ids = self._ball_tree.query_radius( [point], eps, return_distance=False, count_only=False, sort_results=False, ) - return [self.__dataset[i] for i in ids[0]] + return [self._dataset[i] for i in ids[0]] def knn_search(self, point, k): - ids = self.__ball_tree.query( + ids = self._ball_tree.query( [point], k=k, return_distance=False, dualtree=False, breadth_first=False, ) - return [self.__dataset[i] for i in ids[0]] + return [self._dataset[i] for i in ids[0]] diff --git a/tests/test_unit_params.py b/tests/test_unit_params.py index fd594e3..5b52f85 100644 --- a/tests/test_unit_params.py +++ b/tests/test_unit_params.py @@ -5,23 +5,23 @@ from tdamapper.learn import MapperAlgorithm, MapperClustering -def __test_clone(obj): +def _test_clone(obj): obj_repr = repr(obj) obj_cln = clone(obj) cln_repr = repr(obj_cln) assert obj_repr == cln_repr -def __test_repr(obj): +def _test_repr(obj): obj_repr = repr(obj) _obj = eval(obj_repr) _obj_repr = repr(_obj) assert obj_repr == _obj_repr -def __test_clone_and_repr(obj): - __test_clone(obj) - __test_repr(obj) +def _test_clone_and_repr(obj): + _test_clone(obj) + _test_repr(obj) def test_params_mapper_algorithm(): @@ -65,8 +65,8 @@ def test_params_mapper_clustering(): def test_clone_and_repr_ball_cover(): - __test_clone_and_repr(BallCover()) - __test_clone_and_repr( + _test_clone_and_repr(BallCover()) + _test_clone_and_repr( BallCover( radius=2.0, metric="test", @@ -80,8 +80,8 @@ def test_clone_and_repr_ball_cover(): def test_clone_and_repr_cubical_cover(): - __test_clone_and_repr(CubicalCover()) - __test_clone_and_repr( + _test_clone_and_repr(CubicalCover()) + _test_clone_and_repr( CubicalCover( n_intervals=4, overlap_frac=5, @@ -95,8 +95,8 @@ def test_clone_and_repr_cubical_cover(): def test_clone_repr_mapper_algorithm(): - __test_clone_and_repr(MapperAlgorithm()) - __test_clone_and_repr( + _test_clone_and_repr(MapperAlgorithm()) + _test_clone_and_repr( MapperAlgorithm( cover=CubicalCover( n_intervals=3, @@ -114,8 +114,8 @@ def test_clone_repr_mapper_algorithm(): def test_clone_repr_mapper_clustering(): - __test_clone_and_repr(MapperClustering()) - __test_clone_and_repr( + _test_clone_and_repr(MapperClustering()) + _test_clone_and_repr( MapperClustering( cover=CubicalCover( n_intervals=3, From 5702277426ee7e812c2beea6a111c264c93ecd6a Mon Sep 17 00:00:00 2001 From: Luca Simi Date: Sun, 1 Jun 2025 12:13:41 +0200 Subject: [PATCH 05/15] Minor improvements --- docs/source/notebooks/circles.py | 2 ++ docs/source/notebooks/digits.py | 4 +-- src/tdamapper/plot_backends/plot_plotly.py | 42 +++++++++++----------- 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/docs/source/notebooks/circles.py b/docs/source/notebooks/circles.py index 40f7405..0fe54e4 100644 --- a/docs/source/notebooks/circles.py +++ b/docs/source/notebooks/circles.py @@ -96,6 +96,7 @@ fig = plot.plot_plotly( colors=labels, cmap=["jet", "viridis", "cividis"], + node_size=[0.0, 0.5, 1.0, 1.5, 2.0], agg=np.nanmean, width=600, height=600, @@ -118,6 +119,7 @@ fig = plot.plot_plotly( colors=labels, cmap=["jet", "viridis", "cividis"], + node_size=[0.0, 0.5, 1.0, 1.5, 2.0], agg=np.nanstd, width=600, height=600, diff --git a/docs/source/notebooks/digits.py b/docs/source/notebooks/digits.py index daf7412..ee75e16 100644 --- a/docs/source/notebooks/digits.py +++ b/docs/source/notebooks/digits.py @@ -103,10 +103,10 @@ def mode(arr): colors=labels, cmap=["jet", "viridis", "cividis"], agg=mode, + node_size=[0.0, 0.5, 1.0, 1.5, 2.0], title="mode of digits", width=600, height=600, - node_size=0.5, ) fig.show(config={"scrollZoom": True}, renderer="notebook_connected") @@ -134,10 +134,10 @@ def entropy(arr): colors=labels, cmap=["jet", "viridis", "cividis"], agg=entropy, + node_size=[0.0, 0.5, 1.0, 1.5, 2.0], title="entropy of digits", width=600, height=600, - node_size=0.5, ) fig.show(config={"scrollZoom": True}, renderer="notebook_connected") diff --git a/src/tdamapper/plot_backends/plot_plotly.py b/src/tdamapper/plot_backends/plot_plotly.py index 4ef453c..d19cc2f 100644 --- a/src/tdamapper/plot_backends/plot_plotly.py +++ b/src/tdamapper/plot_backends/plot_plotly.py @@ -53,6 +53,24 @@ def _get_plotly_colorscales(): PLOTLY_CMAPS = _get_plotly_colorscales() +class PlotlyUI: + + def __init__(self): + self.menu_cmap = None + self.menu_color = None + self.slider_size = None + + def set_menu_cmap(self, mapper_plot, cmaps): + cmaps_plotly = [PLOTLY_CMAPS.get(c.lower()) for c in cmaps] + self.menu_cmap = _ui_cmap(mapper_plot, cmaps_plotly) + + def set_menu_color(self, mapper_plot, colors, titles, agg): + self.menu_color = _ui_color(mapper_plot, colors, titles, agg) + + def set_slider_size(self, mapper_plot, node_sizes): + self.slider_size = _ui_node_size(mapper_plot, node_sizes) + + def _to_cmaps(cmap: Union[str, List[str]]) -> List[str]: """Convert a single cmap or a list of cmaps to a list of cmaps.""" if isinstance(cmap, str): @@ -488,32 +506,16 @@ def _layout(): ) -class PlotlyUI: - - def __init__(self): - self.menu_cmap = None - self.menu_color = None - self.slider_size = None - - def set_menu_cmap(self, mapper_plot, cmaps): - cmaps_plotly = [PLOTLY_CMAPS.get(c.lower()) for c in cmaps] - self.menu_cmap = _ui_cmap(mapper_plot, cmaps_plotly) - - def set_menu_color(self, mapper_plot, colors, titles, agg): - self.menu_color = _ui_color(mapper_plot, colors, titles, agg) - - def set_slider_size(self, mapper_plot, node_sizes): - self.slider_size = _ui_node_size(mapper_plot, node_sizes) - - def _set_ui(mapper_fig, plotly_ui: PlotlyUI): menus = [] sliders = [] + x = 0.0 if plotly_ui.menu_cmap: - plotly_ui.menu_cmap["x"] = 0.25 + plotly_ui.menu_cmap["x"] = x + x += 0.25 menus.append(plotly_ui.menu_cmap) if plotly_ui.menu_color: - plotly_ui.menu_color["x"] = 0.0 + plotly_ui.menu_color["x"] = x menus.append(plotly_ui.menu_color) if plotly_ui.slider_size: plotly_ui.slider_size["x"] = 0.0 From 81ec45d61557e0501a17c21d76a4a55c031c7b11 Mon Sep 17 00:00:00 2001 From: Luca Simi Date: Sun, 1 Jun 2025 12:32:44 +0200 Subject: [PATCH 06/15] Fixed node size ui --- src/tdamapper/plot_backends/plot_plotly.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/tdamapper/plot_backends/plot_plotly.py b/src/tdamapper/plot_backends/plot_plotly.py index d19cc2f..3cc8925 100644 --- a/src/tdamapper/plot_backends/plot_plotly.py +++ b/src/tdamapper/plot_backends/plot_plotly.py @@ -107,10 +107,10 @@ def _to_titles(title, colors_num): ) -def _to_node_sizes(node_size, colors_num): +def _to_node_sizes(node_size): if isinstance(node_size, (int, float)): - return [node_size] * colors_num - elif isinstance(node_size, list) and len(node_size) == colors_num: + return [node_size] + elif isinstance(node_size, list): return node_size else: raise ValueError( @@ -133,7 +133,7 @@ def plot_plotly( colors = _to_colors(colors) colors_num = colors.shape[1] titles = _to_titles(title, colors_num) - node_sizes = _to_node_sizes(node_size, colors_num) + node_sizes = _to_node_sizes(node_size) fig = _figure(mapper_plot, width, height, node_sizes, colors, titles, agg, cmaps) ui = PlotlyUI() ui.set_menu_cmap(mapper_plot, cmaps) @@ -170,7 +170,7 @@ def plot_plotly_update( ui.set_menu_color(mapper_plot, colors, titles, agg) node_sizes = None if node_size is not None: - node_sizes = _to_node_sizes(node_size, colors_num) + node_sizes = _to_node_sizes(node_size) ui.set_slider_size(mapper_plot, node_sizes) _update( mapper_plot, From 6887d24810de8311c7a7edc4cfa02fd5769fb8b5 Mon Sep 17 00:00:00 2001 From: Luca Simi Date: Sun, 1 Jun 2025 18:25:39 +0200 Subject: [PATCH 07/15] Added type hints. Improved handling of missing arguments --- src/tdamapper/plot_backends/plot_plotly.py | 80 +++++++++++++--------- 1 file changed, 49 insertions(+), 31 deletions(-) diff --git a/src/tdamapper/plot_backends/plot_plotly.py b/src/tdamapper/plot_backends/plot_plotly.py index 3cc8925..ffdea18 100644 --- a/src/tdamapper/plot_backends/plot_plotly.py +++ b/src/tdamapper/plot_backends/plot_plotly.py @@ -60,7 +60,9 @@ def __init__(self): self.menu_color = None self.slider_size = None - def set_menu_cmap(self, mapper_plot, cmaps): + def set_menu_cmap(self, mapper_plot, cmaps: Optional[List[str]]) -> None: + if cmaps is None: + return cmaps_plotly = [PLOTLY_CMAPS.get(c.lower()) for c in cmaps] self.menu_cmap = _ui_cmap(mapper_plot, cmaps_plotly) @@ -71,8 +73,10 @@ def set_slider_size(self, mapper_plot, node_sizes): self.slider_size = _ui_node_size(mapper_plot, node_sizes) -def _to_cmaps(cmap: Union[str, List[str]]) -> List[str]: +def _to_cmaps(cmap: Optional[Union[str, List[str]]]) -> List[str]: """Convert a single cmap or a list of cmaps to a list of cmaps.""" + if cmap is None: + return [DEFAULT_CMAP] if isinstance(cmap, str): return [cmap] elif isinstance(cmap, list): @@ -94,11 +98,11 @@ def _to_colors(colors: Union[np.ndarray, List[float]]) -> np.ndarray: ) -def _to_titles(title, colors_num): +def _to_titles(title: Optional[Union[str, List[str]]], colors_num: int) -> List[str]: if title is None: - return [DEFAULT_TITLE for _ in range(colors_num)] + return [f"{i}" for i in range(colors_num)] elif isinstance(title, str): - return [title for _ in range(colors_num)] + return [f"{title} {i}" for i in range(colors_num)] elif isinstance(title, list) and len(title) == colors_num: return title else: @@ -107,7 +111,9 @@ def _to_titles(title, colors_num): ) -def _to_node_sizes(node_size): +def _to_node_sizes( + node_size: Optional[Union[int, float, List[Union[int, float]]]] +) -> List[float]: if isinstance(node_size, (int, float)): return [node_size] elif isinstance(node_size, list): @@ -123,11 +129,11 @@ def plot_plotly( mapper_plot, width: int, height: int, - node_size: Optional[Union[int, float, List[Union[int, float]]]] = DEFAULT_NODE_SIZE, - colors=None, + colors: Union[np.ndarray, List[float]], + node_size: Optional[Union[int, float, List[Union[int, float]]]] = None, title: Optional[Union[str, List[str]]] = None, agg=np.nanmean, - cmap: Union[str, List[str]] = DEFAULT_CMAP, + cmap: Optional[Union[str, List[str]]] = None, ) -> go.Figure: cmaps = _to_cmaps(cmap) colors = _to_colors(colors) @@ -187,7 +193,7 @@ def plot_plotly_update( return fig -def _node_pos_array(graph, dim, node_pos): +def _node_pos_array(graph: nx.Graph, dim: int, node_pos): return tuple([node_pos[n][i] for n in graph.nodes()] for i in range(dim)) @@ -202,7 +208,7 @@ def _edge_pos_array(graph, dim, node_pos): return edges_arr -def _marker_size(mapper_plot, node_size): +def _marker_size(mapper_plot, node_size: float) -> List[float]: attr_size = nx.get_node_attributes(mapper_plot.graph, ATTR_SIZE) max_size = max(attr_size.values(), default=1.0) scale = node_size * (25.0 if mapper_plot.dim == 2 else 15.0) @@ -212,14 +218,14 @@ def _marker_size(mapper_plot, node_size): return marker_size -def _get_cmap_rgb(cmap): +def _get_cmap_rgb(cmap: str): """Return a colorscale in [[float, 'rgb(r,g,b)']] format.""" base_scale = pc.get_colorscale(cmap) # If it's already in [float, color] format, we're good return [[pos, color] for pos, color in base_scale] -def _set_cmap(mapper_plot, fig, cmap): +def _set_cmap(mapper_plot, fig: go.Figure, cmap: str) -> None: cmap_rgb = _get_cmap_rgb(cmap) fig.update_traces( patch=dict( @@ -244,7 +250,7 @@ def _set_cmap(mapper_plot, fig, cmap): ) -def _set_colors(mapper_plot, fig, colors, agg): +def _set_colors(mapper_plot, fig: go.Figure, colors, agg): node_col = aggregate_graph(colors, mapper_plot.graph, agg) scatter_text = _text(mapper_plot, node_col) colors_arr = list(node_col.values()) @@ -278,7 +284,7 @@ def _set_colors(mapper_plot, fig, colors, agg): ) -def _set_title(mapper_plot, fig, color_name): +def _set_title(mapper_plot, fig: go.Figure, color_name: str): fig.update_traces( patch=dict( marker_colorbar=_colorbar(mapper_plot, color_name), @@ -287,7 +293,7 @@ def _set_title(mapper_plot, fig, color_name): ) -def _set_node_size(mapper_plot, fig, node_size): +def _set_node_size(mapper_plot, fig: go.Figure, node_size: float) -> None: fig.update_traces( patch=dict( marker_size=_marker_size(mapper_plot, node_size), @@ -296,19 +302,28 @@ def _set_node_size(mapper_plot, fig, node_size): ) -def _set_width(fig, width): +def _set_width(fig: go.Figure, width: int) -> None: fig.update_layout( width=width, ) -def _set_height(fig, height): +def _set_height(fig: go.Figure, height: int) -> None: fig.update_layout( height=height, ) -def _figure(mapper_plot, width, height, node_sizes, colors, titles, agg, cmaps): +def _figure( + mapper_plot, + width: int, + height: int, + node_sizes: List[float], + colors: np.ndarray, + titles: List[str], + agg, + cmaps: List[str], +) -> go.Figure: node_pos = mapper_plot.positions node_pos_arr = _node_pos_array( mapper_plot.graph, @@ -346,7 +361,7 @@ def _update( width: Optional[int] = None, height: Optional[int] = None, titles: Optional[List[str]] = None, - node_sizes: Optional[List[int]] = None, + node_sizes: Optional[List[float]] = None, colors=None, agg=None, cmaps: Optional[List[str]] = None, @@ -422,7 +437,9 @@ def _edges_trace(mapper_plot, edge_pos_arr): return go.Scatter(scatter) -def _colorbar(mapper_plot, title): +def _colorbar( + mapper_plot, title: str +) -> Union[go.scatter3d.marker.ColorBar, go.scatter.marker.ColorBar]: cbar = dict( showticklabels=True, outlinewidth=1, @@ -463,7 +480,7 @@ def _fmt(x, max_len=3): return f"{x:{fmt}}" -def _layout(): +def _layout() -> go.Layout: line_col = "rgba(230, 230, 230, 1.0)" axis = dict( showline=False, @@ -506,7 +523,7 @@ def _layout(): ) -def _set_ui(mapper_fig, plotly_ui: PlotlyUI): +def _set_ui(mapper_fig: go.Figure, plotly_ui: PlotlyUI) -> None: menus = [] sliders = [] x = 0.0 @@ -526,10 +543,10 @@ def _set_ui(mapper_fig, plotly_ui: PlotlyUI): ) -def _ui_cmap(mapper_plot, cmaps): +def _ui_cmap(mapper_plot, cmaps: List[str]) -> dict: target_traces = [1] if mapper_plot.dim == 2 else [0, 1] - def _update_cmap(cmap): + def _update_cmap(cmap: str) -> dict: cmap_rgb = _get_cmap_rgb(cmap) if mapper_plot.dim == 2: return { @@ -542,6 +559,7 @@ def _update_cmap(cmap): "marker.line.colorscale": [None, cmap_rgb], "line.colorscale": [cmap_rgb, None], } + return {} buttons = [] if len(cmaps) > 1: @@ -564,7 +582,7 @@ def _update_cmap(cmap): ) -def _ui_node_size(mapper_plot, node_sizes): +def _ui_node_size(mapper_plot, node_sizes: List[float]) -> dict: steps = [ dict( method="restyle", @@ -589,21 +607,21 @@ def _ui_node_size(mapper_plot, node_sizes): ) -def _ui_color(mapper_plot, colors, titles, agg): +def _ui_color(mapper_plot, colors, titles: List[str], agg) -> dict: colors_arr = np.array(colors) colors_num = colors_arr.shape[1] if colors_arr.ndim == 2 else 1 - def _colors_agg(i): + def _colors_agg(i: int) -> dict: if i is None: arr = colors_arr else: arr = colors_arr[:, i] if colors_arr.ndim == 2 else colors_arr return aggregate_graph(arr, mapper_plot.graph, agg) - def _colors(i): + def _colors(i: int) -> List[float]: return list(_colors_agg(i).values()) - def _edge_colors(i): + def _edge_colors(i: int) -> List[float]: colors_avg = [] colors_agg = _colors_agg(i) for edge in mapper_plot.graph.edges(): @@ -613,7 +631,7 @@ def _edge_colors(i): colors_avg.append(c1) return colors_avg - def _update_colors(i): + def _update_colors(i: int) -> dict: arr_agg = _colors_agg(i) arr = list(arr_agg.values()) scatter_text = _text(mapper_plot, arr_agg) From ef0351d6912e94e3efbeef04593a6a59e52e6244 Mon Sep 17 00:00:00 2001 From: Luca Simi Date: Sun, 1 Jun 2025 18:29:04 +0200 Subject: [PATCH 08/15] removed deprecation --- src/tdamapper/plot.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/tdamapper/plot.py b/src/tdamapper/plot.py index fea6fb9..4c43388 100644 --- a/src/tdamapper/plot.py +++ b/src/tdamapper/plot.py @@ -206,10 +206,6 @@ def plot_plotly( cmap=cmap, ) - @deprecated( - "This method is deprecated and will be removed in a future release. " - "Use a new instance of tdamapper.plot.MapperPlot." - ) def plot_plotly_update( self, fig, From 28a58bd6882c02b5a5fa7dd1668bb936b3445312 Mon Sep 17 00:00:00 2001 From: Luca Simi Date: Sun, 1 Jun 2025 20:10:26 +0200 Subject: [PATCH 09/15] Fixed colorbar --- app/nicegui_app.py | 10 +++++++--- src/tdamapper/plot_backends/plot_plotly.py | 20 ++++++++------------ 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/app/nicegui_app.py b/app/nicegui_app.py index 496e215..d5d3b6c 100644 --- a/app/nicegui_app.py +++ b/app/nicegui_app.py @@ -1,4 +1,5 @@ import numpy as np +import pandas as pd import plotly.graph_objs as go from nicegui import run, ui from sklearn.cluster import DBSCAN, AgglomerativeClustering, KMeans @@ -388,14 +389,17 @@ def update_plot(self): iterations=iterations, seed=42, ) + colors = pd.concat([self.labels, self.X], axis=1) + colors_arr = colors.to_numpy() + color_names = colors.columns.tolist() mapper_fig = mapper_plot.plot_plotly( - colors=self.labels, + colors=colors_arr, cmap=["jet", "viridis", "cividis"], agg=mode, - title="mode of digits", + title=color_names, width=800, height=800, - node_size=0.5, + node_size=list(0.125 * x for x in range(17)), ) mapper_fig.layout.width = None mapper_fig.layout.autosize = True diff --git a/src/tdamapper/plot_backends/plot_plotly.py b/src/tdamapper/plot_backends/plot_plotly.py index ffdea18..e9d341b 100644 --- a/src/tdamapper/plot_backends/plot_plotly.py +++ b/src/tdamapper/plot_backends/plot_plotly.py @@ -284,10 +284,10 @@ def _set_colors(mapper_plot, fig: go.Figure, colors, agg): ) -def _set_title(mapper_plot, fig: go.Figure, color_name: str): +def _set_title(fig: go.Figure, color_name: str): fig.update_traces( patch=dict( - marker_colorbar=_colorbar(mapper_plot, color_name), + marker_colorbar=_colorbar(color_name), ), selector=dict(name=_NODES_TRACE), ) @@ -371,7 +371,7 @@ def _update( if height is not None: _set_height(fig, height) if titles is not None: - _set_title(mapper_plot, fig, titles[0]) + _set_title(fig, titles[0]) if node_sizes is not None: _set_node_size(mapper_plot, fig, node_sizes[len(node_sizes) // 2]) if (colors is not None) and (agg is not None): @@ -398,7 +398,7 @@ def _nodes_trace(mapper_plot, node_pos_arr): line_color=_NODE_OUTER_COLOR, line_colorscale=DEFAULT_CMAP, colorscale=DEFAULT_CMAP, - colorbar=_colorbar(mapper_plot, DEFAULT_TITLE), + colorbar=_colorbar(DEFAULT_TITLE), ), ) if mapper_plot.dim == 3: @@ -437,9 +437,7 @@ def _edges_trace(mapper_plot, edge_pos_arr): return go.Scatter(scatter) -def _colorbar( - mapper_plot, title: str -) -> Union[go.scatter3d.marker.ColorBar, go.scatter.marker.ColorBar]: +def _colorbar(title: str) -> dict: cbar = dict( showticklabels=True, outlinewidth=1, @@ -458,10 +456,7 @@ def _colorbar( ) if title is not None: cbar["title"] = title - if mapper_plot.dim == 3: - return go.scatter3d.marker.ColorBar(cbar) - elif mapper_plot.dim == 2: - return go.scatter.marker.ColorBar(cbar) + return cbar def _text(mapper_plot, colors): @@ -635,7 +630,7 @@ def _update_colors(i: int) -> dict: arr_agg = _colors_agg(i) arr = list(arr_agg.values()) scatter_text = _text(mapper_plot, arr_agg) - cbar = _colorbar(mapper_plot, titles[i]) + cbar = _colorbar(titles[i]) if mapper_plot.dim == 2: return { "text": [scatter_text], @@ -656,6 +651,7 @@ def _update_colors(i: int) -> dict: "line.cmax": [max(arr_edge, default=None), None], "line.cmin": [min(arr_edge, default=None), None], } + return {} target_traces = [1] if mapper_plot.dim == 2 else [0, 1] From 665e3c5efea2ed08a441f02c552d27821da172d7 Mon Sep 17 00:00:00 2001 From: Luca Simi Date: Mon, 2 Jun 2025 12:40:45 +0200 Subject: [PATCH 10/15] Added app to src directory. Improved options with csv upload --- pyproject.toml | 9 ++ app/nicegui_app.py => src/tdamapper/app.py | 173 ++++++++++++++++----- 2 files changed, 146 insertions(+), 36 deletions(-) rename app/nicegui_app.py => src/tdamapper/app.py (71%) diff --git a/pyproject.toml b/pyproject.toml index c07e8b1..b8f3d20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,15 @@ dev = [ "flake8", "nbformat>=4.2.0", ] +app = [ + "pandas<3.0.0", + "scikit-learn<1.6.0", + "nicegui>=2.18.0,<3.0.0", + "umap-learn<0.6.0", +] + +[project.scripts] +tda-mapper-app = "tdamapper.app:main" [project.urls] Homepage = "https://github.com/lucasimi/tda-mapper-python" diff --git a/app/nicegui_app.py b/src/tdamapper/app.py similarity index 71% rename from app/nicegui_app.py rename to src/tdamapper/app.py index d5d3b6c..cd26d14 100644 --- a/app/nicegui_app.py +++ b/src/tdamapper/app.py @@ -1,9 +1,11 @@ +import logging + import numpy as np import pandas as pd import plotly.graph_objs as go from nicegui import run, ui from sklearn.cluster import DBSCAN, AgglomerativeClustering, KMeans -from sklearn.datasets import load_digits, load_iris +from sklearn.datasets import fetch_openml, load_digits, load_iris from sklearn.decomposition import PCA from umap import UMAP @@ -12,6 +14,9 @@ from tdamapper.learn import MapperAlgorithm from tdamapper.plot import MapperPlot +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + def mode(arr): values, counts = np.unique(arr, return_counts=True) @@ -20,6 +25,14 @@ def mode(arr): return np.nanmean(mode_values) +def _fix_data(data): + df = pd.DataFrame(data) + df = df.select_dtypes(include="number") + df.dropna(axis=1, how="all", inplace=True) + df.fillna(df.mean(), inplace=True) + return df + + def _identity(X): return X @@ -66,6 +79,9 @@ def _func(X): DRAW_3D = "3D" DRAW_2D = "2D" DRAW_ITERATIONS = 50 +DRAW_MEAN = "Mean" +DRAW_MEDIAN = "Median" +DRAW_MODE = "Mode" class App: @@ -96,7 +112,9 @@ def build_dataset(self): value=DATA_SOURCE_EXAMPLE, ) self.data_source_csv = ui.upload( - on_upload=self.update_dataset_handler, + on_upload=self.update_csv_handler, + auto_upload=True, + label="Upload CSV", ).classes("w-full") self.data_source_csv.bind_visibility_from( target_object=self.data_source_type, @@ -278,102 +296,166 @@ def build_draw(self): value=DRAW_ITERATIONS, on_change=self.update_plot_handler, ) + self.draw_aggregation = ui.select( + label="Aggregation", + options=[ + DRAW_MEAN, + DRAW_MEDIAN, + DRAW_MODE, + ], + value=DRAW_MEAN, + on_change=self.update_plot_handler, + ) def build_plot(self): fig = go.Figure() fig.layout.width = None fig.layout.autosize = True self.plot_container = ui.element("div").classes("w-full h-full") - with self.plot_container: - ui.plotly(go.Figure()) def render_dataset(self): source_type = self.data_source_type.value if source_type == DATA_SOURCE_EXAMPLE: name = self.data_source_example_file.value if name == DATA_SOURCE_EXAMPLE_DIGITS: - X, y = load_digits(return_X_y=True, as_frame=True) - return X, y + df_X, df_y = load_digits(return_X_y=True, as_frame=True) elif name == DATA_SOURCE_EXAMPLE_IRIS: - X, y = load_iris(return_X_y=True, as_frame=True) - return X, y + df_X, df_y = load_iris(return_X_y=True, as_frame=True) elif source_type == DATA_SOURCE_CSV: - pass + csv_file = self.csv_file + if csv_file is None: + logger.warning("No CSV file uploaded") + df_X, df_y = pd.DataFrame(), pd.Series() + else: + df_X = pd.read_csv(csv_file.content) + df_y = pd.Series() + elif source_type == DATA_SOURCE_OPENML: + code = self.data_source_openml.value + if not code: + logger.warning("No OpenML code provided") + df_X, df_y = pd.DataFrame(), pd.Series() + else: + df_X, df_y = fetch_openml(code, return_X_y=True, as_frame=True) + df_X = _fix_data(df_X) + df_y = _fix_data(df_y) + return df_X, df_y def render_lens(self): if self.lens_type.value == LENS_IDENTITY: return _identity elif self.lens_type.value == LENS_PCA: - n = int(self.pca_n_components.value) + n = 2 + if self.pca_n_components.value is not None: + n = int(self.pca_n_components.value) return _pca(n) elif self.lens_type.value == LENS_UMAP: - n = int(self.umap_n_components.value) + n = 2 + if self.umap_n_components.value is not None: + n = int(self.umap_n_components.value) return _umap(n) def render_cover(self): if self.cover_type.value == COVER_TRIVIAL: return TrivialCover() elif self.cover_type.value == COVER_BALL: - radius = float(self.cover_ball_radius.value) + radius = 1.0 + if self.cover_ball_radius.value is not None: + radius = float(self.cover_ball_radius.value) return BallCover(radius=radius) elif self.cover_type.value == COVER_CUBICAL: - n_intervals = int(self.cover_cubical_n_intervals.value) - overlap_frac = float(self.cover_cubical_overlap_frac.value) + n_intervals = 1 + if self.cover_cubical_n_intervals.value is not None: + n_intervals = int(self.cover_cubical_n_intervals.value) + overlap_frac = None + if self.cover_cubical_overlap_frac.value is not None: + overlap_frac = float(self.cover_cubical_overlap_frac.value) return CubicalCover(n_intervals=n_intervals, overlap_frac=overlap_frac) elif self.cover_type.value == COVER_KNN: - neighbors = int(self.cover_knn_neighbors.value) + neighbors = 1 + if self.cover_knn_neighbors.value is not None: + neighbors = int(self.cover_knn_neighbors.value) return KNNCover(neighbors=neighbors) def render_clustering(self): + clustering_type = self.clustering_type.value if self.clustering_type.value == CLUSTERING_TRIVIAL: return TrivialClustering() - elif self.clustering_type.value == CLUSTERING_KMEANS: - n_clusters = int(self.clustering_kmeans_n_clusters.value) + elif clustering_type == CLUSTERING_KMEANS: + n_clusters = 1 + if self.clustering_kmeans_n_clusters.value is not None: + n_clusters = int(self.clustering_kmeans_n_clusters.value) return KMeans(n_clusters) - elif self.clustering_type.value == CLUSTERING_DBSCAN: - eps = float(self.clustering_dbscan_eps.value) - min_samples = int(self.clustering_dbscan_min_samples.value) + elif clustering_type == CLUSTERING_DBSCAN: + eps = 0.5 + if self.clustering_dbscan_eps.value is not None: + eps = float(self.clustering_dbscan_eps.value) + min_samples = 5 + if self.clustering_dbscan_min_samples.value is not None: + min_samples = int(self.clustering_dbscan_min_samples.value) return DBSCAN(eps=eps, min_samples=min_samples) - elif self.clustering_type == CLUSTERING_AGGLOMERATIVE: - n_clusters = int(self.clustering_agglomerative_n_clusters.value) + elif clustering_type == CLUSTERING_AGGLOMERATIVE: + n_clusters = 2 + if self.clustering_agglomerative_n_clusters.value is not None: + n_clusters = int(self.clustering_agglomerative_n_clusters.value) return AgglomerativeClustering(n_clusters=n_clusters) - async def update_graph_handler(self, _=None): - await run.io_bound(self.update_graph) + async def update_csv_handler(self, file): + await run.io_bound(self.update_csv, file) + await self.update_dataset_handler() async def update_dataset_handler(self, _=None): await run.io_bound(self.update_dataset) + await self.update_graph_handler() + + async def update_graph_handler(self, _=None): + await run.io_bound(self.update_graph) + await self.update_plot_handler() + + async def update_plot_handler(self, _=None): + await run.io_bound(self.update_plot) + + def update_csv(self, file): + if file is None: + logger.warning("No file uploaded") + return + self.csv_file = file def update_dataset(self, _=None): - self.X, self.labels = self.render_dataset() - self.update_graph() + self.df_X, self.labels = self.render_dataset() def update_graph(self, _=None): self.lens = self.render_lens() if self.lens is None: + logger.warning("No lens selected") return - if self.X is None: + if self.df_X is None or self.df_X.empty: + logger.warning("No dataset loaded for computation") return + logger.info(f"Uploaded dataset with shape {self.df_X.shape}") + self.X = self.df_X.to_numpy() self.y = self.lens(self.X) cover = self.render_cover() if cover is None: + logger.warning("No cover selected") return clustering = self.render_clustering() if clustering is None: + logger.warning("No clustering selected") return mapper_algo = MapperAlgorithm( cover=cover, clustering=clustering, verbose=False, ) + logger.info(f"Configuration: {mapper_algo}") self.mapper_graph = mapper_algo.fit_transform(self.X, self.y) - self.update_plot() - - async def update_plot_handler(self, _=None): - await run.io_bound(self.update_plot) def update_plot(self): + if self.df_X is None or self.df_X.empty: + logger.warning("No dataset loaded for plotting") + return if self.mapper_graph is None: + logger.warning("No graph computed") return dim = 3 @@ -389,13 +471,23 @@ def update_plot(self): iterations=iterations, seed=42, ) - colors = pd.concat([self.labels, self.X], axis=1) + + colors = pd.concat([self.labels, self.df_X], axis=1) colors_arr = colors.to_numpy() color_names = colors.columns.tolist() + + agg = np.nanmean + if self.draw_aggregation.value == DRAW_MEAN: + agg = np.nanmean + elif self.draw_aggregation.value == DRAW_MEDIAN: + agg = np.nanmedian + elif self.draw_aggregation.value == DRAW_MODE: + agg = mode + mapper_fig = mapper_plot.plot_plotly( colors=colors_arr, cmap=["jet", "viridis", "cividis"], - agg=mode, + agg=agg, title=color_names, width=800, height=800, @@ -408,8 +500,10 @@ def update_plot(self): ui.plotly(mapper_fig) def __init__(self): + self.csv_file = None + self.df_X = None with ui.row().classes("w-full h-screen m-0 p-0 gap-0 overflow-hidden"): - with ui.column().classes("w-64 h-full m-0 p-0"): # fixed-width sidebar + with ui.column().classes("w-64 h-full m-0 p-0"): with ui.column().classes("w-64 h-full overflow-y-auto p-3 gap-2"): with ui.card().classes("w-full"): ui.markdown("#### 📊 Data") @@ -429,7 +523,14 @@ def __init__(self): self.build_draw() self.build_plot() self.update_dataset() + self.update_graph() + self.update_plot() + + +def main(): + App() + ui.run() -app = App() -ui.run() +if __name__ in {"__main__", "__mp_main__", "tdamapper.app"}: + main() From 67a2dc3e8c14ba3fa78b1be21c79556f5fc4b32c Mon Sep 17 00:00:00 2001 From: Luca Simi Date: Mon, 2 Jun 2025 23:05:25 +0200 Subject: [PATCH 11/15] Refactored App with state and async callbacks --- src/tdamapper/app.py | 507 ++++++++++++++++++++++--------------------- 1 file changed, 262 insertions(+), 245 deletions(-) diff --git a/src/tdamapper/app.py b/src/tdamapper/app.py index cd26d14..9d001b0 100644 --- a/src/tdamapper/app.py +++ b/src/tdamapper/app.py @@ -1,5 +1,7 @@ import logging +from dataclasses import asdict, dataclass +import networkx as nx import numpy as np import pandas as pd import plotly.graph_objs as go @@ -14,47 +16,6 @@ from tdamapper.learn import MapperAlgorithm from tdamapper.plot import MapperPlot -logger = logging.getLogger(__name__) -logging.basicConfig(level=logging.INFO) - - -def mode(arr): - values, counts = np.unique(arr, return_counts=True) - max_count = np.max(counts) - mode_values = values[counts == max_count] - return np.nanmean(mode_values) - - -def _fix_data(data): - df = pd.DataFrame(data) - df = df.select_dtypes(include="number") - df.dropna(axis=1, how="all", inplace=True) - df.fillna(df.mean(), inplace=True) - return df - - -def _identity(X): - return X - - -def _pca(n_components): - pca = PCA(n_components=n_components, random_state=42) - - def _func(X): - return pca.fit_transform(X) - - return _func - - -def _umap(n_components): - um = UMAP(n_components=n_components, random_state=42) - - def _func(X): - return um.fit_transform(X) - - return _func - - LENS_IDENTITY = "Identity" LENS_PCA = "PCA" LENS_UMAP = "UMAP" @@ -75,6 +36,7 @@ def _func(X): DATA_SOURCE_EXAMPLE_DIGITS = "Digits" DATA_SOURCE_EXAMPLE_IRIS = "Iris" +SOURCE_OPENML = "554" DRAW_3D = "3D" DRAW_2D = "2D" @@ -84,6 +46,203 @@ def _func(X): DRAW_MODE = "Mode" +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +@dataclass +class State: + source_type: str = DATA_SOURCE_EXAMPLE + source_name: str = DATA_SOURCE_EXAMPLE_DIGITS + source_csv: str = "tmp/data.csv" + source_openml: str = SOURCE_OPENML + lens_type: str = LENS_PCA + lens_pca_n_components: int = 2 + lens_umap_n_components: int = 2 + cover_type: str = COVER_CUBICAL + cover_cubical_n_intervals: int = 2 + cover_cubical_overlap_frac: float = 0.25 + cover_knn_neighbors: int = 10 + cover_ball_radius: float = 100.0 + clustering_type: str = CLUSTERING_TRIVIAL + clustering_kmeans_n_clusters: int = 2 + clustering_dbscan_eps: float = 0.5 + clustering_dbscan_min_samples: int = 5 + clustering_agglomerative_n_clusters: int = 2 + draw_dim: int = 3 + draw_aggregation: str = DRAW_MEAN + draw_iterations: int = 50 + + +def _fix_data(data): + df = pd.DataFrame(data) + df = df.select_dtypes(include="number") + df.dropna(axis=1, how="all", inplace=True) + df.fillna(df.mean(), inplace=True) + return df + + +def get_dataset(state: State): + source_type = state.source_type + source_name = state.source_name + csv_file = state.source_csv + openml_code = state.source_openml + df_X, df_y = pd.DataFrame(), pd.Series() + if source_type == DATA_SOURCE_EXAMPLE: + if source_name == DATA_SOURCE_EXAMPLE_DIGITS: + df_X, df_y = load_digits(return_X_y=True, as_frame=True) + elif source_name == DATA_SOURCE_EXAMPLE_IRIS: + df_X, df_y = load_iris(return_X_y=True, as_frame=True) + elif source_type == DATA_SOURCE_CSV: + if csv_file is None: + logger.warning("No CSV file uploaded") + df_X, df_y = pd.DataFrame(), pd.Series() + else: + df_X = pd.read_csv() + df_y = pd.Series() + elif source_type == DATA_SOURCE_OPENML: + if not openml_code: + logger.warning("No OpenML code provided") + df_X, df_y = pd.DataFrame(), pd.Series() + else: + df_X, df_y = fetch_openml(openml_code, return_X_y=True, as_frame=True) + else: + logger.error(f"Unknown data source type: {source_type}") + return pd.DataFrame(), pd.Series() + df_X = _fix_data(df_X) + df_y = _fix_data(df_y) + return df_X, df_y + + +def get_lens(state: State): + lens_type = state.lens_type + if lens_type == LENS_IDENTITY: + return lambda X: X + elif lens_type == LENS_PCA: + n_components = int(state.lens_pca_n_components) + pca = PCA(n_components=n_components, random_state=42) + return lambda X: pca.fit_transform(X) + elif lens_type == LENS_UMAP: + n_components = int(state.lens_umap_n_components) + umap = UMAP(n_components=n_components, random_state=42) + return lambda X: umap.fit_transform(X) + else: + logger.error(f"Unknown lens type: {lens_type}") + return None + + +def get_cover(state: State): + cover_type = state.cover_type + if cover_type == COVER_TRIVIAL: + return TrivialCover() + elif cover_type == COVER_CUBICAL: + n_intervals = int(state.cover_cubical_n_intervals) + overlap_frac = float(state.cover_cubical_overlap_frac) + return CubicalCover(n_intervals=n_intervals, overlap_frac=overlap_frac) + elif cover_type == COVER_BALL: + radius = float(state.cover_ball_radius) + return BallCover(radius=radius) + elif cover_type == COVER_KNN: + neighbors = int(state.cover_knn_neighbors) + return KNNCover(neighbors=neighbors) + else: + logger.error(f"Unknown cover type: {cover_type}") + return None + + +def get_clustering(state: State): + clustering_type = state.clustering_type + if clustering_type == CLUSTERING_TRIVIAL: + return TrivialClustering() + elif clustering_type == CLUSTERING_KMEANS: + n_clusters = state.clustering_kmeans_n_clusters + return KMeans(n_clusters=n_clusters, random_state=42) + elif clustering_type == CLUSTERING_DBSCAN: + eps = state.clustering_dbscan_eps + min_samples = state.clustering_dbscan_min_samples + return DBSCAN(eps=eps, min_samples=min_samples) + elif clustering_type == CLUSTERING_AGGLOMERATIVE: + n_clusters = state.clustering_agglomerative_n_clusters + return AgglomerativeClustering(n_clusters=n_clusters) + else: + logger.error(f"Unknown clustering type: {clustering_type}") + return None + + +def compute_mapper(**kwargs): + state = State(**kwargs) + + df_X, labels = get_dataset(state) + if df_X.empty: + logger.warning("No dataset loaded") + return None + + lens = get_lens(state) + if lens is None: + logger.warning("No lens selected") + return None + + X = df_X.to_numpy() + y = lens(X) + + cover = get_cover(state) + if cover is None: + logger.warning("No cover selected") + return None + + clustering = get_clustering(state) + if clustering is None: + logger.warning("No clustering selected") + return None + + mapper_algo = MapperAlgorithm( + cover=cover, + clustering=clustering, + verbose=False, + ) + logger.info(f"Mapper configuration: {mapper_algo}") + mapper_graph = mapper_algo.fit_transform(X, y) + + mapper_plot = MapperPlot( + mapper_graph, + dim=state.draw_dim, + iterations=state.draw_iterations, + seed=42, + ) + + colors = pd.concat([labels, df_X], axis=1) + colors_arr = colors.to_numpy() + color_names = colors.columns.tolist() + + mapper_fig = mapper_plot.plot_plotly( + colors=colors_arr, + cmap=["jet", "viridis", "cividis"], + agg=np.nanmean, + title=color_names, + width=800, + height=800, + node_size=list(0.125 * x for x in range(17)), + ) + mapper_fig.layout.width = None + mapper_fig.layout.autosize = True + + return mapper_graph, mapper_fig + + +def mode(arr): + values, counts = np.unique(arr, return_counts=True) + max_count = np.max(counts) + mode_values = values[counts == max_count] + return np.nanmean(mode_values) + + +def update_csv(file): + if file is None: + logger.warning("No file uploaded") + return None + return file + + class App: def build_dataset(self): @@ -95,7 +254,7 @@ def build_dataset(self): DATA_SOURCE_OPENML, ], value=DATA_SOURCE_EXAMPLE, - on_change=self.update_dataset_handler, + on_change=self.update, ).classes("w-full") self.data_source_example_file = ui.select( label="File", @@ -104,7 +263,7 @@ def build_dataset(self): DATA_SOURCE_EXAMPLE_IRIS, ], value=DATA_SOURCE_EXAMPLE_DIGITS, - on_change=self.update_dataset_handler, + on_change=self.update, ).classes("w-full") self.data_source_example_file.bind_visibility_from( target_object=self.data_source_type, @@ -112,7 +271,7 @@ def build_dataset(self): value=DATA_SOURCE_EXAMPLE, ) self.data_source_csv = ui.upload( - on_upload=self.update_csv_handler, + on_upload=self.update, auto_upload=True, label="Upload CSV", ).classes("w-full") @@ -123,7 +282,7 @@ def build_dataset(self): ) self.data_source_openml = ui.input( label="OpenML Code", - on_change=self.update_dataset_handler, + on_change=self.update, ).classes("w-full") self.data_source_openml.bind_visibility_from( target_object=self.data_source_type, @@ -140,14 +299,14 @@ def build_lens(self): LENS_UMAP, ], value=LENS_PCA, - on_change=self.update_graph_handler, + on_change=self.update, ).classes("w-full") self.pca_n_components = ui.number( label="PCA Components", min=1, max=10, value=2, - on_change=self.update_graph_handler, + on_change=self.update, ).classes("w-full") self.pca_n_components.bind_visibility_from( target_object=self.lens_type, @@ -159,7 +318,7 @@ def build_lens(self): min=1, max=10, value=2, - on_change=self.update_graph_handler, + on_change=self.update, ).classes("w-full") self.umap_n_components.bind_visibility_from( target_object=self.lens_type, @@ -177,14 +336,14 @@ def build_cover(self): COVER_KNN, ], value=COVER_CUBICAL, - on_change=self.update_graph_handler, + on_change=self.update, ).classes("w-full") self.cover_cubical_n_intervals = ui.number( label="Intervals", min=1, max=100, value=2, - on_change=self.update_graph_handler, + on_change=self.update, ).classes("w-full") self.cover_cubical_n_intervals.bind_visibility_from( target_object=self.cover_type, @@ -196,7 +355,7 @@ def build_cover(self): min=0.0, max=0.5, value=0.25, - on_change=self.update_graph_handler, + on_change=self.update, ).classes("w-full") self.cover_cubical_overlap_frac.bind_visibility_from( target_object=self.cover_type, @@ -207,7 +366,7 @@ def build_cover(self): label="Radius", min=0.0, value=100.0, - on_change=self.update_graph_handler, + on_change=self.update, ).classes("w-full") self.cover_ball_radius.bind_visibility_from( target_object=self.cover_type, @@ -218,7 +377,7 @@ def build_cover(self): label="Neighbors", min=0, value=10, - on_change=self.update_graph_handler, + on_change=self.update, ).classes("w-full") self.cover_knn_neighbors.bind_visibility_from( target_object=self.cover_type, @@ -236,13 +395,13 @@ def build_clustering(self): CLUSTERING_DBSCAN, ], value=CLUSTERING_TRIVIAL, - on_change=self.update_graph_handler, + on_change=self.update, ).classes("w-full") self.clustering_kmeans_n_clusters = ui.number( label="Clusters", min=1, value=2, - on_change=self.update_graph_handler, + on_change=self.update, ).classes("w-full") self.clustering_kmeans_n_clusters.bind_visibility_from( target_object=self.clustering_type, @@ -253,7 +412,7 @@ def build_clustering(self): label="Eps", min=0.0, value=0.5, - on_change=self.update_graph_handler, + on_change=self.update, ).classes("w-full") self.clustering_dbscan_eps.bind_visibility_from( target_object=self.clustering_type, @@ -264,7 +423,7 @@ def build_clustering(self): label="Min Samples", min=1, value=5, - on_change=self.update_graph_handler, + on_change=self.update, ).classes("w-full") self.clustering_dbscan_min_samples.bind_visibility_from( target_object=self.clustering_type, @@ -275,7 +434,7 @@ def build_clustering(self): label="Clusters", min=1, value=2, - on_change=self.update_graph_handler, + on_change=self.update, ).classes("w-full") self.clustering_agglomerative_n_clusters.bind_visibility_from( target_object=self.clustering_type, @@ -287,14 +446,14 @@ def build_draw(self): self.draw_3d = ui.toggle( options=[DRAW_2D, DRAW_3D], value=DRAW_3D, - on_change=self.update_plot_handler, + on_change=self.update, ) self.draw_iterations = ui.number( label="Layout Iterations", min=1, max=1000, value=DRAW_ITERATIONS, - on_change=self.update_plot_handler, + on_change=self.update, ) self.draw_aggregation = ui.select( label="Aggregation", @@ -304,7 +463,7 @@ def build_draw(self): DRAW_MODE, ], value=DRAW_MEAN, - on_change=self.update_plot_handler, + on_change=self.update, ) def build_plot(self): @@ -313,195 +472,53 @@ def build_plot(self): fig.layout.autosize = True self.plot_container = ui.element("div").classes("w-full h-full") - def render_dataset(self): - source_type = self.data_source_type.value - if source_type == DATA_SOURCE_EXAMPLE: - name = self.data_source_example_file.value - if name == DATA_SOURCE_EXAMPLE_DIGITS: - df_X, df_y = load_digits(return_X_y=True, as_frame=True) - elif name == DATA_SOURCE_EXAMPLE_IRIS: - df_X, df_y = load_iris(return_X_y=True, as_frame=True) - elif source_type == DATA_SOURCE_CSV: - csv_file = self.csv_file - if csv_file is None: - logger.warning("No CSV file uploaded") - df_X, df_y = pd.DataFrame(), pd.Series() - else: - df_X = pd.read_csv(csv_file.content) - df_y = pd.Series() - elif source_type == DATA_SOURCE_OPENML: - code = self.data_source_openml.value - if not code: - logger.warning("No OpenML code provided") - df_X, df_y = pd.DataFrame(), pd.Series() - else: - df_X, df_y = fetch_openml(code, return_X_y=True, as_frame=True) - df_X = _fix_data(df_X) - df_y = _fix_data(df_y) - return df_X, df_y - - def render_lens(self): - if self.lens_type.value == LENS_IDENTITY: - return _identity - elif self.lens_type.value == LENS_PCA: - n = 2 - if self.pca_n_components.value is not None: - n = int(self.pca_n_components.value) - return _pca(n) - elif self.lens_type.value == LENS_UMAP: - n = 2 - if self.umap_n_components.value is not None: - n = int(self.umap_n_components.value) - return _umap(n) - - def render_cover(self): - if self.cover_type.value == COVER_TRIVIAL: - return TrivialCover() - elif self.cover_type.value == COVER_BALL: - radius = 1.0 - if self.cover_ball_radius.value is not None: - radius = float(self.cover_ball_radius.value) - return BallCover(radius=radius) - elif self.cover_type.value == COVER_CUBICAL: - n_intervals = 1 - if self.cover_cubical_n_intervals.value is not None: - n_intervals = int(self.cover_cubical_n_intervals.value) - overlap_frac = None - if self.cover_cubical_overlap_frac.value is not None: - overlap_frac = float(self.cover_cubical_overlap_frac.value) - return CubicalCover(n_intervals=n_intervals, overlap_frac=overlap_frac) - elif self.cover_type.value == COVER_KNN: - neighbors = 1 - if self.cover_knn_neighbors.value is not None: - neighbors = int(self.cover_knn_neighbors.value) - return KNNCover(neighbors=neighbors) - - def render_clustering(self): - clustering_type = self.clustering_type.value - if self.clustering_type.value == CLUSTERING_TRIVIAL: - return TrivialClustering() - elif clustering_type == CLUSTERING_KMEANS: - n_clusters = 1 - if self.clustering_kmeans_n_clusters.value is not None: - n_clusters = int(self.clustering_kmeans_n_clusters.value) - return KMeans(n_clusters) - elif clustering_type == CLUSTERING_DBSCAN: - eps = 0.5 - if self.clustering_dbscan_eps.value is not None: - eps = float(self.clustering_dbscan_eps.value) - min_samples = 5 - if self.clustering_dbscan_min_samples.value is not None: - min_samples = int(self.clustering_dbscan_min_samples.value) - return DBSCAN(eps=eps, min_samples=min_samples) - elif clustering_type == CLUSTERING_AGGLOMERATIVE: - n_clusters = 2 - if self.clustering_agglomerative_n_clusters.value is not None: - n_clusters = int(self.clustering_agglomerative_n_clusters.value) - return AgglomerativeClustering(n_clusters=n_clusters) - - async def update_csv_handler(self, file): - await run.io_bound(self.update_csv, file) - await self.update_dataset_handler() - - async def update_dataset_handler(self, _=None): - await run.io_bound(self.update_dataset) - await self.update_graph_handler() - - async def update_graph_handler(self, _=None): - await run.io_bound(self.update_graph) - await self.update_plot_handler() - - async def update_plot_handler(self, _=None): - await run.io_bound(self.update_plot) - - def update_csv(self, file): - if file is None: - logger.warning("No file uploaded") - return - self.csv_file = file - - def update_dataset(self, _=None): - self.df_X, self.labels = self.render_dataset() - - def update_graph(self, _=None): - self.lens = self.render_lens() - if self.lens is None: - logger.warning("No lens selected") - return - if self.df_X is None or self.df_X.empty: - logger.warning("No dataset loaded for computation") - return - logger.info(f"Uploaded dataset with shape {self.df_X.shape}") - self.X = self.df_X.to_numpy() - self.y = self.lens(self.X) - cover = self.render_cover() - if cover is None: - logger.warning("No cover selected") - return - clustering = self.render_clustering() - if clustering is None: - logger.warning("No clustering selected") - return - mapper_algo = MapperAlgorithm( - cover=cover, - clustering=clustering, - verbose=False, + async def update(self, _=None): + self.state.source_type = str(self.data_source_type.value) + self.state.source_name = str(self.data_source_example_file.value) + self.state.source_openml = str(self.data_source_openml.value) + self.state.lens_type = str(self.lens_type.value) + self.state.lens_pca_n_components = int(self.pca_n_components.value) + self.state.lens_umap_n_components = int(self.umap_n_components.value) + self.state.cover_type = str(self.cover_type.value) + self.state.cover_cubical_n_intervals = int(self.cover_cubical_n_intervals.value) + self.state.cover_cubical_overlap_frac = float( + self.cover_cubical_overlap_frac.value ) - logger.info(f"Configuration: {mapper_algo}") - self.mapper_graph = mapper_algo.fit_transform(self.X, self.y) - - def update_plot(self): - if self.df_X is None or self.df_X.empty: - logger.warning("No dataset loaded for plotting") - return - if self.mapper_graph is None: - logger.warning("No graph computed") - return - - dim = 3 - if self.draw_3d.value == DRAW_3D: - dim = 3 - elif self.draw_3d.value == DRAW_2D: - dim = 2 - - iterations = int(self.draw_iterations.value) - mapper_plot = MapperPlot( - self.mapper_graph, - dim=dim, - iterations=iterations, - seed=42, + self.state.cover_knn_neighbors = int(self.cover_knn_neighbors.value) + self.state.cover_ball_radius = float(self.cover_ball_radius.value) + self.state.clustering_type = str(self.clustering_type.value) + self.state.clustering_kmeans_n_clusters = int( + self.clustering_kmeans_n_clusters.value ) - - colors = pd.concat([self.labels, self.df_X], axis=1) - colors_arr = colors.to_numpy() - color_names = colors.columns.tolist() - - agg = np.nanmean - if self.draw_aggregation.value == DRAW_MEAN: - agg = np.nanmean - elif self.draw_aggregation.value == DRAW_MEDIAN: - agg = np.nanmedian - elif self.draw_aggregation.value == DRAW_MODE: - agg = mode - - mapper_fig = mapper_plot.plot_plotly( - colors=colors_arr, - cmap=["jet", "viridis", "cividis"], - agg=agg, - title=color_names, - width=800, - height=800, - node_size=list(0.125 * x for x in range(17)), + self.state.clustering_dbscan_eps = float(self.clustering_dbscan_eps.value) + self.state.clustering_dbscan_min_samples = int( + self.clustering_dbscan_min_samples.value + ) + self.state.clustering_agglomerative_n_clusters = int( + self.clustering_agglomerative_n_clusters.value ) - mapper_fig.layout.width = None - mapper_fig.layout.autosize = True + self.state.draw_dim = 3 if self.draw_3d.value == DRAW_3D else 2 + self.state.draw_iterations = int(self.draw_iterations.value) + self.state.draw_aggregation = self.draw_aggregation.value + await self.render() + + async def render(self): + mapper_graph, mapper_fig = await run.cpu_bound( + compute_mapper, + **asdict(self.state), + ) + self.plot_container.clear() with self.plot_container: ui.plotly(mapper_fig) + async def update_csv_handler(self, file): + self.csv_file = run.cpu_bound(update_csv, file) + await self.update_dataset_handler() + await self.update_plot_handler() + def __init__(self): - self.csv_file = None - self.df_X = None + self.state = State() with ui.row().classes("w-full h-screen m-0 p-0 gap-0 overflow-hidden"): with ui.column().classes("w-64 h-full m-0 p-0"): with ui.column().classes("w-64 h-full overflow-y-auto p-3 gap-2"): @@ -522,9 +539,9 @@ def __init__(self): with ui.row(align_items="baseline"): self.build_draw() self.build_plot() - self.update_dataset() - self.update_graph() - self.update_plot() + mapper_graph, mapper_fig = compute_mapper(**asdict(self.state)) + with self.plot_container: + ui.plotly(mapper_fig) def main(): From db38258a869121e104a20a907b67432300acb887 Mon Sep 17 00:00:00 2001 From: Luca Simi Date: Tue, 3 Jun 2025 07:53:04 +0200 Subject: [PATCH 12/15] Improved UI. Added storage --- src/tdamapper/app.py | 129 +++++++++++++++++++++++++++---------------- 1 file changed, 81 insertions(+), 48 deletions(-) diff --git a/src/tdamapper/app.py b/src/tdamapper/app.py index 9d001b0..26ae7f1 100644 --- a/src/tdamapper/app.py +++ b/src/tdamapper/app.py @@ -5,30 +5,41 @@ import numpy as np import pandas as pd import plotly.graph_objs as go -from nicegui import run, ui +from nicegui import app, run, ui from sklearn.cluster import DBSCAN, AgglomerativeClustering, KMeans from sklearn.datasets import fetch_openml, load_digits, load_iris from sklearn.decomposition import PCA from umap import UMAP -from tdamapper.core import TrivialClustering, TrivialCover +from tdamapper.core import Cover, TrivialClustering, TrivialCover from tdamapper.cover import BallCover, CubicalCover, KNNCover from tdamapper.learn import MapperAlgorithm from tdamapper.plot import MapperPlot +RANDOM_STATE = 42 LENS_IDENTITY = "Identity" LENS_PCA = "PCA" +LENS_PCA_N_COMPONENTS = 2 LENS_UMAP = "UMAP" +LENS_UMAP_N_COMPONENTS = 2 COVER_TRIVIAL = "Trivial" COVER_CUBICAL = "Cubical" +COVER_CUBICAL_N_INTERVALS = 2 +COVER_CUBICAL_OVERLAP_FRAC = 0.25 COVER_BALL = "Ball" +COVER_BALL_RADIUS = 100.0 COVER_KNN = "KNN" +COVER_KNN_NEIGHBORS = 10 CLUSTERING_TRIVIAL = "Trivial" CLUSTERING_KMEANS = "KMeans" +CLUSTERING_KMEANS_N_CLUSTERS = 2 CLUSTERING_AGGLOMERATIVE = "Agglomerative" +CLUSTERING_AGGLOMERATIVE_N_CLUSTERS = 2 CLUSTERING_DBSCAN = "DBSCAN" +CLUSTERING_DBSCAN_EPS = 0.5 +CLUSTERING_DBSCAN_MIN_SAMPLES = 5 DATA_SOURCE_EXAMPLE = "Example" DATA_SOURCE_CSV = "CSV" @@ -57,21 +68,21 @@ class State: source_csv: str = "tmp/data.csv" source_openml: str = SOURCE_OPENML lens_type: str = LENS_PCA - lens_pca_n_components: int = 2 - lens_umap_n_components: int = 2 + lens_pca_n_components: int = LENS_PCA_N_COMPONENTS + lens_umap_n_components: int = LENS_UMAP_N_COMPONENTS cover_type: str = COVER_CUBICAL - cover_cubical_n_intervals: int = 2 - cover_cubical_overlap_frac: float = 0.25 - cover_knn_neighbors: int = 10 - cover_ball_radius: float = 100.0 + cover_cubical_n_intervals: int = COVER_CUBICAL_N_INTERVALS + cover_cubical_overlap_frac: float = COVER_CUBICAL_OVERLAP_FRAC + cover_knn_neighbors: int = COVER_KNN_NEIGHBORS + cover_ball_radius: float = COVER_BALL_RADIUS clustering_type: str = CLUSTERING_TRIVIAL - clustering_kmeans_n_clusters: int = 2 - clustering_dbscan_eps: float = 0.5 - clustering_dbscan_min_samples: int = 5 - clustering_agglomerative_n_clusters: int = 2 - draw_dim: int = 3 + clustering_kmeans_n_clusters: int = CLUSTERING_KMEANS_N_CLUSTERS + clustering_dbscan_eps: float = CLUSTERING_DBSCAN_EPS + clustering_dbscan_min_samples: int = CLUSTERING_DBSCAN_MIN_SAMPLES + clustering_agglomerative_n_clusters: int = CLUSTERING_AGGLOMERATIVE_N_CLUSTERS + draw_dim: str = DRAW_3D draw_aggregation: str = DRAW_MEAN - draw_iterations: int = 50 + draw_iterations: int = DRAW_ITERATIONS def _fix_data(data): @@ -115,39 +126,47 @@ def get_dataset(state: State): def get_lens(state: State): + def _pca(n): + pca = PCA(n_components=n, random_state=RANDOM_STATE) + return lambda X: pca.fit_transform(X) + + def _umap(n): + umap = UMAP(n_components=n, random_state=RANDOM_STATE) + return lambda X: umap.fit_transform(X) + + def _identity(): + return lambda X: X + + lens = _pca(2) lens_type = state.lens_type if lens_type == LENS_IDENTITY: - return lambda X: X + lens = _identity() elif lens_type == LENS_PCA: - n_components = int(state.lens_pca_n_components) - pca = PCA(n_components=n_components, random_state=42) - return lambda X: pca.fit_transform(X) + lens = _pca(state.lens_pca_n_components) elif lens_type == LENS_UMAP: - n_components = int(state.lens_umap_n_components) - umap = UMAP(n_components=n_components, random_state=42) - return lambda X: umap.fit_transform(X) + lens = _umap(state.lens_umap_n_components) else: - logger.error(f"Unknown lens type: {lens_type}") - return None + logger.error("Defaulting to PCA lens") + return lens -def get_cover(state: State): +def get_cover(state: State) -> Cover: cover_type = state.cover_type + cover: Cover = CubicalCover(n_intervals=2, overlap_frac=0.25) if cover_type == COVER_TRIVIAL: - return TrivialCover() + cover = TrivialCover() elif cover_type == COVER_CUBICAL: - n_intervals = int(state.cover_cubical_n_intervals) - overlap_frac = float(state.cover_cubical_overlap_frac) - return CubicalCover(n_intervals=n_intervals, overlap_frac=overlap_frac) + cover = CubicalCover( + n_intervals=state.cover_cubical_n_intervals, + overlap_frac=state.cover_cubical_overlap_frac, + ) elif cover_type == COVER_BALL: - radius = float(state.cover_ball_radius) - return BallCover(radius=radius) + cover = BallCover(radius=state.cover_ball_radius) elif cover_type == COVER_KNN: - neighbors = int(state.cover_knn_neighbors) - return KNNCover(neighbors=neighbors) + cover = KNNCover(neighbors=state.cover_knn_neighbors) else: - logger.error(f"Unknown cover type: {cover_type}") - return None + logger.error("Defaulting to CubicalCover") + return cover def get_clustering(state: State): @@ -155,18 +174,21 @@ def get_clustering(state: State): if clustering_type == CLUSTERING_TRIVIAL: return TrivialClustering() elif clustering_type == CLUSTERING_KMEANS: - n_clusters = state.clustering_kmeans_n_clusters - return KMeans(n_clusters=n_clusters, random_state=42) + return KMeans( + n_clusters=state.clustering_kmeans_n_clusters, random_state=RANDOM_STATE + ) elif clustering_type == CLUSTERING_DBSCAN: - eps = state.clustering_dbscan_eps - min_samples = state.clustering_dbscan_min_samples - return DBSCAN(eps=eps, min_samples=min_samples) + return DBSCAN( + eps=state.clustering_dbscan_eps, + min_samples=state.clustering_dbscan_min_samples, + ) elif clustering_type == CLUSTERING_AGGLOMERATIVE: - n_clusters = state.clustering_agglomerative_n_clusters - return AgglomerativeClustering(n_clusters=n_clusters) + return AgglomerativeClustering( + n_clusters=state.clustering_agglomerative_n_clusters + ) else: - logger.error(f"Unknown clustering type: {clustering_type}") - return None + logger.error("Defaulting to TrivialClustering") + return TrivialClustering() def compute_mapper(**kwargs): @@ -203,9 +225,11 @@ def compute_mapper(**kwargs): logger.info(f"Mapper configuration: {mapper_algo}") mapper_graph = mapper_algo.fit_transform(X, y) + dim = 3 if state.draw_dim == DRAW_3D else 2 + mapper_plot = MapperPlot( mapper_graph, - dim=state.draw_dim, + dim=dim, iterations=state.draw_iterations, seed=42, ) @@ -497,7 +521,7 @@ async def update(self, _=None): self.state.clustering_agglomerative_n_clusters = int( self.clustering_agglomerative_n_clusters.value ) - self.state.draw_dim = 3 if self.draw_3d.value == DRAW_3D else 2 + self.state.draw_dim = str(self.draw_3d.value) self.state.draw_iterations = int(self.draw_iterations.value) self.state.draw_aggregation = self.draw_aggregation.value await self.render() @@ -508,6 +532,9 @@ async def render(self): **asdict(self.state), ) + self.storage["mapper_graph"] = mapper_graph + self.storage["mapper_fig"] = mapper_fig + self.plot_container.clear() with self.plot_container: ui.plotly(mapper_fig) @@ -517,7 +544,8 @@ async def update_csv_handler(self, file): await self.update_dataset_handler() await self.update_plot_handler() - def __init__(self): + def __init__(self, storage): + self.storage = storage self.state = State() with ui.row().classes("w-full h-screen m-0 p-0 gap-0 overflow-hidden"): with ui.column().classes("w-64 h-full m-0 p-0"): @@ -544,9 +572,14 @@ def __init__(self): ui.plotly(mapper_fig) +@ui.page("/") +def main_page(): + storage = app.storage.client + App(storage=storage) + + def main(): - App() - ui.run() + ui.run(storage_secret="tdamapper_secret", title="TDA Mapper App", port=8080) if __name__ in {"__main__", "__mp_main__", "tdamapper.app"}: From aa6879d32689d3a3fc95968d1c375105bf6b923b Mon Sep 17 00:00:00 2001 From: Luca Simi Date: Tue, 3 Jun 2025 08:29:39 +0200 Subject: [PATCH 13/15] Handling of csv uploads --- src/tdamapper/app.py | 223 +++++++++++++++++++++++++++---------------- 1 file changed, 143 insertions(+), 80 deletions(-) diff --git a/src/tdamapper/app.py b/src/tdamapper/app.py index 26ae7f1..33a529b 100644 --- a/src/tdamapper/app.py +++ b/src/tdamapper/app.py @@ -65,7 +65,6 @@ class State: source_type: str = DATA_SOURCE_EXAMPLE source_name: str = DATA_SOURCE_EXAMPLE_DIGITS - source_csv: str = "tmp/data.csv" source_openml: str = SOURCE_OPENML lens_type: str = LENS_PCA lens_pca_n_components: int = LENS_PCA_N_COMPONENTS @@ -93,38 +92,6 @@ def _fix_data(data): return df -def get_dataset(state: State): - source_type = state.source_type - source_name = state.source_name - csv_file = state.source_csv - openml_code = state.source_openml - df_X, df_y = pd.DataFrame(), pd.Series() - if source_type == DATA_SOURCE_EXAMPLE: - if source_name == DATA_SOURCE_EXAMPLE_DIGITS: - df_X, df_y = load_digits(return_X_y=True, as_frame=True) - elif source_name == DATA_SOURCE_EXAMPLE_IRIS: - df_X, df_y = load_iris(return_X_y=True, as_frame=True) - elif source_type == DATA_SOURCE_CSV: - if csv_file is None: - logger.warning("No CSV file uploaded") - df_X, df_y = pd.DataFrame(), pd.Series() - else: - df_X = pd.read_csv() - df_y = pd.Series() - elif source_type == DATA_SOURCE_OPENML: - if not openml_code: - logger.warning("No OpenML code provided") - df_X, df_y = pd.DataFrame(), pd.Series() - else: - df_X, df_y = fetch_openml(openml_code, return_X_y=True, as_frame=True) - else: - logger.error(f"Unknown data source type: {source_type}") - return pd.DataFrame(), pd.Series() - df_X = _fix_data(df_X) - df_y = _fix_data(df_y) - return df_X, df_y - - def get_lens(state: State): def _pca(n): pca = PCA(n_components=n, random_state=RANDOM_STATE) @@ -191,18 +158,18 @@ def get_clustering(state: State): return TrivialClustering() -def compute_mapper(**kwargs): +def compute_mapper(df_X, labels, **kwargs): state = State(**kwargs) - df_X, labels = get_dataset(state) + # df_X, labels = get_dataset(state) if df_X.empty: logger.warning("No dataset loaded") - return None + return None, None lens = get_lens(state) if lens is None: logger.warning("No lens selected") - return None + return None, None X = df_X.to_numpy() y = lens(X) @@ -210,12 +177,12 @@ def compute_mapper(**kwargs): cover = get_cover(state) if cover is None: logger.warning("No cover selected") - return None + return None, None clustering = get_clustering(state) if clustering is None: logger.warning("No clustering selected") - return None + return None, None mapper_algo = MapperAlgorithm( cover=cover, @@ -260,13 +227,6 @@ def mode(arr): return np.nanmean(mode_values) -def update_csv(file): - if file is None: - logger.warning("No file uploaded") - return None - return file - - class App: def build_dataset(self): @@ -278,7 +238,7 @@ def build_dataset(self): DATA_SOURCE_OPENML, ], value=DATA_SOURCE_EXAMPLE, - on_change=self.update, + on_change=self.load_dataset, ).classes("w-full") self.data_source_example_file = ui.select( label="File", @@ -287,7 +247,7 @@ def build_dataset(self): DATA_SOURCE_EXAMPLE_IRIS, ], value=DATA_SOURCE_EXAMPLE_DIGITS, - on_change=self.update, + on_change=self.load_dataset, ).classes("w-full") self.data_source_example_file.bind_visibility_from( target_object=self.data_source_type, @@ -295,7 +255,7 @@ def build_dataset(self): value=DATA_SOURCE_EXAMPLE, ) self.data_source_csv = ui.upload( - on_upload=self.update, + on_upload=self.upload_csv, auto_upload=True, label="Upload CSV", ).classes("w-full") @@ -306,7 +266,7 @@ def build_dataset(self): ) self.data_source_openml = ui.input( label="OpenML Code", - on_change=self.update, + on_change=self.load_dataset, ).classes("w-full") self.data_source_openml.bind_visibility_from( target_object=self.data_source_type, @@ -496,39 +456,144 @@ def build_plot(self): fig.layout.autosize = True self.plot_container = ui.element("div").classes("w-full h-full") - async def update(self, _=None): + def get_dataset(self): + state = self.state + source_type = state.source_type + source_name = state.source_name + csv_file = self.storage.get("csv_file", None) + openml_code = state.source_openml + df_X, df_y = pd.DataFrame(), pd.Series() + if source_type == DATA_SOURCE_EXAMPLE: + if source_name == DATA_SOURCE_EXAMPLE_DIGITS: + df_X, df_y = load_digits(return_X_y=True, as_frame=True) + elif source_name == DATA_SOURCE_EXAMPLE_IRIS: + df_X, df_y = load_iris(return_X_y=True, as_frame=True) + elif source_type == DATA_SOURCE_CSV: + if csv_file is None: + logger.warning("No CSV file uploaded") + df_X, df_y = pd.DataFrame(), pd.Series() + else: + df_X = pd.read_csv(csv_file) + df_y = pd.Series() + elif source_type == DATA_SOURCE_OPENML: + if not openml_code: + logger.warning("No OpenML code provided") + df_X, df_y = pd.DataFrame(), pd.Series() + else: + df_X, df_y = fetch_openml(openml_code, return_X_y=True, as_frame=True) + else: + logger.error(f"Unknown data source type: {source_type}") + return pd.DataFrame(), pd.Series() + df_X = _fix_data(df_X) + df_y = _fix_data(df_y) + return df_X, df_y + + async def upload_csv(self, file): + if file is None: + logger.warning("No file uploaded") + else: + self.storage["csv_file"] = file.content + await self.load_dataset() + + async def load_dataset(self, _=None): self.state.source_type = str(self.data_source_type.value) self.state.source_name = str(self.data_source_example_file.value) self.state.source_openml = str(self.data_source_openml.value) - self.state.lens_type = str(self.lens_type.value) - self.state.lens_pca_n_components = int(self.pca_n_components.value) - self.state.lens_umap_n_components = int(self.umap_n_components.value) - self.state.cover_type = str(self.cover_type.value) - self.state.cover_cubical_n_intervals = int(self.cover_cubical_n_intervals.value) - self.state.cover_cubical_overlap_frac = float( - self.cover_cubical_overlap_frac.value - ) - self.state.cover_knn_neighbors = int(self.cover_knn_neighbors.value) - self.state.cover_ball_radius = float(self.cover_ball_radius.value) - self.state.clustering_type = str(self.clustering_type.value) - self.state.clustering_kmeans_n_clusters = int( - self.clustering_kmeans_n_clusters.value - ) - self.state.clustering_dbscan_eps = float(self.clustering_dbscan_eps.value) - self.state.clustering_dbscan_min_samples = int( - self.clustering_dbscan_min_samples.value - ) - self.state.clustering_agglomerative_n_clusters = int( - self.clustering_agglomerative_n_clusters.value + df_X, labels = self.get_dataset() + if df_X.empty: + logger.warning("No dataset loaded") + return None + self.storage["df_X"] = df_X + self.storage["labels"] = labels + await self.update() + + async def update(self, _=None): + self.state.lens_type = LENS_PCA + if self.lens_type.value is not None: + self.state.lens_type = str(self.lens_type.value) + + self.state.lens_pca_n_components = LENS_PCA_N_COMPONENTS + if self.pca_n_components.value is not None: + self.state.lens_pca_n_components = int(self.pca_n_components.value) + + self.state.lens_umap_n_components = LENS_UMAP_N_COMPONENTS + if self.umap_n_components.value is not None: + self.state.lens_umap_n_components = int(self.umap_n_components.value) + + self.state.cover_type = COVER_CUBICAL + if self.cover_type.value is not None: + self.state.cover_type = str(self.cover_type.value) + + self.state.cover_cubical_n_intervals = COVER_CUBICAL_N_INTERVALS + if self.cover_cubical_n_intervals.value is not None: + self.state.cover_cubical_n_intervals = int( + self.cover_cubical_n_intervals.value + ) + + self.state.cover_cubical_overlap_frac = COVER_CUBICAL_OVERLAP_FRAC + if self.cover_cubical_overlap_frac.value is not None: + self.state.cover_cubical_overlap_frac = float( + self.cover_cubical_overlap_frac.value + ) + + self.state.cover_ball_radius = COVER_BALL_RADIUS + if self.cover_ball_radius.value is not None: + self.state.cover_ball_radius = float(self.cover_ball_radius.value) + + self.state.cover_knn_neighbors = COVER_KNN_NEIGHBORS + if self.cover_knn_neighbors.value is not None: + self.state.cover_knn_neighbors = int(self.cover_knn_neighbors.value) + + self.state.clustering_type = CLUSTERING_TRIVIAL + if self.clustering_type.value is not None: + self.state.clustering_type = str(self.clustering_type.value) + + self.state.clustering_kmeans_n_clusters = CLUSTERING_KMEANS_N_CLUSTERS + if self.clustering_kmeans_n_clusters.value is not None: + self.state.clustering_kmeans_n_clusters = int( + self.clustering_kmeans_n_clusters.value + ) + + self.state.clustering_dbscan_eps = CLUSTERING_DBSCAN_EPS + if self.clustering_dbscan_eps.value is not None: + self.state.clustering_dbscan_eps = float(self.clustering_dbscan_eps.value) + + self.state.clustering_dbscan_min_samples = CLUSTERING_DBSCAN_MIN_SAMPLES + if self.clustering_dbscan_min_samples.value is not None: + self.state.clustering_dbscan_min_samples = int( + self.clustering_dbscan_min_samples.value + ) + + self.state.clustering_agglomerative_n_clusters = ( + CLUSTERING_AGGLOMERATIVE_N_CLUSTERS ) - self.state.draw_dim = str(self.draw_3d.value) - self.state.draw_iterations = int(self.draw_iterations.value) - self.state.draw_aggregation = self.draw_aggregation.value + if self.clustering_agglomerative_n_clusters.value is not None: + self.state.clustering_agglomerative_n_clusters = int( + self.clustering_agglomerative_n_clusters.value + ) + + self.state.draw_dim = DRAW_3D + if self.draw_3d.value is not None: + self.state.draw_dim = str(self.draw_3d.value) + + self.state.draw_iterations = DRAW_ITERATIONS + if self.draw_iterations.value is not None: + self.state.draw_iterations = int(self.draw_iterations.value) + + self.state.draw_aggregation = DRAW_MEAN + if self.draw_aggregation.value is not None: + self.state.draw_aggregation = str(self.draw_aggregation.value) + await self.render() async def render(self): + df_X = self.storage.get("df_X", pd.DataFrame()) + labels = self.storage.get("labels", pd.Series()) + mapper_graph, mapper_fig = await run.cpu_bound( compute_mapper, + df_X, + labels, **asdict(self.state), ) @@ -539,11 +604,6 @@ async def render(self): with self.plot_container: ui.plotly(mapper_fig) - async def update_csv_handler(self, file): - self.csv_file = run.cpu_bound(update_csv, file) - await self.update_dataset_handler() - await self.update_plot_handler() - def __init__(self, storage): self.storage = storage self.state = State() @@ -567,7 +627,10 @@ def __init__(self, storage): with ui.row(align_items="baseline"): self.build_draw() self.build_plot() - mapper_graph, mapper_fig = compute_mapper(**asdict(self.state)) + df_X, labels = self.get_dataset() + self.storage["df_X"] = df_X + self.storage["labels"] = labels + mapper_graph, mapper_fig = compute_mapper(df_X, labels, **asdict(self.state)) with self.plot_container: ui.plotly(mapper_fig) From f2886383356e17addf70694ea0b8640e103a42d7 Mon Sep 17 00:00:00 2001 From: Luca Simi Date: Wed, 4 Jun 2025 08:16:30 +0200 Subject: [PATCH 14/15] Improved state handling --- src/tdamapper/app.py | 184 ++++++++++++++++++++----------------------- 1 file changed, 86 insertions(+), 98 deletions(-) diff --git a/src/tdamapper/app.py b/src/tdamapper/app.py index 33a529b..5c28875 100644 --- a/src/tdamapper/app.py +++ b/src/tdamapper/app.py @@ -1,5 +1,6 @@ import logging from dataclasses import asdict, dataclass +from typing import Any import networkx as nx import numpy as np @@ -92,6 +93,38 @@ def _fix_data(data): return df +def get_dataset(state: State, storage: dict[str, Any]): + source_type = state.source_type + source_name = state.source_name + csv_file = storage.get("csv_file", None) + openml_code = state.source_openml + df_X, df_y = pd.DataFrame(), pd.Series() + if source_type == DATA_SOURCE_EXAMPLE: + if source_name == DATA_SOURCE_EXAMPLE_DIGITS: + df_X, df_y = load_digits(return_X_y=True, as_frame=True) + elif source_name == DATA_SOURCE_EXAMPLE_IRIS: + df_X, df_y = load_iris(return_X_y=True, as_frame=True) + elif source_type == DATA_SOURCE_CSV: + if csv_file is None: + logger.warning("No CSV file uploaded") + df_X, df_y = pd.DataFrame(), pd.Series() + else: + df_X = pd.read_csv(csv_file) + df_y = pd.Series() + elif source_type == DATA_SOURCE_OPENML: + if not openml_code: + logger.warning("No OpenML code provided") + df_X, df_y = pd.DataFrame(), pd.Series() + else: + df_X, df_y = fetch_openml(openml_code, return_X_y=True, as_frame=True) + else: + logger.error(f"Unknown data source type: {source_type}") + return pd.DataFrame(), pd.Series() + df_X = _fix_data(df_X) + df_y = _fix_data(df_y) + return df_X, df_y + + def get_lens(state: State): def _pca(n): pca = PCA(n_components=n, random_state=RANDOM_STATE) @@ -161,7 +194,6 @@ def get_clustering(state: State): def compute_mapper(df_X, labels, **kwargs): state = State(**kwargs) - # df_X, labels = get_dataset(state) if df_X.empty: logger.warning("No dataset loaded") return None, None @@ -456,38 +488,6 @@ def build_plot(self): fig.layout.autosize = True self.plot_container = ui.element("div").classes("w-full h-full") - def get_dataset(self): - state = self.state - source_type = state.source_type - source_name = state.source_name - csv_file = self.storage.get("csv_file", None) - openml_code = state.source_openml - df_X, df_y = pd.DataFrame(), pd.Series() - if source_type == DATA_SOURCE_EXAMPLE: - if source_name == DATA_SOURCE_EXAMPLE_DIGITS: - df_X, df_y = load_digits(return_X_y=True, as_frame=True) - elif source_name == DATA_SOURCE_EXAMPLE_IRIS: - df_X, df_y = load_iris(return_X_y=True, as_frame=True) - elif source_type == DATA_SOURCE_CSV: - if csv_file is None: - logger.warning("No CSV file uploaded") - df_X, df_y = pd.DataFrame(), pd.Series() - else: - df_X = pd.read_csv(csv_file) - df_y = pd.Series() - elif source_type == DATA_SOURCE_OPENML: - if not openml_code: - logger.warning("No OpenML code provided") - df_X, df_y = pd.DataFrame(), pd.Series() - else: - df_X, df_y = fetch_openml(openml_code, return_X_y=True, as_frame=True) - else: - logger.error(f"Unknown data source type: {source_type}") - return pd.DataFrame(), pd.Series() - df_X = _fix_data(df_X) - df_y = _fix_data(df_y) - return df_X, df_y - async def upload_csv(self, file): if file is None: logger.warning("No file uploaded") @@ -496,10 +496,8 @@ async def upload_csv(self, file): await self.load_dataset() async def load_dataset(self, _=None): - self.state.source_type = str(self.data_source_type.value) - self.state.source_name = str(self.data_source_example_file.value) - self.state.source_openml = str(self.data_source_openml.value) - df_X, labels = self.get_dataset() + state = self.get_state() + df_X, labels = get_dataset(state, self.storage) if df_X.empty: logger.warning("No dataset loaded") return None @@ -507,106 +505,96 @@ async def load_dataset(self, _=None): self.storage["labels"] = labels await self.update() - async def update(self, _=None): - self.state.lens_type = LENS_PCA + def get_state(self) -> State: + state = State( + source_type=DATA_SOURCE_EXAMPLE, + source_name=DATA_SOURCE_EXAMPLE_DIGITS, + source_openml=SOURCE_OPENML, + lens_type=LENS_PCA, + lens_pca_n_components=LENS_PCA_N_COMPONENTS, + lens_umap_n_components=LENS_UMAP_N_COMPONENTS, + cover_type=COVER_CUBICAL, + cover_cubical_n_intervals=COVER_CUBICAL_N_INTERVALS, + cover_cubical_overlap_frac=COVER_CUBICAL_OVERLAP_FRAC, + cover_ball_radius=COVER_BALL_RADIUS, + cover_knn_neighbors=COVER_KNN_NEIGHBORS, + clustering_type=CLUSTERING_TRIVIAL, + clustering_kmeans_n_clusters=CLUSTERING_KMEANS_N_CLUSTERS, + clustering_dbscan_eps=CLUSTERING_DBSCAN_EPS, + clustering_dbscan_min_samples=CLUSTERING_DBSCAN_MIN_SAMPLES, + clustering_agglomerative_n_clusters=CLUSTERING_AGGLOMERATIVE_N_CLUSTERS, + draw_dim=DRAW_3D, + draw_aggregation=DRAW_MEAN, + draw_iterations=DRAW_ITERATIONS, + ) + if self.data_source_type.value is not None: + state.source_type = str(self.data_source_type.value) + if self.data_source_example_file.value is not None: + state.source_name = str(self.data_source_example_file.value) + if self.data_source_openml.value is not None: + state.source_openml = str(self.data_source_openml.value) if self.lens_type.value is not None: - self.state.lens_type = str(self.lens_type.value) - - self.state.lens_pca_n_components = LENS_PCA_N_COMPONENTS + state.lens_type = str(self.lens_type.value) if self.pca_n_components.value is not None: - self.state.lens_pca_n_components = int(self.pca_n_components.value) - - self.state.lens_umap_n_components = LENS_UMAP_N_COMPONENTS + state.lens_pca_n_components = int(self.pca_n_components.value) if self.umap_n_components.value is not None: - self.state.lens_umap_n_components = int(self.umap_n_components.value) - - self.state.cover_type = COVER_CUBICAL + state.lens_umap_n_components = int(self.umap_n_components.value) if self.cover_type.value is not None: - self.state.cover_type = str(self.cover_type.value) - - self.state.cover_cubical_n_intervals = COVER_CUBICAL_N_INTERVALS + state.cover_type = str(self.cover_type.value) if self.cover_cubical_n_intervals.value is not None: - self.state.cover_cubical_n_intervals = int( - self.cover_cubical_n_intervals.value - ) - - self.state.cover_cubical_overlap_frac = COVER_CUBICAL_OVERLAP_FRAC + state.cover_cubical_n_intervals = int(self.cover_cubical_n_intervals.value) if self.cover_cubical_overlap_frac.value is not None: - self.state.cover_cubical_overlap_frac = float( + state.cover_cubical_overlap_frac = float( self.cover_cubical_overlap_frac.value ) - - self.state.cover_ball_radius = COVER_BALL_RADIUS if self.cover_ball_radius.value is not None: - self.state.cover_ball_radius = float(self.cover_ball_radius.value) - - self.state.cover_knn_neighbors = COVER_KNN_NEIGHBORS + state.cover_ball_radius = float(self.cover_ball_radius.value) if self.cover_knn_neighbors.value is not None: - self.state.cover_knn_neighbors = int(self.cover_knn_neighbors.value) - - self.state.clustering_type = CLUSTERING_TRIVIAL + state.cover_knn_neighbors = int(self.cover_knn_neighbors.value) if self.clustering_type.value is not None: - self.state.clustering_type = str(self.clustering_type.value) - - self.state.clustering_kmeans_n_clusters = CLUSTERING_KMEANS_N_CLUSTERS + state.clustering_type = str(self.clustering_type.value) if self.clustering_kmeans_n_clusters.value is not None: - self.state.clustering_kmeans_n_clusters = int( + state.clustering_kmeans_n_clusters = int( self.clustering_kmeans_n_clusters.value ) - - self.state.clustering_dbscan_eps = CLUSTERING_DBSCAN_EPS if self.clustering_dbscan_eps.value is not None: - self.state.clustering_dbscan_eps = float(self.clustering_dbscan_eps.value) - - self.state.clustering_dbscan_min_samples = CLUSTERING_DBSCAN_MIN_SAMPLES + state.clustering_dbscan_eps = float(self.clustering_dbscan_eps.value) if self.clustering_dbscan_min_samples.value is not None: - self.state.clustering_dbscan_min_samples = int( + state.clustering_dbscan_min_samples = int( self.clustering_dbscan_min_samples.value ) - - self.state.clustering_agglomerative_n_clusters = ( - CLUSTERING_AGGLOMERATIVE_N_CLUSTERS - ) if self.clustering_agglomerative_n_clusters.value is not None: - self.state.clustering_agglomerative_n_clusters = int( + state.clustering_agglomerative_n_clusters = int( self.clustering_agglomerative_n_clusters.value ) - - self.state.draw_dim = DRAW_3D if self.draw_3d.value is not None: - self.state.draw_dim = str(self.draw_3d.value) - - self.state.draw_iterations = DRAW_ITERATIONS + state.draw_dim = str(self.draw_3d.value) if self.draw_iterations.value is not None: - self.state.draw_iterations = int(self.draw_iterations.value) - - self.state.draw_aggregation = DRAW_MEAN + state.draw_iterations = int(self.draw_iterations.value) if self.draw_aggregation.value is not None: - self.state.draw_aggregation = str(self.draw_aggregation.value) + state.draw_aggregation = str(self.draw_aggregation.value) - await self.render() + return state - async def render(self): + async def update(self, _=None): + state = self.get_state() df_X = self.storage.get("df_X", pd.DataFrame()) labels = self.storage.get("labels", pd.Series()) - mapper_graph, mapper_fig = await run.cpu_bound( compute_mapper, df_X, labels, - **asdict(self.state), + **asdict(state), ) - self.storage["mapper_graph"] = mapper_graph self.storage["mapper_fig"] = mapper_fig - self.plot_container.clear() - with self.plot_container: - ui.plotly(mapper_fig) + if mapper_fig is not None: + with self.plot_container: + ui.plotly(mapper_fig) def __init__(self, storage): self.storage = storage - self.state = State() with ui.row().classes("w-full h-screen m-0 p-0 gap-0 overflow-hidden"): with ui.column().classes("w-64 h-full m-0 p-0"): with ui.column().classes("w-64 h-full overflow-y-auto p-3 gap-2"): From aa1adb439d3ff215136c39e8a4edac80b2f0478d Mon Sep 17 00:00:00 2001 From: Luca Simi Date: Wed, 4 Jun 2025 08:18:15 +0200 Subject: [PATCH 15/15] Fixed error --- src/tdamapper/app.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/tdamapper/app.py b/src/tdamapper/app.py index 5c28875..7b53f87 100644 --- a/src/tdamapper/app.py +++ b/src/tdamapper/app.py @@ -615,10 +615,12 @@ def __init__(self, storage): with ui.row(align_items="baseline"): self.build_draw() self.build_plot() - df_X, labels = self.get_dataset() + + state = self.get_state() + df_X, labels = get_dataset(state, self.storage) self.storage["df_X"] = df_X self.storage["labels"] = labels - mapper_graph, mapper_fig = compute_mapper(df_X, labels, **asdict(self.state)) + mapper_graph, mapper_fig = compute_mapper(df_X, labels, **asdict(state)) with self.plot_container: ui.plotly(mapper_fig)