From 4efc6c4299c1cd9643baeb6c2c408eba7b1525a3 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Sun, 30 Jan 2022 02:59:30 +0000 Subject: [PATCH] option to treat y_s as detection probability --- cell2location/models/_cell2location_model.py | 10 +++- cell2location/models/_cell2location_module.py | 60 +++++++++++++------ tests/test_cell2location.py | 16 +++++ 3 files changed, 64 insertions(+), 22 deletions(-) diff --git a/cell2location/models/_cell2location_model.py b/cell2location/models/_cell2location_model.py index 21c70437..b763589a 100755 --- a/cell2location/models/_cell2location_model.py +++ b/cell2location/models/_cell2location_model.py @@ -76,11 +76,15 @@ def __init__( self.factor_names_ = cell_state_df.columns.values if not detection_mean_per_sample: - # compute expected change in sensitivity (m_g in V1 or y_s in V2) + # compute expected change in sensitivity (m_g in V1 and y_s in V2) sc_total = cell_state_df.sum(0).mean() - sp_total = get_from_registry(self.adata, _CONSTANTS.X_KEY).sum(1).mean() - get_from_registry(adata, _CONSTANTS.BATCH_KEY) + sp_total = get_from_registry(self.adata, _CONSTANTS.X_KEY).sum(1) + batch = get_from_registry(self.adata, _CONSTANTS.BATCH_KEY).flatten() + sp_total = np.array([sp_total[batch == b].mean() for b in range(self.summary_stats["n_batch"])]) self.detection_mean_ = (sp_total / model_kwargs.get("N_cells_per_location", 1)) / sc_total + if (self.detection_mean_.max() > 1.0) and (model_kwargs.get("use_detection_probability", False) is True): + self.detection_mean_ = self.detection_mean_ / (self.detection_mean_.max() + 0.000001) + self.detection_mean_ = self.detection_mean_.mean() self.detection_mean_ = self.detection_mean_ * detection_mean_correction model_kwargs["detection_mean"] = self.detection_mean_ else: diff --git a/cell2location/models/_cell2location_module.py b/cell2location/models/_cell2location_module.py index ab705ee8..81fac259 100755 --- a/cell2location/models/_cell2location_module.py +++ 
b/cell2location/models/_cell2location_module.py @@ -83,6 +83,7 @@ def __init__( n_groups: int = 50, detection_mean=1 / 2, detection_alpha=200.0, + use_detection_probability: bool = False, m_g_gene_level_prior={"mean": 1, "mean_var_ratio": 1.0, "alpha_mean": 3.0}, N_cells_per_location=8.0, A_factors_per_location=7.0, @@ -117,6 +118,7 @@ def __init__( detection_hyp_prior["mean"] = detection_mean detection_hyp_prior["alpha"] = detection_alpha self.detection_hyp_prior = detection_hyp_prior + self.use_detection_probability = use_detection_probability if (init_vals is not None) & (type(init_vals) is dict): self.np_init_vals = init_vals @@ -323,27 +325,47 @@ def forward(self, x_data, idx, batch_index): ) # (self.n_obs, self.n_factors) # =====================Location-specific detection efficiency ======================= # - # y_s with hierarchical mean prior - detection_mean_y_e = pyro.sample( - "detection_mean_y_e", - dist.Gamma( - self.ones * self.detection_mean_hyp_prior_alpha, - self.ones * self.detection_mean_hyp_prior_beta, + if not self.use_detection_probability: + # y_s with hierarchical mean prior + detection_mean_y_e = pyro.sample( + "detection_mean_y_e", + dist.Gamma( + self.ones * self.detection_mean_hyp_prior_alpha, + self.ones * self.detection_mean_hyp_prior_beta, + ) + .expand([self.n_batch, 1]) + .to_event(2), + ) + detection_hyp_prior_alpha = pyro.deterministic( + "detection_hyp_prior_alpha", + self.ones_n_batch_1 * self.detection_hyp_prior_alpha, ) - .expand([self.n_batch, 1]) - .to_event(2), - ) - detection_hyp_prior_alpha = pyro.deterministic( - "detection_hyp_prior_alpha", - self.ones_n_batch_1 * self.detection_hyp_prior_alpha, - ) - beta = (obs2sample @ detection_hyp_prior_alpha) / (obs2sample @ detection_mean_y_e) - with obs_plate: - detection_y_s = pyro.sample( - "detection_y_s", - dist.Gamma(obs2sample @ detection_hyp_prior_alpha, beta), - ) # (self.n_obs, 1) + beta = (obs2sample @ detection_hyp_prior_alpha) / (obs2sample @ detection_mean_y_e) + 
with obs_plate: + detection_y_s = pyro.sample( + "detection_y_s", + dist.Gamma(obs2sample @ detection_hyp_prior_alpha, beta), + ) # (self.n_obs, 1) + else: + # y_s with hierarchical mean prior + detection_mean_y_e = pyro.sample( + "detection_mean_y_e", + dist.Beta( + self.ones * self.detection_mean_hyp_prior_alpha, + self.ones * self.detection_mean_hyp_prior_beta, + ) + .expand([self.n_batch, 1]) + .to_event(2), + ) + + alpha = (obs2sample @ detection_mean_y_e) * self.ones * self.detection_hyp_prior_alpha + beta = (obs2sample @ (self.ones - detection_mean_y_e)) * self.ones * self.detection_hyp_prior_alpha + with obs_plate: + detection_y_s = pyro.sample( + "detection_y_s", + dist.Beta(alpha, beta), + ) # (self.n_obs, 1) # =====================Gene-specific additive component ======================= # # per gene molecule contribution that cannot be explained by diff --git a/tests/test_cell2location.py b/tests/test_cell2location.py index 02e4f587..423a1ec3 100644 --- a/tests/test_cell2location.py +++ b/tests/test_cell2location.py @@ -120,3 +120,19 @@ def test_cell2location(): # export the estimated cell abundance (summary of the posterior distribution) # full data st_model.export_posterior(dataset, sample_kwargs={"num_samples": 10, "batch_size": st_model.adata.n_obs}) + + ### test new cell2location models ### + ## detection probability rather than detection efficiency ## + st_model = Cell2location( + dataset, + cell_state_df=inf_aver, + N_cells_per_location=30, + detection_alpha=200, + use_detection_probability=True, + detection_hyp_prior={"mean_alpha": 100.0}, + ) + # test full data training + st_model.train(max_epochs=1) + # export the estimated cell abundance (summary of the posterior distribution) + # full data + dataset = st_model.export_posterior(dataset, sample_kwargs={"num_samples": 10, "batch_size": st_model.adata.n_obs})