Skip to content

Commit bf3ad2d

Browse files
committed
updates in response to reviewers
1 parent 1a29c5c commit bf3ad2d

10 files changed

Lines changed: 549 additions & 104 deletions

analysis.py

Lines changed: 331 additions & 87 deletions
Large diffs are not rendered by default.

gears_runner.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
import pickle
66
from scgpt.utils import compute_perturbation_metrics
77

8-
def run_gears(runs=1, mode="train"):
8+
def run_gears(runs=1, mode="train", cross_validation=False):
99
for run_number in range(0, runs):
10-
for data_name in ["adam_corrected", "adam_corrected_upr", "adamson", "norman", "replogle_k562_essential"]:
10+
for data_name in ["adam_corrected_upr", "norman", "replogle_k562_essential", "adam_corrected", "adamson"]:
1111
##setup PertData object
1212
if data_name == "adam_corrected":
1313
pert_data = get_adam_corrected_dataset(split="simulation", batch_size=64, test_batch_size=64, generate_new=False, just_upr=False)
@@ -26,28 +26,42 @@ def run_gears(runs=1, mode="train"):
2626
pert_data.get_dataloader(batch_size = 64, test_batch_size = 64)
2727
if "replogle" in data_name:
2828
modify_pertdata_dataloaders(pert_data, logger=None)
29+
if cross_validation:
30+
cross_validate_split(pert_data, run_number + 1)
31+
prefix = "pickles/gears_results_cv/"
32+
suffix = f"cross_val_{run_number + 1}"
33+
else:
34+
prefix = "pickles/gears_results/"
35+
suffix = f"{run_number}"
2936
gears_model = GEARS(pert_data, device = 'cuda:0')
3037
if mode == "train":
3138
# set up and train a model
3239
gears_model.model_initialize(hidden_size = 64)
3340
gears_model.train(epochs = 20) #20 originally
34-
gears_model.save_model(f'gears_models/gears_trained_{data_name}_{run_number}')
41+
print("finished training, save model")
42+
gears_model.save_model(f'gears_models/gears_trained_{data_name}_{suffix}')
3543
#load model
36-
gears_model.load_pretrained(f'gears_models/gears_trained_{data_name}_{run_number}')
44+
gears_model.load_pretrained(f'gears_models/gears_trained_{data_name}_{suffix}')
3745
##evaluate
3846
eval_results = evaluate(loader=pert_data.dataloader['test_loader'], model=gears_model.model, uncertainty=gears_model.config['uncertainty'], device=torch.device("cuda:0"))
3947
##get rank score
4048
ranks = get_gears_rank(eval_results)
4149
print("avg rank: ", np.mean(list(ranks.values())), np.std(list(ranks.values())))
42-
pickle.dump(ranks, open(f"pickles/gears_results/gears_rank_metrics_{data_name}_{run_number}.pkl", "wb"))
50+
pickle.dump(ranks, open(f"{prefix}gears_rank_metrics_{data_name}_{suffix}.pkl", "wb"))
4351
##get pearson scores
4452
metrics, metrics_pert = compute_metrics(eval_results)
4553
test_metrics = compute_perturbation_metrics(eval_results, pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"])
4654
print(f"metrics: {metrics}")
4755
print(f"metrics_pert: {metrics_pert}")
4856
print(f"test metrics: {test_metrics}")
49-
pickle.dump((metrics, metrics_pert), open(f"pickles/gears_results/gears_results_{data_name}_{run_number}.pkl", "wb"))
50-
pickle.dump(test_metrics, open(f"pickles/gears_results/gears_pert_delta_results_{data_name}_{run_number}.pkl", "wb"))
57+
pickle.dump((metrics, metrics_pert), open(f"{prefix}gears_results_{data_name}_{suffix}.pkl", "wb"))
58+
pickle.dump(test_metrics, open(f"{prefix}gears_pert_delta_results_{data_name}_{suffix}.pkl", "wb"))
59+
##condition specific performance
60+
condition_map = get_condition_performance_breakdown(eval_results, pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"])
61+
pickle.dump(condition_map, open(f"{prefix}gears_condition_specific_results_{data_name}_{suffix}.pkl", "wb"))
62+
##gene specific performance
63+
gene_to_pearson_map = get_gene_performance_breakdown(eval_results, pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"])
64+
pickle.dump(gene_to_pearson_map, open(f"{prefix}gears_gene_specific_results_{data_name}_{suffix}.pkl", "wb"))
5165

5266
def get_gears_rank(eval_results):
5367
pert_map = {} ##key: condition, value: (actual avg truth vector, predicted avg vector)
@@ -62,6 +76,13 @@ def get_gears_rank(eval_results):
6276
return ranks
6377

6478
print("running gears")
79+
##no cross validation
6580
##mode = "train" for training or "eval" for just evaluating models
66-
run_gears(runs=10, mode="train")
81+
run_gears(runs=10, mode="train", cross_validation=False)
82+
run_gears(runs=10, mode="eval", cross_validation=False)
83+
84+
##using cross validation
85+
##if cross_validate, then the fold will be the run_number + 1
86+
run_gears(runs=4, mode="train", cross_validation=True)
87+
6788

library.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,59 @@ def eval_perturb(
332332
results["avg_loss"] = total_loss / float(len(loader.dataset))
333333
return results
334334

335+
def get_condition_performance_breakdown(results, ctrl_adata):
    """
    Break the output of eval_perturb down into perturbation specific performance.

    For every distinct condition in results["pert_cat"], the predicted and true
    expression rows of that condition are averaged, and four Pearson correlations
    between the two mean vectors are computed:
        pearson:          over all genes
        pearson_de:       over the genes selected by find_DE_genes for the condition
        pearson_delta:    over all genes, after subtracting the mean control expression
        pearson_de_delta: over the selected genes, after subtracting the mean control expression

    results: dictionary with keys "pert_cat", "pred" and "truth"
    ctrl_adata: AnnData restricted to control cells (presumably - verify against callers),
                its .X gives the mean control vector and its .var index the gene positions
    returns: dictionary with key: condition, value: dictionary of the four scores
    """
    from scgpt.utils import find_DE_genes
    ctrl_mean = np.array(ctrl_adata.X.mean(0)).flatten()
    pert_labels = results["pert_cat"]
    conditions = np.unique(pert_labels)
    gene_positions = dict(zip(ctrl_adata.var.index.values, range(len(ctrl_adata.var))))
    ##first element of the find_DE_genes output is used as the index set of the condition's genes
    de_idx = {}
    for c in conditions:
        de_idx[c] = find_DE_genes(ctrl_adata, c, gene_positions, non_zero_genes=False)[0]
    ##condition -> row indices where that condition occurs
    rows_by_condition = {c: np.where(pert_labels == c)[0] for c in conditions}
    metric_names = ["pearson", "pearson_de", "pearson_delta", "pearson_de_delta"]
    condition_map = {}
    for pert in conditions:
        rows = rows_by_condition[pert]
        de = de_idx[pert]
        mean_pred = np.mean(results["pred"][rows, :], axis=0)
        mean_truth = np.mean(results["truth"][rows, :], axis=0)
        pred_de_is_zero = np.sum(mean_pred[de]) == 0
        truth_de_is_zero = np.sum(mean_truth[de]) == 0
        if pred_de_is_zero or truth_de_is_zero:
            ##only a warning: the scores are still computed and the NaN is kept in the map
            print(f"WARNING: {pert} has 0 vector, will result in NaN pearson_de")
        r_all = scipy.stats.pearsonr(mean_pred, mean_truth)[0]
        r_delta = scipy.stats.pearsonr(mean_pred - ctrl_mean, mean_truth - ctrl_mean)[0]
        r_de = scipy.stats.pearsonr(mean_pred[de], mean_truth[de])[0]
        r_de_delta = scipy.stats.pearsonr(mean_pred[de] - ctrl_mean[de], mean_truth[de] - ctrl_mean[de])[0]
        condition_map[pert] = {
            "pearson": r_all,
            "pearson_de": r_de,
            "pearson_delta": r_delta,
            "pearson_de_delta": r_de_delta,
        }
    print("average results in get_condition_performance_breakdown: ")
    for metric in metric_names:
        ##nanmean so that conditions with NaN scores do not poison the average
        per_condition_scores = [condition_map[pert][metric] for pert in condition_map]
        print("  ", metric, np.nanmean(per_condition_scores))
    return condition_map
362+
363+
def get_gene_performance_breakdown(results, ctrl_adata):
    """
    Given output from eval_perturb, compute gene specific performance:
    pearson between actual and predicted expression for each gene (one column per gene).

    note: no such concept as delta scores because pearson (x,y) == pearson (x-k, y-k) for constant k and vectors x,y

    results: dictionary with keys "pred" and "truth", both 2D arrays indexed [sample, gene]
    ctrl_adata: object whose .var["gene_name"] lists the gene names in column order
    returns: dictionary with key: gene name, value: pearson correlation for that gene
    """
    pred = results["pred"]
    truth = results["truth"]
    gene_list = ctrl_adata.var["gene_name"].tolist()
    assert(len(gene_list) == len(pred[0]) == len(truth[0]))
    constant_threshold = 0.000001
    noise_scale = 0.0000001
    gene_to_pearson_map = {}
    for i, gene in enumerate(gene_list):
        pred_vector = pred[:, i]
        truth_vector = truth[:, i]
        ##mode collapse for models like mean: if std == 0 we cannot compute pearson (will be NaN),
        ##so add very small random noise to the constant vector
        if np.std(pred_vector) < constant_threshold:
            pred_vector = pred_vector + np.random.rand(len(pred_vector)) * noise_scale
        ##same guard for the truth column, checked on truth itself (not on pred)
        if np.std(truth_vector) < constant_threshold:
            truth_vector = truth_vector + np.random.rand(len(truth_vector)) * noise_scale
        gene_to_pearson_map[gene] = scipy.stats.pearsonr(pred_vector, truth_vector)[0]
    return gene_to_pearson_map
387+
335388
def get_variables(load_model=None, config_path=None):
336389
"""
337390
Reads config file and returns dictionary of variables
@@ -555,9 +608,88 @@ def modify_pertdata_dataloaders(pert_data, logger=None):
555608
for load_type in ["train", "val", "test"]:
556609
logger.info(f" new {load_type} loader length: {len(pert_data.dataloader[f'{load_type}_loader'])}")
557610

611+
def get_split(pert, pert_map):
    """
    Look up which split a perturbation belongs to.

    pert: the perturbation to search for
    pert_map: dictionary with key: split, value: collection of perts in that split
    returns: the first split (in dictionary order) whose collection contains pert,
             or -1 if no split contains it
    """
    matching_splits = (name for name, members in pert_map.items() if pert in members)
    return next(matching_splits, -1)
621+
622+
def cross_validate_split(pert_data, cross_validation_fold):
    """
    Modify the PertData loaders in place to conform to the given cross validation fold.

    4-fold: the perturbations are cut into four chunks, two are used for training,
    one for validation and one for testing. Each split has unique perturbations
    (minus control, which will be in just train).

    pert_data: object whose .dataloader is a dictionary with "train_loader", "val_loader"
               and "test_loader"; it is replaced by a dictionary of three new loaders
    cross_validation_fold: which fold to build, one of 1, 2, 3, 4
    raises: ValueError if cross_validation_fold is not one of 1, 2, 3, 4
    """
    ##fold -> (chunks used for train, chunk used for val, chunk used for test), chunk indices are 0-based
    fold_layout = {
        1: ((0, 1), 2, 3),
        2: ((3, 0), 1, 2),
        3: ((2, 3), 0, 1),
        4: ((1, 2), 3, 0),
    }
    if cross_validation_fold not in fold_layout:
        raise ValueError(f"cross_validation_fold must be one of 1, 2, 3, 4 but got {cross_validation_fold}")
    print(f"WARNING: splitting data into cross validation fold {cross_validation_fold}")
    ##get all perturbations in train/val/test, sort and then shuffle them by fixed seed so deterministic
    old_dataloaders = pert_data.dataloader
    splits = ["train", "val", "test"]
    unique_perts = set()
    for load_type in splits:
        for batch_data in old_dataloaders[f"{load_type}_loader"]:
            unique_perts.update(batch_data.pert)
    unique_perts.discard("ctrl") ##let's add ctrl back later to just the train perturbations
    all_perts = sorted(unique_perts)
    ##deterministically shuffle all_perts
    g = torch.Generator()
    g.manual_seed(0)
    rand_indices = torch.randperm(len(all_perts), generator=g).tolist()
    shuffled_perts = [all_perts[rand_index] for rand_index in rand_indices]
    print(shuffled_perts)
    ##chunk the list into folds, the last chunk absorbs the remainder
    divisor = len(shuffled_perts) // 4
    chunks = [
        shuffled_perts[0: divisor],
        shuffled_perts[divisor: divisor * 2],
        shuffled_perts[divisor * 2: divisor * 3],
        shuffled_perts[divisor * 3:],
    ]
    ##assign train/val/test depending on fold
    (train_a, train_b), val_chunk, test_chunk = fold_layout[cross_validation_fold]
    pert_map = {
        "train": chunks[train_a] + chunks[train_b] + ["ctrl"],
        "val": chunks[val_chunk],
        "test": chunks[test_chunk],
    }
    ##reverse lookup so that each sample is routed with one dictionary access
    split_of_pert = {pert: split for split, perts in pert_map.items() for pert in perts}
    ##now create new loaders and assign
    new_data_map = {split: [] for split in splits}
    for load_type in splits:
        ##batch_data is of type torch_geometric.data.batch.DataBatch, batch_data[i] is of type torch_geometric.data.data.Data
        for batch_data in old_dataloaders[f"{load_type}_loader"]:
            for i in range(0, len(batch_data)):
                new_data_map[split_of_pert[batch_data.pert[i]]].append(batch_data[i])
    shuffle = {"train": True, "val": True, "test": False}
    new_dataloaders = {}
    for split in splits:
        ##index by the split being built: shuffle flag and batch size both follow this split,
        ##not whichever loader an earlier loop happened to visit last
        batch_size = old_dataloaders[f"{split}_loader"].batch_size
        new_dataloaders[f"{split}_loader"] = DataLoader(new_data_map[split], batch_size=batch_size, shuffle=shuffle[split])
    pert_data.dataloader = new_dataloaders
687+
558688
def check_args(opt):
    """
    Validate combinations of command line arguments before any work is done.

    opt: parsed arguments, reads pretrain_control, mode, cross_validation and cross_validation_fold
    raises: Exception if pretrain_control is set together with mode "test",
            or if cross_validation is set without a cross_validation_fold
    """
    if opt.pretrain_control and opt.mode == "test":
        raise Exception("opt.pretrain_control == True and opt.mode == test")
    ##explicit raise instead of assert so the check survives python -O and reports what is missing
    if opt.cross_validation and opt.cross_validation_fold is None:
        raise Exception("opt.cross_validation == True but opt.cross_validation_fold is not set")
561693

562694
def convert_ENSG_to_gene(ensg_list):
563695
"""

runner.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,13 @@ def main():
2222
parser.add_argument("--use_lora", type=bool_flag, default=False, help="True if we want to use LoRa for finetuning")
2323
parser.add_argument("--lora_rank", type=int, default=8, help="if use_lora, specifies the inner dimension of the low-rank matrices to train")
2424
parser.add_argument("--config_path", type=str, default="config/default_config.json", help="path to JSON configuration file to use for setting up model")
25-
parser.add_argument("--model_type", type=str, default="scGPT", help="scGPT, simple_affine, mean_control, mean_perturbed, mean_control+perturbed, smart_mean_control, smart_mean_perturbed, smart_mean_control+perturbed")
25+
parser.add_argument("--model_type", type=str, default="scGPT", help="scGPT, simple_affine, simple_affine_large, mean_control, mean_perturbed, mean_control+perturbed, smart_mean_control, smart_mean_perturbed, smart_mean_control+perturbed")
2626
parser.add_argument("--validation_selection", type=str, default="pearson", help="how to select the best model during training, if 'pearson' will be by pearson correlation between predicted and actual expression over validation set, if 'loss' will be by minimal loss")
2727
parser.add_argument("--loss_type", type=str, default="mse", help="mse, mse+triplet, mse+pearson")
2828
parser.add_argument("--fixed_seed", type=bool_flag, default=True, help="True if we want to use a constant fixed seed")
29+
parser.add_argument("--cross_validation", type=bool_flag, default=False, help="if True will cross validate instead of single random split, --cross_validation_fold arg must also be set")
30+
parser.add_argument("--cross_validation_fold", type=int, default=None, help="which fold to train")
31+
2932
opt = parser.parse_args()
3033
check_args(opt)
3134
matplotlib.rcParams["savefig.transparent"] = False
@@ -70,7 +73,10 @@ def main():
7073
logger.info("WARNING: filtering dataloaders! but keeping pert_data.adata the same")
7174
modify_pertdata_dataloaders(pert_data, logger)
7275

73-
check_pert_split(opt.data_name, pert_data)
76+
if opt.cross_validation:
77+
cross_validate_split(pert_data, opt.cross_validation_fold)
78+
else:
79+
check_pert_split(opt.data_name, pert_data)
7480

7581
logger.info(f"adata.obs: {pert_data.adata.obs}")
7682
logger.info(f"|conditions|: {len(set(pert_data.adata.obs['condition']))}")
@@ -103,7 +109,7 @@ def main():
103109
##mean predictor models - compute after data loaders are set (after a possible application of filter_perturbations)
104110
if opt.mode in ["train", "test", "analysis"]:
105111
for baseline in ["smart", "baseline"]:
106-
for mean_type in ["perturbed", "control", "control+perturbed"]:
112+
for mean_type in ["perturbed"]:#, "control", "control+perturbed"]:
107113
##baseline mean
108114
if baseline == "baseline":
109115
mean_pred_model = MeanPredictor(pert_data, opt.data_name, mean_type=mean_type)
@@ -115,7 +121,8 @@ def main():
115121
##get rank metrics
116122
test_perts = pickle.load(open(f"pickles/{opt.data_name}_perturbation_splits.pkl", "rb"))["test"]
117123
ranks = get_rank(mean_pred_model, test_perts, pert_data=pert_data, var=var, gene_ids=gene_ids, gene_idx_map=gene_idx_map)
118-
pickle.dump(ranks, open(rank_save, "wb"))
124+
if opt.cross_validation == False: ##save ranks to global pickles/ only on main split
125+
pickle.dump(ranks, open(rank_save, "wb"))
119126
##GEARS-type metrics
120127
mean_res = eval_perturb(pert_data.dataloader["test_loader"], mean_pred_model, gene_ids=[], gene_idx_map={}, var={"device":"cpu"}, loss_type=opt.loss_type) ##keep on cpu, no need to shuttle to gpu for mean pred model
121128
mean_metrics, mean_metrics_pert = compute_metrics(mean_res) ##from GEARS library
@@ -125,6 +132,12 @@ def main():
125132
mean_metrics = compute_perturbation_metrics(mean_res, pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"]) ##from scGPT library
126133
logger.info(f"{opt.data_name} {baseline} mean {mean_type} delta test metrics: {mean_metrics}")
127134
pickle.dump(mean_metrics, open(save_dir / f"{baseline}_mean_{mean_type}_pert_delta_results_{opt.data_name}.pkl", "wb"))
135+
##condition specific performance
136+
condition_map = get_condition_performance_breakdown(mean_res, pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"])
137+
pickle.dump(condition_map, open(save_dir / f"{baseline}_mean_{mean_type}_condition_specific_results_{opt.data_name}.pkl", "wb"))
138+
##gene specific performance
139+
gene_to_pearson_map = get_gene_performance_breakdown(mean_res, pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"])
140+
pickle.dump(gene_to_pearson_map, open(save_dir / f"{baseline}_mean_{mean_type}_gene_specific_results_{opt.data_name}.pkl", "wb"))
128141

129142
if opt.model_type == "scGPT":
130143
model = TransformerGenerator(
@@ -142,9 +155,10 @@ def main():
142155
pert_pad_id=var["pert_pad_id"],
143156
use_fast_transformer=var["use_fast_transformer"],
144157
)
145-
146-
elif opt.model_type == "simple_affine":
158+
elif "simple_affine" in opt.model_type:
147159
from simple_affine import SimpleAffine
160+
is_large = True if "large" in opt.model_type else False
161+
print("LARGE: ", is_large)
148162
model = SimpleAffine(
149163
ntoken=ntokens,
150164
d_model=var["embsize"],
@@ -154,6 +168,7 @@ def main():
154168
dropout=var["dropout"],
155169
pad_token=var["pad_token"],
156170
pert_pad_id=var["pert_pad_id"],
171+
is_large=is_large
157172
)
158173
elif "mean" in opt.model_type:
159174
if "smart" in opt.model_type:
@@ -165,7 +180,7 @@ def main():
165180
else:
166181
raise Exception("model_type must be one of scGPT, simple_affine, mean_control, mean_perturbed, mean_control+perturbed, smart_mean_control, smart_mean_perturbed, smart_mean_control+perturbed")
167182

168-
if opt.model_type in ["scGPT", "simple_affine"]:
183+
if opt.model_type in ["scGPT", "simple_affine", "simple_affine_large"]:
169184
model = load_model(var, model, model_file, logger, attention_control=opt.attention_control, freeze_input_encoder=opt.freeze_input_encoder, freeze_transformer_encoder=opt.freeze_transformer_encoder, mode=opt.mode, use_lora=opt.use_lora, lora_rank=opt.lora_rank, pretrain_control=opt.pretrain_control, transformer_encoder_control=opt.transformer_encoder_control, input_encoder_control=opt.input_encoder_control)
170185
model.to(var["device"])
171186

@@ -244,6 +259,11 @@ def main():
244259
test_metrics = compute_perturbation_metrics(test_res, pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"]) ##from scGPT utils library
245260
logger.info(f"{opt.data_name} delta test metrics: {test_metrics}")
246261
pickle.dump(test_metrics, open(save_dir / f"{opt.model_type}_pert_delta_results_{opt.data_name}.pkl", "wb"))
262+
##condition specific performance
263+
condition_map = get_condition_performance_breakdown(test_res, pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"])
264+
pickle.dump(condition_map, open(save_dir / f"{opt.model_type}_condition_specific_results_{opt.data_name}.pkl", "wb"))
265+
gene_to_pearson_map = get_gene_performance_breakdown(test_res, pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"])
266+
pickle.dump(gene_to_pearson_map, open(save_dir / f"{opt.model_type}_gene_specific_results_{opt.data_name}.pkl", "wb"))
247267

248268
if opt.mode == "analysis":
249269
for plot_type in ["boxplots", "scatterplots"]:

0 commit comments

Comments
 (0)