
Commit e987f54

docs: add clustering example notebook with all 7 methods
New examples/gromacs/clustering.ipynb demonstrates:
- Distance-matrix methods: Gromos, Hierarchical, DBSCAN (numba + sklearn), HDBSCAN, with comparison table and population plots
- Feature-vector methods: KMeans, MiniBatchKMeans, RegularSpace on PCA-projected backbone torsions, with PC1/PC2 scatter plots
- Medoid structure extraction
1 parent 21a6d11 commit e987f54

1 file changed

Lines changed: 348 additions & 0 deletions

File tree

examples/gromacs/clustering.ipynb

@@ -0,0 +1,348 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "6c0add61",
6+
"metadata": {},
7+
"source": "# `mdpp` Example: Conformational Clustering\n\nThis notebook demonstrates all clustering methods available in `mdpp`:\n\n**Distance-matrix methods** (operate on a pairwise RMSD matrix):\n\n- `Gromos` -- greedy largest-cluster-first (Numba JIT, O(n) aux memory)\n- `Hierarchical` -- agglomerative clustering (scipy)\n- `DBSCAN` -- density-based with noise detection (Numba JIT or sklearn)\n- `HDBSCAN` -- hierarchical density-based (sklearn)\n\n**Feature-vector methods** (operate on PCA/TICA projections):\n\n- `KMeans` -- standard k-means (sklearn)\n- `MiniBatchKMeans` -- scalable mini-batch variant (sklearn)\n- `RegularSpace` -- regular-space discretization (deeptime)\n\nEach method is a frozen dataclass configured at construction and called on data:\n\n```python\nresult = Gromos(cutoff_nm=0.15)(rmsd_matrix)\nresult = KMeans(n_clusters=10)(pca.projections)\n```"
8+
},
9+
{
10+
"cell_type": "code",
11+
"execution_count": null,
12+
"id": "3c411c15",
13+
"metadata": {},
14+
"outputs": [],
15+
"source": [
16+
"from __future__ import annotations\n",
17+
"\n",
18+
"from pathlib import Path\n",
19+
"\n",
20+
"import matplotlib.pyplot as plt\n",
21+
"import numpy as np\n",
22+
"from mplplots.utils import auto_ticks\n",
23+
"\n",
24+
"from mdpp.analysis.clustering import (\n",
25+
"    DBSCAN,\n",
26+
"    HDBSCAN,\n",
27+
"    Gromos,\n",
28+
"    Hierarchical,\n",
29+
"    KMeans,\n",
30+
"    MiniBatchKMeans,\n",
31+
"    RegularSpace,\n",
32+
"    compute_rmsd_matrix,\n",
33+
")\n",
34+
"from mdpp.analysis.decomposition import compute_pca, featurize_backbone_torsions\n",
35+
"from mdpp.core.trajectory import align_trajectory, load_trajectory\n",
36+
"\n",
37+
"plt.style.use(\"mplplots.styles.GraphPadPrism\")"
38+
]
39+
},
40+
{
41+
"cell_type": "code",
42+
"execution_count": null,
43+
"id": "39804045",
44+
"metadata": {},
45+
"outputs": [],
46+
"source": [
47+
"TOPOLOGY_PATH = Path(\"/path/to/topology.pdb\")\n",
48+
"TRAJECTORY_PATH = Path(\"/path/to/trajectory.xtc\")\n",
49+
"STRIDE = 5\n",
50+
"CUTOFF_NM = 0.15\n",
51+
"\n",
52+
"if not TOPOLOGY_PATH.exists() or not TRAJECTORY_PATH.exists():\n",
53+
" raise FileNotFoundError(\n",
54+
" \"Update TOPOLOGY_PATH and TRAJECTORY_PATH before running analysis cells.\"\n",
55+
" )\n",
56+
"\n",
57+
"traj = load_trajectory(\n",
58+
" trajectory_path=TRAJECTORY_PATH,\n",
59+
" topology_path=TOPOLOGY_PATH,\n",
60+
" stride=STRIDE,\n",
61+
" atom_selection=\"protein\",\n",
62+
")\n",
63+
"traj = align_trajectory(traj, atom_selection=\"name CA\")\n",
64+
"\n",
65+
"print(f\"Frames: {traj.n_frames}, Atoms: {traj.n_atoms}\")"
66+
]
67+
},
68+
{
69+
"cell_type": "markdown",
70+
"id": "ce40bec9",
71+
"metadata": {},
72+
"source": "## Compute RMSD Matrix\n\nThe pairwise RMSD matrix is shared by all distance-matrix clustering methods.\nUse `backend=\"numba\"` or `backend=\"torch\"` for large trajectories."
73+
},
74+
{
75+
"cell_type": "code",
76+
"execution_count": null,
77+
"id": "36584688",
78+
"metadata": {},
79+
"outputs": [],
80+
"source": [
81+
"rmsd_mat = compute_rmsd_matrix(traj, atom_selection=\"backbone\", backend=\"numba\")\n",
82+
"\n",
83+
"print(f\"RMSD matrix: {rmsd_mat.rmsd_matrix_nm.shape}, dtype={rmsd_mat.rmsd_matrix_nm.dtype}\")\n",
84+
"print(f\"Max RMSD: {rmsd_mat.rmsd_matrix_nm.max():.3f} nm\")"
85+
]
86+
},
87+
{
88+
"cell_type": "markdown",
89+
"id": "dd1d4045",
90+
"metadata": {},
91+
"source": "## Distance-Matrix Methods\n\n### GROMOS\n\nGreedy largest-cluster-first assignment. Custom Numba kernel with O(n) auxiliary memory -- handles 120k+ frames."
92+
},
93+
{
94+
"cell_type": "code",
95+
"execution_count": null,
96+
"id": "c5ffad58",
97+
"metadata": {},
98+
"outputs": [],
99+
"source": [
100+
"gromos = Gromos(cutoff_nm=CUTOFF_NM)(rmsd_mat.rmsd_matrix_nm)\n",
101+
"\n",
102+
"print(f\"GROMOS: {gromos.n_clusters} clusters\")\n",
103+
"for i in range(min(5, gromos.n_clusters)):\n",
104+
" count = int(np.sum(gromos.labels == i))\n",
105+
" print(f\" Cluster {i}: {count} frames, medoid={gromos.medoid_frames[i]}\")"
106+
]
107+
},
108+
{
109+
"cell_type": "markdown",
110+
"id": "ceb1c9b1",
111+
"metadata": {},
112+
"source": "### Hierarchical\n\nAgglomerative clustering via scipy. Supports `distance_threshold` (default) or fixed `n_clusters`."
113+
},
114+
{
115+
"cell_type": "code",
116+
"execution_count": null,
117+
"id": "92b829ce",
118+
"metadata": {},
119+
"outputs": [],
120+
"source": [
121+
"# Distance-threshold mode (like GROMOS cutoff)\n",
122+
"hier_dist = Hierarchical(\n",
123+
" linkage_method=\"average\",\n",
124+
" distance_threshold=CUTOFF_NM,\n",
125+
")(rmsd_mat.rmsd_matrix_nm)\n",
126+
"\n",
127+
"# Fixed cluster count mode\n",
128+
"hier_k = Hierarchical(\n",
129+
" linkage_method=\"average\",\n",
130+
" n_clusters=5,\n",
131+
")(rmsd_mat.rmsd_matrix_nm)\n",
132+
"\n",
133+
"print(f\"Hierarchical (distance_threshold={CUTOFF_NM}): {hier_dist.n_clusters} clusters\")\n",
134+
"print(f\"Hierarchical (n_clusters=5): {hier_k.n_clusters} clusters\")"
135+
]
136+
},
137+
{
138+
"cell_type": "markdown",
139+
"id": "3e7f99fa",
140+
"metadata": {},
141+
"source": "### DBSCAN\n\nDensity-based clustering with noise detection. Frames that don't belong to any dense region get label -1.\n\nThe default `backend=\"numba\"` uses a custom Numba kernel with O(n) auxiliary memory. Pass `backend=\"sklearn\"` for the official scikit-learn implementation."
142+
},
143+
{
144+
"cell_type": "code",
145+
"execution_count": null,
146+
"id": "7b330e5f",
147+
"metadata": {},
148+
"outputs": [],
149+
"source": [
150+
"dbscan = DBSCAN(eps=CUTOFF_NM, min_samples=5)(rmsd_mat.rmsd_matrix_nm)\n",
151+
"\n",
152+
"noise = int(np.sum(dbscan.labels == -1))\n",
153+
"print(f\"DBSCAN: {dbscan.n_clusters} clusters, {noise} noise frames\")\n",
154+
"\n",
155+
"# sklearn backend for comparison\n",
156+
"dbscan_sk = DBSCAN(eps=CUTOFF_NM, min_samples=5, backend=\"sklearn\")(rmsd_mat.rmsd_matrix_nm)\n",
157+
"print(f\"DBSCAN (sklearn): {dbscan_sk.n_clusters} clusters\")"
158+
]
159+
},
160+
{
161+
"cell_type": "markdown",
162+
"id": "390f9624",
163+
"metadata": {},
164+
"source": "### HDBSCAN\n\nHierarchical density-based clustering via sklearn. Handles clusters of varying density without an epsilon parameter."
165+
},
166+
{
167+
"cell_type": "code",
168+
"execution_count": null,
169+
"id": "6b4d0516",
170+
"metadata": {},
171+
"outputs": [],
172+
"source": [
173+
"hdbscan = HDBSCAN(min_cluster_size=10, min_samples=5)(rmsd_mat.rmsd_matrix_nm)\n",
174+
"\n",
175+
"noise = int(np.sum(hdbscan.labels == -1))\n",
176+
"print(f\"HDBSCAN: {hdbscan.n_clusters} clusters, {noise} noise frames\")"
177+
]
178+
},
179+
{
180+
"cell_type": "markdown",
181+
"id": "0d3594cd",
182+
"metadata": {},
183+
"source": "### Compare Distance-Matrix Methods"
184+
},
185+
{
186+
"cell_type": "code",
187+
"execution_count": null,
188+
"id": "70da02ff",
189+
"metadata": {},
190+
"outputs": [],
191+
"source": [
192+
"results = {\n",
193+
" \"GROMOS\": gromos,\n",
194+
" \"Hierarchical\": hier_dist,\n",
195+
" \"DBSCAN\": dbscan,\n",
196+
" \"HDBSCAN\": hdbscan,\n",
197+
"}\n",
198+
"\n",
199+
"n_total = traj.n_frames\n",
200+
"print(f\"{'Method':<15s} {'Clusters':>10s} {'Noise':>8s} {'Largest':>10s}\")\n",
201+
"print(\"-\" * 45)\n",
202+
"for name, r in results.items():\n",
203+
" noise = int(np.sum(r.labels == -1))\n",
204+
" valid = r.labels[r.labels >= 0]\n",
205+
" largest = int(np.bincount(valid).max()) if len(valid) > 0 else 0\n",
206+
" pct = largest / n_total * 100\n",
207+
" print(f\"{name:<15s} {r.n_clusters:>10d} {noise:>8d} {largest:>6d} ({pct:.1f}%)\")"
208+
]
209+
},
210+
{
211+
"cell_type": "code",
212+
"execution_count": null,
213+
"id": "bde64aa1",
214+
"metadata": {},
215+
"outputs": [],
216+
"source": [
217+
"fig, axes = plt.subplots(2, 2, figsize=(12, 8), dpi=120, sharey=True)\n",
218+
"\n",
219+
"for ax, (name, r) in zip(axes.ravel(), results.items()):\n",
220+
" valid = r.labels[r.labels >= 0]\n",
221+
" if len(valid) > 0:\n",
222+
" counts = np.bincount(valid)\n",
223+
" top_k = min(20, len(counts))\n",
224+
" ax.bar(range(top_k), counts[:top_k])\n",
225+
" ax.set_xlabel(\"Cluster\")\n",
226+
" ax.set_title(f\"{name} ({r.n_clusters} clusters)\")\n",
227+
" auto_ticks(ax)\n",
228+
"\n",
229+
"axes[0, 0].set_ylabel(\"Frames\")\n",
230+
"axes[1, 0].set_ylabel(\"Frames\")\n",
231+
"fig.suptitle(f\"Cluster Populations (cutoff = {CUTOFF_NM} nm)\", y=1.02)\n",
232+
"fig.tight_layout()"
233+
]
234+
},
235+
{
236+
"cell_type": "markdown",
237+
"id": "594c705e",
238+
"metadata": {},
239+
"source": "## Feature-Vector Methods\n\nBackbone torsion featurization (sin/cos-embedded phi/psi) followed by PCA.\nFeature-vector methods scale linearly with the number of frames N and don't require the N x N RMSD matrix."
240+
},
241+
{
242+
"cell_type": "code",
243+
"execution_count": null,
244+
"id": "c0928a15",
245+
"metadata": {},
246+
"outputs": [],
247+
"source": [
248+
"torsions = featurize_backbone_torsions(traj, atom_selection=\"protein\")\n",
249+
"pca = compute_pca(torsions.values, n_components=10)\n",
250+
"\n",
251+
"print(f\"Torsion features: {torsions.values.shape[1]}\")\n",
252+
"print(\n",
253+
" f\"PCA: {pca.projections.shape[1]} components, \"\n",
254+
" f\"explained variance = {pca.explained_variance_ratio.sum():.1%}\"\n",
255+
")"
256+
]
257+
},
258+
{
259+
"cell_type": "code",
260+
"execution_count": null,
261+
"id": "95842a0d",
262+
"metadata": {},
263+
"outputs": [],
264+
"source": [
265+
"N_CLUSTERS = 5\n",
266+
"\n",
267+
"km = KMeans(n_clusters=N_CLUSTERS)(pca.projections)\n",
268+
"mb = MiniBatchKMeans(n_clusters=N_CLUSTERS, batch_size=256)(pca.projections)\n",
269+
"rs = RegularSpace(dmin=1.0)(pca.projections)\n",
270+
"\n",
271+
"print(f\"KMeans: {km.n_clusters} clusters, inertia={km.inertia:.1f}\")\n",
272+
"print(f\"MiniBatchKMeans: {mb.n_clusters} clusters, inertia={mb.inertia:.1f}\")\n",
273+
"print(f\"RegularSpace: {rs.n_clusters} clusters (dmin=1.0)\")"
274+
]
275+
},
276+
{
277+
"cell_type": "code",
278+
"execution_count": null,
279+
"id": "c0bc4780",
280+
"metadata": {},
281+
"outputs": [],
282+
"source": [
283+
"fig, axes = plt.subplots(1, 3, figsize=(16, 4.5), dpi=120)\n",
284+
"\n",
285+
"for ax, (name, r) in zip(axes, [(\"KMeans\", km), (\"MiniBatch\", mb), (\"RegularSpace\", rs)]):\n",
286+
" sc = ax.scatter(\n",
287+
" pca.projections[:, 0],\n",
288+
" pca.projections[:, 1],\n",
289+
" c=r.labels,\n",
290+
" cmap=\"tab10\",\n",
291+
" s=2,\n",
292+
" alpha=0.4,\n",
293+
" rasterized=True,\n",
294+
" )\n",
295+
" ax.scatter(\n",
296+
" r.cluster_centers[:, 0],\n",
297+
" r.cluster_centers[:, 1],\n",
298+
" c=\"black\",\n",
299+
" marker=\"x\",\n",
300+
" s=100,\n",
301+
" linewidths=2,\n",
302+
" zorder=5,\n",
303+
" )\n",
304+
" ax.set_xlabel(\"PC1\")\n",
305+
" ax.set_ylabel(\"PC2\")\n",
306+
" ax.set_title(f\"{name} ({r.n_clusters} clusters)\")\n",
307+
"\n",
308+
"fig.tight_layout()"
309+
]
310+
},
311+
{
312+
"cell_type": "markdown",
313+
"id": "639a69e0",
314+
"metadata": {},
315+
"source": "## Save Medoid Structures\n\nExtract representative frames from the GROMOS result."
316+
},
317+
{
318+
"cell_type": "code",
319+
"execution_count": null,
320+
"id": "6b18c017",
321+
"metadata": {},
322+
"outputs": [],
323+
"source": [
324+
"output_dir = Path(\"cluster_medoids\")\n",
325+
"output_dir.mkdir(exist_ok=True)\n",
326+
"\n",
327+
"for i, frame_idx in enumerate(gromos.medoid_frames[:10]):\n",
328+
" out = output_dir / f\"cluster{i}_medoid.pdb\"\n",
329+
" traj[int(frame_idx)].save(str(out))\n",
330+
" count = int(np.sum(gromos.labels == i))\n",
331+
" print(f\"Cluster {i}: {count} frames, medoid frame {frame_idx} -> {out}\")"
332+
]
333+
}
334+
],
335+
"metadata": {
336+
"kernelspec": {
337+
"display_name": "Python 3",
338+
"language": "python",
339+
"name": "python3"
340+
},
341+
"language_info": {
342+
"name": "python",
343+
"version": "3.12.0"
344+
}
345+
},
346+
"nbformat": 4,
347+
"nbformat_minor": 5
348+
}
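
Note on the GROMOS cell: the notebook describes the method as "greedy largest-cluster-first assignment". The sketch below illustrates that idea in plain NumPy, under stated assumptions. It is not `mdpp`'s Numba kernel: `gromos_sketch`, its arguments, and its return values are hypothetical names chosen for illustration, and it builds a full boolean neighbor matrix (O(n^2) memory) rather than using O(n) auxiliary memory. Ties are broken by lowest frame index, which may differ from the library's behavior.

```python
import numpy as np


def gromos_sketch(dist: np.ndarray, cutoff: float) -> tuple[np.ndarray, list[int]]:
    """Greedy largest-cluster-first clustering on a symmetric distance matrix.

    Hypothetical illustration only; not the mdpp implementation.
    Returns per-frame labels and the frame index chosen as each cluster center.
    """
    n = dist.shape[0]
    labels = np.full(n, -1, dtype=np.int64)
    centers: list[int] = []
    neighbors = dist <= cutoff            # (n, n) boolean, diagonal is True
    unassigned = np.ones(n, dtype=bool)

    while unassigned.any():
        # For every frame, count neighbors that are still unassigned.
        counts = (neighbors & unassigned).sum(axis=1)
        counts[~unassigned] = -1          # already-assigned frames cannot be centers
        center = int(np.argmax(counts))   # frame with the most remaining neighbors
        members = neighbors[center] & unassigned
        labels[members] = len(centers)    # next cluster id, so cluster 0 is the largest
        centers.append(center)
        unassigned &= ~members            # remove the whole cluster from the pool
    return labels, centers


# Toy 1-D example: three groups of sizes 3, 2, and 1.
points = np.array([0.0, 0.1, 0.2, 5.0, 5.1, 9.0])
dist = np.abs(points[:, None] - points[None, :])
labels, centers = gromos_sketch(dist, cutoff=0.25)
print(labels.tolist())  # [0, 0, 0, 1, 1, 2]
print(centers)          # [0, 3, 5]
```

Because clusters are peeled off in order of size, cluster 0 is always the most populated one in this sketch, which is the property the notebook's "first 5 clusters" printout relies on.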
