Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions app/streamlit_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@
)


logger = st.logger.get_logger(__name__)


def _check_limits_mapper_graph(mapper_graph):
if LIMITS_ENABLED:
num_nodes = mapper_graph.number_of_nodes()
Expand Down Expand Up @@ -298,22 +301,34 @@ def mapper_lens_input_section(X):
value=2,
min_value=1,
)
pca_random_state = st.number_input(
"PCA random state",
value=VD_SEED,
)
_, n_feats = X.shape
if pca_n > n_feats:
lens = X
else:
lens = PCA(n_components=pca_n).fit_transform(X)
lens = PCA(n_components=pca_n, random_state=pca_random_state).fit_transform(
X
)
elif lens_type == V_LENS_UMAP:
umap_n = st.number_input(
"UMAP Components",
value=2,
min_value=1,
)
umap_random_state = st.number_input(
"UMAP random state",
value=VD_SEED,
)
_, n_feats = X.shape
if umap_n > n_feats:
lens = X
else:
lens = UMAP(n_components=umap_n).fit_transform(X)
lens = UMAP(
n_components=umap_n, random_state=umap_random_state
).fit_transform(X)
return lens


Expand Down Expand Up @@ -492,6 +507,7 @@ def mapper_clustering_input_section():
show_spinner="Computing Mapper",
)
def compute_mapper(mapper, X, y):
logger.info("Generating Mapper graph")
mapper_graph = mapper.fit_transform(X, y)
return mapper_graph

Expand Down Expand Up @@ -599,6 +615,7 @@ def plot_color_input_section(df_X, df_y):
)
def compute_mapper_plot(mapper_graph, dim, seed, iterations):
_check_limits_mapper_graph(mapper_graph)
logger.info("Generating Mapper plot")
mapper_plot = MapperPlot(
mapper_graph,
dim,
Expand Down Expand Up @@ -640,6 +657,7 @@ def mapper_plot_section(mapper_graph):
def compute_mapper_fig(
mapper_plot, colors, node_size, cmap, _agg, agg_name, colors_feat
):
logger.info("Generating Mapper figure")
mapper_fig = mapper_plot.plot_plotly(
colors,
node_size=node_size,
Expand Down