Skip to content

Commit f288638

Browse files
committed
Improved state handling
1 parent aa6879d commit f288638

1 file changed

Lines changed: 86 additions & 98 deletions

File tree

src/tdamapper/app.py

Lines changed: 86 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
from dataclasses import asdict, dataclass
3+
from typing import Any
34

45
import networkx as nx
56
import numpy as np
@@ -92,6 +93,38 @@ def _fix_data(data):
9293
return df
9394

9495

96+
def get_dataset(state: State, storage: dict[str, Any]):
97+
source_type = state.source_type
98+
source_name = state.source_name
99+
csv_file = storage.get("csv_file", None)
100+
openml_code = state.source_openml
101+
df_X, df_y = pd.DataFrame(), pd.Series()
102+
if source_type == DATA_SOURCE_EXAMPLE:
103+
if source_name == DATA_SOURCE_EXAMPLE_DIGITS:
104+
df_X, df_y = load_digits(return_X_y=True, as_frame=True)
105+
elif source_name == DATA_SOURCE_EXAMPLE_IRIS:
106+
df_X, df_y = load_iris(return_X_y=True, as_frame=True)
107+
elif source_type == DATA_SOURCE_CSV:
108+
if csv_file is None:
109+
logger.warning("No CSV file uploaded")
110+
df_X, df_y = pd.DataFrame(), pd.Series()
111+
else:
112+
df_X = pd.read_csv(csv_file)
113+
df_y = pd.Series()
114+
elif source_type == DATA_SOURCE_OPENML:
115+
if not openml_code:
116+
logger.warning("No OpenML code provided")
117+
df_X, df_y = pd.DataFrame(), pd.Series()
118+
else:
119+
df_X, df_y = fetch_openml(openml_code, return_X_y=True, as_frame=True)
120+
else:
121+
logger.error(f"Unknown data source type: {source_type}")
122+
return pd.DataFrame(), pd.Series()
123+
df_X = _fix_data(df_X)
124+
df_y = _fix_data(df_y)
125+
return df_X, df_y
126+
127+
95128
def get_lens(state: State):
96129
def _pca(n):
97130
pca = PCA(n_components=n, random_state=RANDOM_STATE)
@@ -161,7 +194,6 @@ def get_clustering(state: State):
161194
def compute_mapper(df_X, labels, **kwargs):
162195
state = State(**kwargs)
163196

164-
# df_X, labels = get_dataset(state)
165197
if df_X.empty:
166198
logger.warning("No dataset loaded")
167199
return None, None
@@ -456,38 +488,6 @@ def build_plot(self):
456488
fig.layout.autosize = True
457489
self.plot_container = ui.element("div").classes("w-full h-full")
458490

459-
def get_dataset(self):
460-
state = self.state
461-
source_type = state.source_type
462-
source_name = state.source_name
463-
csv_file = self.storage.get("csv_file", None)
464-
openml_code = state.source_openml
465-
df_X, df_y = pd.DataFrame(), pd.Series()
466-
if source_type == DATA_SOURCE_EXAMPLE:
467-
if source_name == DATA_SOURCE_EXAMPLE_DIGITS:
468-
df_X, df_y = load_digits(return_X_y=True, as_frame=True)
469-
elif source_name == DATA_SOURCE_EXAMPLE_IRIS:
470-
df_X, df_y = load_iris(return_X_y=True, as_frame=True)
471-
elif source_type == DATA_SOURCE_CSV:
472-
if csv_file is None:
473-
logger.warning("No CSV file uploaded")
474-
df_X, df_y = pd.DataFrame(), pd.Series()
475-
else:
476-
df_X = pd.read_csv(csv_file)
477-
df_y = pd.Series()
478-
elif source_type == DATA_SOURCE_OPENML:
479-
if not openml_code:
480-
logger.warning("No OpenML code provided")
481-
df_X, df_y = pd.DataFrame(), pd.Series()
482-
else:
483-
df_X, df_y = fetch_openml(openml_code, return_X_y=True, as_frame=True)
484-
else:
485-
logger.error(f"Unknown data source type: {source_type}")
486-
return pd.DataFrame(), pd.Series()
487-
df_X = _fix_data(df_X)
488-
df_y = _fix_data(df_y)
489-
return df_X, df_y
490-
491491
async def upload_csv(self, file):
492492
if file is None:
493493
logger.warning("No file uploaded")
@@ -496,117 +496,105 @@ async def upload_csv(self, file):
496496
await self.load_dataset()
497497

498498
async def load_dataset(self, _=None):
499-
self.state.source_type = str(self.data_source_type.value)
500-
self.state.source_name = str(self.data_source_example_file.value)
501-
self.state.source_openml = str(self.data_source_openml.value)
502-
df_X, labels = self.get_dataset()
499+
state = self.get_state()
500+
df_X, labels = get_dataset(state, self.storage)
503501
if df_X.empty:
504502
logger.warning("No dataset loaded")
505503
return None
506504
self.storage["df_X"] = df_X
507505
self.storage["labels"] = labels
508506
await self.update()
509507

510-
async def update(self, _=None):
511-
self.state.lens_type = LENS_PCA
508+
def get_state(self) -> State:
509+
state = State(
510+
source_type=DATA_SOURCE_EXAMPLE,
511+
source_name=DATA_SOURCE_EXAMPLE_DIGITS,
512+
source_openml=SOURCE_OPENML,
513+
lens_type=LENS_PCA,
514+
lens_pca_n_components=LENS_PCA_N_COMPONENTS,
515+
lens_umap_n_components=LENS_UMAP_N_COMPONENTS,
516+
cover_type=COVER_CUBICAL,
517+
cover_cubical_n_intervals=COVER_CUBICAL_N_INTERVALS,
518+
cover_cubical_overlap_frac=COVER_CUBICAL_OVERLAP_FRAC,
519+
cover_ball_radius=COVER_BALL_RADIUS,
520+
cover_knn_neighbors=COVER_KNN_NEIGHBORS,
521+
clustering_type=CLUSTERING_TRIVIAL,
522+
clustering_kmeans_n_clusters=CLUSTERING_KMEANS_N_CLUSTERS,
523+
clustering_dbscan_eps=CLUSTERING_DBSCAN_EPS,
524+
clustering_dbscan_min_samples=CLUSTERING_DBSCAN_MIN_SAMPLES,
525+
clustering_agglomerative_n_clusters=CLUSTERING_AGGLOMERATIVE_N_CLUSTERS,
526+
draw_dim=DRAW_3D,
527+
draw_aggregation=DRAW_MEAN,
528+
draw_iterations=DRAW_ITERATIONS,
529+
)
530+
if self.data_source_type.value is not None:
531+
state.source_type = str(self.data_source_type.value)
532+
if self.data_source_example_file.value is not None:
533+
state.source_name = str(self.data_source_example_file.value)
534+
if self.data_source_openml.value is not None:
535+
state.source_openml = str(self.data_source_openml.value)
512536
if self.lens_type.value is not None:
513-
self.state.lens_type = str(self.lens_type.value)
514-
515-
self.state.lens_pca_n_components = LENS_PCA_N_COMPONENTS
537+
state.lens_type = str(self.lens_type.value)
516538
if self.pca_n_components.value is not None:
517-
self.state.lens_pca_n_components = int(self.pca_n_components.value)
518-
519-
self.state.lens_umap_n_components = LENS_UMAP_N_COMPONENTS
539+
state.lens_pca_n_components = int(self.pca_n_components.value)
520540
if self.umap_n_components.value is not None:
521-
self.state.lens_umap_n_components = int(self.umap_n_components.value)
522-
523-
self.state.cover_type = COVER_CUBICAL
541+
state.lens_umap_n_components = int(self.umap_n_components.value)
524542
if self.cover_type.value is not None:
525-
self.state.cover_type = str(self.cover_type.value)
526-
527-
self.state.cover_cubical_n_intervals = COVER_CUBICAL_N_INTERVALS
543+
state.cover_type = str(self.cover_type.value)
528544
if self.cover_cubical_n_intervals.value is not None:
529-
self.state.cover_cubical_n_intervals = int(
530-
self.cover_cubical_n_intervals.value
531-
)
532-
533-
self.state.cover_cubical_overlap_frac = COVER_CUBICAL_OVERLAP_FRAC
545+
state.cover_cubical_n_intervals = int(self.cover_cubical_n_intervals.value)
534546
if self.cover_cubical_overlap_frac.value is not None:
535-
self.state.cover_cubical_overlap_frac = float(
547+
state.cover_cubical_overlap_frac = float(
536548
self.cover_cubical_overlap_frac.value
537549
)
538-
539-
self.state.cover_ball_radius = COVER_BALL_RADIUS
540550
if self.cover_ball_radius.value is not None:
541-
self.state.cover_ball_radius = float(self.cover_ball_radius.value)
542-
543-
self.state.cover_knn_neighbors = COVER_KNN_NEIGHBORS
551+
state.cover_ball_radius = float(self.cover_ball_radius.value)
544552
if self.cover_knn_neighbors.value is not None:
545-
self.state.cover_knn_neighbors = int(self.cover_knn_neighbors.value)
546-
547-
self.state.clustering_type = CLUSTERING_TRIVIAL
553+
state.cover_knn_neighbors = int(self.cover_knn_neighbors.value)
548554
if self.clustering_type.value is not None:
549-
self.state.clustering_type = str(self.clustering_type.value)
550-
551-
self.state.clustering_kmeans_n_clusters = CLUSTERING_KMEANS_N_CLUSTERS
555+
state.clustering_type = str(self.clustering_type.value)
552556
if self.clustering_kmeans_n_clusters.value is not None:
553-
self.state.clustering_kmeans_n_clusters = int(
557+
state.clustering_kmeans_n_clusters = int(
554558
self.clustering_kmeans_n_clusters.value
555559
)
556-
557-
self.state.clustering_dbscan_eps = CLUSTERING_DBSCAN_EPS
558560
if self.clustering_dbscan_eps.value is not None:
559-
self.state.clustering_dbscan_eps = float(self.clustering_dbscan_eps.value)
560-
561-
self.state.clustering_dbscan_min_samples = CLUSTERING_DBSCAN_MIN_SAMPLES
561+
state.clustering_dbscan_eps = float(self.clustering_dbscan_eps.value)
562562
if self.clustering_dbscan_min_samples.value is not None:
563-
self.state.clustering_dbscan_min_samples = int(
563+
state.clustering_dbscan_min_samples = int(
564564
self.clustering_dbscan_min_samples.value
565565
)
566-
567-
self.state.clustering_agglomerative_n_clusters = (
568-
CLUSTERING_AGGLOMERATIVE_N_CLUSTERS
569-
)
570566
if self.clustering_agglomerative_n_clusters.value is not None:
571-
self.state.clustering_agglomerative_n_clusters = int(
567+
state.clustering_agglomerative_n_clusters = int(
572568
self.clustering_agglomerative_n_clusters.value
573569
)
574-
575-
self.state.draw_dim = DRAW_3D
576570
if self.draw_3d.value is not None:
577-
self.state.draw_dim = str(self.draw_3d.value)
578-
579-
self.state.draw_iterations = DRAW_ITERATIONS
571+
state.draw_dim = str(self.draw_3d.value)
580572
if self.draw_iterations.value is not None:
581-
self.state.draw_iterations = int(self.draw_iterations.value)
582-
583-
self.state.draw_aggregation = DRAW_MEAN
573+
state.draw_iterations = int(self.draw_iterations.value)
584574
if self.draw_aggregation.value is not None:
585-
self.state.draw_aggregation = str(self.draw_aggregation.value)
575+
state.draw_aggregation = str(self.draw_aggregation.value)
586576

587-
await self.render()
577+
return state
588578

589-
async def render(self):
579+
async def update(self, _=None):
580+
state = self.get_state()
590581
df_X = self.storage.get("df_X", pd.DataFrame())
591582
labels = self.storage.get("labels", pd.Series())
592-
593583
mapper_graph, mapper_fig = await run.cpu_bound(
594584
compute_mapper,
595585
df_X,
596586
labels,
597-
**asdict(self.state),
587+
**asdict(state),
598588
)
599-
600589
self.storage["mapper_graph"] = mapper_graph
601590
self.storage["mapper_fig"] = mapper_fig
602-
603591
self.plot_container.clear()
604-
with self.plot_container:
605-
ui.plotly(mapper_fig)
592+
if mapper_fig is not None:
593+
with self.plot_container:
594+
ui.plotly(mapper_fig)
606595

607596
def __init__(self, storage):
608597
self.storage = storage
609-
self.state = State()
610598
with ui.row().classes("w-full h-screen m-0 p-0 gap-0 overflow-hidden"):
611599
with ui.column().classes("w-64 h-full m-0 p-0"):
612600
with ui.column().classes("w-64 h-full overflow-y-auto p-3 gap-2"):

0 commit comments

Comments
 (0)