Skip to content

Commit 38b5bed

Browse files
authored
Merge pull request #115 from MannLabs/update_plots
Enhanced visualization controls and tests
2 parents 718ea1e + fc0b15d commit 38b5bed

7 files changed

Lines changed: 392 additions & 68 deletions

File tree

alphaquant/cluster/cluster_utils.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -223,9 +223,13 @@ def remove_outlier_fragion_childs(childs):
223223
"""Filters extreme fragment ions before aggregating to peptide level.
224224
225225
When a peptide has many fragment ions, this function selects a subset to avoid
226-
bias from extreme outliers. For >4 fragments, it keeps the 5 most central fragments
226+
bias from extreme outliers. For >4 fragments, it keeps the 4 most central fragments
227227
(ranked by z-value). For ≤4 fragments, all are retained.
228228
229+
This function also sets the is_outlier_fragment attribute on all child nodes to
230+
mark which fragments are excluded from aggregation (similar to is_outlier_peptide
231+
for peptides).
232+
229233
Args:
230234
childs: List of fragment ion nodes (children of a peptide node)
231235
@@ -241,9 +245,7 @@ def remove_outlier_fragion_childs(childs):
241245
idxs_to_use = sorted_idxs_zvals[:median_idx+1]
242246
else:
243247
idxs_to_use = sorted_idxs_zvals
244-
return [childs[idx] for idx in idxs_to_use]
245-
246-
if len(zvals) > 4:
248+
elif len(zvals) > 4:
247249
sorted_idxs_zvals = np.argsort(zvals)
248250
median_idx = math.floor(len(zvals)/2)
249251
idx_start = median_idx - 2
@@ -253,6 +255,11 @@ def remove_outlier_fragion_childs(childs):
253255
# When there are 4 or fewer children, use all of them
254256
idxs_to_use = list(range(len(childs)))
255257

258+
# Mark which fragments are outliers (excluded from aggregation)
259+
idxs_to_use_set = set(idxs_to_use)
260+
for i, child in enumerate(childs):
261+
child.is_outlier_fragment = i not in idxs_to_use_set
262+
256263
return [childs[idx] for idx in idxs_to_use]
257264

258265

alphaquant/plotting/fcviz.py

Lines changed: 93 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,24 @@
66
import alphamap.organisms_data
77
import alphaquant.utils.utils as aq_utils
88
import alphaquant.resources.database_loader as aq_db_loader
9+
import re
10+
11+
def _format_tree_label_string(labelstring: str) -> str:
    """Format a raw tree node name into a multi-line display label.

    Local copy of the tree label formatter, kept here to avoid a circular
    import; it is meant to mirror TreeLabelFormatter.format_label_string
    without importing treeviz.

    The steps are applied in this order:
      1. Drop one leading run of letters/digits followed by an underscore
         (the type classifier, e.g. 'SEQ_').
      2. Strip leading and trailing underscores.
      3. Remove every occurrence of the default ion suffix '_noloss_1'.
      4. Turn each remaining underscore and square bracket into a line break.

    Args:
        labelstring: Raw node name.

    Returns:
        The formatted label, with its parts separated by newline characters.
        Adjacent separators are not merged, so consecutive or trailing line
        breaks can occur (for example after a closing bracket).
    """
    # Cut leading type classifier like 'SEQ_' etc.
    labelstring = re.sub(r'^[a-zA-Z0-9]+_', '', labelstring)
    # Remove leading/trailing underscores
    labelstring = labelstring.strip('_')
    # Remove default ion suffix (must happen before underscores become line breaks)
    labelstring = labelstring.replace('_noloss_1', '')
    # Replace all separators with line breaks in a single pass
    separators_to_linebreak = str.maketrans({'_': '\n', '[': '\n', ']': '\n'})
    return labelstring.translate(separators_to_linebreak)
927

1028
import alphaquant.config.config as aqconfig
1129
import logging
@@ -18,7 +36,9 @@ def __init__(self, condition1, condition2, results_directory, samplemap_file,
1836
order_along_protein_sequence = False, organism = 'Human',colorlist = aq_plot_base.AlphaQuantColorMap().colorlist, tree_level = 'seq', protein_identifier = 'gene_symbol', label_rotation = 90, add_stripplot = False,
1937
narrowing_factor_for_fcplot = 1/14, rescale_factor_x = 1.0, rescale_factor_y = 2,
2038
figsize = None, showfliers = True,
21-
show_node_annotations = False, node_annotation_attributes = None, node_annotation_formats = None):
39+
show_node_annotations = False, node_annotation_attributes = None, node_annotation_formats = None,
40+
hide_root_in_tree = False,
41+
exclude_outlier_fragments = True):
2242

2343
"""
2444
Class to visualize the peptide fold changes of a protein (precursor, fragment fcs etc. can also be visualized). Can be initialized once and subsequently used to visualize different proteins with the visualize_protein function.
@@ -40,11 +60,13 @@ def __init__(self, condition1, condition2, results_directory, samplemap_file,
4060
show_node_annotations (bool): Whether to show statistical annotations on tree nodes.
4161
node_annotation_attributes (list): List of node attributes to display (e.g., ['p_val', 'z_val', 'fc']).
4262
node_annotation_formats (dict): Custom formatting for each attribute.
63+
exclude_outlier_fragments (bool): Whether to exclude outlier fragments from plots. Defaults to True.
4364
4465
"""
4566

4667
self.plotconfig = PlotConfig(label_rotation = label_rotation, add_stripplot = add_stripplot, narrowing_factor_for_fcplot = narrowing_factor_for_fcplot, rescale_factor_x = rescale_factor_x, rescale_factor_y = rescale_factor_y, colorlist = colorlist, protein_identifier = protein_identifier, tree_level = tree_level, organism = organism, order_peptides_along_protein_sequence=order_along_protein_sequence, figsize=figsize, showfliers=showfliers,
47-
show_node_annotations=show_node_annotations, node_annotation_attributes=node_annotation_attributes, node_annotation_formats=node_annotation_formats)
68+
show_node_annotations=show_node_annotations, node_annotation_attributes=node_annotation_attributes, node_annotation_formats=node_annotation_formats,
69+
hide_root_in_tree=hide_root_in_tree, exclude_outlier_fragments=exclude_outlier_fragments)
4870

4971
self.quantification_info = CondpairQuantificationInfo((condition1, condition2), results_directory, samplemap_file)
5072

@@ -85,7 +107,13 @@ class PlotConfig():
85107
def __init__(self, label_rotation = 90, add_stripplot = False, narrowing_factor_for_fcplot = 1/14, rescale_factor_x = 1.0, rescale_factor_y = 2,
86108
colorlist = aq_plot_base.AlphaQuantColorMap().colorlist, protein_identifier = 'gene_symbol', tree_level = 'seq', organism = 'Human',
87109
order_peptides_along_protein_sequence = False, figsize = None, showfliers = True,
88-
show_node_annotations = False, node_annotation_attributes = None, node_annotation_formats = None, node_fontsize = 12):
110+
show_node_annotations = False, node_annotation_attributes = None, node_annotation_formats = None, node_fontsize = 12,
111+
tree_to_fc_height_ratio = 1.0, subplot_spacing = 0.3,
112+
node_size = 600,
113+
shortened_xticklabels = False,
114+
remove_leaf_labels_in_tree = False,
115+
hide_root_in_tree = False,
116+
exclude_outlier_fragments = True):
89117
"""
90118
Configuration class for plotting.
91119
@@ -103,6 +131,9 @@ def __init__(self, label_rotation = 90, add_stripplot = False, narrowing_factor_
103131
node_annotation_attributes (list): List of node attributes to display (e.g., ['p_val', 'z_val', 'fc']).
104132
node_annotation_formats (dict): Custom formatting for each attribute.
105133
node_fontsize (int): Font size for tree node labels.
134+
exclude_outlier_fragments (bool): Whether to exclude fragment ions marked as outliers from plots.
135+
When True (default), only fragments used in statistical aggregation are displayed.
136+
Mirrors the fragment_outlier_filtering behavior from the analysis pipeline.
106137
"""
107138
self.label_rotation = label_rotation
108139
self.add_stripplot = add_stripplot
@@ -115,6 +146,13 @@ def __init__(self, label_rotation = 90, add_stripplot = False, narrowing_factor_
115146
self.figsize = figsize
116147
self.showfliers = showfliers
117148
self.node_fontsize = node_fontsize
149+
self.tree_to_fc_height_ratio = tree_to_fc_height_ratio
150+
self.subplot_spacing = subplot_spacing
151+
self.node_size = node_size
152+
self.shortened_xticklabels = shortened_xticklabels
153+
self.remove_leaf_labels_in_tree = remove_leaf_labels_in_tree
154+
self.hide_root_in_tree = hide_root_in_tree
155+
self.exclude_outlier_fragments = exclude_outlier_fragments
118156

119157
# Node annotation configuration
120158
self.show_node_annotations = show_node_annotations
@@ -291,7 +329,10 @@ def __init__(self, protein_node, quantification_info : CondpairQuantificationInf
291329

292330
def _init_melted_df(self):
293331
protein_intensity_df_getter = ProteinIntensityDataFrameGetter(self._protein_node, self._quantification_info)
294-
self._melted_df = protein_intensity_df_getter.get_melted_df_all(self._plotconfig.parent_level)
332+
self._melted_df = protein_intensity_df_getter.get_melted_df_all(
333+
self._plotconfig.parent_level,
334+
exclude_outlier_fragments=self._plotconfig.exclude_outlier_fragments
335+
)
295336

296337
def _define_parent2elements(self):# for example you have precursor as a parent and ms1 and ms2 as the leafs
297338
if self._parent2elements is None:
@@ -310,7 +351,14 @@ def _plot_all_child_elements(self):
310351
melted_df_subset = self._subset_to_elements(self._melted_df, elements)
311352
colormap = ClusterColorMapper(self._plotconfig.colorlist).get_element2color(melted_df_subset)
312353
ProteinPlot = IonFoldChangePlotter(melted_df=melted_df_subset, condpair = self._quantification_info.condpair, plotconfig=self._plotconfig)
313-
ProteinPlot.plot_fcs_with_specified_color_scheme(colormap,self._axes[idx])
354+
355+
# Build xticklabels from the actual leaf node labels used in the tree (base part only)
356+
xticklabels = None
357+
if getattr(self._plotconfig, 'shortened_xticklabels', False):
358+
name2label = self._map_specified_level_to_formatted_leaf_label_base()
359+
xticklabels = [name2label.get(name, name) for name in ProteinPlot.precursors]
360+
361+
ProteinPlot.plot_fcs_with_specified_color_scheme(colormap, self._axes[idx], xticklabels=xticklabels)
314362
#self._set_title_of_subplot(ax = self._axes[idx], peptide_nodes = cluster_sorted_groups_of_peptide_nodes[idx], first_subplot=idx==0)
315363
self._set_yaxes_to_same_scale()
316364
self._set_title()
@@ -371,6 +419,27 @@ def _get_color_from_list(self, idx):
371419
modulo_idx = idx % (len(self._colormap)) #if idx becomes larger than the list length, start at 0 again
372420
return self._colormap[modulo_idx]
373421

422+
def _map_specified_level_to_formatted_leaf_label_base(self):
    """Map each node name under the protein node to a shortened base label.

    Every node found below (and including) self._protein_node that exposes a
    'children' attribute is visited. For each one, the label source is
    node.name_reduced when present, otherwise node.name. The source is passed
    through _format_tree_label_string and only the first line of the result
    is kept (e.g. a base ion like 'y5').

    The method is best-effort: a node whose label cannot be formatted maps to
    its own name, and a failure of the tree traversal returns whatever has
    been collected so far (possibly an empty dict). Failures are logged at
    debug level instead of being dropped silently.

    Returns:
        dict: node.name -> first line of the formatted label.
    """
    import logging  # function-scope import keeps this block self-contained

    logger = logging.getLogger(__name__)
    mapping = {}
    try:
        level_nodes = anytree.findall(self._protein_node, filter_=lambda x: hasattr(x, 'children'))
        for node in level_nodes:
            try:
                base_source = getattr(node, 'name_reduced', node.name)
                mapping[node.name] = _format_tree_label_string(base_source).split('\n')[0]
            except Exception:
                # Fall back to the unformatted name so the tick label is still usable.
                logger.debug("Could not format label for node %r; using raw name", node, exc_info=True)
                mapping[node.name] = node.name
    except Exception:
        # Callers look names up with .get(name, name), so a partial mapping is acceptable.
        logger.debug("Could not build shortened tick-label mapping", exc_info=True)
    return mapping
442+
374443
def _label_x_and_y(self):
375444
self._fig.supylabel("log2(FC)")
376445

@@ -409,9 +478,9 @@ def __init__(self, protein_node, quantification_info : CondpairQuantificationInf
409478
self._quantification_info= quantification_info
410479
self._ion_header = ion_header
411480

412-
def get_melted_df_all(self, specified_level):
481+
def get_melted_df_all(self, specified_level, exclude_outlier_fragments=True):
413482
melted_df = ProteinIntensityDfFormatter( self._protein_node, self._quantification_info, self._ion_header).get_melted_protein_ion_intensity_table()
414-
melted_df = ProteinQuantDfAnnotator(self._protein_node, specified_level).get_annotated_melted_df(melted_df)
483+
melted_df = ProteinQuantDfAnnotator(self._protein_node, specified_level, exclude_outlier_fragments=exclude_outlier_fragments).get_annotated_melted_df(melted_df)
415484
return melted_df
416485

417486
def get_melted_df_selected_peptides(self, protein_id, selected_peptides, specified_level):
@@ -464,9 +533,10 @@ def _melt_protein_dataframe(self, protein_df):
464533

465534
class ProteinQuantDfAnnotator():
466535

467-
def __init__(self, protein_node, specified_level):
536+
def __init__(self, protein_node, specified_level, exclude_outlier_fragments=True):
468537
self._protein_node = protein_node
469538
self._specified_level = specified_level
539+
self._exclude_outlier_fragments = exclude_outlier_fragments
470540

471541
self._ion2is_included = {}
472542
self._ion2ml_score = {}
@@ -512,6 +582,10 @@ def _fill_ion_mapping_dicts(self):
512582
for level_node in level_nodes:
513583
for child in level_node.children:
514584
for leaf in child.leaves:
585+
# Skip fragment ions that were filtered out during aggregation (if flag is enabled)
586+
if self._exclude_outlier_fragments and hasattr(leaf, 'is_outlier_fragment') and leaf.is_outlier_fragment:
587+
continue
588+
515589
self._ion2is_included[leaf.name] = aqclustutils.check_if_node_is_included(child)
516590
self._ion2ml_score[leaf.name] = self._get_ml_score_if_possible(child)
517591
self._ion2level[leaf.name] = child.name
@@ -630,7 +704,7 @@ def plot_fcs_ml_score_unicolor(self, color, ax = None):
630704
self.plot_fcs_with_specified_color_scheme(colormap_single_color, ax)
631705
return ax
632706

633-
def plot_fcs_with_specified_color_scheme(self, colormap, ax):
707+
def plot_fcs_with_specified_color_scheme(self, colormap, ax, xticklabels=None):
634708
if type(colormap) == type(dict()):
635709
colormap = {idx: colormap.get(self.precursors[idx]) for idx in range(len(self.precursors))}
636710

@@ -640,12 +714,18 @@ def plot_fcs_with_specified_color_scheme(self, colormap, ax):
640714
self._plot_fcs_with_boxplot(colormap, ax)
641715

642716
idxs = list(range(len(self.precursors)))
643-
ax.set_xticks(idxs, labels = self.precursors, rotation = 'vertical')
717+
if xticklabels is not None:
718+
ax.set_xticks(idxs, labels=xticklabels, rotation='vertical')
719+
elif getattr(self._plotconfig, 'shortened_xticklabels', False):
720+
# Fallback: Only the base part of the label from the ion names
721+
formatted = [_format_tree_label_string(x).split('\n')[0] for x in self.precursors]
722+
ax.set_xticks(idxs, labels=formatted, rotation='vertical')
723+
else:
724+
ax.set_xticks(idxs, labels=self.precursors, rotation='vertical')
644725

645726
def _plot_fcs_with_swarmplot(self, colormap, ax):
646-
sns.stripplot(data = self.fcs, ax=ax, palette=colormap)
647-
sns.boxplot(data = self.fcs, ax=ax, showfliers=self._plotconfig.showfliers,
648-
boxprops=dict(facecolor="none", edgecolor="black"))
727+
sns.stripplot(data = self.fcs, ax=ax, color='#404040', alpha=1.0, size=3)
728+
sns.boxplot(data = self.fcs, ax=ax, palette=colormap, showfliers=self._plotconfig.showfliers)
649729

650730
def _plot_fcs_with_boxplot(self, colormap, ax):
651731
sns.boxplot(data = self.fcs, ax=ax, palette=colormap, showfliers=self._plotconfig.showfliers)

alphaquant/plotting/pairwise.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,13 @@ def plot_sample_vs_median_fcs(df_c1_normed, df_c2_normed):
8383

8484

8585

86-
def volcano_plot(results_df, fc_header="log2fc", fdr_header="fdr", fdr_cutoff=0.05,
86+
def volcano_plot(results_df, fc_header="log2fc", fdr_header="fdr", fdr_cutoff=0.05,
8787
log2fc_cutoff=0.5, xlim=None, ylim = None,
88-
organism_column=None, organism2color_dict=None,
88+
organism_column=None, organism2color_dict=None,
8989
color_only_significant=True, alpha= None,ax = None,
9090
draw_vertical_lines = True, draw_horizontal_lines = True,
91-
ground_truth_ratios = None):
92-
91+
ground_truth_ratios = None, point_size=None):
92+
9393
results_df[fdr_header] = results_df[fdr_header].replace(0, np.min(results_df[fdr_header].replace(0, 1.0)))
9494
fdrs = results_df[fdr_header].to_numpy()
9595
fcs = results_df[fc_header].to_numpy()
@@ -113,11 +113,15 @@ def volcano_plot(results_df, fc_header="log2fc", fdr_header="fdr", fdr_cutoff=0.
113113
if alpha is None:
114114
alpha = max(0.1, min(0.7, 0.7 - 0.6 * (len(fdrs) / 1000)))
115115

116-
scatter = sns.scatterplot(data=results_df, x=fc_header, y='-log10(fdr)',
117-
c=results_df['color'].to_list(), ax=ax, legend=None, alpha = alpha)
116+
# Set default point size if not provided
117+
if point_size is None:
118+
point_size = 20 # Default seaborn size
119+
120+
scatter = sns.scatterplot(data=results_df, x=fc_header, y='-log10(fdr)',
121+
c=results_df['color'].to_list(), ax=ax, legend=None, alpha = alpha, s=point_size)
118122
for scatter_collection in scatter.collections:
119123
scatter_collection.set_rasterized(True)
120-
124+
121125
# Drawing vertical lines for fold change thresholds and horizontal lines for p-value threshold
122126
if draw_vertical_lines:
123127
if log2fc_cutoff !=0:
@@ -126,18 +130,18 @@ def volcano_plot(results_df, fc_header="log2fc", fdr_header="fdr", fdr_cutoff=0.
126130
if draw_horizontal_lines:
127131
if fdr_cutoff !=0:
128132
ax.axhline(y=-np.log10(fdr_cutoff), linestyle='--', color='black')
129-
133+
130134
ax.set_xlabel("log2(FC)")
131135
ax.set_ylabel("-log10(FDR)")
132136

133137

134-
138+
135139
if xlim is None:
136140
maxfc = max(abs(results_df[fc_header])) + 0.5
137141
ax.set_xlim(-maxfc, maxfc)
138142
else:
139143
ax.set_xlim(xlim)
140-
144+
141145
if ylim:
142146
ax.set_ylim(ylim)
143147

@@ -154,7 +158,7 @@ def volcano_plot(results_df, fc_header="log2fc", fdr_header="fdr", fdr_cutoff=0.
154158
def add_color_column(results_df ,organism2color_dict, organism_column, color_only_significant):
155159
# Create a color column based on organism and significance
156160
if organism2color_dict:
157-
results_df['color'] = results_df.apply(lambda row: organism2color_dict[row[organism_column]]
161+
results_df['color'] = results_df.apply(lambda row: organism2color_dict[row[organism_column]]
158162
if row['is_significant'] or not color_only_significant else 'gray', axis=1)
159163
else:
160164
results_df['color'] = np.where(results_df['is_significant'], 'green', 'gray')

alphaquant/plotting/tree_and_fc_viz.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __init__(self, protein_node, quantification_info: aqfcviz.CondpairQuantifica
1919
self._plotconfig = plotconfig
2020
self._shorten_protein_to_level()
2121
self._sort_tree_according_to_plotconfig()
22+
self._maybe_hide_root_in_tree_for_single_branch()
2223
self._define_fig_and_ax()
2324
self._plot_tree()
2425
self._plot_fcs()
@@ -31,6 +32,19 @@ def _shorten_protein_to_level(self):
3132
def _sort_tree_according_to_plotconfig(self):
3233
self._protein_node = aqtreeutils.TreeSorter(self._plotconfig, self._protein_node).get_sorted_tree()
3334

35+
def _maybe_hide_root_in_tree_for_single_branch(self):
    """Turn on hide_root_in_tree when the protein node has at most one child.

    Nothing happens when self._plotconfig.hide_root_in_tree is already
    truthy. Otherwise the flag is set to True if self._protein_node has zero
    or one children. A node without a usable 'children' collection, or a
    config object that rejects the attribute, leaves the config unchanged.
    """
    # An explicit user choice wins; getattr's default already covers a missing attribute.
    if getattr(self._plotconfig, 'hide_root_in_tree', False):
        return

    try:
        if len(self._protein_node.children) <= 1:
            self._plotconfig.hide_root_in_tree = True
    except (AttributeError, TypeError):
        # Best-effort only: this is a cosmetic default and must not abort plotting.
        pass
47+
3448
def _define_fig_and_ax(self):
3549
axis_creator = aqtreeviz.TreePlotAxisCreator(self._protein_node, self._plotconfig)
3650
axis_creator.define_combined_tree_fc_fig_and_axes()

0 commit comments

Comments
 (0)