-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrunner.py
More file actions
294 lines (271 loc) · 21.2 KB
/
runner.py
File metadata and controls
294 lines (271 loc) · 21.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
"""
Main runner script for training, evaluating, and analyzing results
Much of the core scGPT code is based off the scGPT authors' tutorial: https://github.com/bowang-lab/scGPT/blob/7301b51a72f5db321fccebb51bc4dd1380d99023/tutorials/Tutorial_Perturbation.ipynb#L831
"""
from library import *
def _dump_pickle(obj, path):
    """Pickle *obj* to *path*, closing the file handle even if pickling fails."""
    with open(path, "wb") as handle:
        pickle.dump(obj, handle)


def _load_pickle(path):
    """Return the object pickled at *path*.

    NOTE(review): unpickling can execute arbitrary code; only point this at
    files produced by this project.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


def _first_index_map(names):
    """Map each name to the index of its first occurrence in *names*.

    Produces the same dict as ``{n: names.index(n) for n in names}`` but in a
    single pass instead of one linear scan per element.
    """
    index_of = {}
    for position, name in enumerate(names):
        index_of.setdefault(name, position)
    return index_of


def _build_parser():
    """Build the command-line parser for the runner script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str, default="test", help="train or test")
    parser.add_argument("--random_shuffle", type=bool_flag, default=False, help="True if want to randomly shuffle pert_flag, else False, used for control testing")
    parser.add_argument("--data_name", type=str, default="adamson", help="which dataset to use adamson, norman, replogle_k562_essential")
    parser.add_argument("--load_model", type=str, default="models/scgpt-pretrained/scGPT_human", help="which pretrained model to load")
    parser.add_argument("--filter_perturbations", type=bool_flag, default=False, help="True if want to remove perturbated cells that have a perturbation NOT part of the gene set, else False")
    parser.add_argument("--transformer_encoder_control", type=bool_flag, default=False, help="True if want to intentionally NOT load any pre-trained transformer encoder weights prior to training, else False, used for control testing")
    parser.add_argument("--attention_control", type=bool_flag, default=False, help="True if want to intentionally NOT load the pre-trained self attention weights prior to training, else False, used for control testing")
    parser.add_argument("--input_encoder_control", type=bool_flag, default=False, help="True if want to intentionally NOT load the pre-trained input encoding weights (gene encoder + expression encoder) prior to training, else False, used for control testing")
    parser.add_argument("--pretrain_control", type=bool_flag, default=False, help="True if want to intentionally NOT load any pre-trained weights prior to training, else False, used for control testing")
    parser.add_argument("--save_dir", type=str, default="default", help="set to 'default' if want the default save_dir, else set to specific path")
    parser.add_argument("--always_keep_pert_gene", type=bool_flag, default=False, help="True if we want to always inform the model of which gene was perturbed during training")
    parser.add_argument("--freeze_input_encoder", type=bool_flag, default=False, help="True if we want to freeze the input encoder weights during training - just the gene and expression encoder, leave the perturbation encoder unfrozen")
    parser.add_argument("--freeze_transformer_encoder", type=bool_flag, default=False, help="True if we want to freeze the transformer encoder weights during training")
    parser.add_argument("--use_lora", type=bool_flag, default=False, help="True if we want to use LoRa for finetuning")
    parser.add_argument("--lora_rank", type=int, default=8, help="if use_lora, specifies the inner dimension of the low-rank matrices to train")
    parser.add_argument("--config_path", type=str, default="config/default_config.json", help="path to JSON configuration file to use for setting up model")
    parser.add_argument("--model_type", type=str, default="scGPT", help="scGPT, simple_affine, simple_affine_large, mean_control, mean_perturbed, mean_control+perturbed, smart_mean_control, smart_mean_perturbed, smart_mean_control+perturbed")
    parser.add_argument("--validation_selection", type=str, default="pearson", help="how to select the best model during training, if 'pearson' will be by pearson correlation between predicted and actual expression over validation set, if 'loss' will be by minimal loss")
    parser.add_argument("--loss_type", type=str, default="mse", help="mse, mse+triplet, mse+pearson")
    parser.add_argument("--fixed_seed", type=bool_flag, default=True, help="True if we want to use a constant fixed seed")
    parser.add_argument("--cross_validation", type=bool_flag, default=False, help="if True will cross validate instead of single random split, --cross_validation_fold arg must also be set")
    parser.add_argument("--cross_validation_fold", type=int, default=None, help="which fold to train")
    return parser


def _load_pert_data(opt, var):
    """Create the perturbation dataset object selected by ``opt.data_name``.

    Raises:
        ValueError: if ``opt.data_name`` is not one of the handled datasets
            (previously this surfaced later as an unbound-variable error).
    """
    name = opt.data_name
    custom_kwargs = dict(split=var["split"], batch_size=var["batch_size"], test_batch_size=var["eval_batch_size"], generate_new=False)
    if name == "telohaec":
        return get_telohaec_pert_data(**custom_kwargs)
    if name == "replogle_k562_gwps":
        return get_replogle_gwps_pert_data(**custom_kwargs)
    if name == "adam_corrected":
        return get_adam_corrected_dataset(just_upr=False, **custom_kwargs)
    if name == "adam_corrected_upr":
        return get_adam_corrected_dataset(just_upr=True, **custom_kwargs)
    if name in ["adamson", "norman", "replogle_k562_essential"]:
        pert_data = PertData("./data")
        pert_data.load(data_name=name) ##seems to instantiate a lot of PertData attributes
        pert_data.prepare_split(split=var["split"], seed=1)
        pert_data.get_dataloader(batch_size=var["batch_size"], test_batch_size=var["eval_batch_size"])
        return pert_data
    raise ValueError(
        f"unsupported data_name '{name}'; expected one of telohaec, replogle_k562_gwps, "
        "adam_corrected, adam_corrected_upr, adamson, norman, replogle_k562_essential"
    )


def main():
    """Parse the command line, then run the stage selected by ``--mode``.

    Common setup: seed, config variables, save directory and log file, the
    perturbation dataset with its dataloaders, and the model vocabulary setup.

    Mode-specific work, as visible in this function:
      * ``benchmark``: train and evaluate the three perturbench baselines
        (linear additive, latent additive, decoder only) and pickle their
        metrics and rank metrics into the save directory.
      * ``train`` / ``test`` / ``analysis``: evaluate the mean-predictor
        baselines, then build the model named by ``--model_type``.
      * ``train``: run the epoch loop with early stopping, keep the best
        model by the validation score and save it as ``best_model.pt``.
      * ``train`` / ``test``: evaluate on the test loader and pickle metrics.
      * ``analysis``: compute rank metrics, box plots and a per-perturbation
        data structure, pickled under ``pickles/``.

    NOTE(review): the source this was reconstructed from had lost its
    indentation. Two placements should be confirmed against the original file:
    model construction and ``model.to(...)`` are placed at function level
    (so they also run in ``benchmark`` mode and for every model type), and the
    early-stop check is nested under the "no improvement" branch.

    Raises:
        ValueError: for an unsupported ``--data_name``, ``--model_type`` or
            (in train mode) ``--validation_selection``.
    """
    opt = _build_parser().parse_args()
    check_args(opt)
    matplotlib.rcParams["savefig.transparent"] = False
    if opt.fixed_seed:
        set_seed(42)
    else:
        # Non-reproducible run: seed from the performance counter.
        set_seed(int(time.perf_counter()))

    ##var values will depend on load_model and mode
    var = get_variables(load_model=opt.load_model, config_path=opt.config_path)

    ##set up save dir and logger
    if opt.save_dir == "default":
        save_dir = Path(f"./save/random_shuffle={opt.random_shuffle}/dev_perturb_{opt.data_name}-{time.strftime('%b%d-%H-%M')}/")
    else:
        save_dir = Path(opt.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    logger = scg.logger
    scg.utils.add_file_handler(logger, save_dir / "run.log")
    # log command line arguments
    logger.info(f"{opt}")
    # log running date, output location and resolved config variables
    logger.info(f"Running on {time.strftime('%Y-%m-%d %H:%M:%S')}")
    logger.info(f"saving to {save_dir}")
    logger.info(f"var: {var}")

    ##setup PertData object
    pert_data = _load_pert_data(opt, var)
    if opt.filter_perturbations:
        logger.info("WARNING: filtering dataloaders! but keeping pert_data.adata the same")
        modify_pertdata_dataloaders(pert_data, logger)
    if opt.cross_validation:
        cross_validate_split(pert_data, opt.cross_validation_fold)
    else:
        check_pert_split(opt.data_name, pert_data)
    logger.info(f"adata.obs: {pert_data.adata.obs}")
    logger.info(f"|conditions|: {len(set(pert_data.adata.obs['condition']))}")

    def control_adata():
        # Fresh slice of the control cells on every call, exactly as the
        # inline expression did at each call site.
        return pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"]

    gene_name_list = pert_data.adata.var["gene_name"].tolist()
    ##dictionary from gene_name to index for passing to train to fix bug left by scGPT authors to get perturbation index (they don't account for latest version of GEARS loaders)
    gene_idx_map = _first_index_map(gene_name_list)
    logger.info(f"|gene_name_list|: {len(set(gene_name_list))}")
    model_file, vocab, n_genes, gene_ids, ntokens = get_model_setup(var, pert_data, logger)

    test_perts_path = f"pickles/{opt.data_name}_perturbation_splits.pkl"
    crispr_type = "crispra" if opt.data_name in ["norman"] else "crispri"

    ##baselines from perturBench
    if opt.mode in ["benchmark"]:
        encoder_input_dim = len(pert_data.adata.var["gene_name"].tolist()) ##want all the genes for these baseline models
        perturbench_models = {"linear_additive": LinearAdditive(encoder_input_dim, pert_data, var), "latent_additive": LatentAdditive(encoder_input_dim, pert_data, var), "decoder_only": DecoderOnly(encoder_input_dim, pert_data, var)}
        logger.info(f"perturbench models: ")
        for perturbench_model, p_model in perturbench_models.items():
            p_model.to(var["device"])
            p_model = p_model.train_model(pert_data.dataloader["train_loader"], pert_data.dataloader["val_loader"], var, gene_ids)
            p_res = eval_perturb(pert_data.dataloader["test_loader"], p_model, gene_ids=gene_ids, gene_idx_map={}, var=var, loss_type=opt.loss_type) ##keep on cpu, no need to shuttle to gpu
            p_metrics, p_metrics_pert = compute_metrics(p_res)
            logger.info(f"test metrics: {perturbench_model}")
            _dump_pickle((p_metrics, p_metrics_pert), save_dir / f"{perturbench_model}_results_{opt.data_name}.pkl")
            p_metrics = compute_perturbation_metrics(p_res, control_adata())
            logger.info(f"{opt.data_name} {perturbench_model} delta test metrics: {p_metrics}")
            _dump_pickle(p_metrics, save_dir / f"{perturbench_model}_pert_delta_results_{opt.data_name}.pkl")
            test_perts = _load_pickle(test_perts_path)["test"]
            ranks = get_rank(p_model, test_perts, pert_data=pert_data, var=var, gene_ids=gene_ids, gene_idx_map=gene_idx_map)
            _dump_pickle(ranks, save_dir / f"rank_metrics_{opt.data_name}_{perturbench_model}.pkl")

    ##mean predictor models - compute after data loaders are set (after a possible application of filter_perturbations)
    if opt.mode in ["train", "test", "analysis"]:
        for baseline in ["smart", "baseline"]:
            for mean_type in ["perturbed"]: ##"control" and "control+perturbed" are currently disabled
                if baseline == "baseline":
                    mean_pred_model = MeanPredictor(pert_data, opt.data_name, mean_type=mean_type)
                    rank_save = f"pickles/rank_metrics_{opt.data_name}_mean_{mean_type}.pkl"
                else:
                    mean_pred_model = SmartMeanPredictor(pert_data, opt.data_name, mean_type=mean_type, crispr_type=crispr_type)
                    rank_save = f"pickles/rank_metrics_{opt.data_name}_smart_mean_{mean_type}.pkl"
                ##get rank metrics
                test_perts = _load_pickle(test_perts_path)["test"]
                ranks = get_rank(mean_pred_model, test_perts, pert_data=pert_data, var=var, gene_ids=gene_ids, gene_idx_map=gene_idx_map)
                if not opt.cross_validation: ##save ranks to global pickles/ only on main split
                    _dump_pickle(ranks, rank_save)
                ##GEARS-type metrics
                mean_res = eval_perturb(pert_data.dataloader["test_loader"], mean_pred_model, gene_ids=[], gene_idx_map={}, var={"device":"cpu"}, loss_type=opt.loss_type) ##keep on cpu, no need to shuttle to gpu for mean pred model
                mean_metrics, mean_metrics_pert = compute_metrics(mean_res) ##from GEARS library
                logger.info(f"test metrics: ")
                result_prefix = f"{baseline}_mean_{mean_type}"
                _dump_pickle((mean_metrics, mean_metrics_pert), save_dir / f"{result_prefix}_results_{opt.data_name}.pkl")
                ##scGPT-type metrics
                mean_metrics = compute_perturbation_metrics(mean_res, control_adata()) ##from scGPT library
                logger.info(f"{opt.data_name} {baseline} mean {mean_type} delta test metrics: {mean_metrics}")
                _dump_pickle(mean_metrics, save_dir / f"{result_prefix}_pert_delta_results_{opt.data_name}.pkl")
                ##condition specific performance
                condition_map = get_condition_performance_breakdown(mean_res, control_adata())
                _dump_pickle(condition_map, save_dir / f"{result_prefix}_condition_specific_results_{opt.data_name}.pkl")
                ##gene specific performance
                gene_to_pearson_map = get_gene_performance_breakdown(mean_res, control_adata())
                _dump_pickle(gene_to_pearson_map, save_dir / f"{result_prefix}_gene_specific_results_{opt.data_name}.pkl")

    ##build the model selected by --model_type
    if opt.model_type == "scGPT":
        model = TransformerGenerator(
            ntoken=ntokens,
            d_model=var["embsize"],
            nhead=var["nhead"],
            d_hid=var["d_hid"],
            nlayers=var["nlayers"],
            nlayers_cls=var["n_layers_cls"],
            n_cls=1,
            vocab=vocab,
            dropout=var["dropout"],
            pad_token=var["pad_token"],
            pad_value=var["pad_value"],
            pert_pad_id=var["pert_pad_id"],
            use_fast_transformer=var["use_fast_transformer"],
        )
    elif "simple_affine" in opt.model_type:
        from simple_affine import SimpleAffine
        is_large = "large" in opt.model_type
        print("LARGE: ", is_large)
        model = SimpleAffine(
            ntoken=ntokens,
            d_model=var["embsize"],
            nlayers=var["nlayers"],
            nlayers_cls=var["n_layers_cls"],
            vocab=vocab,
            dropout=var["dropout"],
            pad_token=var["pad_token"],
            pert_pad_id=var["pert_pad_id"],
            is_large=is_large
        )
    elif "mean" in opt.model_type:
        if "smart" in opt.model_type:
            mean_type = opt.model_type.split("smart_mean_")[1]
            model = SmartMeanPredictor(pert_data, opt.data_name, mean_type=mean_type, crispr_type=crispr_type)
        else:
            mean_type = opt.model_type.split("_")[1]
            model = MeanPredictor(pert_data, opt.data_name, mean_type=mean_type)
    else:
        # ValueError subclasses Exception, so existing handlers still catch it.
        raise ValueError("model_type must be one of scGPT, simple_affine, mean_control, mean_perturbed, mean_control+perturbed, smart_mean_control, smart_mean_perturbed, smart_mean_control+perturbed")
    if opt.model_type in ["scGPT", "simple_affine", "simple_affine_large"]:
        model = load_model(var, model, model_file, logger, attention_control=opt.attention_control, freeze_input_encoder=opt.freeze_input_encoder, freeze_transformer_encoder=opt.freeze_transformer_encoder, mode=opt.mode, use_lora=opt.use_lora, lora_rank=opt.lora_rank, pretrain_control=opt.pretrain_control, transformer_encoder_control=opt.transformer_encoder_control, input_encoder_control=opt.input_encoder_control)
    model.to(var["device"])

    if opt.mode == "train":
        # Fail before spending an epoch on training rather than after it.
        if opt.validation_selection not in ("pearson", "loss"):
            raise ValueError(f"validation_selection must be 'pearson' or 'loss', got '{opt.validation_selection}'")
        loss_map = {"train": {}, "val": {}}
        criterion = masked_mse_loss
        optimizer = torch.optim.Adam(model.parameters(), lr=var["lr"])
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, var["schedule_interval"], gamma=0.9)
        scaler = torch.cuda.amp.GradScaler(enabled=var["amp"])
        best_model = None
        best_val_score = -100000000000
        patience = 0
        for epoch in range(1, var["epochs"] + 1):
            memory = psutil.virtual_memory()
            print(f'epoch: {epoch} RAM used: {memory[2]}%, {memory[3]/1000000000}')
            epoch_start_time = time.time()
            train_loader = pert_data.dataloader["train_loader"]
            valid_loader = pert_data.dataloader["val_loader"]
            train_loss = train(
                model=model,
                train_loader=train_loader,
                n_genes=n_genes,
                gene_ids=gene_ids,
                criterion=criterion,
                scaler=scaler,
                optimizer=optimizer,
                scheduler=scheduler,
                logger=logger,
                epoch=epoch,
                gene_idx_map=gene_idx_map,
                random_shuffle=opt.random_shuffle,
                always_keep_pert_gene=opt.always_keep_pert_gene,
                loss_type=opt.loss_type,
                var=var
            )
            val_res = eval_perturb(valid_loader, model, gene_ids, gene_idx_map, var, loss_type=opt.loss_type)
            loss_map = update_loss_map(loss_map, train_loss, val_res)
            val_metrics = compute_perturbation_metrics(val_res, control_adata())
            logger.info(f"val_metrics at epoch {epoch}: ")
            logger.info(val_metrics)
            logger.info(f" avg val loss: {val_res['avg_loss']}")
            elapsed = time.time() - epoch_start_time
            logger.info(f"| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | ")
            if opt.validation_selection == "pearson": ##default: select model based on highest pearson score over all genes
                val_score = val_metrics["pearson"]
            else: ##"loss": select model based on what gives lowest loss in the validation set
                val_score = 1.0 / val_res["avg_loss"] ##invert: best model will have inverse furthest to the right on the (+) number line
            if val_score > best_val_score:
                best_val_score = val_score
                best_model = copy.deepcopy(model)
                logger.info(f"Best model with score {val_score:5.4f}")
                patience = 0
            else:
                patience += 1
                if patience >= var["early_stop"]:
                    logger.info(f"Early stop at epoch {epoch}")
                    break
            scheduler.step()
        if best_model is None:
            # No epoch beat the sentinel (e.g. NaN scores, or zero epochs):
            # keep the final weights instead of crashing on None.state_dict().
            logger.info("WARNING: no epoch improved the validation score; saving the final model state instead")
            best_model = model
        torch.save(best_model.state_dict(), save_dir / "best_model.pt")
        _dump_pickle(loss_map, save_dir / f"loss_map_{opt.data_name}.pkl")
        logger.info(f"loss_map: {loss_map}")

    if opt.mode in ["test", "analysis"]:
        best_model = model

    if opt.mode in ["train", "test"]: ##test model always for both mode == train or test
        test_res = eval_perturb(pert_data.dataloader["test_loader"], best_model, gene_ids, gene_idx_map, var, loss_type=opt.loss_type)
        ##GEARS-type metrics
        metrics, metrics_pert = compute_metrics(test_res) ##from GEARS library
        logger.info(f"test metrics: ")
        _dump_pickle((metrics, metrics_pert), save_dir / f"{opt.model_type}_results_{opt.data_name}.pkl")
        ##scGPT-type metrics
        test_metrics = compute_perturbation_metrics(test_res, control_adata()) ##from scGPT utils library
        logger.info(f"{opt.data_name} delta test metrics: {test_metrics}")
        _dump_pickle(test_metrics, save_dir / f"{opt.model_type}_pert_delta_results_{opt.data_name}.pkl")
        ##condition specific performance
        condition_map = get_condition_performance_breakdown(test_res, control_adata())
        _dump_pickle(condition_map, save_dir / f"{opt.model_type}_condition_specific_results_{opt.data_name}.pkl")
        ##gene specific performance
        gene_to_pearson_map = get_gene_performance_breakdown(test_res, control_adata())
        _dump_pickle(gene_to_pearson_map, save_dir / f"{opt.model_type}_gene_specific_results_{opt.data_name}.pkl")

    if opt.mode == "analysis":
        for plot_type in ["boxplots", "scatterplots"]:
            os.makedirs(f"figures/{plot_type}/{opt.data_name}/{opt.model_type}/{opt.load_model.replace('/', '_')}/", exist_ok=True)
        test_perts = _load_pickle(test_perts_path)["test"]
        print("test perts: ", test_perts)
        ##get rank metrics
        ranks = get_rank(best_model, test_perts, pert_data=pert_data, var=var, gene_ids=gene_ids, gene_idx_map=gene_idx_map)
        _dump_pickle(ranks, f"pickles/rank_metrics_{opt.data_name}_{opt.model_type}.pkl")
        ##make boxplots
        pert_to_gene_iqr = plot_perturbation_boxplots(best_model, test_perts, save_directory=f"figures/boxplots/{opt.data_name}/{opt.model_type}/{opt.load_model.replace('/', '_')}/", pert_data=pert_data, var=var, gene_ids=gene_ids, gene_idx_map=gene_idx_map, data_name=opt.data_name, model_type=opt.model_type)
        for plot_differential in [True, False]:
            pearson_string = "pearson de delta" if plot_differential else "pearson de"
            # Each value is indexed as [0] for the pearson figure and [1] for the Wasserstein figure.
            per_pert = list(pert_to_gene_iqr[f"control_differential={plot_differential}"].values())
            pearson_values = [entry[0] for entry in per_pert]
            wasserstein_values = [entry[1] for entry in per_pert]
            print(f"{opt.data_name}/{opt.model_type} {pearson_string} mean/std: ", np.mean(pearson_values), np.std(pearson_values))
            print(f"{opt.data_name}/{opt.model_type} Wasserstein distance mean/std: ", np.mean(wasserstein_values), np.std(wasserstein_values))
        os.makedirs("pickles/pert_to_gene_iqr/", exist_ok=True)
        _dump_pickle(pert_to_gene_iqr, f"pickles/pert_to_gene_iqr/pert_to_gene_iqr_{opt.data_name}_{opt.model_type}.pkl")
        ##convenience data structure
        perturbation_structure = get_complete_data_structure_for_perturbation(best_model, test_perts, pert_data=pert_data, var=var, gene_ids=gene_ids, gene_idx_map=gene_idx_map, data_name=opt.data_name, model_type=opt.model_type)
        _dump_pickle(perturbation_structure, f"pickles/pert_data_structure_{opt.data_name}_{opt.model_type}_load_{opt.load_model.replace('/', '_')}.pkl")
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()