Skip to content

Commit 0649e1e

Browse files
committed
rename clusters fix+
1 parent 2810413 commit 0649e1e

5 files changed

Lines changed: 315 additions & 49 deletions

File tree

src/pySingleCellNet/tools/__init__.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,14 @@
3131
deg
3232
)
3333

34-
from .gene import (
35-
build_gene_knn,
36-
find_gene_modules,
37-
whoare_genes_neighbors,
38-
what_module_has_gene,
39-
score_gene_sets
40-
)
34+
from .gene import (
35+
build_gene_knn,
36+
find_gene_modules,
37+
whoare_genes_neighbors,
38+
what_module_has_gene,
39+
score_gene_sets,
40+
correlate_module_scores_with_pcs
41+
)
4142

4243
# API
4344
__all__ = [
@@ -57,9 +58,10 @@
5758
"convert_diffExp_to_dict",
5859
"deg",
5960
"build_gene_knn",
60-
"find_gene_modules",
61-
"whoare_genes_neighbors",
62-
"what_module_has_gene",
63-
"score_gene_sets"
64-
]
61+
"find_gene_modules",
62+
"whoare_genes_neighbors",
63+
"what_module_has_gene",
64+
"score_gene_sets",
65+
"correlate_module_scores_with_pcs"
66+
]
6567

src/pySingleCellNet/tools/cluster.py

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,14 @@ def cluster_alot(
2828
2929
Assumptions:
3030
* ``adata.X`` is **already log-transformed**.
31-
* PCA has been computed and ``adata.obsm['X_pca']`` is present; this is
32-
used as the base embedding for PC selection/subsampling.
31+
* A base embedding is stored in ``adata.obsm``. By default this is
32+
``adata.obsm['X_pca']``, but you can override it via
33+
``knn_params['use_rep']`` to leverage an alternative representation.
3334
3435
Args:
3536
adata: AnnData object containing the log-transformed expression matrix.
36-
Must include ``obsm['X_pca']`` (shape ``(n_cells, n_pcs_total)``).
37+
Must include the embedding referenced by ``knn_params['use_rep']``
38+
(defaults to ``obsm['X_pca']``).
3739
leiden_resolutions: Leiden resolution values to evaluate (passed to
3840
``sc.tl.leiden``). Each resolution is combined with every KNN/PC
3941
configuration in the sweep.
@@ -56,6 +58,9 @@ def cluster_alot(
5658
knn_params: KNN graph parameters. Supported keys:
5759
* ``"n_neighbors"`` (List[int], default ``[10]``): Candidate values
5860
for ``K`` used in ``sc.pp.neighbors``.
61+
* ``"use_rep"`` (str, default ``"X_pca"``): Name of the
62+
``adata.obsm`` representation to use as the base embedding (e.g.,
63+
``"X_pca_noPC1"``). PC subsampling operates on this matrix.
5964
random_state: Random seed for PC subset sampling (when
6065
``percent_of_pcs`` is used). Pass ``None`` for non-deterministic
6166
sampling. Defaults to ``None``.
@@ -70,6 +75,7 @@ def cluster_alot(
7075
* **runs** (``pd.DataFrame``): One row per clustering run with metadata columns such as:
7176
- ``obs_key``: Name of the column in ``adata.obs`` that stores cluster labels.
7277
- ``neighbors_key``: Name of the neighbors graph key used/created.
78+
- ``use_rep``: Embedding key that served as the base representation.
7379
- ``resolution``: Leiden resolution value used for the run.
7480
- ``top_n_pcs``: Number of leading PCs considered.
7581
- ``pct_pcs``: Fraction of PCs used when subsampling (``percent_of_pcs``), or ``1.0`` if all were used.
@@ -80,10 +86,10 @@ def cluster_alot(
8086
(``round(pct_pcs * top_n_pcs)`` or ``top_n_pcs`` if no subsampling).
8187
8288
Raises:
83-
KeyError: If ``'X_pca'`` is missing from ``adata.obsm``.
84-
ValueError: If any provided parameter is out of range (e.g.,
85-
``percent_of_pcs`` not in ``(0, 1]``; empty lists; non-positive
86-
``n_neighbors``).
89+
ValueError: If the requested ``knn_params['use_rep']`` embedding is
90+
missing from ``adata.obsm`` or if any provided parameter is out of
91+
range (e.g., ``percent_of_pcs`` not in ``(0, 1]``; empty lists;
92+
non-positive ``n_neighbors``).
8793
RuntimeError: If neighbor graph construction or Leiden clustering fails.
8894
8995
Notes:
@@ -108,17 +114,27 @@ def cluster_alot(
108114
>>> runs[["obs_key", "n_clusters"]].head()
109115
"""
110116

111-
# ---- Validate prerequisites ----
112-
if "X_pca" not in adata.obsm:
113-
raise ValueError("`adata.obsm['X_pca']` not found. Please run PCA first.")
114-
Xpca = adata.obsm["X_pca"]
115-
n_pcs_available = Xpca.shape[1]
116-
if n_pcs_available < 2:
117-
raise ValueError(f"Not enough PCs ({n_pcs_available}) in `X_pca`.")
118-
119117
# ---- Normalize params ----
120118
pca_params = dict(pca_params or {})
121119
knn_params = dict(knn_params or {})
120+
121+
use_rep_key = knn_params.get("use_rep", "X_pca")
122+
if use_rep_key is None:
123+
use_rep_key = "X_pca"
124+
if not isinstance(use_rep_key, str):
125+
raise ValueError("`knn_params['use_rep']` must be a string key in `adata.obsm`.")
126+
127+
# ---- Validate prerequisites ----
128+
if use_rep_key not in adata.obsm:
129+
raise ValueError(
130+
f"`adata.obsm['{use_rep_key}']` not found. Please compute that representation first."
131+
)
132+
X_rep = adata.obsm[use_rep_key]
133+
n_pcs_available = X_rep.shape[1]
134+
if n_pcs_available < 2:
135+
raise ValueError(
136+
f"Not enough components ({n_pcs_available}) in `adata.obsm['{use_rep_key}']`."
137+
)
122138
top_n_pcs: List[int] = pca_params.get("top_n_pcs", [40])
123139
percent_of_pcs: Optional[float] = pca_params.get("percent_of_pcs", None)
124140
n_random_samples: Optional[int] = pca_params.get("n_random_samples", None)
@@ -143,9 +159,9 @@ def cluster_alot(
143159
# ---- Helper: build neighbors from a given PC subspace ----
144160
def _neighbors_from_pc_indices(pc_idx: np.ndarray, n_neighbors: int, neighbors_key: str):
145161
"""Create a neighbors graph using the given PC column indices."""
146-
# Create a temporary representation name
147-
temp_rep_key = f"X_pca_sub_{neighbors_key}"
148-
adata.obsm[temp_rep_key] = Xpca[:, pc_idx]
162+
# Create a temporary representation name derived from the requested embedding
163+
temp_rep_key = f"{use_rep_key}_sub_{neighbors_key}"
164+
adata.obsm[temp_rep_key] = X_rep[:, pc_idx]
149165

150166
# Build neighbors; store under unique keys (in uns & obsp)
151167
sc.pp.neighbors(
@@ -158,6 +174,7 @@ def _neighbors_from_pc_indices(pc_idx: np.ndarray, n_neighbors: int, neighbors_k
158174
# Record which PCs were used (for provenance)
159175
if neighbors_key in adata.uns:
160176
adata.uns[neighbors_key]["pcs_indices"] = pc_idx.astype(int)
177+
adata.uns[neighbors_key]["base_representation"] = use_rep_key
161178

162179
# Clean up the temporary representation to save memory
163180
del adata.obsm[temp_rep_key]
@@ -203,6 +220,7 @@ def _neighbors_from_pc_indices(pc_idx: np.ndarray, n_neighbors: int, neighbors_k
203220
rows.append({
204221
"obs_key": obs_key,
205222
"neighbors_key": neighbors_key,
223+
"use_rep": use_rep_key,
206224
"resolution": res,
207225
"top_n_pcs": N,
208226
"pct_pcs": float(pct_str),
@@ -236,6 +254,7 @@ def _neighbors_from_pc_indices(pc_idx: np.ndarray, n_neighbors: int, neighbors_k
236254
rows.append({
237255
"obs_key": obs_key,
238256
"neighbors_key": neighbors_key,
257+
"use_rep": use_rep_key,
239258
"resolution": float(res),
240259
"top_n_pcs": int(N),
241260
"pct_pcs": float(pct_str),
@@ -248,8 +267,19 @@ def _neighbors_from_pc_indices(pc_idx: np.ndarray, n_neighbors: int, neighbors_k
248267

249268
summary_df = pd.DataFrame(rows)
250269
# nice ordering
251-
cols = ["obs_key","neighbors_key","resolution","top_n_pcs","pct_pcs","sample_idx",
252-
"n_neighbors","pcs_used_count","n_clusters","status"]
270+
cols = [
271+
"obs_key",
272+
"neighbors_key",
273+
"use_rep",
274+
"resolution",
275+
"top_n_pcs",
276+
"pct_pcs",
277+
"sample_idx",
278+
"n_neighbors",
279+
"pcs_used_count",
280+
"n_clusters",
281+
"status",
282+
]
253283
summary_df = summary_df[cols]
254284

255285
return summary_df
@@ -352,6 +382,3 @@ def cluster_subclusters(
352382
# Prefix subcluster labels and write back
353383
new_labels = orig + "_" + sub.obs['leiden_sub'].astype(str)
354384
adata.obs.loc[mask, subcluster_col_name] = new_labels.values
355-
356-
357-

src/pySingleCellNet/tools/gene.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -886,4 +886,154 @@ def what_module_has_gene(
886886
return [key for key, genes in genemodules.items() if target_gene in genes]
887887

888888

def correlate_module_scores_with_pcs(
    adata: AnnData,
    score_key: Union[str, Sequence[float], np.ndarray, pd.Series],
    *,
    pca_key: str = "X_pca",
    variance_key: Optional[str] = "pca",
    method: str = "pearson",
    min_abs_corr: Optional[float] = 0.3,
    drop_na: bool = True,
    sort: bool = True,
) -> pd.DataFrame:
    """Quantify the association between a module score and individual PCs.

    Parameters
    ----------
    adata
        AnnData object containing PCs in ``adata.obsm`` and per-cell module scores.
    score_key
        Either the name of an ``adata.obs`` column holding module scores (e.g., the
        output of :func:`score_gene_sets`) or an explicit array-like of shape
        ``(n_cells,)``.
    pca_key
        Key of the embedding in ``adata.obsm`` to correlate against (defaults to
        ``"X_pca"``).
    variance_key
        Optional ``adata.uns`` key that stores ``"variance_ratio"`` for the chosen
        PCA run (defaults to ``"pca"`` when using ``sc.tl.pca``).
    method
        Correlation metric: ``"pearson"`` (default) or ``"spearman"``.
    min_abs_corr
        Absolute-correlation threshold used to flag PCs that strongly follow the
        module score. Set to ``None`` to skip flagging.
    drop_na
        If ``True`` (default), silently drop cells with missing scores/PC values.
        Otherwise raise when NaNs are detected.
    sort
        If ``True`` (default), sort the output by descending absolute correlation.

    Returns
    -------
    pandas.DataFrame
        Table with one row per PC containing the correlation, absolute correlation,
        two-sided p-value, variance ratio (when available), and a boolean flag
        indicating whether the PC exceeds ``min_abs_corr``. Zero-variance PCs
        yield ``NaN`` correlations and are never flagged.

    Raises
    ------
    ValueError
        If the embedding or score column is missing, shapes disagree, fewer than
        three valid cells remain, the score has zero variance, or ``method`` is
        not one of ``"pearson"`` / ``"spearman"``.
    """

    if pca_key not in adata.obsm:
        raise ValueError(f"'{pca_key}' not found in adata.obsm. Run PCA first.")
    pcs = np.asarray(adata.obsm[pca_key], dtype=np.float64)
    if pcs.ndim != 2:
        raise ValueError(f"adata.obsm['{pca_key}'] must be 2-D (cells × PCs).")
    if pcs.shape[0] != adata.n_obs:
        raise ValueError("Number of rows in the PCA embedding does not match n_obs.")

    # Resolve the module scores vector (obs column name or explicit array)
    score_label = None
    if isinstance(score_key, str):
        if score_key not in adata.obs:
            raise ValueError(f"score_key='{score_key}' not present in adata.obs.")
        scores = adata.obs[score_key].to_numpy(dtype=np.float64)
        score_label = score_key
    else:
        scores = np.asarray(score_key, dtype=np.float64).reshape(-1)
        if scores.shape[0] != adata.n_obs:
            raise ValueError("score_key array must have length equal to adata.n_obs.")

    # Handle missing data: keep only cells finite in both the score and every PC
    finite_scores = np.isfinite(scores)
    finite_pcs = np.all(np.isfinite(pcs), axis=1)
    if drop_na:
        mask = finite_scores & finite_pcs
    else:
        if not (finite_scores.all() and finite_pcs.all()):
            raise ValueError("NaN/inf detected in scores or PCs; set drop_na=True to filter them.")
        mask = np.ones_like(finite_scores, dtype=bool)

    n_valid = int(mask.sum())
    if n_valid < 3:
        # Also guarantees dof = n_valid - 2 >= 1 for the t-based p-values below.
        raise ValueError("Need at least 3 valid cells to compute correlations.")

    y = scores[mask]
    X = pcs[mask]

    method_lc = method.lower()
    if method_lc not in {"pearson", "spearman"}:
        raise ValueError("method must be either 'pearson' or 'spearman'.")

    from scipy import stats as _stats  # local import to avoid a module-level dependency

    if method_lc == "spearman":
        # Spearman = Pearson on ranks; rank the score and each PC column
        # (rankdata's axis argument ranks every column in one vectorized call).
        y = _stats.rankdata(y)
        X = _stats.rankdata(X, axis=0)

    # Vectorized Pearson correlation of y against every PC column
    y = y.astype(np.float64)
    y_centered = y - y.mean()
    y_norm = np.sqrt(np.sum(y_centered ** 2))
    if y_norm == 0:
        raise ValueError("Module score has zero variance; correlation undefined.")

    X_centered = X - X.mean(axis=0)
    X_norm = np.sqrt(np.sum(X_centered ** 2, axis=0))

    # Zero-variance PCs divide by zero here; suppress the warning and let the
    # resulting NaN propagate into the output table.
    with np.errstate(divide="ignore", invalid="ignore"):
        corr = (y_centered @ X_centered) / (y_norm * X_norm)
    corr = corr.astype(np.float64)

    n_pcs = corr.size
    dof = n_valid - 2  # >= 1, enforced by the n_valid >= 3 check above

    # Two-sided p-values from the t-distribution (exact for Pearson under
    # normality; an approximation when applied to Spearman ranks).
    with np.errstate(divide="ignore", invalid="ignore"):
        denom = np.clip(1.0 - corr**2, 1e-12, None)  # guard against |r| ~ 1
        t_stat = corr * np.sqrt(dof / denom)
    p_values = 2.0 * _stats.t.sf(np.abs(t_stat), df=dof)

    # Per-PC variance ratios, when the PCA run recorded them in adata.uns
    var_ratio = np.full(n_pcs, np.nan)
    if variance_key is not None and variance_key in adata.uns:
        uns_entry = adata.uns[variance_key]
        if isinstance(uns_entry, Mapping) and "variance_ratio" in uns_entry:
            vr = np.asarray(uns_entry["variance_ratio"], dtype=np.float64).ravel()
            if vr.size:
                var_ratio[: min(n_pcs, vr.size)] = vr[:n_pcs]

    result = pd.DataFrame({
        "pc": [f"PC{i}" for i in range(1, n_pcs + 1)],
        "pc_index": np.arange(1, n_pcs + 1, dtype=int),
        "correlation": corr,
        "abs_correlation": np.abs(corr),
        "p_value": p_values,
        "variance_ratio": var_ratio,
        "n_cells": n_valid,
        "score_key": score_label or "array",
    })

    if min_abs_corr is not None:
        threshold = float(min_abs_corr)
        # NaN correlations compare False, so zero-variance PCs are never flagged.
        result["flag_high_corr"] = result["abs_correlation"] >= threshold
        result.attrs["min_abs_corr"] = threshold
    else:
        result["flag_high_corr"] = False

    if sort:
        result = result.sort_values("abs_correlation", ascending=False).reset_index(drop=True)

    return result
8891039

src/pySingleCellNet/utils/__init__.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66
)
77

88
from .adataTools import (
9-
split_adata_indices,
10-
rename_cluster_labels,
11-
limit_anndata_to_common_genes,
12-
remove_genes,
13-
filter_anndata_slots,
14-
filter_adata_by_group_size
15-
)
9+
split_adata_indices,
10+
rename_cluster_labels,
11+
limit_anndata_to_common_genes,
12+
remove_genes,
13+
filter_anndata_slots,
14+
filter_adata_by_group_size,
15+
drop_pcs_from_embedding
16+
)
1617

1718
#from .gene import (
1819
# extract_top_bottom_genes,
@@ -52,10 +53,11 @@
5253
"score_sex",
5354
"split_adata_indices",
5455
"rename_cluster_labels",
55-
"limit_anndata_to_common_genes",
56-
"remove_genes",
57-
"filter_anndata_slots",
58-
"filter_adata_by_group_size",
56+
"limit_anndata_to_common_genes",
57+
"remove_genes",
58+
"filter_anndata_slots",
59+
"filter_adata_by_group_size",
60+
"drop_pcs_from_embedding",
5961
# "extract_top_bottom_genes",
6062
# "pull_out_genes",
6163
# "pull_out_genes_v2",

0 commit comments

Comments
 (0)