
Commit 18036ef

Compatibility with scvi-tools 1.0.0+, lightning and pytorch 2.0 + signal gene and cell abundance quantification around candidate target cells (#305)
* quantification of cell abundance and gene expression necessary for cell communication models
* fix new lightning imports
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* replacing use_gpu with accelerator
* increment version

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a583a83 commit 18036ef

10 files changed: 407 additions & 33 deletions


README.md

Lines changed: 48 additions & 0 deletions
@@ -224,3 +224,51 @@ adata_incl_nontissue = read_all_and_qc(
     count_file='raw_feature_bc_matrix.h5',
 )
 ```
+
+Since anndata version 0.9.0 (released on 2023-04-11), `AnnData.concatenate()` has been deprecated in favour of `anndata.concat()`, as per the official release notes ([Reference](https://anndata.readthedocs.io/en/latest/release-notes/index.html#id4)). Here is the updated `read_all_and_qc` code snippet:
+
+```python
+from anndata import concat
+
+def read_all_and_qc(
+    sample_annot, Sample_ID_col, file_col, sp_data_folder,
+    count_file='filtered_feature_bc_matrix.h5',
+):
+    """
+    Read and concatenate all Visium files.
+    """
+
+    # read all samples and store them in a list
+    adatas = []
+    for i, s in enumerate(sample_annot[Sample_ID_col]):
+        adata_i = read_and_qc(s, sample_annot[file_col][i], path=sp_data_folder)
+        adatas.append(adata_i)
+    # combine individual samples
+    adata = concat(
+        adatas,
+        merge="unique",
+        uns_merge="unique",
+        label="batch",
+        keys=sample_annot[Sample_ID_col].tolist(),
+        index_unique=None
+    )
+
+    sample_annot.index = sample_annot[Sample_ID_col]
+    for c in sample_annot.columns:
+        sample_annot.loc[:, c] = sample_annot[c].astype(str)
+    adata.obs[sample_annot.columns] = sample_annot.reindex(index=adata.obs['sample']).values
+
+    return adata
+
+adata_vis = read_all_and_qc(
+    sample_annot=sample_annot,
+    Sample_ID_col='Sample_ID',
+    file_col='file',
+    sp_data_folder=sp_data_folder,
+    count_file='filtered_feature_bc_matrix.h5',
+)
+
+cell2location.models.Cell2location.setup_anndata(
+    adata=adata_vis,
+    batch_key="batch")
+```
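For reference, a minimal sketch of the `anndata.concat()` call that replaces the deprecated `AnnData.concatenate()` in the snippet above; the toy objects and sample names here are illustrative, not part of this commit:

```python
import numpy as np
import anndata as ad

# two toy AnnData objects standing in for individual Visium samples
sample_a = ad.AnnData(np.ones((3, 2), dtype=np.float32))
sample_b = ad.AnnData(np.zeros((4, 2), dtype=np.float32))

# anndata.concat() replaces AnnData.concatenate():
# `label`/`keys` record the sample of origin in .obs["batch"],
# `merge`/`uns_merge` control how .var and .uns metadata are combined,
# `index_unique=None` keeps the original observation names.
combined = ad.concat(
    [sample_a, sample_b],
    merge="unique",
    uns_merge="unique",
    label="batch",
    keys=["sample_A", "sample_B"],
    index_unique=None,
)
print(combined.obs["batch"].value_counts())  # sample_A: 3, sample_B: 4
```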

cell2location/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
 from torch.distributions import biject_to, transform_to

 from . import models
+from .cell_comm.around_target import compute_weighted_average_around_target
 from .run_colocation import run_colocation

 # https://github.com/python-poetry/poetry/pull/2366#issuecomment-652418094
@@ -46,4 +47,5 @@ def _transform_to_positive(constraint):
 __all__ = [
     "models",
     "run_colocation",
+    "compute_weighted_average_around_target",
 ]

cell2location/cell_comm/__init__.py

Whitespace-only changes.
cell2location/cell_comm/around_target.py

Lines changed: 215 additions & 0 deletions

@@ -0,0 +1,215 @@
import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix
from scipy.spatial.distance import cdist


def compute_weighted_average_around_target(
    adata,
    target_cell_type_quantile: float = 0.995,
    source_cell_type_quantile: float = 0.95,
    normalisation_quantile: float = 0.95,
    distance_bin: list = None,
    sample_key: str = "sample",
    genes_to_use_as_source: list = None,
    gene_symbols: str = None,
    obsm_spatial_key: str = "X_spatial",
    normalisation_key: str = None,
    layer: str = None,
    cell_abundance_key: str = "cell_abundance_w_sf",
    cell_abundance_quantile_key: str = "q05",
):
    """
    Compute the average abundance of source cell types or genes around each target cell type.

    Parameters
    ----------
    adata
        AnnData object of a spatial dataset with cell2location results.
    target_cell_type_quantile
        Quantile of target cell type abundance used to define the set of locations
        with the highest abundance of the target cell type.
        Cell abundance below this threshold is set to 0.
    source_cell_type_quantile
        Quantile of source cell type abundance used to define the set of locations
        with the highest abundance of source cell types.
        Cell abundance (or RNA abundance for genes) below this threshold is set to 0.
    normalisation_quantile
        Quantile of source cell type abundance (or source RNA abundance for genes) used as a
        normalising constant. This step can be seen as scaling that puts all source cell types
        or genes on the same scale.
    distance_bin
        If using concentric bins, a list with two elements specifying the inner and outer edge
        of the bin. Distances are specified in the coordinates of `obsm_spatial_key`.
    sample_key
        `adata.obs` column key identifying distinct sections, across which
        distance bin computation is invalid.
    genes_to_use_as_source
        To request RNA abundance of genes around target cells, provide a list of
        var_names or gene symbols.
    gene_symbols
        `adata.var` column key containing gene symbols.
    obsm_spatial_key
        `adata.obsm` key containing spatial coordinates (can be 2D, 3D or N-D).
    normalisation_key
        RNA abundance must be normalised using the y_s technical effect term
        estimated by cell2location. Provide the `adata.obsm` key containing this normalisation term.
    layer
        `adata.layers` key to use for RNA abundance. Default: `adata.X`.
    cell_abundance_key
        Which cell2location variable to use as cell abundance.
    cell_abundance_quantile_key
        Which quantile of cell abundance to use.

    Returns
    -------
    pd.DataFrame of average abundance of source cell types, or RNA abundance of the requested genes,
    around target cell types.
    """
    # save initial names
    if genes_to_use_as_source is None:
        source_names = adata.uns["mod"]["factor_names"]
    else:
        source_names = genes_to_use_as_source

    cell_abundance_key_ = cell_abundance_quantile_key + cell_abundance_key
    cell_abundance_key = cell_abundance_quantile_key + "_" + cell_abundance_key

    # create result data frame to be completed
    weighted_avg = pd.DataFrame(
        index=[f"target {ct}" for ct in adata.uns["mod"]["factor_names"]],
        columns=source_names,
    )
    if genes_to_use_as_source is None:
        # pick locations where source cell type abundance is above source_cell_type_quantile
        source_cell_type_filter = adata.obsm[cell_abundance_key] > adata.obsm[cell_abundance_key].quantile(
            source_cell_type_quantile
        )
        # zero-out source cell abundance below the selected quantile
        source_cell_type_data = adata.obsm[cell_abundance_key] * source_cell_type_filter
        # get normalising quantile values
        source_normalisation_quantile = adata.obsm[cell_abundance_key].quantile(normalisation_quantile, axis=0)
        # compute average abundance above this quantile
        source_normalisation_quantile = np.average(
            adata.obsm[cell_abundance_key],
            weights=adata.obsm[cell_abundance_key] > source_normalisation_quantile,
            axis=0,
        )
    else:
        # if using gene symbols, get var names
        if gene_symbols is not None:
            genes_to_use_as_source = adata.var_names[adata.var[gene_symbols].isin(genes_to_use_as_source)]
        # get RNA abundance data
        if layer is None:
            source_cell_type_data = adata[:, genes_to_use_as_source].X.toarray()
        else:
            source_cell_type_data = adata[:, genes_to_use_as_source].layers[layer].toarray()
        # apply technical across-location normalisation
        if normalisation_key:
            source_cell_type_data = source_cell_type_data / adata.obsm[normalisation_key]
        # pick locations where source RNA abundance is above source_cell_type_quantile
        source_cell_type_filter = source_cell_type_data > np.quantile(
            source_cell_type_data, q=source_cell_type_quantile, axis=0
        )
        # zero-out source RNA abundance below the selected quantile
        source_cell_type_data = source_cell_type_data * source_cell_type_filter
        # create a dataframe with initial source RNA abundance
        source_cell_type_data = pd.DataFrame(
            source_cell_type_data,
            index=adata.obs_names,
            columns=source_names,
        )
        # get normalising quantile values
        source_normalisation_quantile = source_cell_type_data.quantile(normalisation_quantile, axis=0)
        # compute average abundance above this quantile
        source_normalisation_quantile = np.average(
            source_cell_type_data,
            weights=source_cell_type_data > source_normalisation_quantile,
            axis=0,
        )

    # [optional] compute average source_cell_type_data across the closest locations (concentric bins)
    if distance_bin is not None:
        # iterate over samples (connected locations from the same section,
        # or independent chunks of registered 3D data)
        for s in adata.obs[sample_key].unique():
            # get sample observations
            sample_ind = adata.obs[sample_key].isin([s])

            # compute distances between locations
            distances = cdist(adata[sample_ind, :].obsm[obsm_spatial_key], adata[sample_ind, :].obsm[obsm_spatial_key])
            # select locations within the distance bin
            binary_distance = csr_matrix((distances > distance_bin[0]) & (distances <= distance_bin[1]))
            # compute average abundance across locations within a bin
            data_ = (
                (binary_distance @ csr_matrix(source_cell_type_data.loc[sample_ind, :].values))
                .multiply(1 / binary_distance.sum(1))
                .toarray()
            )
            # account for locations with no neighbours within a bin (sum == 0)
            data_[np.isnan(data_)] = 0
            # complete the average for a given sample
            source_cell_type_data.loc[sample_ind, :] = data_
    # normalise data by the normalising quantile (global value across distance bins)
    source_cell_type_data = source_cell_type_data / source_normalisation_quantile
    # account for cases of undetected signal
    source_cell_type_data[source_cell_type_data.isna()] = 0

    # compute the average for each target cell type
    for ct in adata.uns["mod"]["factor_names"]:
        # find locations containing a high abundance of the target cell type
        target_cell_type_filter = adata.obsm[cell_abundance_key][f"{cell_abundance_key_}_{ct}"] > adata.obsm[
            cell_abundance_key
        ][f"{cell_abundance_key_}_{ct}"].quantile(target_cell_type_quantile)
        # use the thresholded abundance of the target cell type as a weight
        weights = adata.obsm[cell_abundance_key][f"{cell_abundance_key_}_{ct}"] * target_cell_type_filter
        # normalise for target cell type abundance
        target_quantile = adata.obsm[cell_abundance_key][f"{cell_abundance_key_}_{ct}"].quantile(normalisation_quantile)
        target_quantile = np.average(
            adata.obsm[cell_abundance_key][f"{cell_abundance_key_}_{ct}"].values,
            weights=adata.obsm[cell_abundance_key][f"{cell_abundance_key_}_{ct}"].values > target_quantile,
        ).flatten()
        assert target_quantile.shape == (1,), target_quantile.shape
        weights = weights / target_quantile
        # compute the final weighted average
        weighted_avg_ = np.average(
            source_cell_type_data,
            weights=weights,
            axis=0,
        )
        # weighted_avg_[weighted_avg_.isna()] = 0

        weighted_avg_ = pd.Series(weighted_avg_, name=ct, index=source_names)

        # hack to make self interactions less apparent
        weighted_avg_[ct] = weighted_avg_[~weighted_avg_.index.isin([ct])].max() + 0.02
        # complete the results dataframe
        weighted_avg.loc[f"target {ct}", :] = weighted_avg_

    return weighted_avg.astype("float32")


def melt_data_frame_per_signal(weighted_avg_dict, source_var, distance_bins):
    source_var_1 = pd.DataFrame(
        np.array([weighted_avg_dict[str(distance_bin)][source_var].values for distance_bin in distance_bins]),
        columns=weighted_avg_dict[str(distance_bins[0])].index,
        index=[np.mean(distance_bin) for distance_bin in distance_bins],
    ).T

    source_var_1 = source_var_1.melt(
        value_name="Abundance",
        var_name="Distance bin",
        ignore_index=False,
    )
    source_var_1["Target"] = source_var_1.index
    source_var_1["Signal"] = source_var
    return source_var_1


def melt_signal_target_data_frame(weighted_avg_dict, distance_bins):
    source_vars = weighted_avg_dict[str(distance_bins[0])].columns

    return pd.concat(
        [melt_data_frame_per_signal(weighted_avg_dict, source_var, distance_bins) for source_var in source_vars]
    )
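A minimal usage sketch of the new module, assuming `adata_vis` is an AnnData object with cell2location results exported to `adata_vis.obsm["q05_cell_abundance_w_sf"]` and `adata_vis.uns["mod"]["factor_names"]`; the distance bin edges below are illustrative and should be chosen in the units of `adata_vis.obsm["X_spatial"]`:

```python
import cell2location
from cell2location.cell_comm.around_target import melt_signal_target_data_frame

# illustrative concentric distance bins (inner edge, outer edge)
distance_bins = [[0, 50], [50, 100], [100, 200]]

# average source cell type abundance around each target cell type, one result per bin
weighted_avg_dict = {
    str(distance_bin): cell2location.compute_weighted_average_around_target(
        adata_vis,
        distance_bin=distance_bin,
        sample_key="sample",
        obsm_spatial_key="X_spatial",
        cell_abundance_key="cell_abundance_w_sf",
        cell_abundance_quantile_key="q05",
    )
    for distance_bin in distance_bins
}

# reshape into a long dataframe with "Abundance", "Distance bin", "Target" and "Signal" columns
df_long = melt_signal_target_data_frame(weighted_avg_dict, distance_bins)
```

To quantify RNA abundance of signal genes around target cells instead of cell type abundance, the same call can be given `genes_to_use_as_source=[...]` together with `normalisation_key`, as described in the docstring.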

cell2location/models/_cell2location_model.py

Lines changed: 10 additions & 3 deletions
@@ -17,7 +17,7 @@
     NumericalJointObsField,
     NumericalObsField,
 )
-from scvi.dataloaders import DataSplitter, DeviceBackedDataSplitter
+from scvi.dataloaders import DeviceBackedDataSplitter
 from scvi.model.base import BaseModelClass, PyroSampleMixin, PyroSviTrainMixin
 from scvi.model.base._pyromixin import PyroJitGuideWarmup
 from scvi.train import TrainRunner
@@ -212,8 +212,11 @@ def train_aggressive(
         self,
         max_epochs: Optional[int] = 1000,
         use_gpu: Optional[Union[str, int, bool]] = None,
+        accelerator: str = "auto",
+        device: Union[int, str] = "auto",
         train_size: float = 1,
         validation_size: Optional[float] = None,
+        shuffle_set_split: bool = True,
         batch_size: int = None,
         early_stopping: bool = False,
         lr: Optional[float] = None,
@@ -266,14 +269,16 @@ def train_aggressive(
                 validation_size=validation_size,
                 batch_size=batch_size,
                 use_gpu=use_gpu,
+                accelerator=accelerator,
+                device=device,
             )
         else:
-            data_splitter = DataSplitter(
+            data_splitter = self._data_splitter_cls(
                 self.adata_manager,
                 train_size=train_size,
                 validation_size=validation_size,
+                shuffle_set_split=shuffle_set_split,
                 batch_size=batch_size,
-                use_gpu=use_gpu,
             )
         training_plan = PyroAggressiveTrainingPlan(pyro_module=self.module, **plan_kwargs)

@@ -291,6 +296,8 @@ def train_aggressive(
             data_splitter=data_splitter,
             max_epochs=max_epochs,
             use_gpu=use_gpu,
+            accelerator=accelerator,
+            devices=device,
             **trainer_kwargs,
         )
         res = runner()

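To illustrate the API change above, a hedged sketch of calling `train_aggressive` with the new `accelerator`/`device` arguments in place of `use_gpu`; the model construction and hyperparameter values are placeholders from the standard cell2location workflow, not part of this diff:

```python
import cell2location

# `adata_vis` prepared with Cell2location.setup_anndata(...) as in the README snippet;
# `inf_aver` is a placeholder for the per-gene reference cell type signatures
mod = cell2location.models.Cell2location(
    adata_vis,
    cell_state_df=inf_aver,
    N_cells_per_location=30,
    detection_alpha=20,
)

# previously: mod.train_aggressive(max_epochs=1000, use_gpu=True)
# after this commit, the scvi-tools 1.0 / lightning 2.0 style arguments are available:
mod.train_aggressive(
    max_epochs=1000,
    accelerator="gpu",       # "cpu", "gpu" or "auto", replaces use_gpu
    device="auto",           # device index or "auto"
    shuffle_set_split=True,
    batch_size=None,
    early_stopping=False,
)
```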