Skip to content

Commit 41f7098

Browse files
committed
updates in response to reviewer feedback
1 parent bf3ad2d commit 41f7098

1 file changed

Lines changed: 57 additions & 12 deletions

File tree

analysis.py

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ def plot_subset_model_scores(mode, include_simple_affine=False):
353353
if key in path:
354354
assigned_key = key
355355
break
356+
356357
if dataset in ["adam_corrected_upr", "norman", "replogle_k562_essential"]: ##only include the 3 core datasets in this calculation for statistical significance (for main results)
357358
if unreduced_map[assigned_key][dataset] == "":
358359
unreduced_map[assigned_key][dataset] = unreduced
@@ -437,6 +438,8 @@ def get_p_val_comparisons(unreduced_map, x_labels):
437438
else:
438439
significance_char = ""
439440
print(f"{p1} | {p2}: {x_label}: sample_sizes: {len(consolidated[p1][x_label])}, {len(consolidated[p2][x_label])}, mean_difference: {abs(np.mean(consolidated[p1][x_label]) - np.mean(consolidated[p2][x_label]))}, p_val: {significance_char}{p_value}{significance_char}")
441+
# print(f" {np.mean(consolidated[p1][x_label])} | {np.mean(consolidated[p2][x_label])}")
442+
# print(f" {consolidated[p1][x_label]} | {consolidated[p2][x_label]}")
440443

441444
def plot_model_losses():
442445
paths = []
@@ -1351,6 +1354,10 @@ def get_perturbation_level_metrics(preds_df, actual_df, control_means_df, pertur
13511354

13521355
def plot_condition_specific_performance():
13531356
"""
1357+
Two plots:
1358+
1) scatterplot where each point is a target with x = deep model (scGPT or gears) pearson delta, y = CRISPR-informed mean pearson delta
1359+
1360+
2)
13541361
Plot performance of different perturbation conditions
13551362
plot will be specific dataset pearson delta / pearson de delta
13561363
x-axis: conditions
@@ -1376,14 +1383,49 @@ def plot_condition_specific_performance():
13761383
model_map[model][dataset] = pickle.load(open(f"pickles/gears_results/gears_condition_specific_results_{dataset}_{model_run_id}.pkl", "rb"))
13771384
else: ##scGPT eval done with front-running already
13781385
model_map[model][dataset] = pickle.load(open(f"save/test_condition_specific_performance/{model}_condition_specific_results_{dataset}.pkl", "rb"))
1379-
##select 20 conditions at random per study and keep consistent for all models
1380-
# dataset_to_conditions = {dataset: [] for dataset in datasets}
1381-
# for dataset in model_map["scGPT"]:
1382-
# ##select 20 conditions at random and sort
1383-
# n = 20
1384-
# plot_conditions = sorted(random.sample(list(model_map["scGPT"][dataset].keys()), n))
1385-
# dataset_to_conditions[dataset] = plot_conditions
13861386

1387+
##scatterplot where each point is a target with x = deep model (scGPT or gears) pearson delta, y = CRISPR-informed mean pearson delta
1388+
for dataset in datasets:
1389+
for metric in ["pearson_delta"]:
1390+
for model in ["scGPT", "gears"]:
1391+
fig, ax = plt.subplots()
1392+
x1, y1 = [], [] ##when mean model is better
1393+
x2, y2 = [], [] ##when deep model is better
1394+
x3, y3 = [], [] ##when equal
1395+
# targets = []
1396+
for target in model_map["smart_mean_perturbed"][dataset]:
1397+
deep_score = model_map[model][dataset][target][metric]
1398+
mean_score = model_map["smart_mean_perturbed"][dataset][target][metric]
1399+
if mean_score > deep_score:
1400+
x1.append(deep_score)
1401+
y1.append(mean_score)
1402+
elif deep_score > mean_score:
1403+
x2.append(deep_score)
1404+
y2.append(mean_score)
1405+
# targets.append(target)
1406+
else:
1407+
x3.append(deep_score)
1408+
y3.append(mean_score)
1409+
print(dataset, targets)
1410+
ax.scatter(x1, y1, color=color_map["smart_mean_perturbed"])
1411+
ax.scatter(x2, y2, color=color_map[model])
1412+
ax.scatter(x3, y3, color="grey")
1413+
##label the targets where deep did better
1414+
# print(targets)
1415+
# for i, label in enumerate(targets):
1416+
# if label in ["RPS5+ctrl", "RPL35A+ctrl"]: ##manually move some of these because they overlap too much and can't see them
1417+
# plt.annotate(label, (x2[i], y2[i] + 0.03), fontsize=4)
1418+
# else:
1419+
# plt.annotate(label, (x2[i], y2[i]), fontsize=4)
1420+
ax.set_xlabel(models[model])
1421+
ax.set_ylabel("CRISPR-informed Mean")
1422+
ax.plot([-.3, 1], [-.3, 1], '--', alpha=0.75, color="black", zorder=0) ##plot x = y line
1423+
plt.xlim((-.3, 1.03))
1424+
plt.ylim((-.3, 1.03))
1425+
plt.title(f"{models[model]} vs CRISPR-informed Mean{metric_label_map[metric]}\nby Target for {get_dataset_title(dataset)}")
1426+
plt.savefig(f"outputs/breakdown/scatter_{dataset}_{model}_{metric}.png", dpi=300)
1427+
1428+
##bar graphs of top 20 and bottom 20
13871429
##select CRISPR mean's top 20 and bottom 20 conditions per metric per study and keep consistent for all models
13881430
dataset_to_conditions = {place: {metric: {dataset: [] for dataset in datasets} for metric in metrics} for place in places}
13891431
for dataset in model_map["smart_mean_perturbed"]:
@@ -1425,7 +1467,10 @@ def plot_condition_specific_performance():
14251467
x = x + width
14261468
min_y = min(min_y, min(y))
14271469
plt.ylim((min_y - 0.1, 1.01))
1428-
plt.title(f"Performance of CRISPR-informed Mean's {place}\nPerturbation Conditions for {get_dataset_title(dataset)}", fontsize=9)
1470+
if dataset == "adam_corrected_upr": ##this dataset just has 20 test set perturbations
1471+
plt.title(f"Performance of All Test Set Conditions for {get_dataset_title(dataset)}", fontsize=9)
1472+
else:
1473+
plt.title(f"Performance of CRISPR-informed Mean's {place}\nPerturbation Conditions for {get_dataset_title(dataset)}", fontsize=9)
14291474
box = ax.get_position()
14301475
ax.set_position([box.x0, box.y0, box.width, box.height * 0.85])
14311476
ax.legend(loc='upper right', prop={"size":7}, bbox_to_anchor=(1, 1.38))
@@ -1570,7 +1615,6 @@ def plot_cross_validation():
15701615
get_avg_baseline(mode=1)
15711616
plot_perturbench_comparison(mode=1)
15721617
plot_subset_model_scores(mode=1)
1573-
##plot_subset_model_scores(mode=1, include_simple_affine=True) ##for checking statistical significance, actual plot is too crowded
15741618
plot_rank_scores(mode=1, include_perturbench=True)
15751619
plot_model_scores(mode=1)
15761620
plot_simple_affine_run_times()
@@ -1588,6 +1632,7 @@ def plot_cross_validation():
15881632
plot_wasserstein_pert_gene_comparison()
15891633

15901634
#not for manuscript figure generation
1591-
plot_model_losses()
1592-
root_dirs = ["save/no_pretraining/", "save/default_config_baseline/", "save/simple_affine/", "save/simple_affine_with_pretraining/", "save/perturbench/", "pickles/gears_results/"]
1593-
find_best_models(mode=1)
1635+
# plot_subset_model_scores(mode=1, include_simple_affine=True) ##for checking statistical significance, actual plot is too crowded
1636+
# plot_model_losses()
1637+
# root_dirs = ["save/no_pretraining/", "save/default_config_baseline/", "save/simple_affine/", "save/simple_affine_with_pretraining/", "save/perturbench/", "pickles/gears_results/"]
1638+
# find_best_models(mode=1)

0 commit comments

Comments
 (0)