Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ from tdamapper.plot import MapperPlot

# Generate toy dataset
X, labels = make_circles(n_samples=5000, noise=0.05, factor=0.3, random_state=42)
plt.scatter(X[:,0], X[:,1], c=labels, cmap='jet', s=0.25)
plt.figure(figsize=(5, 5))
plt.scatter(X[:,0], X[:,1], c=labels, s=0.25, cmap="jet")
plt.axis("off")
plt.show()

# Apply PCA as lens
Expand All @@ -105,7 +107,7 @@ graph = MapperAlgorithm(cover, clust).fit_transform(X, y)

# Visualize the Mapper graph
fig = MapperPlot(graph, dim=2, seed=42, iterations=60).plot_plotly(colors=labels)
fig.show(config={'scrollZoom': True})
fig.show(config={"scrollZoom": True})
```

| Original Dataset | Mapper Graph |
Expand Down
81 changes: 35 additions & 46 deletions docs/source/notebooks/circles_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,74 +17,63 @@

# %%
import numpy as np

from matplotlib import pyplot as plt

from sklearn.cluster import DBSCAN
from sklearn.datasets import make_circles
from sklearn.decomposition import PCA
from sklearn.cluster import DBSCAN

from tdamapper.learn import MapperAlgorithm
from tdamapper.cover import CubicalCover
from tdamapper.learn import MapperAlgorithm
from tdamapper.plot import MapperPlot

X, y = make_circles( # load a labelled dataset
n_samples=5000,
noise=0.05,
factor=0.3,
random_state=42
)
lens = PCA(2, random_state=42).fit_transform(X)
width, height, dpi = 500, 500, 100

# Generate toy dataset
X, labels = make_circles(n_samples=5000, noise=0.05, factor=0.3, random_state=42)

fig = plt.figure(figsize=(width / dpi, height / dpi), dpi=dpi)
plt.scatter(X[:, 0], X[:, 1], c=labels, s=0.25, cmap="jet")
plt.axis("off")
plt.show()
# fig.savefig("circles_dataset.png", dpi=dpi)

# Apply PCA as lens
y = PCA(2, random_state=42).fit_transform(X)

plt.scatter(lens[:, 0], lens[:, 1], c=y, cmap='jet')

# %% [markdown]
# ### Build Mapper graph

# %%
mapper_algo = MapperAlgorithm(
cover=CubicalCover(
n_intervals=10,
overlap_frac=0.3
),
clustering=DBSCAN()
)

mapper_graph = mapper_algo.fit_transform(X, lens)
cover = CubicalCover(n_intervals=10, overlap_frac=0.3)
clust = DBSCAN()
mapper = MapperAlgorithm(cover=cover, clustering=clust)
graph = mapper.fit_transform(X, y)

# %% [markdown]
# ### Plot Mapper graph with mean

# %%
mapper_plot = MapperPlot(
mapper_graph,
dim=2,
iterations=60,
seed=42
)
plot = MapperPlot(graph, dim=2, iterations=60, seed=42)

fig = mapper_plot.plot_plotly(
colors=y, # color according to categorical values
cmap='jet', # Jet colormap, for classes
agg=np.nanmean, # aggregate on nodes according to mean
fig = plot.plot_plotly(
colors=labels, # color according to categorical values
cmap="jet", # Jet colormap, for classes
agg=np.nanmean, # aggregate on nodes according to mean
width=600,
height=600
height=600,
)

fig.show(
renderer='notebook_connected',
config={'scrollZoom': True}
)
fig.show(renderer="notebook_connected", config={"scrollZoom": True})
# fig.write_image("circles_mean.png", width=width, height=height)

# %%
mapper_plot.plot_plotly_update(
fig, # update the old figure
colors=y,
cmap='viridis', # viridis colormap, for ranges
agg=np.nanstd # aggregate on nodes according to std
)

fig.show(
renderer='notebook_connected',
config={'scrollZoom': True}
plot.plot_plotly_update(
fig, # update the old figure
colors=labels,
cmap="viridis", # viridis colormap, for ranges
agg=np.nanstd, # aggregate on nodes according to std
)

fig.show(renderer="notebook_connected", config={"scrollZoom": True})
# fig.write_image("circles_std.png", width=width, height=height)
82 changes: 43 additions & 39 deletions docs/source/notebooks/digits_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,74 +17,78 @@

# %%
import numpy as np

from sklearn.datasets import load_digits
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA

from tdamapper.learn import MapperAlgorithm
from tdamapper.cover import CubicalCover
from tdamapper.clustering import FailSafeClustering
from tdamapper.cover import CubicalCover
from tdamapper.learn import MapperAlgorithm
from tdamapper.plot import MapperPlot

# We load a labelled dataset
X, labels = load_digits(return_X_y=True)

X, y = load_digits(return_X_y=True) # We load a labelled dataset
lens = PCA(2, random_state=42).fit_transform(X) # We compute the lens values
# Apply PCA as lens
y = PCA(2, random_state=42).fit_transform(X)

# %% [markdown]
# ### Build Mapper graph

# %%
mapper_algo = MapperAlgorithm(
cover=CubicalCover(
n_intervals=10,
overlap_frac=0.65
),
algo = MapperAlgorithm(
cover=CubicalCover(n_intervals=10, overlap_frac=0.5),
clustering=AgglomerativeClustering(10),
verbose=False
verbose=False,
)

mapper_graph = mapper_algo.fit_transform(X, lens)
graph = algo.fit_transform(X, y)

# %% [markdown]
# ### Plot Mapper graph with mean

# %%
mapper_plot = MapperPlot(
mapper_graph,
dim=2,
iterations=400,
seed=42
)
plot = MapperPlot(graph, dim=3, iterations=400, seed=42)

fig = mapper_plot.plot_plotly(
colors=y, # We color according to digit values
cmap='jet', # Jet colormap, used for classes
agg=np.nanmean, # We aggregate on graph nodes according to mean
title='digit (mean)',
fig = plot.plot_plotly(
colors=labels, # We color according to digit values
cmap="jet", # Jet colormap, used for classes
agg=np.nanmean, # We aggregate on graph nodes according to mean
title="digit (mean)",
width=600,
height=600
height=600,
)

fig.show(
renderer='notebook_connected',
config={'scrollZoom': True}
)
fig.show(renderer="notebook_connected", config={"scrollZoom": True})

# %% [markdown]
# ### Plot Mapper graph with standard deviation

# %%
fig = mapper_plot.plot_plotly(
colors=y,
cmap='viridis', # Viridis colormap, used for ranges
agg=np.nanstd, # We aggregate on graph nodes according to std
title='digit (std)',
fig = plot.plot_plotly(
colors=labels,
cmap="viridis", # Viridis colormap, used for ranges
agg=np.nanstd, # We aggregate on graph nodes according to std
title="digit (std)",
width=600,
height=600
height=600,
)

fig.show(
renderer='notebook_connected',
config={'scrollZoom': True}
)
fig.show(renderer="notebook_connected", config={"scrollZoom": True})

# %% [markdown]
# ### Inspect interesting nodes

# %%
from matplotlib import pyplot as plt

# By interacting with the plot we see that node 140 is joining the cluster of
# digit 0 with the cluster of digit 4. Let's see what the digits inside
# look like!

node_140 = [X[i, :] for i in graph.nodes()[140]["ids"]]
fig, axes = plt.subplots(1, len(node_140))
for dgt, ax in zip(node_140, axes):
ax.imshow(dgt.reshape(8, 8), cmap="gray")
ax.axis("off")
plt.tight_layout()
plt.show()