|
| 1 | +""" pyplots.ai |
| 2 | +ma-differential-expression: MA Plot for Differential Expression |
| 3 | +Library: altair 6.0.0 | Python 3.14.3 |
| 4 | +Quality: 91/100 | Created: 2026-03-20 |
| 5 | +""" |
| 6 | + |
| 7 | +import altair as alt |
| 8 | +import numpy as np |
| 9 | +import pandas as pd |
| 10 | + |
| 11 | + |
# Data - Simulated RNA-seq differential expression results
np.random.seed(42)
n_genes = 15000

# Mean expression (A): mixture of exponential, uniform and normal draws,
# clipped to the range [0.1, 16]
third = n_genes // 3
mean_expression = np.clip(
    np.concatenate(
        [
            np.random.exponential(3, third),
            np.random.uniform(0.5, 12, third),
            np.random.normal(6, 2.5, n_genes - 2 * third),
        ]
    ),
    0.1,
    16,
)

# Log fold change (M): N(0, 0.4) background, then two disjoint random gene sets
# are overwritten with strong positive / negative effects
log_fold_change = np.random.normal(0, 0.4, n_genes)
n_up = 400
n_down = 350
up_idx = np.random.choice(n_genes, n_up, replace=False)
remaining = np.setdiff1d(np.arange(n_genes), up_idx)
down_idx = np.random.choice(remaining, n_down, replace=False)
log_fold_change[up_idx] = np.random.uniform(1.0, 4.5, n_up)
log_fold_change[down_idx] = np.random.uniform(-4.5, -1.0, n_down)

# Only the overwritten genes are flagged as significant
significant = np.zeros(n_genes, dtype=bool)
significant[up_idx] = True
significant[down_idx] = True

gene_names = [f"Gene{i}" for i in range(n_genes)]

# Give the three strongest up- and down-regulated genes example symbols for labeling
top_up_sorted = up_idx[np.argsort(-log_fold_change[up_idx])]
top_down_sorted = down_idx[np.argsort(log_fold_change[down_idx])]
up_names = ["BRCA1", "MYC", "EGFR"]
down_names = ["PTEN", "RB1", "KRAS"]
label_idx = np.concatenate([top_up_sorted[:3], top_down_sorted[:3]])
example_names = up_names + down_names
for idx, name in zip(label_idx, example_names):
    gene_names[idx] = name

df = pd.DataFrame(
    {
        "mean_expression": mean_expression,
        "log_fold_change": log_fold_change,
        "significant": significant,
        "gene_name": gene_names,
    }
)
df["status"] = np.where(
    ~df["significant"], "Not significant", np.where(df["log_fold_change"] > 0, "Upregulated", "Downregulated")
)

# Label anchors start at the gene's own coordinates, ordered left to right
df_labels = df.loc[df["gene_name"].isin(example_names)].copy()
df_labels = df_labels.sort_values("mean_expression").reset_index(drop=True)
df_labels["label_x"] = df_labels["mean_expression"].values
df_labels["label_y"] = df_labels["log_fold_change"].values

# De-overlap labels: a label that lands within label_min_gap (in y) of any
# already-placed label inside label_x_window (in x) is moved a full gap beyond
# the outermost of those neighbours, in the direction of its own fold change.
# Moving by the full gap (rather than a smaller fixed nudge) and checking every
# neighbour in the window guarantees the placed labels no longer collide.
label_x_window = 1.5
label_min_gap = 0.8
for i in range(1, len(df_labels)):
    x_i = df_labels.at[i, "label_x"]
    y_i = df_labels.at[i, "label_y"]
    placed = df_labels.iloc[:i]
    nearby = placed[(placed["label_x"] - x_i).abs() < label_x_window]
    if ((nearby["label_y"] - y_i).abs() < label_min_gap).any():
        outward = 1 if df_labels.at[i, "log_fold_change"] > 0 else -1
        outermost = (outward * nearby["label_y"]).max()
        df_labels.at[i, "label_y"] = outward * (outermost + label_min_gap)

# Separate non-significant and significant for layering with different sizes/opacity
df_nonsig = df[~df["significant"]].copy()
df_sig = df[df["significant"]].copy()
| 73 | + |
# X and Y axis definitions; both channels use the same axis styling
axis_style = {
    "labelFontSize": 18,
    "titleFontSize": 22,
    "titleColor": "#333333",
    "labelColor": "#555555",
    "domain": False,
    "tickSize": 6,
    "tickColor": "#999999",
    "tickWidth": 1,
}
x_axis = alt.X("mean_expression:Q", title="Mean Expression (A)", axis=alt.Axis(**axis_style))
y_axis = alt.Y("log_fold_change:Q", title="Log₂ Fold Change (M)", axis=alt.Axis(**axis_style))
| 103 | + |
# Shaded horizontal band spanning y = -1 to y = 1 (no x encoding)
fc_band_data = pd.DataFrame([{"y": -1, "y2": 1}])
fc_band_base = alt.Chart(fc_band_data).mark_rect(color="#F0F4F8", opacity=0.5)
fc_band = fc_band_base.encode(alt.Y("y:Q"), alt.Y2("y2:Q"))
| 107 | + |
# Non-significant genes: small, faint gray filled points with a hover tooltip
nonsig_mark = {"filled": True, "size": 20, "opacity": 0.25, "strokeWidth": 0, "color": "#CCCCCC"}
nonsig_tooltip = [
    alt.Tooltip("gene_name:N", title="Gene"),
    alt.Tooltip("mean_expression:Q", title="Mean Expr", format=".2f"),
    alt.Tooltip("log_fold_change:Q", title="Log₂ FC", format=".2f"),
]
points_nonsig = alt.Chart(df_nonsig).mark_point(**nonsig_mark).encode(x=x_axis, y=y_axis, tooltip=nonsig_tooltip)
| 122 | + |
# Point selection driven by pointer hover, keyed on gene name; empty selection matches nothing
highlight = alt.selection_point(on="pointerover", fields=["gene_name"], empty=False)

# Significant genes: status drives both color and shape (triangle = up, square = down)
# NOTE(review): the shape legend is suppressed, so the legend symbols may not match
# the triangle/square marks drawn on the chart — confirm this is intended.
status_domain = ["Upregulated", "Downregulated"]
color_scale = alt.Scale(domain=status_domain, range=["#D7263D", "#306998"])
shape_scale = alt.Scale(domain=status_domain, range=["triangle-up", "square"])

# Legend is positioned manually (orient="none") at fixed pixel coordinates
status_legend = alt.Legend(
    title=None,
    labelFontSize=16,
    symbolSize=200,
    orient="none",
    legendX=1250,
    legendY=5,
    direction="horizontal",
    padding=8,
)
sig_tooltip = [
    alt.Tooltip("gene_name:N", title="Gene"),
    alt.Tooltip("mean_expression:Q", title="Mean Expr", format=".2f"),
    alt.Tooltip("log_fold_change:Q", title="Log₂ FC", format=".2f"),
    alt.Tooltip("status:N", title="Status"),
]
points_sig = (
    alt.Chart(df_sig)
    .mark_point(filled=True, stroke="white", strokeWidth=0.5)
    .encode(
        x=alt.X("mean_expression:Q"),
        y=alt.Y("log_fold_change:Q"),
        color=alt.Color("status:N", scale=color_scale, legend=status_legend),
        shape=alt.Shape("status:N", scale=shape_scale, legend=None),
        # Hovered gene is drawn at twice the default point size
        size=alt.condition(highlight, alt.value(160), alt.value(80)),
        tooltip=sig_tooltip,
    )
    .add_params(highlight)
)
| 160 | + |
# Horizontal reference rules: solid line at y = 0, dashed lines at y = -1 and y = 1
zero_data = pd.DataFrame({"y": [0]})
zero_line = alt.Chart(zero_data).mark_rule(color="#333333", strokeWidth=1.5, opacity=0.6).encode(y="y:Q")

threshold_data = pd.DataFrame({"y": [-1, 1]})
threshold_mark = {"color": "#777777", "strokeWidth": 2, "strokeDash": [8, 6], "opacity": 0.6}
fc_thresholds = alt.Chart(threshold_data).mark_rule(**threshold_mark).encode(y="y:Q")
| 169 | + |
# LOESS trend of log fold change against mean expression, fitted over all genes
loess_source = alt.Chart(df).transform_loess(on="mean_expression", loess="log_fold_change", bandwidth=0.3)
loess_line = loess_source.mark_line(color="#D4770B", strokeWidth=2.5, opacity=0.7).encode(
    x="mean_expression:Q",
    y="log_fold_change:Q",
)
| 177 | + |
# Text labels for the named genes, drawn in bold italic 14 px above the label anchor
label_mark = {
    "fontSize": 16,
    "fontStyle": "italic",
    "fontWeight": "bold",
    "color": "#222222",
    "dy": -14,
    "align": "center",
}
labels = alt.Chart(df_labels).mark_text(**label_mark).encode(x="label_x:Q", y="label_y:Q", text="gene_name:N")
| 184 | + |
# Compose chart
# Pan/zoom is bound on the non-significant point layer because it encodes both
# an x and a y field. A chart-level .interactive() on a layered chart attaches
# the scale-bound selection to the first layer (fc_band), which has no x field,
# so the x axis would not respond to pan/zoom.
chart = (
    (fc_band + zero_line + fc_thresholds + points_nonsig.interactive() + points_sig + loess_line + labels)
    .properties(
        width=1600,
        height=900,
        title=alt.Title(
            "ma-differential-expression · altair · pyplots.ai",
            fontSize=28,
            color="#222222",
            anchor="middle",
            offset=10,
            subtitle="RNA-seq differential expression: upregulated (red) and downregulated (blue) genes",
            subtitleFontSize=17,
            subtitleColor="#777777",
            subtitlePadding=6,
        ),
    )
    .configure_axis(
        labelFontSize=18,
        titleFontSize=22,
        titlePadding=12,
        grid=True,
        gridOpacity=0.12,
        gridColor="#cccccc",
        gridDash=[3, 3],
    )
    .configure_view(strokeWidth=0)
)

# Save a static PNG (rendered at 3x scale) and an interactive HTML version
chart.save("plot.png", scale_factor=3.0)
chart.save("plot.html")
0 commit comments