Skip to content

Commit 44abd57

Browse files
feat: add plot_feature_clustering and plot_cluster_populations
Add clustering visualization helpers to mdpp.plots:

- plot_feature_clustering: scatter of PCA/TICA projections colored by cluster labels, with cluster centers overlaid
- plot_cluster_populations: bar chart of frame counts per cluster, ranked largest-first, with noise frames excluded

Update the examples/gromacs/clustering notebook to use the new functions instead of inline plotting code.
1 parent 4b1731c commit 44abd57

4 files changed

Lines changed: 352 additions & 70 deletions

File tree

examples/gromacs/clustering.ipynb

Lines changed: 3 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -12,30 +12,7 @@
1212
"id": "3c411c15",
1313
"metadata": {},
1414
"outputs": [],
15-
"source": [
16-
"from __future__ import annotations\n",
17-
"\n",
18-
"from pathlib import Path\n",
19-
"\n",
20-
"import matplotlib.pyplot as plt\n",
21-
"import numpy as np\n",
22-
"from mplplots.utils import auto_ticks\n",
23-
"\n",
24-
"from mdpp.analysis.clustering import (\n",
25-
" DBSCAN,\n",
26-
" HDBSCAN,\n",
27-
" Gromos,\n",
28-
" Hierarchical,\n",
29-
" KMeans,\n",
30-
" MiniBatchKMeans,\n",
31-
" RegularSpace,\n",
32-
" compute_rmsd_matrix,\n",
33-
")\n",
34-
"from mdpp.analysis.decomposition import compute_pca, featurize_backbone_torsions\n",
35-
"from mdpp.core.trajectory import align_trajectory, load_trajectory\n",
36-
"\n",
37-
"plt.style.use(\"mplplots.styles.GraphPadPrism\")"
38-
]
15+
"source": "from __future__ import annotations\n\nfrom pathlib import Path\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nfrom mplplots.utils import auto_ticks\n\nfrom mdpp.analysis.clustering import (\n DBSCAN,\n HDBSCAN,\n Gromos,\n Hierarchical,\n KMeans,\n MiniBatchKMeans,\n RegularSpace,\n compute_rmsd_matrix,\n)\nfrom mdpp.analysis.decomposition import compute_pca, featurize_backbone_torsions\nfrom mdpp.core.trajectory import align_trajectory, load_trajectory\nfrom mdpp.plots import plot_cluster_populations, plot_feature_clustering\n\nplt.style.use(\"mplplots.styles.GraphPadPrism\")"
3916
},
4017
{
4118
"cell_type": "code",
@@ -213,24 +190,7 @@
213190
"id": "bde64aa1",
214191
"metadata": {},
215192
"outputs": [],
216-
"source": [
217-
"fig, axes = plt.subplots(2, 2, figsize=(12, 8), dpi=120, sharey=True)\n",
218-
"\n",
219-
"for ax, (name, r) in zip(axes.ravel(), results.items()):\n",
220-
" valid = r.labels[r.labels >= 0]\n",
221-
" if len(valid) > 0:\n",
222-
" counts = np.bincount(valid)\n",
223-
" top_k = min(20, len(counts))\n",
224-
" ax.bar(range(top_k), counts[:top_k])\n",
225-
" ax.set_xlabel(\"Cluster\")\n",
226-
" ax.set_title(f\"{name} ({r.n_clusters} clusters)\")\n",
227-
" auto_ticks(ax)\n",
228-
"\n",
229-
"axes[0, 0].set_ylabel(\"Frames\")\n",
230-
"axes[1, 0].set_ylabel(\"Frames\")\n",
231-
"fig.suptitle(f\"Cluster Populations (cutoff = {CUTOFF_NM} nm)\", y=1.02)\n",
232-
"fig.tight_layout()"
233-
]
193+
"source": "fig, axes = plt.subplots(2, 2, figsize=(12, 8), dpi=120, sharey=True)\n\nfor ax, (name, r) in zip(axes.ravel(), results.items()):\n plot_cluster_populations(r, top_k=20, ax=ax)\n ax.set_title(f\"{name} ({r.n_clusters} clusters)\")\n auto_ticks(ax)\n\naxes[0, 0].set_ylabel(\"Frames\")\naxes[1, 0].set_ylabel(\"Frames\")\nfig.suptitle(f\"Cluster Populations (cutoff = {CUTOFF_NM} nm)\", y=1.02)\nfig.tight_layout()"
234194
},
235195
{
236196
"cell_type": "markdown",
@@ -279,34 +239,7 @@
279239
"id": "c0bc4780",
280240
"metadata": {},
281241
"outputs": [],
282-
"source": [
283-
"fig, axes = plt.subplots(1, 3, figsize=(16, 4.5), dpi=120)\n",
284-
"\n",
285-
"for ax, (name, r) in zip(axes, [(\"KMeans\", km), (\"MiniBatch\", mb), (\"RegularSpace\", rs)]):\n",
286-
" sc = ax.scatter(\n",
287-
" pca.projections[:, 0],\n",
288-
" pca.projections[:, 1],\n",
289-
" c=r.labels,\n",
290-
" cmap=\"tab10\",\n",
291-
" s=2,\n",
292-
" alpha=0.4,\n",
293-
" rasterized=True,\n",
294-
" )\n",
295-
" ax.scatter(\n",
296-
" r.cluster_centers[:, 0],\n",
297-
" r.cluster_centers[:, 1],\n",
298-
" c=\"black\",\n",
299-
" marker=\"x\",\n",
300-
" s=100,\n",
301-
" linewidths=2,\n",
302-
" zorder=5,\n",
303-
" )\n",
304-
" ax.set_xlabel(\"PC1\")\n",
305-
" ax.set_ylabel(\"PC2\")\n",
306-
" ax.set_title(f\"{name} ({r.n_clusters} clusters)\")\n",
307-
"\n",
308-
"fig.tight_layout()"
309-
]
242+
"source": "fig, axes = plt.subplots(1, 3, figsize=(16, 4.5), dpi=120)\n\nfor ax, (name, r) in zip(axes, [(\"KMeans\", km), (\"MiniBatch\", mb), (\"RegularSpace\", rs)]):\n plot_feature_clustering(r, pca, s=2, alpha=0.4, ax=ax)\n ax.set_title(f\"{name} ({r.n_clusters} clusters)\")\n\nfig.tight_layout()"
310243
},
311244
{
312245
"cell_type": "markdown",

src/mdpp/plots/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Plotting helpers for MD post-analysis outputs."""
22

3+
from mdpp.plots.clustering import plot_cluster_populations, plot_feature_clustering
34
from mdpp.plots.contacts import contact_frequency_to_matrix, plot_contact_map
45
from mdpp.plots.fes import plot_fes
56
from mdpp.plots.matrix import plot_dccm
@@ -31,11 +32,13 @@
3132
"draw_mols",
3233
"get_highlight_bonds",
3334
"make_atom_labels_3d",
35+
"plot_cluster_populations",
3436
"plot_contact_map",
3537
"plot_dccm",
3638
"plot_delta_rmsf",
3739
"plot_distances",
3840
"plot_energy",
41+
"plot_feature_clustering",
3942
"plot_fes",
4043
"plot_hbond_counts",
4144
"plot_hbond_occupancy",

src/mdpp/plots/clustering.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
"""Clustering visualization helpers."""
2+
3+
from __future__ import annotations
4+
5+
import numpy as np
6+
from matplotlib.axes import Axes
7+
from numpy.typing import NDArray
8+
9+
from mdpp.analysis.clustering import ClusteringResult, FeatureClusteringResult
10+
from mdpp.analysis.decomposition import PCAResult, TICAResult
11+
from mdpp.plots.utils import get_axis
12+
13+
14+
def plot_feature_clustering(
    result: FeatureClusteringResult,
    projections: PCAResult | TICAResult | NDArray[np.floating],
    *,
    x_index: int = 0,
    y_index: int = 1,
    cmap: str = "tab10",
    center_color: str = "black",
    center_marker: str = "x",
    center_size: float = 80.0,
    center_linewidths: float = 2.0,
    s: float = 2.0,
    alpha: float = 0.4,
    rasterized: bool = True,
    show_centers: bool = True,
    ax: Axes | None = None,
) -> Axes:
    """Draw a label-colored scatter of projected frames, optionally with centers.

    Two columns of the projection are plotted against each other and each
    point is colored by its entry in ``result.labels``. When ``show_centers``
    is true, the same two columns of ``result.cluster_centers`` are drawn on
    top of the data points.

    Args:
        result: Feature clustering result providing ``labels`` and
            ``cluster_centers``.
        projections: Either a ``PCAResult`` / ``TICAResult`` (its
            ``projections`` attribute is used) or anything ``np.asarray``
            can turn into a 2-D array indexed as ``[:, column]``.
        x_index: Column plotted on the x-axis.
        y_index: Column plotted on the y-axis.
        cmap: Colormap used to map cluster labels to colors.
        center_color: Color of the cluster center markers.
        center_marker: Marker style of the cluster center markers.
        center_size: Marker size of the cluster center markers.
        center_linewidths: Line width of the cluster center markers.
        s: Marker size of the data points.
        alpha: Transparency of the data points.
        rasterized: Whether the data-point layer is rasterized.
        show_centers: Whether to overlay the cluster centers.
        ax: Axis to draw on; passed to ``get_axis`` to resolve the target.

    Returns:
        The matplotlib axis that was drawn on.
    """
    axis = get_axis(ax)

    # Classify the input once; the answer drives both the coordinate lookup
    # and the axis-label prefix further down.
    from_pca = isinstance(projections, PCAResult)
    from_tica = isinstance(projections, TICAResult)
    points = projections.projections if (from_pca or from_tica) else np.asarray(projections)

    axis.scatter(
        points[:, x_index],
        points[:, y_index],
        c=result.labels,
        cmap=cmap,
        s=s,
        alpha=alpha,
        edgecolors="none",
        rasterized=rasterized,
    )

    if show_centers:
        # NOTE(review): centers are indexed with the same columns as the
        # points, which presumes clustering ran in this projection space.
        centers = result.cluster_centers
        axis.scatter(
            centers[:, x_index],
            centers[:, y_index],
            c=center_color,
            marker=center_marker,
            s=center_size,
            linewidths=center_linewidths,
            zorder=5,
        )

    if from_pca:
        x_label, y_label = f"PC{x_index + 1}", f"PC{y_index + 1}"
    elif from_tica:
        x_label, y_label = f"IC{x_index + 1}", f"IC{y_index + 1}"
    else:
        x_label, y_label = f"Component {x_index + 1}", f"Component {y_index + 1}"
    axis.set_xlabel(x_label)
    axis.set_ylabel(y_label)

    return axis
98+
99+
100+
def plot_cluster_populations(
    result: ClusteringResult | FeatureClusteringResult,
    *,
    top_k: int = 20,
    color: str | None = None,
    ax: Axes | None = None,
) -> Axes:
    """Bar chart of cluster populations (frame counts per cluster).

    Frames whose label is negative are dropped before counting. Bars are
    ordered by descending population, so the x position of a bar is its rank
    and not its cluster label. If no frame has a non-negative label, no bars
    are drawn but the axis labels are still set.

    Args:
        result: Clustering result providing an integer ``labels`` array.
        top_k: Maximum number of clusters to show (largest first). Must be
            non-negative.
        color: Bar color. If ``None``, uses the default color cycle.
        ax: Optional matplotlib axis.

    Returns:
        The matplotlib axis with the bar chart.

    Raises:
        ValueError: If ``top_k`` is negative.
    """
    # Reject a negative limit up front: it would otherwise act as a
    # "drop the last k" slice and no longer line up with the bar positions.
    # Checked before get_axis so a bad call does not touch any axis.
    if top_k < 0:
        raise ValueError(f"top_k must be non-negative, got {top_k}")

    axis = get_axis(ax)

    labels = np.asarray(result.labels)
    assigned = labels[labels >= 0]
    if assigned.size > 0:
        # Only bar heights are drawn, so sorting the counts themselves is
        # enough; slicing past the end simply yields every cluster.
        populations = np.sort(np.bincount(assigned))[::-1][:top_k]
        bar_kwargs = {} if color is None else {"color": color}
        axis.bar(range(len(populations)), populations, **bar_kwargs)

    axis.set_xlabel("Cluster (ranked)")
    axis.set_ylabel("Frames")

    return axis

0 commit comments

Comments
 (0)