11import logging
22from dataclasses import asdict , dataclass
3+ from typing import Any
34
45import networkx as nx
56import numpy as np
@@ -92,6 +93,38 @@ def _fix_data(data):
9293 return df
9394
9495
def get_dataset(state: State, storage: dict[str, Any]):
    """Load the feature table and labels selected by ``state``.

    Args:
        state: Application state; ``source_type``, ``source_name`` and
            ``source_openml`` are read to decide what to load.
        storage: Mapping that may hold an uploaded CSV under ``"csv_file"``.

    Returns:
        A ``(df_X, df_y)`` tuple. Both are passed through ``_fix_data``
        before being returned, except for an unknown ``source_type``,
        where a fresh empty ``DataFrame``/``Series`` pair is returned
        directly. A CSV source never provides labels, so ``df_y`` stays
        empty there.
    """
    source_type = state.source_type
    # Empty placeholders: every "nothing to load" path below falls through
    # with these values still in place.
    df_X, df_y = pd.DataFrame(), pd.Series()

    if source_type == DATA_SOURCE_EXAMPLE:
        source_name = state.source_name
        if source_name == DATA_SOURCE_EXAMPLE_DIGITS:
            df_X, df_y = load_digits(return_X_y=True, as_frame=True)
        elif source_name == DATA_SOURCE_EXAMPLE_IRIS:
            df_X, df_y = load_iris(return_X_y=True, as_frame=True)
        else:
            # Previously this case produced an empty dataset without any
            # trace in the log, unlike every other failure path here.
            logger.warning(f"Unknown example dataset: {source_name}")
    elif source_type == DATA_SOURCE_CSV:
        csv_file = storage.get("csv_file")
        if csv_file is None:
            logger.warning("No CSV file uploaded")
        else:
            df_X = pd.read_csv(csv_file)
    elif source_type == DATA_SOURCE_OPENML:
        openml_code = state.source_openml
        if not openml_code:
            logger.warning("No OpenML code provided")
        else:
            # NOTE(review): if this is scikit-learn's fetch_openml, the first
            # positional argument is the dataset *name*, not a numeric
            # data_id -- confirm that is what users are expected to enter.
            df_X, df_y = fetch_openml(openml_code, return_X_y=True, as_frame=True)
    else:
        logger.error(f"Unknown data source type: {source_type}")
        return pd.DataFrame(), pd.Series()

    df_X = _fix_data(df_X)
    df_y = _fix_data(df_y)
    return df_X, df_y
126+
127+
95128def get_lens (state : State ):
96129 def _pca (n ):
97130 pca = PCA (n_components = n , random_state = RANDOM_STATE )
@@ -161,7 +194,6 @@ def get_clustering(state: State):
161194def compute_mapper (df_X , labels , ** kwargs ):
162195 state = State (** kwargs )
163196
164- # df_X, labels = get_dataset(state)
165197 if df_X .empty :
166198 logger .warning ("No dataset loaded" )
167199 return None , None
@@ -456,38 +488,6 @@ def build_plot(self):
456488 fig .layout .autosize = True
457489 self .plot_container = ui .element ("div" ).classes ("w-full h-full" )
458490
459- def get_dataset (self ):
460- state = self .state
461- source_type = state .source_type
462- source_name = state .source_name
463- csv_file = self .storage .get ("csv_file" , None )
464- openml_code = state .source_openml
465- df_X , df_y = pd .DataFrame (), pd .Series ()
466- if source_type == DATA_SOURCE_EXAMPLE :
467- if source_name == DATA_SOURCE_EXAMPLE_DIGITS :
468- df_X , df_y = load_digits (return_X_y = True , as_frame = True )
469- elif source_name == DATA_SOURCE_EXAMPLE_IRIS :
470- df_X , df_y = load_iris (return_X_y = True , as_frame = True )
471- elif source_type == DATA_SOURCE_CSV :
472- if csv_file is None :
473- logger .warning ("No CSV file uploaded" )
474- df_X , df_y = pd .DataFrame (), pd .Series ()
475- else :
476- df_X = pd .read_csv (csv_file )
477- df_y = pd .Series ()
478- elif source_type == DATA_SOURCE_OPENML :
479- if not openml_code :
480- logger .warning ("No OpenML code provided" )
481- df_X , df_y = pd .DataFrame (), pd .Series ()
482- else :
483- df_X , df_y = fetch_openml (openml_code , return_X_y = True , as_frame = True )
484- else :
485- logger .error (f"Unknown data source type: { source_type } " )
486- return pd .DataFrame (), pd .Series ()
487- df_X = _fix_data (df_X )
488- df_y = _fix_data (df_y )
489- return df_X , df_y
490-
491491 async def upload_csv (self , file ):
492492 if file is None :
493493 logger .warning ("No file uploaded" )
@@ -496,117 +496,105 @@ async def upload_csv(self, file):
496496 await self .load_dataset ()
497497
async def load_dataset(self, _=None):
    """Load the dataset selected in the UI, store it, and refresh the plot.

    The single positional argument is accepted and ignored, so the method
    can be called with or without one.

    Returns:
        ``None`` in every case. When the loaded feature table is empty a
        warning is logged and neither the storage nor the plot is touched.
    """
    features, targets = get_dataset(self.get_state(), self.storage)

    if features.empty:
        logger.warning("No dataset loaded")
        return None

    # Keep the loaded data in storage so update() can read it back.
    self.storage["df_X"] = features
    self.storage["labels"] = targets
    await self.update()
509507
510- async def update (self , _ = None ):
511- self .state .lens_type = LENS_PCA
def get_state(self) -> State:
    """Build a fresh ``State`` snapshot from the current widget values.

    The state starts from the module-level defaults. Each field is then
    replaced by its widget's ``.value``, converted to the field's type,
    whenever that value is not ``None``.

    Returns:
        A new ``State`` instance; nothing is stored on ``self``.
    """
    state = State(
        source_type=DATA_SOURCE_EXAMPLE,
        source_name=DATA_SOURCE_EXAMPLE_DIGITS,
        source_openml=SOURCE_OPENML,
        lens_type=LENS_PCA,
        lens_pca_n_components=LENS_PCA_N_COMPONENTS,
        lens_umap_n_components=LENS_UMAP_N_COMPONENTS,
        cover_type=COVER_CUBICAL,
        cover_cubical_n_intervals=COVER_CUBICAL_N_INTERVALS,
        cover_cubical_overlap_frac=COVER_CUBICAL_OVERLAP_FRAC,
        cover_ball_radius=COVER_BALL_RADIUS,
        cover_knn_neighbors=COVER_KNN_NEIGHBORS,
        clustering_type=CLUSTERING_TRIVIAL,
        clustering_kmeans_n_clusters=CLUSTERING_KMEANS_N_CLUSTERS,
        clustering_dbscan_eps=CLUSTERING_DBSCAN_EPS,
        clustering_dbscan_min_samples=CLUSTERING_DBSCAN_MIN_SAMPLES,
        clustering_agglomerative_n_clusters=CLUSTERING_AGGLOMERATIVE_N_CLUSTERS,
        draw_dim=DRAW_3D,
        draw_aggregation=DRAW_MEAN,
        draw_iterations=DRAW_ITERATIONS,
    )

    # (state field, widget attribute on self, converter), applied in order.
    # Field and widget names differ for a few entries, hence both columns.
    overrides = (
        ("source_type", "data_source_type", str),
        ("source_name", "data_source_example_file", str),
        ("source_openml", "data_source_openml", str),
        ("lens_type", "lens_type", str),
        ("lens_pca_n_components", "pca_n_components", int),
        ("lens_umap_n_components", "umap_n_components", int),
        ("cover_type", "cover_type", str),
        ("cover_cubical_n_intervals", "cover_cubical_n_intervals", int),
        ("cover_cubical_overlap_frac", "cover_cubical_overlap_frac", float),
        ("cover_ball_radius", "cover_ball_radius", float),
        ("cover_knn_neighbors", "cover_knn_neighbors", int),
        ("clustering_type", "clustering_type", str),
        ("clustering_kmeans_n_clusters", "clustering_kmeans_n_clusters", int),
        ("clustering_dbscan_eps", "clustering_dbscan_eps", float),
        ("clustering_dbscan_min_samples", "clustering_dbscan_min_samples", int),
        (
            "clustering_agglomerative_n_clusters",
            "clustering_agglomerative_n_clusters",
            int,
        ),
        ("draw_dim", "draw_3d", str),
        ("draw_iterations", "draw_iterations", int),
        ("draw_aggregation", "draw_aggregation", str),
    )
    for field_name, widget_name, convert in overrides:
        widget = getattr(self, widget_name)
        # A None widget value means "keep the default set above".
        if widget.value is not None:
            setattr(state, field_name, convert(widget.value))

    return state
588578
589- async def render (self ):
async def update(self, _=None):
    """Recompute the mapper graph from stored data and redraw the plot.

    The single positional argument is accepted and ignored. Features and
    labels are read from ``self.storage``, falling back to an empty
    ``DataFrame``/``Series`` when absent. ``compute_mapper`` is awaited
    through ``run.cpu_bound`` with the current state expanded into keyword
    arguments. Both results are written back to storage, and the plot
    container is cleared and refilled only if a figure was produced.
    """
    params = asdict(self.get_state())
    features = self.storage.get("df_X", pd.DataFrame())
    targets = self.storage.get("labels", pd.Series())

    graph, figure = await run.cpu_bound(
        compute_mapper, features, targets, **params
    )
    self.storage["mapper_graph"] = graph
    self.storage["mapper_fig"] = figure

    # Always drop the previous plot, even if there is nothing new to show.
    self.plot_container.clear()
    if figure is None:
        return
    with self.plot_container:
        ui.plotly(figure)
606595
607596 def __init__ (self , storage ):
608597 self .storage = storage
609- self .state = State ()
610598 with ui .row ().classes ("w-full h-screen m-0 p-0 gap-0 overflow-hidden" ):
611599 with ui .column ().classes ("w-64 h-full m-0 p-0" ):
612600 with ui .column ().classes ("w-64 h-full overflow-y-auto p-3 gap-2" ):
0 commit comments