@@ -353,6 +353,7 @@ def plot_subset_model_scores(mode, include_simple_affine=False):
353353 if key in path :
354354 assigned_key = key
355355 break
356+
356357 if dataset in ["adam_corrected_upr" , "norman" , "replogle_k562_essential" ]: ##only include the 3 core datasets in this calculation for statistical significance (for main results)
357358 if unreduced_map [assigned_key ][dataset ] == "" :
358359 unreduced_map [assigned_key ][dataset ] = unreduced
@@ -437,6 +438,8 @@ def get_p_val_comparisons(unreduced_map, x_labels):
437438 else :
438439 significance_char = ""
439440 print (f"{ p1 } | { p2 } : { x_label } : sample_sizes: { len (consolidated [p1 ][x_label ])} , { len (consolidated [p2 ][x_label ])} , mean_difference: { abs (np .mean (consolidated [p1 ][x_label ]) - np .mean (consolidated [p2 ][x_label ]))} , p_val: { significance_char } { p_value } { significance_char } " )
441+ # print(f" {np.mean(consolidated[p1][x_label])} | {np.mean(consolidated[p2][x_label])}")
442+ # print(f" {consolidated[p1][x_label]} | {consolidated[p2][x_label]}")
440443
441444def plot_model_losses ():
442445 paths = []
@@ -1351,6 +1354,10 @@ def get_perturbation_level_metrics(preds_df, actual_df, control_means_df, pertur
13511354
13521355def plot_condition_specific_performance ():
13531356 """
1357+ Two plots:
1358+ 1) scatterplot where each point is a target with x = mean model pearson delta, y = scGPT pearson delta
1359+
1360+ 2)
13541361 Plot performance of different perturbation conditions
13551362 plot will be specific dataset pearson delta / pearson de delta
13561363 x-axis: conditions
@@ -1376,14 +1383,49 @@ def plot_condition_specific_performance():
13761383 model_map [model ][dataset ] = pickle .load (open (f"pickles/gears_results/gears_condition_specific_results_{ dataset } _{ model_run_id } .pkl" , "rb" ))
13771384 else : ##scGPT eval done with front-running already
13781385 model_map [model ][dataset ] = pickle .load (open (f"save/test_condition_specific_performance/{ model } _condition_specific_results_{ dataset } .pkl" , "rb" ))
1379- ##select 20 conditions at random per study and keep consistent for all models
1380- # dataset_to_conditions = {dataset: [] for dataset in datasets}
1381- # for dataset in model_map["scGPT"]:
1382- # ##select 20 conditions at random and sort
1383- # n = 20
1384- # plot_conditions = sorted(random.sample(list(model_map["scGPT"][dataset].keys()), n))
1385- # dataset_to_conditions[dataset] = plot_conditions
13861386
1387+ ##scatterplot where each point is a target with x = mean model pearson delta, y = scGPT pearson delta
1388+ for dataset in datasets :
1389+ for metric in ["pearson_delta" ]:
1390+ for model in ["scGPT" , "gears" ]:
1391+ fig , ax = plt .subplots ()
1392+ x1 , y1 = [], [] ##when mean model is better
1393+ x2 , y2 = [], [] ##when deep model is better
1394+ x3 , y3 = [], [] ##when equal
1395+ # targets = []
1396+ for target in model_map ["smart_mean_perturbed" ][dataset ]:
1397+ deep_score = model_map [model ][dataset ][target ][metric ]
1398+ mean_score = model_map ["smart_mean_perturbed" ][dataset ][target ][metric ]
1399+ if mean_score > deep_score :
1400+ x1 .append (deep_score )
1401+ y1 .append (mean_score )
1402+ elif deep_score > mean_score :
1403+ x2 .append (deep_score )
1404+ y2 .append (mean_score )
1405+ # targets.append(target)
1406+ else :
1407+ x3 .append (deep_score )
1408+ y3 .append (mean_score )
1409+ print (dataset , targets )
1410+ ax .scatter (x1 , y1 , color = color_map ["smart_mean_perturbed" ])
1411+ ax .scatter (x2 , y2 , color = color_map [model ])
1412+ ax .scatter (x3 , y3 , color = "grey" )
1413+ ##label the targets where deep did better
1414+ # print(targets)
1415+ # for i, label in enumerate(targets):
1416+ # if label in ["RPS5+ctrl", "RPL35A+ctrl"]: ##manually move some of these because they overlap too much and can't see them
1417+ # plt.annotate(label, (x2[i], y2[i] + 0.03), fontsize=4)
1418+ # else:
1419+ # plt.annotate(label, (x2[i], y2[i]), fontsize=4)
1420+ ax .set_xlabel (models [model ])
1421+ ax .set_ylabel ("CRISPR-informed Mean" )
1422+ ax .plot ([- .3 , 1 ], [- .3 , 1 ], '--' , alpha = 0.75 , color = "black" , zorder = 0 ) ##plot x = y line
1423+ plt .xlim ((- .3 , 1.03 ))
1424+ plt .ylim ((- .3 , 1.03 ))
1425+ plt .title (f"{ models [model ]} vs CRISPR-informed Mean{ metric_label_map [metric ]} \n by Target for { get_dataset_title (dataset )} " )
1426+ plt .savefig (f"outputs/breakdown/scatter_{ dataset } _{ model } _{ metric } .png" , dpi = 300 )
1427+
1428+ ##bar graphs of top 20 and bottom 20
13871429 ##select CRISPR mean's top 20 and bottom 20 conditions per metric per study and keep consistent for all models
13881430 dataset_to_conditions = {place : {metric : {dataset : [] for dataset in datasets } for metric in metrics } for place in places }
13891431 for dataset in model_map ["smart_mean_perturbed" ]:
@@ -1425,7 +1467,10 @@ def plot_condition_specific_performance():
14251467 x = x + width
14261468 min_y = min (min_y , min (y ))
14271469 plt .ylim ((min_y - 0.1 , 1.01 ))
1428- plt .title (f"Performance of CRISPR-informed Mean's { place } \n Perturbation Conditions for { get_dataset_title (dataset )} " , fontsize = 9 )
1470+ if dataset == "adam_corrected_upr" : ##this dataset just has 20 test set perturbations
1471+ plt .title (f"Performance of All Test Set Conditions for { get_dataset_title (dataset )} " , fontsize = 9 )
1472+ else :
1473+ plt .title (f"Performance of CRISPR-informed Mean's { place } \n Perturbation Conditions for { get_dataset_title (dataset )} " , fontsize = 9 )
14291474 box = ax .get_position ()
14301475 ax .set_position ([box .x0 , box .y0 , box .width , box .height * 0.85 ])
14311476 ax .legend (loc = 'upper right' , prop = {"size" :7 }, bbox_to_anchor = (1 , 1.38 ))
@@ -1570,7 +1615,6 @@ def plot_cross_validation():
15701615get_avg_baseline (mode = 1 )
15711616plot_perturbench_comparison (mode = 1 )
15721617plot_subset_model_scores (mode = 1 )
1573- ##plot_subset_model_scores(mode=1, include_simple_affine=True) ##for checking statistical significance, actual plot is too crowded
15741618plot_rank_scores (mode = 1 , include_perturbench = True )
15751619plot_model_scores (mode = 1 )
15761620plot_simple_affine_run_times ()
@@ -1588,6 +1632,7 @@ def plot_cross_validation():
15881632plot_wasserstein_pert_gene_comparison ()
15891633
15901634#not for manuscript figure generation
1591- plot_model_losses ()
1592- root_dirs = ["save/no_pretraining/" , "save/default_config_baseline/" , "save/simple_affine/" , "save/simple_affine_with_pretraining/" , "save/perturbench/" , "pickles/gears_results/" ]
1593- find_best_models (mode = 1 )
1635+ # plot_subset_model_scores(mode=1, include_simple_affine=True) ##for checking statistical significance, actual plot is too crowded
1636+ # plot_model_losses()
1637+ # root_dirs = ["save/no_pretraining/", "save/default_config_baseline/", "save/simple_affine/", "save/simple_affine_with_pretraining/", "save/perturbench/", "pickles/gears_results/"]
1638+ # find_best_models(mode=1)
0 commit comments