@@ -28,12 +28,14 @@ def cluster_alot(
2828
2929 Assumptions:
3030 * ``adata.X`` is **already log-transformed**.
31- * PCA has been computed and ``adata.obsm['X_pca']`` is present; this is
32- used as the base embedding for PC selection/subsampling.
31+ * A base embedding is stored in ``adata.obsm``. By default this is
32+ ``adata.obsm['X_pca']``, but you can override it via
33+ ``knn_params['use_rep']`` to leverage an alternative representation.
3334
3435 Args:
3536 adata: AnnData object containing the log-transformed expression matrix.
36- Must include ``obsm['X_pca']`` (shape ``(n_cells, n_pcs_total)``).
37+ Must include the embedding referenced by ``knn_params['use_rep']``
38+ (defaults to ``obsm['X_pca']``).
3739 leiden_resolutions: Leiden resolution values to evaluate (passed to
3840 ``sc.tl.leiden``). Each resolution is combined with every KNN/PC
3941 configuration in the sweep.
@@ -56,6 +58,9 @@ def cluster_alot(
5658 knn_params: KNN graph parameters. Supported keys:
5759 * ``"n_neighbors"`` (List[int], default ``[10]``): Candidate values
5860 for ``K`` used in ``sc.pp.neighbors``.
61+ * ``"use_rep"`` (str, default ``"X_pca"``): Name of the
62+ ``adata.obsm`` representation to use as the base embedding (e.g.,
63+ ``"X_pca_noPC1"``). PC subsampling operates on this matrix.
5964 random_state: Random seed for PC subset sampling (when
6065 ``percent_of_pcs`` is used). Pass ``None`` for non-deterministic
6166 sampling. Defaults to ``None``.
@@ -70,6 +75,7 @@ def cluster_alot(
7075 * **runs** (``pd.DataFrame``): One row per clustering run with metadata columns such as:
7176 - ``obs_key``: Name of the column in ``adata.obs`` that stores cluster labels.
7277 - ``neighbors_key``: Name of the neighbors graph key used/created.
78+ - ``use_rep``: Embedding key that served as the base representation.
7379 - ``resolution``: Leiden resolution value used for the run.
7480 - ``top_n_pcs``: Number of leading PCs considered.
7581 - ``pct_pcs``: Fraction of PCs used when subsampling (``percent_of_pcs``), or ``1.0`` if all were used.
@@ -80,10 +86,10 @@ def cluster_alot(
8086 (``round(pct_pcs * top_n_pcs)`` or ``top_n_pcs`` if no subsampling).
8187
8288 Raises:
83- KeyError : If ``'X_pca' `` is missing from ``adata.obsm``.
84- ValueError: If any provided parameter is out of range (e.g.,
85- ``percent_of_pcs`` not in ``(0, 1]``; empty lists; non-positive
86- ``n_neighbors``).
89+ ValueError: If ``knn_params['use_rep']`` is neither a string nor ``None``,
90+ if that embedding is missing from ``adata.obsm``, or if any provided
91+ parameter is out of range (e.g., ``percent_of_pcs`` not in ``(0, 1]``;
92+ empty lists; non-positive ``n_neighbors``).
8793 RuntimeError: If neighbor graph construction or Leiden clustering fails.
8894
8995 Notes:
@@ -108,17 +114,27 @@ def cluster_alot(
108114 >>> runs[["obs_key", "n_clusters"]].head()
109115 """
110116
111- # ---- Validate prerequisites ----
112- if "X_pca" not in adata .obsm :
113- raise ValueError ("`adata.obsm['X_pca']` not found. Please run PCA first." )
114- Xpca = adata .obsm ["X_pca" ]
115- n_pcs_available = Xpca .shape [1 ]
116- if n_pcs_available < 2 :
117- raise ValueError (f"Not enough PCs ({ n_pcs_available } ) in `X_pca`." )
118-
119117 # ---- Normalize params ----
120118 pca_params = dict (pca_params or {})
121119 knn_params = dict (knn_params or {})
120+
121+ use_rep_key = knn_params .get ("use_rep" , "X_pca" )
122+ if use_rep_key is None :
123+ use_rep_key = "X_pca"
124+ if not isinstance (use_rep_key , str ):
125+ raise ValueError ("`knn_params['use_rep']` must be a string key in `adata.obsm`." )
126+
127+ # ---- Validate prerequisites ----
128+ if use_rep_key not in adata .obsm :
129+ raise ValueError (
130+ f"`adata.obsm['{ use_rep_key } ']` not found. Please compute that representation first."
131+ )
132+ X_rep = adata .obsm [use_rep_key ]
133+ n_pcs_available = X_rep .shape [1 ]
134+ if n_pcs_available < 2 :
135+ raise ValueError (
136+ f"Not enough components ({ n_pcs_available } ) in `adata.obsm['{ use_rep_key } ']`."
137+ )
122138 top_n_pcs : List [int ] = pca_params .get ("top_n_pcs" , [40 ])
123139 percent_of_pcs : Optional [float ] = pca_params .get ("percent_of_pcs" , None )
124140 n_random_samples : Optional [int ] = pca_params .get ("n_random_samples" , None )
@@ -143,9 +159,9 @@ def cluster_alot(
143159 # ---- Helper: build neighbors from a given PC subspace ----
144160 def _neighbors_from_pc_indices (pc_idx : np .ndarray , n_neighbors : int , neighbors_key : str ):
145161 """Create a neighbors graph using the given PC column indices."""
146- # Create a temporary representation name
147- temp_rep_key = f"X_pca_sub_ { neighbors_key } "
148- adata .obsm [temp_rep_key ] = Xpca [:, pc_idx ]
162+ # Create a temporary representation name derived from the requested embedding
163+ temp_rep_key = f"{ use_rep_key } _sub_ { neighbors_key } "
164+ adata .obsm [temp_rep_key ] = X_rep [:, pc_idx ]
149165
150166 # Build neighbors; store under unique keys (in uns & obsp)
151167 sc .pp .neighbors (
@@ -158,6 +174,7 @@ def _neighbors_from_pc_indices(pc_idx: np.ndarray, n_neighbors: int, neighbors_k
158174 # Record which PCs were used (for provenance)
159175 if neighbors_key in adata .uns :
160176 adata .uns [neighbors_key ]["pcs_indices" ] = pc_idx .astype (int )
177+ adata .uns [neighbors_key ]["base_representation" ] = use_rep_key
161178
162179 # Clean up the temporary representation to save memory
163180 del adata .obsm [temp_rep_key ]
@@ -203,6 +220,7 @@ def _neighbors_from_pc_indices(pc_idx: np.ndarray, n_neighbors: int, neighbors_k
203220 rows .append ({
204221 "obs_key" : obs_key ,
205222 "neighbors_key" : neighbors_key ,
223+ "use_rep" : use_rep_key ,
206224 "resolution" : res ,
207225 "top_n_pcs" : N ,
208226 "pct_pcs" : float (pct_str ),
@@ -236,6 +254,7 @@ def _neighbors_from_pc_indices(pc_idx: np.ndarray, n_neighbors: int, neighbors_k
236254 rows .append ({
237255 "obs_key" : obs_key ,
238256 "neighbors_key" : neighbors_key ,
257+ "use_rep" : use_rep_key ,
239258 "resolution" : float (res ),
240259 "top_n_pcs" : int (N ),
241260 "pct_pcs" : float (pct_str ),
@@ -248,8 +267,19 @@ def _neighbors_from_pc_indices(pc_idx: np.ndarray, n_neighbors: int, neighbors_k
248267
249268 summary_df = pd .DataFrame (rows )
250269 # nice ordering
251- cols = ["obs_key" ,"neighbors_key" ,"resolution" ,"top_n_pcs" ,"pct_pcs" ,"sample_idx" ,
252- "n_neighbors" ,"pcs_used_count" ,"n_clusters" ,"status" ]
270+ cols = [
271+ "obs_key" ,
272+ "neighbors_key" ,
273+ "use_rep" ,
274+ "resolution" ,
275+ "top_n_pcs" ,
276+ "pct_pcs" ,
277+ "sample_idx" ,
278+ "n_neighbors" ,
279+ "pcs_used_count" ,
280+ "n_clusters" ,
281+ "status" ,
282+ ]
253283 summary_df = summary_df [cols ]
254284
255285 return summary_df
@@ -352,6 +382,3 @@ def cluster_subclusters(
352382 # Prefix subcluster labels and write back
353383 new_labels = orig + "_" + sub .obs ['leiden_sub' ].astype (str )
354384 adata .obs .loc [mask , subcluster_col_name ] = new_labels .values
355-
356-
357-
0 commit comments