diff --git a/alphaquant/cluster/residual_decorrelation.py b/alphaquant/cluster/residual_decorrelation.py index 79b33e33..4519019f 100644 --- a/alphaquant/cluster/residual_decorrelation.py +++ b/alphaquant/cluster/residual_decorrelation.py @@ -1,3 +1,72 @@ +"""Residual decorrelation: remove correlated siblings before z-score aggregation. + +Overview +-------- +When AlphaQuant aggregates child nodes (e.g. peptides → protein, fragments → +peptide) using Stouffer's method, inflated sibling correlations bias the +combined z-score upward. This module identifies and prunes the most +correlated children at every level of the ion tree so that the surviving +set's pairwise correlation distribution matches a cross-parent-shuffled null. + +Algorithm +--------- +1. **Residual computation** (``attach_lm_residuals``): + For every base ion the within-condition mean intensity is subtracted, + yielding condition-mean residuals. Residuals for higher-level nodes are + the per-sample mean over their children's residuals, propagated bottom-up. + These residuals capture shared technical variation independently of the + fold-change signal. + +2. **Per-parent precomputation** (``_build_parent``, ``ParentPrecompute``): + For each parent node at a given level the children's residual vectors are + stacked into a matrix and a Pearson correlation matrix ``C`` is computed. + A greedy removal order is then computed once: at each step the child with + the highest mean pairwise correlation to the surviving set is removed. + The maximum pairwise correlation after each removal is stored as + ``max_r_trajectory``, making it cheap to replay any cutoff later via + ``survivors_at``. + +3. **Null distribution** (``_cross_parent_shuffle_null``): + Rows are permuted across parents (cross-parent shuffle) to produce a + baseline that represents what sibling correlations look like when children + are exchanged between unrelated proteins. + +4. 
**Level sweep** (``run_level_sweep``): + A grid of correlation cutoffs (default 1.0 → 0.1) is scanned. For each + cutoff the surviving correlation values across all parents are collected + and compared to the null via a one-sided excess-CDF distance ``D`` + (``_excess_cdf_distance``): the maximum over all ``r`` of + ``F_null(r) − F_corrected(r)``, i.e. how much the corrected distribution + still exceeds the null. The loosest (highest) cutoff with ``D ≤ tolerance`` + is chosen. If none qualifies the tightest cutoff is used regardless. + +5. **Application** (``apply_residual_decorrelation``): + Main entry point. Runs step 1 once and steps 2–4 for every ``LEVEL_PAIRS`` + pair, marks pruned children with ``node.exclude_residual_decorrelation = True``, + then re-aggregates node statistics bottom-up with the decorrelation-aware + aggregation mode. + +Structure +--------- +Data classes + ParentPrecompute – precomputed correlation matrix + removal trajectory + LevelSweepResult – outcome of one level sweep (cutoffs, distances, traces) + +Internal helpers + _node_matches_level – type-string → node matching + _build_parent – build ParentPrecompute for one parent node + _pair_rs_from_C – extract upper-triangle r values given survivors + _cross_parent_shuffle_null – permutation-based null distribution + _aggregate_pair_rs – collect r values across parents at a cutoff + _excess_cdf_distance – one-sided CDF distance metric + +Public API + attach_lm_residuals – attach within-condition residuals to tree nodes + run_level_sweep – sweep cutoffs for one (parent, child) level pair + apply_residual_decorrelation – orchestrate the full pipeline (main entry point) + plot_level_sweep_cdfs – CDF comparison plot for one level result + plot_level_sweep_diagnostics – full diagnostic figure (CDF + sweep trace) +""" from __future__ import annotations from dataclasses import dataclass, field @@ -34,6 +103,21 @@ @dataclass class ParentPrecompute: + """Precomputed correlation structure for one parent node. 
+ + Built once per parent by ``_build_parent``; the removal order and + max-r trajectory are stored so that ``survivors_at`` can replay any + cutoff in O(n) without recomputing correlations. + + Attributes + ---------- + parent_node: the tree node this precompute belongs to + child_nodes: ordered tuple of children whose residuals were used + C: (n × n) Pearson correlation matrix; diagonal is NaN + remove_order: greedy removal sequence (indices into child_nodes) + max_r_trajectory: max pairwise r after removing k children (length n) + """ + parent_node: anytree.Node child_nodes: tuple[anytree.Node, ...] C: np.ndarray @@ -41,15 +125,24 @@ class ParentPrecompute: max_r_trajectory: np.ndarray def survivors_at(self, cutoff: float, min_keep: int) -> np.ndarray: + """Return a boolean mask of children that survive at ``cutoff``. + + Replays the greedy removal order until ``max_r_trajectory`` is at or + below ``cutoff``, always retaining at least ``min_keep`` children. + """ n = self.C.shape[0] if n == 0: return np.zeros(0, dtype=bool) + # at most n-min_keep children may be removed k_max = max(0, n - min_keep) + # scan the trajectory: stop at the first step where max_r is at or below cutoff k = 0 while k <= k_max and self.max_r_trajectory[k] > cutoff: k += 1 + # clamp: the loop exits at k_max + 1 when max_r is still above cutoff after k_max removals (min_keep floor reached) if k > k_max: k = k_max + # replay: mark the first k entries of the greedy removal order as dead alive = np.ones(n, dtype=bool) if k > 0: alive[self.remove_order[:k]] = False @@ -58,6 +151,22 @@ def survivors_at(self, cutoff: float, min_keep: int) -> np.ndarray: @dataclass class LevelSweepResult: + """Outcome of a full cutoff sweep for one (parent_level, child_level) pair. 
+ + Attributes + ---------- + level: (parent_level, child_level) string pair + cutoff: chosen correlation cutoff + d_before/d_after: excess CDF distance before and after pruning + n_parents: total parents examined at this level + parents_touched: parents where at least one child was dropped + children_dropped: total children marked for exclusion + grid_trace: list of (cutoff, distance, dropped, touched) per grid step + unmodified_sorted: sorted pairwise r values before pruning (for plotting) + corrected_sorted: sorted pairwise r values after pruning (for plotting) + null_sorted: sorted null distribution r values (for plotting) + """ + level: tuple[str, str] cutoff: float d_before: float @@ -78,15 +187,25 @@ def _node_matches_level(node, level: str) -> bool: def _build_parent(parent_node, child_nodes, mat): + """Build a ParentPrecompute by greedily computing the removal order. + + At each step the child with the highest mean pairwise correlation to the + remaining set is removed. The maximum pairwise r after each removal is + recorded in ``max_r_trajectory`` and monotonically enforced (non-increasing) + so that ``survivors_at`` can use a simple threshold scan. 
+ """ if mat.shape[0] < 2: return None + # compute full Pearson correlation matrix across children's residual vectors with warnings.catch_warnings(): warnings.filterwarnings("ignore", r"invalid value encountered") C = np.corrcoef(mat) + # constant rows produce NaN correlations; replace with 0 so they don't distort means if not np.all(np.isfinite(C)): C = np.nan_to_num(C, nan=0.0, posinf=1.0, neginf=-1.0) C = C.copy() + # NaN on diagonal so nanmean excludes self-correlation when averaging rows np.fill_diagonal(C, np.nan) n = C.shape[0] @@ -94,6 +213,7 @@ def _build_parent(parent_node, child_nodes, mat): remove_order = [] max_r = [] + # record the max pairwise r before any removal (step 0 of the trajectory) with warnings.catch_warnings(): warnings.filterwarnings("ignore", r"All-NaN slice encountered") init_max = float(np.nanmax(C)) if n >= 2 else -np.inf @@ -102,15 +222,19 @@ def _build_parent(parent_node, child_nodes, mat): max_r.append(init_max) while alive.sum() > 1: + # extract the submatrix of currently surviving children sub_idx = np.where(alive)[0] cc = C[np.ix_(sub_idx, sub_idx)] with warnings.catch_warnings(): warnings.filterwarnings("ignore", r"Mean of empty slice") warnings.filterwarnings("ignore", r"All-NaN slice encountered") mean_r = np.nanmean(cc, axis=1) + # the "worst" child is the one most correlated on average with its siblings; + # replace NaN with -inf so it is never chosen as worst (avoids argmax crash) worst_local = int(np.nanargmax(np.where(np.isnan(mean_r), -np.inf, mean_r))) alive[sub_idx[worst_local]] = False remove_order.append(sub_idx[worst_local]) + # record the new maximum pairwise r among survivors after this removal if alive.sum() >= 2: rem = np.where(alive)[0] cc2 = C[np.ix_(rem, rem)] @@ -124,6 +248,8 @@ def _build_parent(parent_node, child_nodes, mat): max_r.append(m) traj = np.asarray(max_r, dtype=np.float64) + # enforce monotone non-increasing: numerical noise can cause slight upward bumps + # which would break the threshold scan 
in survivors_at for i in range(1, traj.size): if traj[i] > traj[i - 1]: traj[i] = traj[i - 1] @@ -142,25 +268,39 @@ def _pair_rs_from_C(C: np.ndarray, survivors: np.ndarray) -> np.ndarray: return np.empty(0, dtype=np.float64) idx = np.where(survivors)[0] sub = C[np.ix_(idx, idx)] + # k=1 skips the diagonal, giving only unique off-diagonal pairs iu = np.triu_indices(sub.shape[0], k=1) vals = sub[iu] + # drop NaN entries that arise from originally constant residual vectors vals = vals[~np.isnan(vals)] return vals.astype(np.float64, copy=False) def _cross_parent_shuffle_null(mats: list[np.ndarray], rng: np.random.Generator) -> np.ndarray: + """Build a null correlation distribution by shuffling rows across parents. + + All residual rows from every parent are pooled, randomly permuted, then + re-assigned to groups of the original sizes. Pairwise correlations within + each group are computed and concatenated. This destroys within-parent + structure while preserving group sizes, giving the expected correlation + distribution under the null hypothesis that children are unrelated. + """ if not mats: return np.empty(0, dtype=np.float64) + # remember original group sizes so shuffled rows can be re-partitioned identically sizes = [m.shape[0] for m in mats] + # pool all residual rows from all parents into one matrix and shuffle pool = np.vstack(mats) pool = pool[rng.permutation(pool.shape[0])] out = [] idx = 0 for size in sizes: + # re-assign the next 'size' shuffled rows to this group chunk = pool[idx:idx + size] idx += size if size < 2: continue + # skip constant rows: they produce NaN correlations and add no information keep = chunk.std(axis=1) > 0 if keep.sum() < 2: continue @@ -190,12 +330,22 @@ def _aggregate_pair_rs(parents: list[ParentPrecompute], cutoff: float, min_keep: def _excess_cdf_distance(corrected: np.ndarray, null_sorted: np.ndarray) -> float: + """One-sided excess CDF distance between corrected and null distributions. 
+ + Returns max over r of ``F_null(r) − F_corrected(r)``, clipped to 0 from + below. A value of 0 means the corrected distribution is nowhere above the + null; higher values indicate residual excess correlation. + """ if corrected.size == 0 or null_sorted.size == 0: return 0.0 corr_sorted = np.sort(corrected) + # evaluate both CDFs on the union of all observed r values grid = np.unique(np.concatenate([corr_sorted, null_sorted])) + # searchsorted with side="right" gives F(r) = P(X ≤ r) at each grid point f_corr = np.searchsorted(corr_sorted, grid, side="right") / corr_sorted.size f_null = np.searchsorted(null_sorted, grid, side="right") / null_sorted.size + # one-sided: only penalise when the corrected distribution exceeds the null + # (F_null > F_corr means more mass above r in corrected than null → excess correlation) return float(np.max(np.maximum(f_null - f_corr, 0.0))) @@ -208,6 +358,28 @@ def run_level_sweep( min_keep: int = DEFAULT_MIN_KEEP, level: tuple[str, str] = ("", ""), ): + """Sweep correlation cutoffs and return the mildest one within tolerance. + + For each cutoff in ``cutoff_grid`` (scanned from loose to tight), the + surviving pairwise r values are collected and the excess CDF distance to + the null is computed. The first cutoff whose distance falls at or below + ``tolerance`` is chosen. If none qualifies the tightest cutoff is used + unconditionally. + + Parameters + ---------- + parents: precomputed parent structures for this level pair + null_sorted: sorted null pairwise r values (from cross-parent shuffle) + cutoff_grid: correlation thresholds to scan, ordered loose → tight + tolerance: maximum allowed excess CDF distance D after pruning + min_keep: minimum children to retain per parent regardless of cutoff + level: (parent_level, child_level) label stored in the result + + Returns + ------- + LevelSweepResult with the chosen cutoff and full diagnostic traces. 
+ """ + # measure baseline excess distance with no pruning (cutoff = 1.0 keeps everyone) baseline = _aggregate_pair_rs(parents, 1.0, min_keep) d_before = _excess_cdf_distance(baseline, null_sorted) @@ -228,9 +400,12 @@ def run_level_sweep( corrected = np.concatenate(chunks) if chunks else np.empty(0, dtype=np.float64) d = _excess_cdf_distance(corrected, null_sorted) trace.append((cutoff, d, dropped, touched)) + # take the first (loosest) cutoff that already satisfies the tolerance — + # prefer dropping as few children as possible if chosen is None and d <= tolerance: chosen = (cutoff, d, corrected, dropped, touched) + # if no cutoff reached tolerance, fall back to the tightest one in the grid if chosen is None: cutoff, d, dropped, touched = trace[-1] corrected = _aggregate_pair_rs(parents, cutoff, min_keep) @@ -260,25 +435,32 @@ def attach_lm_residuals(protnodes, df_c1_normed, df_c2_normed, min_n_per_cond=2) start from those files must convert them back to log2 before using this helper. 
""" + # build a single intensity matrix with all samples from both conditions X = pd.concat([df_c1_normed, df_c2_normed], axis=1) c1_cols = list(df_c1_normed.columns) c2_cols = list(df_c2_normed.columns) X = X.astype(float) + # compute within-condition means per ion (row-wise) m1 = X[c1_cols].mean(axis=1, skipna=True) m2 = X[c2_cols].mean(axis=1, skipna=True) + # subtract the within-condition mean: residual = intensity - condition mean + # this removes the fold-change signal, leaving only condition-independent noise res = X.copy() res[c1_cols] = X[c1_cols].sub(m1, axis=0) res[c2_cols] = X[c2_cols].sub(m2, axis=0) + # mask ions with too few valid values in either condition — their residuals are unreliable n1_ok = X[c1_cols].notna().sum(axis=1) >= int(min_n_per_cond) n2_ok = X[c2_cols].notna().sum(axis=1) >= int(min_n_per_cond) res.loc[~(n1_ok & n2_ok), :] = np.nan for protnode in protnodes: + # initialise residuals to None on every node before filling for node in PreOrderIter(protnode): node.residuals = None + # assign residual vectors to base (leaf) ions by matching ion name to the matrix index for node in PreOrderIter(protnode): if node.type != "base": continue @@ -287,6 +469,9 @@ def attach_lm_residuals(protnodes, df_c1_normed, df_c2_normed, min_n_per_cond=2) else: node.residuals = None + # propagate residuals bottom-up: each non-base node gets the column-wise mean + # of its children's residual vectors, so higher-level nodes carry an averaged + # representation of shared technical variation across their subtree for level_nodes in aqcluster_utils.iterate_through_tree_levels_bottom_to_top(protnode): for node in level_nodes: if node.type == "base": @@ -304,6 +489,7 @@ def attach_lm_residuals(protnodes, df_c1_normed, df_c2_normed, min_n_per_cond=2) warnings.filterwarnings("ignore", message="Mean of empty slice") with np.errstate(all="ignore"): mean_vec = np.nanmean(stacked, axis=0) + # if all samples are NaN after averaging, treat as missing node.residuals = None if 
np.all(np.isnan(mean_vec)) else mean_vec @@ -319,12 +505,15 @@ def _collect_level_parents(protnodes, parent_level, child_level): if not _node_matches_level(child, child_level): continue v = getattr(child, "residuals", None) + # skip children without residuals or with any NaN sample — + # NaN entries would propagate into the correlation matrix if v is None or not isinstance(v, np.ndarray): continue if np.any(np.isnan(v)): continue child_nodes.append(child) vecs.append(v) + # need at least 2 children to form a correlation matrix if len(vecs) < 2: continue mat = np.vstack(vecs) @@ -345,18 +534,49 @@ def apply_residual_decorrelation( aggregation_mode="stouffer_decorrelation", null_seed=42, ): + """Main entry point: run full residual decorrelation on a list of protein nodes. + + Steps + ----- + 1. Attach within-condition mean residuals to every node (``attach_lm_residuals``). + 2. For each level pair in ``LEVEL_PAIRS``, build a cross-parent shuffle null, + run the cutoff sweep, and mark pruned children with + ``node.exclude_residual_decorrelation = True``. + 3. Optionally apply PTM fragment selection if ``PTM_FRAGMENT_SELECTION`` is set. + 4. Re-aggregate all node statistics bottom-up using ``aggregation_mode``. + 5. Strip residual arrays from nodes to keep the tree serializable. + + Parameters + ---------- + protnodes: list of root protein nodes (anytree) + df_c1_normed: log2-normalized intensities for condition 1 (ions × samples) + df_c2_normed: log2-normalized intensities for condition 2 (ions × samples) + tolerance: maximum excess CDF distance D allowed after pruning + min_keep: minimum children retained per parent at each level + cutoff_grid: correlation cutoffs to sweep per level pair + aggregation_mode: z-aggregation mode used when re-aggregating after pruning + null_seed: random seed for the cross-parent shuffle null + + Returns + ------- + pandas.DataFrame summarising cutoffs, distances, and drop counts per level. 
+ """ + # reset exclusion flags in case this function is called more than once on the same nodes for protnode in protnodes: for node in PreOrderIter(protnode): node.exclude_residual_decorrelation = False node.exclude_ptm_fragment_selection = False + # step 1: compute within-condition residuals and attach them to every node attach_lm_residuals(protnodes, df_c1_normed, df_c2_normed) rng = np.random.default_rng(null_seed) level_results = [] + # step 2: run the sweep for each level pair independently for parent_level, child_level in LEVEL_PAIRS: parents = _collect_level_parents(protnodes, parent_level, child_level) + # build the null from the same residual matrices used for the real sweep mats = [ np.vstack([child.residuals for child in pp.child_nodes]) for pp in parents @@ -381,12 +601,14 @@ def apply_residual_decorrelation( LOGGER.info(msg) print(msg, flush=True) + # mark children that did not survive the chosen cutoff for pp in parents: survivors = pp.survivors_at(sweep.cutoff, min_keep) for keep, child in zip(survivors, pp.child_nodes): if not keep: child.exclude_residual_decorrelation = True + # step 3 (optional): apply PTM fragment selection on top of decorrelation exclusions if aqvariables.PTM_FRAGMENT_SELECTION: n_ptm_dropped, n_ptm_parents = aqcluster_utils.apply_ptm_fragment_selection( protnodes, @@ -404,6 +626,7 @@ def apply_residual_decorrelation( flush=True, ) + # step 4: re-aggregate node statistics bottom-up now that exclusion flags are set for protnode in protnodes: for level_nodes in aqcluster_utils.iterate_through_tree_levels_bottom_to_top(protnode): for node in level_nodes: @@ -416,8 +639,8 @@ def apply_residual_decorrelation( aggregation_mode=aggregation_mode, ) - # Residual vectors are only needed during the sweep; remove them before - # downstream ML reordering / JSON export to keep the tree serializable. 
+ # step 5: residual vectors are only needed during the sweep; remove them before + # downstream ML reordering / JSON export to keep the tree serializable for protnode in protnodes: for node in PreOrderIter(protnode): if hasattr(node, "residuals"):