Skip to content

Commit 7bbcc0b

Browse files
authored
Merge pull request #2 from FrancisCrickInstitute/main
Updated figures
2 parents 46d8d45 + c0e2aec commit 7bbcc0b

2 files changed

Lines changed: 291 additions & 207 deletions

File tree

notebooks/companion_notebook.ipynb

Lines changed: 218 additions & 145 deletions
Large diffs are not rendered by default.

notebooks/utility_functions.py

Lines changed: 73 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def map_wells_to_treatments(data, treatments, treatments_to_compounds, compounds
129129

130130
def generate_swarmplot(plot_order, data, color_dict, treatment_col, variable_of_interest, y_label,
131131
point_size=2, p_values=False, random_seed=42, fig_width=14, fig_height=10, plot_rows=1,
132-
plot_cols=1, n_samples=1, sample_size=-1):
132+
plot_cols=1, n_samples=1, sample_size=-1, filename=None, title=None):
133133
"""
134134
Generates and saves swarm plots for the variable of interest across different treatments.
135135
@@ -172,17 +172,10 @@ def generate_swarmplot(plot_order, data, color_dict, treatment_col, variable_of_
172172
capprops=dict(color="black", linewidth=2, zorder=2),
173173
medianprops=dict(color="black", linewidth=2, zorder=2),
174174
showfliers=False, ax=ax)
175-
# Calculate and plot the confidence intervals
176-
for treatment in plot_order:
177-
y_values = sampled_data[sampled_data[treatment_col] == treatment][variable_of_interest]
178-
# print(f'Treatment: {treatment}, Mean: {y_values.mean()}')
179-
lower, upper = ci(y_values, 0.95)
180-
x_pos = plot_order.index(treatment)
181-
ax.errorbar(x_pos, y_values.mean(), yerr=[[y_values.mean() - lower], [upper - y_values.mean()]],
182-
fmt='none', ecolor='red', capsize=40, capthick=2, zorder=3)
175+
183176
ax.set_ylabel(y_label)
184177
ax.set_xlabel('')
185-
if ~p_values:
178+
if not p_values:
186179
ax.set_ylim(bottom=0, top=1.0)
187180
else:
188181
_, p_value = stats.kruskal(
@@ -223,14 +216,23 @@ def generate_swarmplot(plot_order, data, color_dict, treatment_col, variable_of_
223216

224217
str_p_value = f'p = {dunn_p_values[pair][0]:.3f}'
225218

226-
if dunn_p_values[pair][0] < 0.001:
227-
str_p_value = 'p < 0.001'
219+
if dunn_p_values[pair][0] < 0.0001:
220+
str_p_value = '****'
221+
elif dunn_p_values[pair][0] < 0.001:
222+
str_p_value = '***'
228223
elif dunn_p_values[pair][0] < 0.01:
229-
str_p_value = 'p < 0.01'
224+
str_p_value = '**'
225+
elif dunn_p_values[pair][0] < 0.05:
226+
str_p_value = '*'
230227

231228
# Annotate line with p-value
232229
plt.text((x1 + x2) * .5, y + h, str_p_value, ha='center', va='bottom', color=col, fontsize=20)
233230

231+
if title is not None:
232+
plt.title(title)
233+
plt.tight_layout()
234+
if filename is not None:
235+
plt.savefig(f'../plots/{filename}')
234236
plt.show()
235237

236238

@@ -341,7 +343,7 @@ def plot_effect_size_v_sample_size(sample_sizes, num_iterations, data, treatment
341343

342344

343345
def plot_iqr_v_sample_size(sample_sizes, num_iterations, data, treatment_col, variable_of_interest, y_label,
344-
initial_random_seed=42):
346+
initial_random_seed=42, filename=None):
345347
# Initialize dictionaries to store multiple mean values per sample size for each treatment
346348
mean_values = {treatment: [[] for _ in range(len(sample_sizes))] for treatment in data[treatment_col].unique()}
347349
random_seed = initial_random_seed
@@ -379,15 +381,18 @@ def plot_iqr_v_sample_size(sample_sizes, num_iterations, data, treatment_col, va
379381
plt.xlabel('Number of Cells')
380382
plt.ylabel(y_label)
381383
plt.legend(fontsize=20)
384+
plt.tight_layout()
385+
if filename is not None:
386+
plt.savefig(f'../plots/{filename}')
382387
plt.show()
383388

384389

385390
def plot_cumulative_histogram_samples(data, variable_of_interest, treatment_col, treatment, x_label,
386-
initial_random_seed=42):
391+
initial_random_seed=42, filenames=None):
387392
total_samples = []
388393
max_samples = 500
389394
step = 10
390-
filecount = 1
395+
filecount = 0
391396
random_seed = initial_random_seed
392397
subsample = data[data[treatment_col] == treatment]
393398

@@ -440,43 +445,45 @@ def plot_cumulative_histogram_samples(data, variable_of_interest, treatment_col,
440445
bin_maxes = np.maximum.reduceat(n, np.digitize([q1, q3], bins[:-1]) - 1)
441446
max_density = max(bin_maxes)
442447

443-
# Shade the IQR region
444-
plt.fill_betweenx(np.arange(0, max_density, 0.01), q1, q3, color='grey', alpha=0.3, label='IQR')
445-
446448
plt.title(f'{len(total_samples)} {treatment} Cells')
447449
plt.xlabel(x_label)
448450
plt.ylabel('Frequency (%)')
449-
plt.ylim(bottom=0, top=40)
450-
plt.xlim(left=0.4, right=0.9)
451-
plt.grid(True)
451+
plt.ylim(bottom=0, top=20)
452+
plt.xlim(left=0, right=1)
453+
# plt.grid(True)
454+
plt.tight_layout()
455+
if filenames is not None:
456+
plt.savefig(f'../plots/{filenames[filecount]}')
457+
452458
plt.show()
453459
filecount = filecount + 1
454460

455-
# print(
456-
# f'Median: {np.median(sample_data)} IQR: {np.percentile(sample_data, 75) - np.percentile(sample_data, 25)}')
457-
458461
# Break the loop if we have included all available samples
459462
if remaining_samples <= new_samples_count:
460463
break
461464

465+
mean_values = [x - 0.4649 for x in mean_values]
466+
median_values = [x - 0.5108 for x in median_values]
467+
std_values = [x - 0.1306 for x in std_values]
468+
iqr_values = [x - 0.1288 for x in iqr_values]
462469
plt.figure(figsize=(14, 10))
463-
ax1 = plt.gca()
464-
ax1.scatter(sample_sizes, mean_values, label='_Mean', alpha=0.5, color='blue')
465-
ax1.scatter(sample_sizes, median_values, label='_Median', alpha=0.5, color='orange')
466-
ax1.plot(sample_sizes, mean_values, label='Mean', color='blue')
467-
ax1.plot(sample_sizes, median_values, label='Median', color='orange')
468-
ax1.set_ylabel(f'Mean, Median of {x_label}')
469-
ax1.set_xlabel('Number of Cells')
470-
ax2 = ax1.twinx()
471-
ax2.scatter(sample_sizes, std_values, label='_Standard Deviation', alpha=0.5, color='gray')
472-
ax2.scatter(sample_sizes, iqr_values, label='_IQR', alpha=0.5, color='purple')
473-
ax2.plot(sample_sizes, std_values, label='Standard Deviation', color='gray')
474-
ax2.plot(sample_sizes, iqr_values, label='IQR', color='purple')
475-
ax2.set_ylabel(f'SD, IQR of {x_label}')
470+
plt.scatter(sample_sizes, mean_values, label='_Mean', alpha=0.5, color='blue')
471+
plt.scatter(sample_sizes, median_values, label='_Median', alpha=0.5, color='orange')
472+
plt.plot(sample_sizes, mean_values, label='Mean', color='blue')
473+
plt.plot(sample_sizes, median_values, label='Median', color='orange')
474+
plt.ylabel(f'Difference relative to statistic evaluated\n for all cells in this well')
475+
plt.xlabel('Number of Cells')
476+
plt.axhline(y=0.0, color='black', linestyle='dotted', linewidth=2)
477+
plt.scatter(sample_sizes, std_values, label='_Standard Deviation', alpha=0.5, color='gray')
478+
plt.scatter(sample_sizes, iqr_values, label='_IQR', alpha=0.5, color='purple')
479+
plt.plot(sample_sizes, std_values, label='Standard Deviation', color='gray')
480+
plt.plot(sample_sizes, iqr_values, label='IQR', color='purple')
476481
# Create a single legend
477-
lines, labels = ax1.get_legend_handles_labels()
478-
lines2, labels2 = ax2.get_legend_handles_labels()
479-
ax2.legend(lines + lines2, labels + labels2, loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=4)
482+
# plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=4)
483+
plt.legend()
484+
plt.tight_layout()
485+
if filenames is not None:
486+
plt.savefig(f'../plots/{filenames[filecount]}')
480487
plt.show()
481488

482489

@@ -537,7 +544,8 @@ def plot_p_v_sample_size(sample_sizes, num_iterations, data, treatment_col, vari
537544

538545

539546
def generate_superplot(plot_order, treatments, data, color_dict, treatment_col, variable_of_interest,
540-
y_label, random_seed=42, fig_width=23, fig_height=12, sample_size=-1, point_size=3):
547+
y_label, random_seed=42, fig_width=16, fig_height=8, sample_size=-1, point_size=3,
548+
filename=None):
541549
mean_data = pd.DataFrame()
542550

543551
for t in treatments:
@@ -553,38 +561,41 @@ def generate_superplot(plot_order, treatments, data, color_dict, treatment_col,
553561

554562
ReplicateAverages = mean_data.groupby([treatment_col, 'Replicate'], as_index=False).agg(
555563
{variable_of_interest: "mean"})
556-
plt.figure(figsize=(1.2 * fig_width, fig_height))
557-
ax = plt.subplot(1, 1, 1)
558-
sns.boxplot(x=treatment_col, y=variable_of_interest, data=ReplicateAverages, order=plot_order, color='white',
559-
showfliers=False, linecolor='black', linewidth=2, zorder=1, boxprops=dict(facecolor='none'))
560-
sns.swarmplot(x=treatment_col, y=variable_of_interest, hue="Replicate", data=mean_data, size=1.1 * point_size,
561-
order=plot_order,
562-
zorder=0, palette={0: 'cornflowerblue', 1: 'gray', 2: 'orange'})
563-
sns.swarmplot(x=treatment_col, y=variable_of_interest, hue="Replicate", size=25, edgecolor="k", linewidth=2,
564-
data=ReplicateAverages, order=plot_order, zorder=2,
565-
palette={0: 'cornflowerblue', 1: 'gray', 2: 'orange'})
566-
plt.ylim(bottom=0.0, top=1.0)
567-
plt.xlabel('')
568-
plt.ylabel(y_label)
569-
plt.title(f'{sample_size} cells per population')
570-
ax.legend_.remove()
571-
plt.show()
572-
plt.close()
573-
574564
plt.figure(figsize=(fig_width, fig_height))
575565
ax = plt.subplot(1, 1, 1)
576566
sns.boxplot(x=treatment_col, y=variable_of_interest, data=ReplicateAverages, order=plot_order, color='white',
577567
showfliers=False, linecolor='black', linewidth=2, zorder=1, boxprops=dict(facecolor='none'))
578568
sns.swarmplot(x=treatment_col, y=variable_of_interest, hue="Replicate", data=mean_data, size=point_size,
579569
order=plot_order,
580570
zorder=0, palette={0: 'cornflowerblue', 1: 'gray', 2: 'orange'})
581-
sns.swarmplot(x=treatment_col, y=variable_of_interest, hue="Replicate", size=25, edgecolor="k", linewidth=2,
571+
sns.swarmplot(x=treatment_col, y=variable_of_interest, hue="Replicate", size=20, edgecolor="k", linewidth=2,
582572
data=ReplicateAverages, order=plot_order, zorder=2,
583573
palette={0: 'cornflowerblue', 1: 'gray', 2: 'orange'})
584-
plt.ylim(bottom=0.42, top=0.6)
574+
plt.ylim(bottom=0.0, top=1.0)
585575
plt.xlabel('')
586576
plt.ylabel(y_label)
587577
plt.title(f'{sample_size} cells per population')
588578
ax.legend_.remove()
579+
plt.tight_layout()
580+
if filename is not None:
581+
plt.savefig(f'../plots/{filename}')
589582
plt.show()
590583
plt.close()
584+
585+
# plt.figure(figsize=(fig_width, fig_height))
586+
# ax = plt.subplot(1, 1, 1)
587+
# sns.boxplot(x=treatment_col, y=variable_of_interest, data=ReplicateAverages, order=plot_order, color='white',
588+
# showfliers=False, linecolor='black', linewidth=2, zorder=1, boxprops=dict(facecolor='none'))
589+
# sns.swarmplot(x=treatment_col, y=variable_of_interest, hue="Replicate", data=mean_data, size=point_size,
590+
# order=plot_order,
591+
# zorder=0, palette={0: 'cornflowerblue', 1: 'gray', 2: 'orange'})
592+
# sns.swarmplot(x=treatment_col, y=variable_of_interest, hue="Replicate", size=25, edgecolor="k", linewidth=2,
593+
# data=ReplicateAverages, order=plot_order, zorder=2,
594+
# palette={0: 'cornflowerblue', 1: 'gray', 2: 'orange'})
595+
# plt.ylim(bottom=0.42, top=0.6)
596+
# plt.xlabel('')
597+
# plt.ylabel(y_label)
598+
# plt.title(f'{sample_size} cells per population')
599+
# ax.legend_.remove()
600+
# plt.show()
601+
# plt.close()

0 commit comments

Comments
 (0)