Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
363 changes: 218 additions & 145 deletions notebooks/companion_notebook.ipynb

Large diffs are not rendered by default.

135 changes: 73 additions & 62 deletions notebooks/utility_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def map_wells_to_treatments(data, treatments, treatments_to_compounds, compounds

def generate_swarmplot(plot_order, data, color_dict, treatment_col, variable_of_interest, y_label,
point_size=2, p_values=False, random_seed=42, fig_width=14, fig_height=10, plot_rows=1,
plot_cols=1, n_samples=1, sample_size=-1):
plot_cols=1, n_samples=1, sample_size=-1, filename=None, title=None):
"""
Generates and saves swarm plots for the variable of interest across different treatments.

Expand Down Expand Up @@ -172,17 +172,10 @@ def generate_swarmplot(plot_order, data, color_dict, treatment_col, variable_of_
capprops=dict(color="black", linewidth=2, zorder=2),
medianprops=dict(color="black", linewidth=2, zorder=2),
showfliers=False, ax=ax)
# Calculate and plot the confidence intervals
for treatment in plot_order:
y_values = sampled_data[sampled_data[treatment_col] == treatment][variable_of_interest]
# print(f'Treatment: {treatment}, Mean: {y_values.mean()}')
lower, upper = ci(y_values, 0.95)
x_pos = plot_order.index(treatment)
ax.errorbar(x_pos, y_values.mean(), yerr=[[y_values.mean() - lower], [upper - y_values.mean()]],
fmt='none', ecolor='red', capsize=40, capthick=2, zorder=3)

ax.set_ylabel(y_label)
ax.set_xlabel('')
if ~p_values:
if not p_values:
ax.set_ylim(bottom=0, top=1.0)
else:
_, p_value = stats.kruskal(
Expand Down Expand Up @@ -223,14 +216,23 @@ def generate_swarmplot(plot_order, data, color_dict, treatment_col, variable_of_

str_p_value = f'p = {dunn_p_values[pair][0]:.3f}'

if dunn_p_values[pair][0] < 0.001:
str_p_value = 'p < 0.001'
if dunn_p_values[pair][0] < 0.0001:
str_p_value = '****'
elif dunn_p_values[pair][0] < 0.001:
str_p_value = '***'
elif dunn_p_values[pair][0] < 0.01:
str_p_value = 'p < 0.01'
str_p_value = '**'
elif dunn_p_values[pair][0] < 0.05:
str_p_value = '*'

# Annotate line with p-value
plt.text((x1 + x2) * .5, y + h, str_p_value, ha='center', va='bottom', color=col, fontsize=20)

if title is not None:
plt.title(title)
plt.tight_layout()
if filename is not None:
plt.savefig(f'../plots/{filename}')
plt.show()


Expand Down Expand Up @@ -341,7 +343,7 @@ def plot_effect_size_v_sample_size(sample_sizes, num_iterations, data, treatment


def plot_iqr_v_sample_size(sample_sizes, num_iterations, data, treatment_col, variable_of_interest, y_label,
initial_random_seed=42):
initial_random_seed=42, filename=None):
# Initialize dictionaries to store multiple mean values per sample size for each treatment
mean_values = {treatment: [[] for _ in range(len(sample_sizes))] for treatment in data[treatment_col].unique()}
random_seed = initial_random_seed
Expand Down Expand Up @@ -379,15 +381,18 @@ def plot_iqr_v_sample_size(sample_sizes, num_iterations, data, treatment_col, va
plt.xlabel('Number of Cells')
plt.ylabel(y_label)
plt.legend(fontsize=20)
plt.tight_layout()
if filename is not None:
plt.savefig(f'../plots/{filename}')
plt.show()


def plot_cumulative_histogram_samples(data, variable_of_interest, treatment_col, treatment, x_label,
initial_random_seed=42):
initial_random_seed=42, filenames=None):
total_samples = []
max_samples = 500
step = 10
filecount = 1
filecount = 0
random_seed = initial_random_seed
subsample = data[data[treatment_col] == treatment]

Expand Down Expand Up @@ -440,43 +445,45 @@ def plot_cumulative_histogram_samples(data, variable_of_interest, treatment_col,
bin_maxes = np.maximum.reduceat(n, np.digitize([q1, q3], bins[:-1]) - 1)
max_density = max(bin_maxes)

# Shade the IQR region
plt.fill_betweenx(np.arange(0, max_density, 0.01), q1, q3, color='grey', alpha=0.3, label='IQR')

plt.title(f'{len(total_samples)} {treatment} Cells')
plt.xlabel(x_label)
plt.ylabel('Frequency (%)')
plt.ylim(bottom=0, top=40)
plt.xlim(left=0.4, right=0.9)
plt.grid(True)
plt.ylim(bottom=0, top=20)
plt.xlim(left=0, right=1)
# plt.grid(True)
plt.tight_layout()
if filenames is not None:
plt.savefig(f'../plots/{filenames[filecount]}')

plt.show()
filecount = filecount + 1

# print(
# f'Median: {np.median(sample_data)} IQR: {np.percentile(sample_data, 75) - np.percentile(sample_data, 25)}')

# Break the loop if we have included all available samples
if remaining_samples <= new_samples_count:
break

mean_values = [x - 0.4649 for x in mean_values]
median_values = [x - 0.5108 for x in median_values]
std_values = [x - 0.1306 for x in std_values]
iqr_values = [x - 0.1288 for x in iqr_values]
plt.figure(figsize=(14, 10))
ax1 = plt.gca()
ax1.scatter(sample_sizes, mean_values, label='_Mean', alpha=0.5, color='blue')
ax1.scatter(sample_sizes, median_values, label='_Median', alpha=0.5, color='orange')
ax1.plot(sample_sizes, mean_values, label='Mean', color='blue')
ax1.plot(sample_sizes, median_values, label='Median', color='orange')
ax1.set_ylabel(f'Mean, Median of {x_label}')
ax1.set_xlabel('Number of Cells')
ax2 = ax1.twinx()
ax2.scatter(sample_sizes, std_values, label='_Standard Deviation', alpha=0.5, color='gray')
ax2.scatter(sample_sizes, iqr_values, label='_IQR', alpha=0.5, color='purple')
ax2.plot(sample_sizes, std_values, label='Standard Deviation', color='gray')
ax2.plot(sample_sizes, iqr_values, label='IQR', color='purple')
ax2.set_ylabel(f'SD, IQR of {x_label}')
plt.scatter(sample_sizes, mean_values, label='_Mean', alpha=0.5, color='blue')
plt.scatter(sample_sizes, median_values, label='_Median', alpha=0.5, color='orange')
plt.plot(sample_sizes, mean_values, label='Mean', color='blue')
plt.plot(sample_sizes, median_values, label='Median', color='orange')
plt.ylabel(f'Difference relative to statistic evaluated\n for all cells in this well')
plt.xlabel('Number of Cells')
plt.axhline(y=0.0, color='black', linestyle='dotted', linewidth=2)
plt.scatter(sample_sizes, std_values, label='_Standard Deviation', alpha=0.5, color='gray')
plt.scatter(sample_sizes, iqr_values, label='_IQR', alpha=0.5, color='purple')
plt.plot(sample_sizes, std_values, label='Standard Deviation', color='gray')
plt.plot(sample_sizes, iqr_values, label='IQR', color='purple')
# Create a single legend
lines, labels = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax2.legend(lines + lines2, labels + labels2, loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=4)
# plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=4)
plt.legend()
plt.tight_layout()
if filenames is not None:
plt.savefig(f'../plots/{filenames[filecount]}')
plt.show()


Expand Down Expand Up @@ -537,7 +544,8 @@ def plot_p_v_sample_size(sample_sizes, num_iterations, data, treatment_col, vari


def generate_superplot(plot_order, treatments, data, color_dict, treatment_col, variable_of_interest,
y_label, random_seed=42, fig_width=23, fig_height=12, sample_size=-1, point_size=3):
y_label, random_seed=42, fig_width=16, fig_height=8, sample_size=-1, point_size=3,
filename=None):
mean_data = pd.DataFrame()

for t in treatments:
Expand All @@ -553,38 +561,41 @@ def generate_superplot(plot_order, treatments, data, color_dict, treatment_col,

ReplicateAverages = mean_data.groupby([treatment_col, 'Replicate'], as_index=False).agg(
{variable_of_interest: "mean"})
plt.figure(figsize=(1.2 * fig_width, fig_height))
ax = plt.subplot(1, 1, 1)
sns.boxplot(x=treatment_col, y=variable_of_interest, data=ReplicateAverages, order=plot_order, color='white',
showfliers=False, linecolor='black', linewidth=2, zorder=1, boxprops=dict(facecolor='none'))
sns.swarmplot(x=treatment_col, y=variable_of_interest, hue="Replicate", data=mean_data, size=1.1 * point_size,
order=plot_order,
zorder=0, palette={0: 'cornflowerblue', 1: 'gray', 2: 'orange'})
sns.swarmplot(x=treatment_col, y=variable_of_interest, hue="Replicate", size=25, edgecolor="k", linewidth=2,
data=ReplicateAverages, order=plot_order, zorder=2,
palette={0: 'cornflowerblue', 1: 'gray', 2: 'orange'})
plt.ylim(bottom=0.0, top=1.0)
plt.xlabel('')
plt.ylabel(y_label)
plt.title(f'{sample_size} cells per population')
ax.legend_.remove()
plt.show()
plt.close()

plt.figure(figsize=(fig_width, fig_height))
ax = plt.subplot(1, 1, 1)
sns.boxplot(x=treatment_col, y=variable_of_interest, data=ReplicateAverages, order=plot_order, color='white',
showfliers=False, linecolor='black', linewidth=2, zorder=1, boxprops=dict(facecolor='none'))
sns.swarmplot(x=treatment_col, y=variable_of_interest, hue="Replicate", data=mean_data, size=point_size,
order=plot_order,
zorder=0, palette={0: 'cornflowerblue', 1: 'gray', 2: 'orange'})
sns.swarmplot(x=treatment_col, y=variable_of_interest, hue="Replicate", size=25, edgecolor="k", linewidth=2,
sns.swarmplot(x=treatment_col, y=variable_of_interest, hue="Replicate", size=20, edgecolor="k", linewidth=2,
data=ReplicateAverages, order=plot_order, zorder=2,
palette={0: 'cornflowerblue', 1: 'gray', 2: 'orange'})
plt.ylim(bottom=0.42, top=0.6)
plt.ylim(bottom=0.0, top=1.0)
plt.xlabel('')
plt.ylabel(y_label)
plt.title(f'{sample_size} cells per population')
ax.legend_.remove()
plt.tight_layout()
if filename is not None:
plt.savefig(f'../plots/{filename}')
plt.show()
plt.close()

# plt.figure(figsize=(fig_width, fig_height))
# ax = plt.subplot(1, 1, 1)
# sns.boxplot(x=treatment_col, y=variable_of_interest, data=ReplicateAverages, order=plot_order, color='white',
# showfliers=False, linecolor='black', linewidth=2, zorder=1, boxprops=dict(facecolor='none'))
# sns.swarmplot(x=treatment_col, y=variable_of_interest, hue="Replicate", data=mean_data, size=point_size,
# order=plot_order,
# zorder=0, palette={0: 'cornflowerblue', 1: 'gray', 2: 'orange'})
# sns.swarmplot(x=treatment_col, y=variable_of_interest, hue="Replicate", size=25, edgecolor="k", linewidth=2,
# data=ReplicateAverages, order=plot_order, zorder=2,
# palette={0: 'cornflowerblue', 1: 'gray', 2: 'orange'})
# plt.ylim(bottom=0.42, top=0.6)
# plt.xlabel('')
# plt.ylabel(y_label)
# plt.title(f'{sample_size} cells per population')
# ax.legend_.remove()
# plt.show()
# plt.close()