1+ from __future__ import annotations
2+
13import numpy as np
24import pandas as pd
35# from anndata import AnnData
68import anndata as ad
79# from scipy.sparse import issparse
810# from alive_progress import alive_bar
9- from scipy .stats import median_abs_deviation
11+ from scipy .stats import median_abs_deviation , ttest_ind
1012# import string
1113import igraph as ig
14+ from scipy import sparse
15+ from sklearn .decomposition import PCA
16+
17+ def clustering_quality_vs_nn (
18+ adata ,
19+ label_col : str ,
20+ n_genes : int = 5 ,
21+ naive : dict = {"p_val" : 1e-2 , "fold_change" : 0.25 },
22+ strict : dict = {"minpercentin" : 0.20 , "maxpercentout" : 0.10 },
23+ n_pcs_for_nn : int = 40 ,
24+ ):
25+ """
26+ For each cluster, find its nearest neighbor cluster and count genes that meet
27+ (1) naive DE criteria: p-value <= naive['p_val'] AND log2FC >= naive['fold_change']
28+ (2) strict DE criteria: pct_in >= strict['minpercentin'] AND pct_out <= strict['maxpercentout'].
29+
30+ Parameters
31+ ----------
32+ adata : AnnData
33+ Single-cell object. Uses adata.X for expression (dense or sparse).
34+ label_col : str
35+ .obs column with cluster labels.
36+ n_genes : int, default 5
37+ For reporting convenience: include top n gene names (by effect size) for each rule.
38+ Does not affect the counts.
39+ naive : dict
40+ {'p_val': float, 'fold_change': float} ; fold_change is on log2 scale.
41+ strict : dict
42+ {'minpercentin': float, 'maxpercentout': float}
43+ n_pcs_for_nn : int, default 30
44+ Number of PCs (if available or computed) for nearest-neighbor cluster search.
45+
46+ Returns
47+ -------
48+ pandas.DataFrame
49+ Columns: ['cluster', 'nn_cluster', 'n_genes_naive', 'n_genes_strict',
50+ 'top_naive_genes', 'top_strict_genes']
51+ """
52+ # ----- helpers -----
53+ def _to_dense (X ):
54+ return X .A if sparse .issparse (X ) else np .asarray (X )
55+
56+ def _get_representation (adata , n_pcs_for_nn ):
57+ # Prefer existing PCA; otherwise compute a compact representation
58+ if "X_pca" in adata .obsm and adata .obsm ["X_pca" ].shape [1 ] >= min (2 , n_pcs_for_nn ):
59+ rep = adata .obsm ["X_pca" ][:, :n_pcs_for_nn ]
60+ return np .asarray (rep )
61+ # else: compute PCA on log1p(counts), on a subset of genes to keep things light
62+ X = _to_dense (adata .X )
63+ # choose up to 2000 HVGs if available; else top-1000 variable genes
64+ if "highly_variable" in adata .var .columns and adata .var ["highly_variable" ].any ():
65+ genes_mask = adata .var ["highly_variable" ].values
66+ else :
67+ # compute variance per gene quickly
68+ var = X .var (axis = 0 )
69+ topk = min (1000 , X .shape [1 ])
70+ genes_mask = np .zeros (X .shape [1 ], dtype = bool )
71+ genes_mask [np .argsort (var )[- topk :]] = True
72+ Xg = np .log1p (X [:, genes_mask ])
73+ # center per gene
74+ Xg = Xg - Xg .mean (axis = 0 , keepdims = True )
75+ pca = PCA (n_components = min (n_pcs_for_nn , Xg .shape [1 ], Xg .shape [0 ]- 1 ))
76+ return pca .fit_transform (Xg )
77+
78+ def _cluster_centroids (rep , labels ):
79+ centroids = {}
80+ for c in labels .unique ():
81+ idx = (labels == c ).values
82+ if idx .sum () == 0 :
83+ continue
84+ centroids [c ] = rep [idx ].mean (axis = 0 )
85+ return centroids
86+
87+ def _nearest_neighbors (centroids ):
88+ # For each cluster, find the nearest other cluster (Euclidean)
89+ keys = list (centroids .keys ())
90+ arr = np .stack ([centroids [k ] for k in keys ], axis = 0 )
91+ # pairwise distances
92+ d2 = np .sum ((arr [:, None , :] - arr [None , :, :])** 2 , axis = 2 )
93+ np .fill_diagonal (d2 , np .inf )
94+ nn_idx = np .argmin (d2 , axis = 1 )
95+ return {keys [i ]: keys [j ] for i , j in enumerate (nn_idx )}
96+
97+ # ----- main -----
98+ if label_col not in adata .obs .columns :
99+ raise ValueError (f"'{ label_col } ' not found in adata.obs" )
100+
101+ labels = adata .obs [label_col ].astype ("category" )
102+ clusters = pd .Index (labels .cat .categories )
103+
104+ rep = _get_representation (adata , n_pcs_for_nn )
105+ centroids = _cluster_centroids (rep , labels )
106+ if len (centroids ) < 2 :
107+ raise ValueError ("Need at least two clusters to compute nearest neighbors." )
108+ nn_map = _nearest_neighbors (centroids )
109+
110+ # Expression matrix (cells x genes), dense for vectorized ops
111+ X = _to_dense (adata .X )
112+ genes = adata .var_names .to_numpy ()
113+
114+ rows = []
115+ for c in clusters :
116+ if c not in nn_map :
117+ continue
118+ nnc = nn_map [c ]
119+ in_mask = (labels == c ).to_numpy ()
120+ out_mask = (labels == nnc ).to_numpy ()
121+
122+ Xin = X [in_mask , :]
123+ Xout = X [out_mask , :]
124+
125+ # log1p for t-test stability; raw means for FC with small epsilon
126+ logXin = np .log1p (Xin )
127+ logXout = np .log1p (Xout )
128+
129+ # Welch t-test per gene
130+ # (scipy vectorizes if axis=0; handle potential NaNs for constant columns)
131+ t_stat , p_vals = ttest_ind (logXin , logXout , equal_var = False , axis = 0 , nan_policy = "omit" )
132+ p_vals = np .nan_to_num (p_vals , nan = 1.0 )
133+
134+ # log2 fold-change on raw means with small epsilon
135+ eps = 1e-9
136+ mu_in = Xin .mean (axis = 0 ) + eps
137+ mu_out = Xout .mean (axis = 0 ) + eps
138+ log2fc = np .log2 (mu_in / mu_out )
139+
140+ # pct expressed
141+ pct_in = (Xin > 0 ).mean (axis = 0 )
142+ pct_out = (Xout > 0 ).mean (axis = 0 )
143+
144+ # criteria
145+ naive_mask = (p_vals <= float (naive ["p_val" ])) & (log2fc >= float (naive ["fold_change" ]))
146+ strict_mask = (pct_in >= float (strict ["minpercentin" ])) & (pct_out <= float (strict ["maxpercentout" ]))
147+
148+ n_naive = int (naive_mask .sum ())
149+ n_strict = int (strict_mask .sum ())
150+
151+ # top-gene reporting (up to n_genes), sorted by effect size
152+ top_naive = genes [naive_mask ]
153+ if top_naive .size :
154+ ord_ix = np .argsort (- log2fc [naive_mask ])
155+ top_naive = top_naive [ord_ix ][:n_genes ]
156+ top_strict = genes [strict_mask ]
157+ if top_strict .size :
158+ ord_ix = np .argsort (- log2fc [strict_mask ])
159+ top_strict = top_strict [ord_ix ][:n_genes ]
160+
161+ rows .append ({
162+ "cluster" : c ,
163+ "nn_cluster" : nnc ,
164+ "n_genes_naive" : n_naive ,
165+ "n_genes_strict" : n_strict ,
166+ "top_naive_genes" : ";" .join (map (str , top_naive )) if top_naive .size else "" ,
167+ "top_strict_genes" : ";" .join (map (str , top_strict )) if top_strict .size else "" ,
168+ })
169+
170+ out = pd .DataFrame (rows ).sort_values (["cluster" ]).reset_index (drop = True )
171+ return out
172+
173+
174+
12175
13176
14177def cluster_subclusters (
15178 adata : ad .AnnData ,
16179 cluster_column : str = 'leiden' ,
17- cluster_name : str = None ,
180+ to_subcluster : list [ str ] = None ,
18181 layer : str = 'counts' ,
19182 n_hvg : int = 2000 ,
20183 n_pcs : int = 40 ,
@@ -23,17 +186,19 @@ def cluster_subclusters(
23186 subcluster_col_name : str = 'subcluster'
24187) -> None :
25188 """
26- Subcluster a specified cluster (or all clusters) within an AnnData object by recomputing HVGs, PCA,
189+ Subcluster selected clusters (or all clusters) within an AnnData object by recomputing HVGs, PCA,
27190 kNN graph, and Leiden clustering. Updates the AnnData object in-place, adding or updating
28191 the `subcluster_col_name` column in `.obs` with new labels prefixed by the original cluster.
29-
192+
193+ Cells in clusters not listed in `to_subcluster` retain their original cluster label as their "subcluster".
194+
30195 Args:
31196 adata: AnnData
32197 The AnnData object containing precomputed clusters in `.obs[cluster_column]`.
33198 cluster_column: str, optional
34199 Name of the `.obs` column holding the original cluster assignments. Default is 'leiden'.
35- cluster_name: str or None , optional
36- Specific cluster label to subcluster. If `None`, applies to all clusters. Default is None .
200+ to_subcluster: list of str , optional
201+ List of cluster labels (as strings) to subcluster. If `None`, subclusters * all* clusters.
37202 layer: str, optional
38203 Layer name in `adata.layers` to use for HVG detection. Default is 'counts'.
39204 n_hvg: int, optional
@@ -46,54 +211,53 @@ def cluster_subclusters(
46211 Resolution parameter for Leiden clustering. Default is 0.25.
47212 subcluster_col_name: str, optional
48213 Name of the `.obs` column to store subcluster labels. Default is 'subcluster'.
49-
214+
50215 Raises:
51216 ValueError: If `cluster_column` not in `adata.obs`.
52217 ValueError: If `layer` not in `adata.layers`.
53- ValueError: If `cluster_name ` is specified but not found in `adata.obs[cluster_column]`.
218+ ValueError: If any entry in `to_subcluster ` is not found in `adata.obs[cluster_column]`.
54219 """
55220 # Error checking
56221 if cluster_column not in adata .obs :
57222 raise ValueError (f"Cluster column '{ cluster_column } ' not found in adata.obs" )
58223 if layer not in adata .layers :
59224 raise ValueError (f"Layer '{ layer } ' not found in adata.layers" )
60-
61- # Convert original clusters to string
225+
226+ # Cast original clusters to string
62227 adata .obs ['original_cluster' ] = adata .obs [cluster_column ].astype (str )
63-
64- # Ensure subcluster column exists
65- adata .obs [subcluster_col_name ] = ""
66-
67- # Validate cluster_name
68- unique_clusters = adata .obs ['original_cluster' ].unique ()
69- if cluster_name is not None :
70- if str (cluster_name ) not in unique_clusters :
71- raise ValueError (
72- f"Cluster '{ cluster_name } ' not found in adata.obs['{ cluster_column } ']"
73- )
74- clusters_to_process = [str (cluster_name )]
228+ adata .obs [subcluster_col_name ] = adata .obs ['original_cluster' ]
229+
230+ # Determine clusters to process
231+ unique_clusters = set (adata .obs ['original_cluster' ])
232+ if to_subcluster is None :
233+ clusters_to_process = sorted (unique_clusters )
75234 else :
76- clusters_to_process = unique_clusters
77-
78- # Iterate and subcluster
235+ # ensure strings
236+ requested = {str (c ) for c in to_subcluster }
237+ missing = requested - unique_clusters
238+ if missing :
239+ raise ValueError (f"Clusters not found: { missing } " )
240+ clusters_to_process = sorted (requested )
241+
242+ # Iterate and subcluster each requested cluster
79243 for orig in clusters_to_process :
80244 mask = adata .obs ['original_cluster' ] == orig
81245 sub = adata [mask ].copy ()
82-
246+
83247 # 1) Compute HVGs
84248 sc .pp .highly_variable_genes (
85249 sub ,
86250 flavor = 'seurat_v3' ,
87251 n_top_genes = n_hvg ,
88252 layer = layer
89253 )
90-
254+
91255 # 2) PCA
92256 sc .pp .pca (sub , n_comps = n_pcs , use_highly_variable = True )
93-
257+
94258 # 3) kNN
95259 sc .pp .neighbors (sub , n_neighbors = n_neighbors , use_rep = 'X_pca' )
96-
260+
97261 # 4) Leiden
98262 sc .tl .leiden (
99263 sub ,
@@ -102,10 +266,10 @@ def cluster_subclusters(
102266 n_iterations = 2 ,
103267 key_added = 'leiden_sub'
104268 )
105-
106- # Prefix and assign back
107- labels = ( orig + "_" + sub .obs ['leiden_sub' ].astype (str )). values
108- adata .obs .loc [mask , subcluster_col_name ] = labels
269+
270+ # Prefix subcluster labels and write back
271+ new_labels = orig + "_" + sub .obs ['leiden_sub' ].astype (str )
272+ adata .obs .loc [mask , subcluster_col_name ] = new_labels . values
109273
110274
111275
0 commit comments