|
17 | 17 |
|
18 | 18 | # %% |
19 | 19 | import numpy as np |
20 | | - |
21 | 20 | from matplotlib import pyplot as plt |
22 | | - |
| 21 | +from sklearn.cluster import DBSCAN |
23 | 22 | from sklearn.datasets import make_circles |
24 | 23 | from sklearn.decomposition import PCA |
25 | | -from sklearn.cluster import DBSCAN |
26 | 24 |
|
27 | | -from tdamapper.learn import MapperAlgorithm |
28 | 25 | from tdamapper.cover import CubicalCover |
| 26 | +from tdamapper.learn import MapperAlgorithm |
29 | 27 | from tdamapper.plot import MapperPlot |
30 | 28 |
|
31 | | -X, y = make_circles( # load a labelled dataset |
32 | | - n_samples=5000, |
33 | | - noise=0.05, |
34 | | - factor=0.3, |
35 | | - random_state=42 |
36 | | -) |
37 | | -lens = PCA(2, random_state=42).fit_transform(X) |
| 29 | +width, height, dpi = 500, 500, 100 |
| 30 | + |
| 31 | +# Generate toy dataset |
| 32 | +X, labels = make_circles(n_samples=5000, noise=0.05, factor=0.3, random_state=42) |
| 33 | + |
| 34 | +fig = plt.figure(figsize=(width / dpi, height / dpi), dpi=dpi) |
| 35 | +plt.scatter(X[:, 0], X[:, 1], c=labels, s=0.25, cmap="jet") |
| 36 | +plt.axis("off") |
| 37 | +plt.show() |
| 38 | +# fig.savefig("circles_dataset.png", dpi=dpi) |
| 39 | + |
| 40 | +# Apply PCA as lens |
| 41 | +y = PCA(2, random_state=42).fit_transform(X) |
38 | 42 |
|
39 | | -plt.scatter(lens[:, 0], lens[:, 1], c=y, cmap='jet') |
40 | 43 |
|
41 | 44 | # %% [markdown] |
42 | 45 | # ### Build Mapper graph |
43 | 46 |
|
44 | 47 | # %% |
45 | | -mapper_algo = MapperAlgorithm( |
46 | | - cover=CubicalCover( |
47 | | - n_intervals=10, |
48 | | - overlap_frac=0.3 |
49 | | - ), |
50 | | - clustering=DBSCAN() |
51 | | -) |
52 | | - |
53 | | -mapper_graph = mapper_algo.fit_transform(X, lens) |
| 48 | +cover = CubicalCover(n_intervals=10, overlap_frac=0.3) |
| 49 | +clust = DBSCAN() |
| 50 | +mapper = MapperAlgorithm(cover=cover, clustering=clust) |
| 51 | +graph = mapper.fit_transform(X, y) |
54 | 52 |
|
55 | 53 | # %% [markdown] |
56 | 54 | # ### Plot Mapper graph with mean |
57 | 55 |
|
58 | 56 | # %% |
59 | | -mapper_plot = MapperPlot( |
60 | | - mapper_graph, |
61 | | - dim=2, |
62 | | - iterations=60, |
63 | | - seed=42 |
64 | | -) |
| 57 | +plot = MapperPlot(graph, dim=2, iterations=60, seed=42) |
65 | 58 |
|
66 | | -fig = mapper_plot.plot_plotly( |
67 | | - colors=y, # color according to categorical values |
68 | | - cmap='jet', # Jet colormap, for classes |
69 | | - agg=np.nanmean, # aggregate on nodes according to mean |
| 59 | +fig = plot.plot_plotly( |
| 60 | + colors=labels, # color according to categorical values |
| 61 | + cmap="jet", # Jet colormap, for classes |
| 62 | + agg=np.nanmean, # aggregate on nodes according to mean |
70 | 63 | width=600, |
71 | | - height=600 |
| 64 | + height=600, |
72 | 65 | ) |
73 | 66 |
|
74 | | -fig.show( |
75 | | - renderer='notebook_connected', |
76 | | - config={'scrollZoom': True} |
77 | | -) |
| 67 | +fig.show(renderer="notebook_connected", config={"scrollZoom": True}) |
| 68 | +# fig.write_image("circles_mean.png", width=width, height=height) |
78 | 69 |
|
79 | 70 | # %% |
80 | | -mapper_plot.plot_plotly_update( |
81 | | - fig, # update the old figure |
82 | | - colors=y, |
83 | | - cmap='viridis', # viridis colormap, for ranges |
84 | | - agg=np.nanstd # aggregate on nodes according to std |
85 | | -) |
86 | | - |
87 | | -fig.show( |
88 | | - renderer='notebook_connected', |
89 | | - config={'scrollZoom': True} |
| 71 | +plot.plot_plotly_update( |
| 72 | + fig, # update the old figure |
| 73 | + colors=labels, |
| 74 | + cmap="viridis", # viridis colormap, for ranges |
| 75 | + agg=np.nanstd, # aggregate on nodes according to std |
90 | 76 | ) |
| 77 | + |
| 78 | +fig.show(renderer="notebook_connected", config={"scrollZoom": True}) |
| 79 | +# fig.write_image("circles_std.png", width=width, height=height) |
0 commit comments