|
| 1 | +""" pyplots.ai |
| 2 | +forest-basic: Meta-Analysis Forest Plot |
| 3 | +Library: seaborn 0.13.2 | Python 3.13.11 |
| 4 | +Quality: 92/100 | Created: 2025-12-27 |
| 5 | +""" |
| 6 | + |
| 7 | +import matplotlib.patches as mpatches |
| 8 | +import matplotlib.pyplot as plt |
| 9 | +import numpy as np |
| 10 | +import pandas as pd |
| 11 | +import seaborn as sns |
| 12 | + |
| 13 | + |
| 14 | +# Data: Meta-analysis of treatment effect (mean difference) from 10 studies |
| 15 | +np.random.seed(42) |
| 16 | + |
| 17 | +studies = [ |
| 18 | + "Smith et al. 2018", |
| 19 | + "Johnson et al. 2019", |
| 20 | + "Williams et al. 2019", |
| 21 | + "Brown et al. 2020", |
| 22 | + "Davis et al. 2020", |
| 23 | + "Miller et al. 2021", |
| 24 | + "Wilson et al. 2021", |
| 25 | + "Moore et al. 2022", |
| 26 | + "Taylor et al. 2022", |
| 27 | + "Anderson et al. 2023", |
| 28 | +] |
| 29 | + |
| 30 | +# Effect sizes (mean differences) - some favor treatment, some favor control |
| 31 | +effect_sizes = [-0.45, 0.12, -0.28, -0.52, 0.05, -0.38, -0.15, -0.42, -0.22, -0.35] |
| 32 | +ci_widths = [0.35, 0.28, 0.42, 0.25, 0.55, 0.32, 0.38, 0.30, 0.45, 0.28] |
| 33 | +ci_lower = [e - w for e, w in zip(effect_sizes, ci_widths, strict=True)] |
| 34 | +ci_upper = [e + w for e, w in zip(effect_sizes, ci_widths, strict=True)] |
| 35 | +weights = [12.5, 8.2, 6.8, 14.1, 5.5, 10.3, 7.9, 11.8, 6.2, 9.7] |
| 36 | + |
| 37 | +# Calculate pooled estimate (weighted mean) |
| 38 | +pooled_effect = np.average(effect_sizes, weights=weights) |
| 39 | +pooled_se = np.sqrt(1 / np.sum([w / (ci_w**2) for w, ci_w in zip(weights, ci_widths, strict=True)])) |
| 40 | +pooled_ci_lower = pooled_effect - 1.96 * pooled_se |
| 41 | +pooled_ci_upper = pooled_effect + 1.96 * pooled_se |
| 42 | + |
| 43 | +df = pd.DataFrame( |
| 44 | + {"study": studies, "effect": effect_sizes, "ci_lower": ci_lower, "ci_upper": ci_upper, "weight": weights} |
| 45 | +) |
| 46 | + |
| 47 | +# Sort by effect size |
| 48 | +df = df.sort_values("effect", ascending=True).reset_index(drop=True) |
| 49 | + |
| 50 | +# Create figure |
| 51 | +fig, ax = plt.subplots(figsize=(16, 9)) |
| 52 | +sns.set_style("whitegrid") |
| 53 | + |
| 54 | +# Y positions for studies (leave space at bottom for pooled estimate) |
| 55 | +y_positions = np.arange(len(df)) + 1.5 |
| 56 | + |
| 57 | +# Scale marker sizes based on weight (larger for more precise studies) |
| 58 | +marker_sizes = (df["weight"] / df["weight"].max()) * 400 + 100 |
| 59 | + |
| 60 | +# Plot confidence intervals as horizontal lines |
| 61 | +for i, (_, row) in enumerate(df.iterrows()): |
| 62 | + ax.hlines(y=y_positions[i], xmin=row["ci_lower"], xmax=row["ci_upper"], color="#306998", linewidth=2.5, zorder=1) |
| 63 | + |
| 64 | +# Plot point estimates using seaborn |
| 65 | +sns.scatterplot( |
| 66 | + data=df, |
| 67 | + x="effect", |
| 68 | + y=y_positions, |
| 69 | + size="weight", |
| 70 | + sizes=(100, 500), |
| 71 | + color="#306998", |
| 72 | + edgecolor="white", |
| 73 | + linewidth=1.5, |
| 74 | + legend=False, |
| 75 | + ax=ax, |
| 76 | + zorder=2, |
| 77 | +) |
| 78 | + |
| 79 | +# Add study labels on the left |
| 80 | +for i, (_, row) in enumerate(df.iterrows()): |
| 81 | + ax.text(-1.4, y_positions[i], row["study"], fontsize=14, va="center", ha="left", fontweight="medium") |
| 82 | + |
| 83 | +# Add effect size values on the right |
| 84 | +for i, (_, row) in enumerate(df.iterrows()): |
| 85 | + ax.text( |
| 86 | + 1.1, |
| 87 | + y_positions[i], |
| 88 | + f"{row['effect']:.2f} [{row['ci_lower']:.2f}, {row['ci_upper']:.2f}]", |
| 89 | + fontsize=12, |
| 90 | + va="center", |
| 91 | + ha="left", |
| 92 | + family="monospace", |
| 93 | + ) |
| 94 | + |
| 95 | +# Draw pooled estimate diamond |
| 96 | +diamond_y = 0.3 |
| 97 | +diamond_height = 0.4 |
| 98 | +diamond = mpatches.Polygon( |
| 99 | + [ |
| 100 | + [pooled_effect, diamond_y], |
| 101 | + [pooled_ci_lower, diamond_y + diamond_height / 2], |
| 102 | + [pooled_effect, diamond_y + diamond_height], |
| 103 | + [pooled_ci_upper, diamond_y + diamond_height / 2], |
| 104 | + ], |
| 105 | + closed=True, |
| 106 | + facecolor="#FFD43B", |
| 107 | + edgecolor="#306998", |
| 108 | + linewidth=2, |
| 109 | + zorder=3, |
| 110 | +) |
| 111 | +ax.add_patch(diamond) |
| 112 | + |
| 113 | +# Add pooled estimate label |
| 114 | +ax.text(-1.4, diamond_y + diamond_height / 2, "Pooled Estimate", fontsize=14, va="center", ha="left", fontweight="bold") |
| 115 | +ax.text( |
| 116 | + 1.1, |
| 117 | + diamond_y + diamond_height / 2, |
| 118 | + f"{pooled_effect:.2f} [{pooled_ci_lower:.2f}, {pooled_ci_upper:.2f}]", |
| 119 | + fontsize=12, |
| 120 | + va="center", |
| 121 | + ha="left", |
| 122 | + family="monospace", |
| 123 | + fontweight="bold", |
| 124 | +) |
| 125 | + |
| 126 | +# Vertical reference line at null effect (0) |
| 127 | +ax.axvline(x=0, color="#666666", linestyle="--", linewidth=2, zorder=0, alpha=0.7) |
| 128 | + |
| 129 | +# Separator line above pooled estimate |
| 130 | +ax.axhline(y=1.0, color="#CCCCCC", linewidth=1.5, zorder=0) |
| 131 | + |
| 132 | +# Styling |
| 133 | +ax.set_xlim(-1.5, 1.8) |
| 134 | +ax.set_ylim(-0.3, len(df) + 2) |
| 135 | +ax.set_xlabel("Mean Difference (Treatment - Control)", fontsize=20) |
| 136 | +ax.set_ylabel("") |
| 137 | +ax.set_title("forest-basic · seaborn · pyplots.ai", fontsize=24, fontweight="bold", pad=20) |
| 138 | + |
| 139 | +# Remove y-axis ticks (study names are shown as text) |
| 140 | +ax.set_yticks([]) |
| 141 | + |
| 142 | +# Style x-axis ticks |
| 143 | +ax.tick_params(axis="x", labelsize=16) |
| 144 | + |
| 145 | +# Add annotation for interpretation |
| 146 | +ax.text( |
| 147 | + -0.75, |
| 148 | + len(df) + 1.5, |
| 149 | + "← Favors Treatment", |
| 150 | + fontsize=14, |
| 151 | + ha="center", |
| 152 | + va="center", |
| 153 | + color="#306998", |
| 154 | + fontweight="medium", |
| 155 | +) |
| 156 | +ax.text( |
| 157 | + 0.75, len(df) + 1.5, "Favors Control →", fontsize=14, ha="center", va="center", color="#666666", fontweight="medium" |
| 158 | +) |
| 159 | + |
| 160 | +# Adjust grid |
| 161 | +ax.grid(axis="x", alpha=0.3, linestyle="--") |
| 162 | +ax.grid(axis="y", visible=False) |
| 163 | + |
| 164 | +# Remove top and right spines |
| 165 | +sns.despine(left=True) |
| 166 | + |
| 167 | +plt.tight_layout() |
| 168 | +plt.savefig("plot.png", dpi=300, bbox_inches="tight") |
0 commit comments