Skip to content

Commit 1c76634

Browse files
committed
Updated example
1 parent 8f03130 commit 1c76634

1 file changed

Lines changed: 15 additions & 27 deletions

File tree

tests/example.py

Lines changed: 15 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import matplotlib.pyplot as plt
12
import numpy as np
23
from sklearn.cluster import DBSCAN
34
from sklearn.datasets import make_circles
@@ -7,34 +8,21 @@
78
from tdamapper.learn import MapperAlgorithm
89
from tdamapper.plot import MapperPlot
910

10-
X, y = make_circles( # load a labelled dataset
11-
n_samples=5000, noise=0.05, factor=0.3, random_state=42
12-
)
13-
lens = PCA(2).fit_transform(X)
11+
# Generate toy dataset
12+
X, labels = make_circles(n_samples=5000, noise=0.05, factor=0.3, random_state=42)
13+
plt.figure(figsize=(5, 5))
14+
plt.scatter(X[:, 0], X[:, 1], c=labels, s=0.25, cmap="jet")
15+
plt.axis("off")
16+
plt.show()
1417

15-
mapper_algo = MapperAlgorithm(
16-
cover=CubicalCover(n_intervals=10, overlap_frac=0.3), clustering=DBSCAN()
17-
)
18-
mapper_graph = mapper_algo.fit_transform(X, lens)
18+
# Apply PCA as lens
19+
y = PCA(2, random_state=42).fit_transform(X)
1920

20-
mapper_plot = MapperPlot(mapper_graph, dim=2, iterations=60, seed=42)
21-
22-
fig = mapper_plot.plot_plotly(
23-
title="",
24-
width=600,
25-
height=600,
26-
colors=y, # color according to categorical values
27-
cmap="jet", # Jet colormap, for classes
28-
agg=np.nanmean, # aggregate on nodes according to mean
29-
)
30-
31-
fig.show(config={"scrollZoom": True})
32-
33-
mapper_plot.plot_plotly_update(
34-
fig, # reuse the plot with the same positions
35-
colors=y,
36-
cmap="viridis", # viridis colormap, for ranges
37-
agg=np.nanstd, # aggregate on nodes according to std
38-
)
21+
# Mapper pipeline
22+
cover = CubicalCover(n_intervals=10, overlap_frac=0.3)
23+
clust = DBSCAN()
24+
graph = MapperAlgorithm(cover, clust).fit_transform(X, y)
3925

26+
# Visualize the Mapper graph
27+
fig = MapperPlot(graph, dim=2, seed=42, iterations=60).plot_plotly(colors=labels)
4028
fig.show(config={"scrollZoom": True})

0 commit comments

Comments
 (0)