You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
ifnp.std(pred[:,i]) <0.000001: ##mode collapse for models like mean, if std == 0 we cannot compute pearson (will be NaN), add very small random noise to prediction
forbatch, batch_datainenumerate(old_loader): ##batch_data is of type torch_geometric.data.batch.DataBatch, batch_data[i] is of type torch_geometric.data.data.Data
Copy file name to clipboardExpand all lines: runner.py
+27-7Lines changed: 27 additions & 7 deletions
Original file line number
Diff line number
Diff line change
@@ -22,10 +22,13 @@ def main():
22
22
parser.add_argument("--use_lora", type=bool_flag, default=False, help="True if we want to use LoRa for finetuning")
23
23
parser.add_argument("--lora_rank", type=int, default=8, help="if use_lora, specifies the inner dimension of the low-rank matrices to train")
24
24
parser.add_argument("--config_path", type=str, default="config/default_config.json", help="path to JSON configuration file to use for setting up model")
parser.add_argument("--validation_selection", type=str, default="pearson", help="how to select the best model during training, if 'pearson' will be by pearson correlation between predicted and actual expression over validation set, if 'loss' will be by minimal loss")
parser.add_argument("--fixed_seed", type=bool_flag, default=True, help="True if we want to use a constant fixed seed")
29
+
parser.add_argument("--cross_validation", type=bool_flag, default=False, help="if True will cross validate instead of single random split, --cross_validation_fold arg must also be set")
30
+
parser.add_argument("--cross_validation_fold", type=int, default=None, help="which fold to train")
31
+
29
32
opt=parser.parse_args()
30
33
check_args(opt)
31
34
matplotlib.rcParams["savefig.transparent"] =False
@@ -70,7 +73,10 @@ def main():
70
73
logger.info("WARNING: filtering dataloaders! but keeping pert_data.adata the same")
ifopt.cross_validation==False: ##save ranks to global pickles/ only on main split
125
+
pickle.dump(ranks, open(rank_save, "wb"))
119
126
##GEARS-type metrics
120
127
mean_res=eval_perturb(pert_data.dataloader["test_loader"], mean_pred_model, gene_ids=[], gene_idx_map={}, var={"device":"cpu"}, loss_type=opt.loss_type) ##keep on cpu, no need to shuttle to gpu for mean pred model
raiseException("model_type must be one of scGPT, simple_affine, mean_control, mean_perturbed, mean_control+perturbed, smart_mean_control, smart_mean_perturbed, smart_mean_control+perturbed")
0 commit comments