Skip to content

Commit cfff154

Browse files
committed
count deg as clustering metric
1 parent ceff773 commit cfff154

3 files changed

Lines changed: 213 additions & 43 deletions

File tree

README.md

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
1-
[![Documentation Status](https://readthedocs.org/projects/pysinglecellnet/badge/?version=latest)](https://pysinglecellnet.readthedocs.io/en/latest/?badge=latest)
1+
# pySingleCellNet
22

3-
# PySingleCellNet: a computational toolkit for the single cell analysis and comparison of embryos and embryo models
3+
[![Docs Status](https://readthedocs.org/projects/pysinglecellnet/badge/?version=latest)](https://cahanlab-pysinglecellnet.readthedocs-hosted.com/)
4+
5+
6+
[![GitHub stars](https://img.shields.io/github/stars/CahanLab/PySingleCellNet.svg?style=social&label=Star)](https://github.com/CahanLab/PySingleCellNet/stargazers)
7+
8+
[![PyPI version](https://img.shields.io/pypi/v/pySingleCellNet.svg)](https://pypi.org/project/pySingleCellNet/)
9+
10+
**pySingleCellNet** helps you classify and analyze single-cell RNA-Seq data, …
11+
12+
# PySingleCellNet
13+
### A computational toolkit for the single cell analysis and comparison of embryos and embryo models
414
PySingleCellNet (PySCN) predicts the 'cell type' of query scRNA-seq data by Random forest multi-class classification. See [Tan & Cahan 2019] for more details. PySCN includes functionality to aid in the analysis of engineered cell populations (i.e. cells derived via directed differentiation of pluripotent stem cells or via direct conversion).
515

616
[Tan & Cahan 2019]: https://doi.org/10.1016/j.cels.2019.06.004

docs/install.md

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,11 @@
1-
PySingleCellnet depends on several packages, most of which can be installed with pip, and some are available in conda channels.
1+
PySingleCellNet is available on PyPI. The easiest way to install it is with:
22

3-
I recommend pre-installing the following:
4-
5-
```shell
6-
pip install scanpy python-igraph leidenalg
3+
```
4+
pip install pySingleCellNet
75
```
86

9-
Then, you should be able to install pySingleCellNet with PIP as follows:
7+
Alternatively, you can install directly from GitHub with:
108

119
```shell
1210
pip install git+https://github.com/CahanLab/pySingleCellNet.git
1311
```
14-
15-
This will install any remaining required packages, too

src/pySingleCellNet/utils/cell.py

Lines changed: 197 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import numpy as np
24
import pandas as pd
35
# from anndata import AnnData
@@ -6,15 +8,176 @@
68
import anndata as ad
79
# from scipy.sparse import issparse
810
# from alive_progress import alive_bar
9-
from scipy.stats import median_abs_deviation
11+
from scipy.stats import median_abs_deviation, ttest_ind
1012
# import string
1113
import igraph as ig
14+
from scipy import sparse
15+
from sklearn.decomposition import PCA
16+
17+
def clustering_quality_vs_nn(
    adata,
    label_col: str,
    n_genes: int = 5,
    naive: dict | None = None,
    strict: dict | None = None,
    n_pcs_for_nn: int = 40,
):
    """
    For each cluster, find its nearest-neighbor cluster and count genes that meet

    (1) naive DE criteria: p-value <= naive['p_val'] AND log2FC >= naive['fold_change']
    (2) strict DE criteria: pct_in >= strict['minpercentin'] AND pct_out <= strict['maxpercentout'].

    Parameters
    ----------
    adata : AnnData
        Single-cell object. Uses adata.X for expression (dense or sparse).
    label_col : str
        .obs column with cluster labels.
    n_genes : int, default 5
        For reporting convenience: include top n gene names (by effect size) for each rule.
        Does not affect the counts.
    naive : dict, optional
        {'p_val': float, 'fold_change': float}; fold_change is on log2 scale.
        Defaults to {'p_val': 1e-2, 'fold_change': 0.25}.
    strict : dict, optional
        {'minpercentin': float, 'maxpercentout': float}.
        Defaults to {'minpercentin': 0.20, 'maxpercentout': 0.10}.
    n_pcs_for_nn : int, default 40
        Number of PCs (existing ``X_pca`` if available, else freshly computed)
        used for the nearest-neighbor cluster search.

    Returns
    -------
    pandas.DataFrame
        Columns: ['cluster', 'nn_cluster', 'n_genes_naive', 'n_genes_strict',
        'top_naive_genes', 'top_strict_genes']

    Raises
    ------
    ValueError
        If `label_col` is missing from adata.obs, or fewer than two clusters exist.
    """
    # Avoid mutable default arguments: materialize the documented defaults here
    # so a caller-supplied (or shared) dict is never retained across calls.
    if naive is None:
        naive = {"p_val": 1e-2, "fold_change": 0.25}
    if strict is None:
        strict = {"minpercentin": 0.20, "maxpercentout": 0.10}

    # ----- helpers -----
    def _to_dense(X):
        # .toarray() works for both sparse matrices and sparse arrays;
        # the .A shorthand is deprecated/removed for the latter.
        return X.toarray() if sparse.issparse(X) else np.asarray(X)

    def _get_representation(adata, n_pcs_for_nn):
        # Prefer an existing PCA embedding; otherwise compute a compact one.
        if "X_pca" in adata.obsm and adata.obsm["X_pca"].shape[1] >= min(2, n_pcs_for_nn):
            return np.asarray(adata.obsm["X_pca"][:, :n_pcs_for_nn])
        # else: PCA on log1p(counts), restricted to a gene subset to keep things light
        X = _to_dense(adata.X)
        if "highly_variable" in adata.var.columns and adata.var["highly_variable"].any():
            genes_mask = adata.var["highly_variable"].values
        else:
            # fall back to the top (up to) 1000 most variable genes
            var = X.var(axis=0)
            topk = min(1000, X.shape[1])
            genes_mask = np.zeros(X.shape[1], dtype=bool)
            genes_mask[np.argsort(var)[-topk:]] = True
        Xg = np.log1p(X[:, genes_mask])
        Xg = Xg - Xg.mean(axis=0, keepdims=True)  # center per gene
        pca = PCA(n_components=min(n_pcs_for_nn, Xg.shape[1], Xg.shape[0] - 1))
        return pca.fit_transform(Xg)

    def _cluster_centroids(rep, labels):
        # Mean embedding coordinate per cluster.
        centroids = {}
        for c in labels.unique():
            idx = (labels == c).values
            if idx.sum() == 0:
                continue
            centroids[c] = rep[idx].mean(axis=0)
        return centroids

    def _nearest_neighbors(centroids):
        # For each cluster, the nearest *other* cluster (Euclidean distance).
        keys = list(centroids.keys())
        arr = np.stack([centroids[k] for k in keys], axis=0)
        d2 = np.sum((arr[:, None, :] - arr[None, :, :]) ** 2, axis=2)
        np.fill_diagonal(d2, np.inf)  # a cluster must not pick itself
        nn_idx = np.argmin(d2, axis=1)
        return {keys[i]: keys[j] for i, j in enumerate(nn_idx)}

    # ----- main -----
    if label_col not in adata.obs.columns:
        raise ValueError(f"'{label_col}' not found in adata.obs")

    labels = adata.obs[label_col].astype("category")
    clusters = pd.Index(labels.cat.categories)

    rep = _get_representation(adata, n_pcs_for_nn)
    centroids = _cluster_centroids(rep, labels)
    if len(centroids) < 2:
        raise ValueError("Need at least two clusters to compute nearest neighbors.")
    nn_map = _nearest_neighbors(centroids)

    # Expression matrix (cells x genes), dense for vectorized ops
    X = _to_dense(adata.X)
    genes = adata.var_names.to_numpy()

    rows = []
    for c in clusters:
        if c not in nn_map:
            continue
        nnc = nn_map[c]
        in_mask = (labels == c).to_numpy()
        out_mask = (labels == nnc).to_numpy()

        Xin = X[in_mask, :]
        Xout = X[out_mask, :]

        # log1p for t-test stability; raw means for FC with small epsilon
        logXin = np.log1p(Xin)
        logXout = np.log1p(Xout)

        # Welch t-test per gene (scipy vectorizes along axis=0); constant
        # columns produce NaN p-values, which we treat as non-significant.
        t_stat, p_vals = ttest_ind(logXin, logXout, equal_var=False, axis=0, nan_policy="omit")
        p_vals = np.nan_to_num(p_vals, nan=1.0)

        # log2 fold-change on raw means with small epsilon to avoid log(0)
        eps = 1e-9
        mu_in = Xin.mean(axis=0) + eps
        mu_out = Xout.mean(axis=0) + eps
        log2fc = np.log2(mu_in / mu_out)

        # fraction of cells expressing each gene, inside vs. neighbor cluster
        pct_in = (Xin > 0).mean(axis=0)
        pct_out = (Xout > 0).mean(axis=0)

        # criteria
        naive_mask = (p_vals <= float(naive["p_val"])) & (log2fc >= float(naive["fold_change"]))
        strict_mask = (pct_in >= float(strict["minpercentin"])) & (pct_out <= float(strict["maxpercentout"]))

        n_naive = int(naive_mask.sum())
        n_strict = int(strict_mask.sum())

        # top-gene reporting (up to n_genes), sorted by descending log2FC
        top_naive = genes[naive_mask]
        if top_naive.size:
            ord_ix = np.argsort(-log2fc[naive_mask])
            top_naive = top_naive[ord_ix][:n_genes]
        top_strict = genes[strict_mask]
        if top_strict.size:
            ord_ix = np.argsort(-log2fc[strict_mask])
            top_strict = top_strict[ord_ix][:n_genes]

        rows.append({
            "cluster": c,
            "nn_cluster": nnc,
            "n_genes_naive": n_naive,
            "n_genes_strict": n_strict,
            "top_naive_genes": ";".join(map(str, top_naive)) if top_naive.size else "",
            "top_strict_genes": ";".join(map(str, top_strict)) if top_strict.size else "",
        })

    out = pd.DataFrame(rows).sort_values(["cluster"]).reset_index(drop=True)
    return out
172+
173+
174+
12175

13176

14177
def cluster_subclusters(
15178
adata: ad.AnnData,
16179
cluster_column: str = 'leiden',
17-
cluster_name: str = None,
180+
to_subcluster: list[str] = None,
18181
layer: str = 'counts',
19182
n_hvg: int = 2000,
20183
n_pcs: int = 40,
@@ -23,17 +186,19 @@ def cluster_subclusters(
23186
subcluster_col_name: str = 'subcluster'
24187
) -> None:
25188
"""
26-
Subcluster a specified cluster (or all clusters) within an AnnData object by recomputing HVGs, PCA,
189+
Subcluster selected clusters (or all clusters) within an AnnData object by recomputing HVGs, PCA,
27190
kNN graph, and Leiden clustering. Updates the AnnData object in-place, adding or updating
28191
the `subcluster_col_name` column in `.obs` with new labels prefixed by the original cluster.
29-
192+
193+
Cells in clusters not listed in `to_subcluster` retain their original cluster label as their "subcluster".
194+
30195
Args:
31196
adata: AnnData
32197
The AnnData object containing precomputed clusters in `.obs[cluster_column]`.
33198
cluster_column: str, optional
34199
Name of the `.obs` column holding the original cluster assignments. Default is 'leiden'.
35-
cluster_name: str or None, optional
36-
Specific cluster label to subcluster. If `None`, applies to all clusters. Default is None.
200+
to_subcluster: list of str, optional
201+
List of cluster labels (as strings) to subcluster. If `None`, subclusters *all* clusters.
37202
layer: str, optional
38203
Layer name in `adata.layers` to use for HVG detection. Default is 'counts'.
39204
n_hvg: int, optional
@@ -46,54 +211,53 @@ def cluster_subclusters(
46211
Resolution parameter for Leiden clustering. Default is 0.25.
47212
subcluster_col_name: str, optional
48213
Name of the `.obs` column to store subcluster labels. Default is 'subcluster'.
49-
214+
50215
Raises:
51216
ValueError: If `cluster_column` not in `adata.obs`.
52217
ValueError: If `layer` not in `adata.layers`.
53-
ValueError: If `cluster_name` is specified but not found in `adata.obs[cluster_column]`.
218+
ValueError: If any entry in `to_subcluster` is not found in `adata.obs[cluster_column]`.
54219
"""
55220
# Error checking
56221
if cluster_column not in adata.obs:
57222
raise ValueError(f"Cluster column '{cluster_column}' not found in adata.obs")
58223
if layer not in adata.layers:
59224
raise ValueError(f"Layer '{layer}' not found in adata.layers")
60-
61-
# Convert original clusters to string
225+
226+
# Cast original clusters to string
62227
adata.obs['original_cluster'] = adata.obs[cluster_column].astype(str)
63-
64-
# Ensure subcluster column exists
65-
adata.obs[subcluster_col_name] = ""
66-
67-
# Validate cluster_name
68-
unique_clusters = adata.obs['original_cluster'].unique()
69-
if cluster_name is not None:
70-
if str(cluster_name) not in unique_clusters:
71-
raise ValueError(
72-
f"Cluster '{cluster_name}' not found in adata.obs['{cluster_column}']"
73-
)
74-
clusters_to_process = [str(cluster_name)]
228+
adata.obs[subcluster_col_name] = adata.obs['original_cluster']
229+
230+
# Determine clusters to process
231+
unique_clusters = set(adata.obs['original_cluster'])
232+
if to_subcluster is None:
233+
clusters_to_process = sorted(unique_clusters)
75234
else:
76-
clusters_to_process = unique_clusters
77-
78-
# Iterate and subcluster
235+
# ensure strings
236+
requested = {str(c) for c in to_subcluster}
237+
missing = requested - unique_clusters
238+
if missing:
239+
raise ValueError(f"Clusters not found: {missing}")
240+
clusters_to_process = sorted(requested)
241+
242+
# Iterate and subcluster each requested cluster
79243
for orig in clusters_to_process:
80244
mask = adata.obs['original_cluster'] == orig
81245
sub = adata[mask].copy()
82-
246+
83247
# 1) Compute HVGs
84248
sc.pp.highly_variable_genes(
85249
sub,
86250
flavor='seurat_v3',
87251
n_top_genes=n_hvg,
88252
layer=layer
89253
)
90-
254+
91255
# 2) PCA
92256
sc.pp.pca(sub, n_comps=n_pcs, use_highly_variable=True)
93-
257+
94258
# 3) kNN
95259
sc.pp.neighbors(sub, n_neighbors=n_neighbors, use_rep='X_pca')
96-
260+
97261
# 4) Leiden
98262
sc.tl.leiden(
99263
sub,
@@ -102,10 +266,10 @@ def cluster_subclusters(
102266
n_iterations=2,
103267
key_added='leiden_sub'
104268
)
105-
106-
# Prefix and assign back
107-
labels = (orig + "_" + sub.obs['leiden_sub'].astype(str)).values
108-
adata.obs.loc[mask, subcluster_col_name] = labels
269+
270+
# Prefix subcluster labels and write back
271+
new_labels = orig + "_" + sub.obs['leiden_sub'].astype(str)
272+
adata.obs.loc[mask, subcluster_col_name] = new_labels.values
109273

110274

111275

0 commit comments

Comments
 (0)