diff --git a/muon/_prot/preproc.py b/muon/_prot/preproc.py index 5f03b90..cf27106 100644 --- a/muon/_prot/preproc.py +++ b/muon/_prot/preproc.py @@ -1,10 +1,11 @@ -from typing import Optional, Iterable, Tuple, Union +from typing import Optional, Iterable, Tuple, Union, Literal from numbers import Integral, Real from warnings import warn import numpy as np import pandas as pd -from scipy.sparse import issparse, csc_matrix, csr_matrix +from scipy.sparse import issparse, csc_matrix, csr_matrix, csc_array, csr_array +from scipy.stats import gmean from sklearn.mixture import GaussianMixture from sklearn.decomposition import PCA from sklearn.linear_model import LinearRegression @@ -198,7 +199,12 @@ def dsb( return toreturn -def clr(adata: AnnData, inplace: bool = True, axis: int = 0) -> Union[None, AnnData]: +def clr( + adata: AnnData, + inplace: bool = True, + axis: int = 0, + flavor: Literal["seurat", "stoeckius", "standard"] = "seurat", +) -> AnnData | None: """ Apply the centered log ratio (CLR) transformation to normalize counts in adata.X. @@ -207,6 +213,19 @@ def clr(adata: AnnData, inplace: bool = True, axis: int = 0) -> Union[None, AnnD data: AnnData object with protein expression counts. inplace: Whether to update adata.X inplace. axis: Axis across which CLR is performed. + flavor: How to perform the CLR transformation. + + - seurat: Uses log1p transformations throughout. This results in non-negative values and preserves + sparse matrices. + - stoeckius: Follows the original CITE-Seq paper by adding a pseudocount of 1 to the data before + performing any transformations and using the standard log transform. This adheres more closely + to the standard definition of the CLR transform, but can yield negative values and does not + preserve sparse matrices (the result is always a dense matrix.) + - standard: The standard CLR transform without any pseudocounts. Does not preserve sparse matrices + and may yield infinite values if the input contains zeros. 
+ + References: + Stoeckius et al, 2017 (`doi:10.1038/nmeth.4380 <https://doi.org/10.1038/nmeth.4380>`_) """ if axis not in [0, 1]: @@ -215,25 +234,33 @@ def clr(adata: AnnData, inplace: bool = True, axis: int = 0) -> Union[None, AnnD if not inplace: adata = adata.copy() - if issparse(adata.X) and axis == 0 and not isinstance(adata.X, csc_matrix): - warn("adata.X is sparse but not in CSC format. Converting to CSC.") - x = csc_matrix(adata.X) - elif issparse(adata.X) and axis == 1 and not isinstance(adata.X, csr_matrix): - warn("adata.X is sparse but not in CSR format. Converting to CSR.") - x = csr_matrix(adata.X) - else: - x = adata.X + x = adata.X - if issparse(x): - x.data /= np.repeat( - np.exp(np.log1p(x).sum(axis=axis).A / x.shape[axis]), x.getnnz(axis=axis) - ) - np.log1p(x.data, out=x.data) + if flavor == "seurat": + if issparse(x): + if axis == 0 and not isinstance(x, csc_matrix | csc_array): + warn( + "adata.X is sparse but not in CSC format. CSC format required for `axis=0`. Converting to CSC." + ) + x = x.tocsc() + elif axis == 1 and not isinstance(x, csr_matrix | csr_array): + warn( + "adata.X is sparse but not in CSR format. CSR format required for `axis=1`. Converting to CSR." + ) + x = x.tocsr() + + x.data /= np.repeat(np.exp(np.log1p(x).mean(axis=axis).toarray()), x.getnnz(axis=axis)) + np.log1p(x.data, out=x.data) + else: + np.log1p(x / np.exp(np.log1p(x).mean(axis=axis, keepdims=True)), out=x) + elif flavor in ("stoeckius", "standard"): + if issparse(x): + x = x.toarray() + if flavor == "stoeckius": + x += 1 + np.log(x / gmean(x, axis=axis, keepdims=True), out=x) else: - np.log1p( - x / np.exp(np.log1p(x).sum(axis=axis, keepdims=True) / x.shape[axis]), - out=x, - ) + raise ValueError(f"Unknown flavor `{flavor}`.") adata.X = x