|
| 1 | +import matplotlib.pyplot as plt |
1 | 2 | import numpy as np |
2 | 3 | from sklearn.cluster import DBSCAN |
3 | 4 | from sklearn.datasets import make_circles |
|
7 | 8 | from tdamapper.learn import MapperAlgorithm |
8 | 9 | from tdamapper.plot import MapperPlot |
9 | 10 |
|
10 | | -X, y = make_circles( # load a labelled dataset |
11 | | - n_samples=5000, noise=0.05, factor=0.3, random_state=42 |
12 | | -) |
13 | | -lens = PCA(2).fit_transform(X) |
| 11 | +# Generate toy dataset |
| 12 | +X, labels = make_circles(n_samples=5000, noise=0.05, factor=0.3, random_state=42) |
| 13 | +plt.figure(figsize=(5, 5)) |
| 14 | +plt.scatter(X[:, 0], X[:, 1], c=labels, s=0.25, cmap="jet") |
| 15 | +plt.axis("off") |
| 16 | +plt.show() |
14 | 17 |
|
15 | | -mapper_algo = MapperAlgorithm( |
16 | | - cover=CubicalCover(n_intervals=10, overlap_frac=0.3), clustering=DBSCAN() |
17 | | -) |
18 | | -mapper_graph = mapper_algo.fit_transform(X, lens) |
| 18 | +# Apply PCA as lens |
| 19 | +y = PCA(2, random_state=42).fit_transform(X) |
19 | 20 |
|
20 | | -mapper_plot = MapperPlot(mapper_graph, dim=2, iterations=60, seed=42) |
21 | | - |
22 | | -fig = mapper_plot.plot_plotly( |
23 | | - title="", |
24 | | - width=600, |
25 | | - height=600, |
26 | | - colors=y, # color according to categorical values |
27 | | - cmap="jet", # Jet colormap, for classes |
28 | | - agg=np.nanmean, # aggregate on nodes according to mean |
29 | | -) |
30 | | - |
31 | | -fig.show(config={"scrollZoom": True}) |
32 | | - |
33 | | -mapper_plot.plot_plotly_update( |
34 | | - fig, # reuse the plot with the same positions |
35 | | - colors=y, |
36 | | - cmap="viridis", # viridis colormap, for ranges |
37 | | - agg=np.nanstd, # aggregate on nodes according to std |
38 | | -) |
| 21 | +# Mapper pipeline |
| 22 | +cover = CubicalCover(n_intervals=10, overlap_frac=0.3) |
| 23 | +clust = DBSCAN() |
| 24 | +graph = MapperAlgorithm(cover, clust).fit_transform(X, y) |
39 | 25 |
|
| 26 | +# Visualize the Mapper graph |
| 27 | +fig = MapperPlot(graph, dim=2, seed=42, iterations=60).plot_plotly(colors=labels) |
40 | 28 | fig.show(config={"scrollZoom": True}) |
0 commit comments