|
| 1 | +""" pyplots.ai |
| 2 | +survival-kaplan-meier: Kaplan-Meier Survival Plot |
| 3 | +Library: matplotlib 3.10.8 | Python 3.13.11 |
| 4 | +Quality: 92/100 | Created: 2025-12-29 |
| 5 | +""" |
| 6 | + |
| 7 | +import matplotlib.pyplot as plt |
| 8 | +import numpy as np |
| 9 | + |
| 10 | + |
| 11 | +# Data - Clinical trial survival data for two treatment groups |
| 12 | +np.random.seed(42) |
| 13 | + |
| 14 | +# Generate realistic survival times (exponential-like distribution) |
| 15 | +n_per_group = 80 |
| 16 | + |
| 17 | +# Treatment group (better survival) |
| 18 | +treatment_times = np.random.exponential(scale=24, size=n_per_group) |
| 19 | +treatment_times = np.clip(treatment_times, 0.5, 60) # Cap at 60 months |
| 20 | +treatment_censored = np.random.binomial(1, 0.25, n_per_group) # 25% censored |
| 21 | +treatment_events = 1 - treatment_censored |
| 22 | + |
| 23 | +# Control group (worse survival) |
| 24 | +control_times = np.random.exponential(scale=16, size=n_per_group) |
| 25 | +control_times = np.clip(control_times, 0.5, 60) |
| 26 | +control_censored = np.random.binomial(1, 0.20, n_per_group) # 20% censored |
| 27 | +control_events = 1 - control_censored |
| 28 | + |
| 29 | +# Kaplan-Meier calculation for Treatment group |
| 30 | +order = np.argsort(treatment_times) |
| 31 | +treat_t_sorted = treatment_times[order] |
| 32 | +treat_e_sorted = treatment_events[order] |
| 33 | +treat_unique = np.unique(treat_t_sorted[treat_e_sorted == 1]) |
| 34 | + |
| 35 | +treat_time_pts = [0.0] |
| 36 | +treat_surv_probs = [1.0] |
| 37 | +treat_std_errs = [0.0] |
| 38 | +treat_surv = 1.0 |
| 39 | +treat_var = 0.0 |
| 40 | + |
| 41 | +for t in treat_unique: |
| 42 | + n_risk = np.sum(treat_t_sorted >= t) |
| 43 | + d = np.sum((treat_t_sorted == t) & (treat_e_sorted == 1)) |
| 44 | + if n_risk > 0 and d > 0: |
| 45 | + treat_surv *= (n_risk - d) / n_risk |
| 46 | + if n_risk > d: |
| 47 | + treat_var += d / (n_risk * (n_risk - d)) |
| 48 | + treat_time_pts.append(t) |
| 49 | + treat_surv_probs.append(treat_surv) |
| 50 | + treat_std_errs.append(np.sqrt(treat_var) * treat_surv if treat_surv > 0 else 0) |
| 51 | + |
| 52 | +treat_times_km = np.array(treat_time_pts) |
| 53 | +treat_surv_km = np.array(treat_surv_probs) |
| 54 | +treat_se_km = np.array(treat_std_errs) |
| 55 | + |
| 56 | +# Kaplan-Meier calculation for Control group |
| 57 | +order = np.argsort(control_times) |
| 58 | +ctrl_t_sorted = control_times[order] |
| 59 | +ctrl_e_sorted = control_events[order] |
| 60 | +ctrl_unique = np.unique(ctrl_t_sorted[ctrl_e_sorted == 1]) |
| 61 | + |
| 62 | +ctrl_time_pts = [0.0] |
| 63 | +ctrl_surv_probs = [1.0] |
| 64 | +ctrl_std_errs = [0.0] |
| 65 | +ctrl_surv = 1.0 |
| 66 | +ctrl_var = 0.0 |
| 67 | + |
| 68 | +for t in ctrl_unique: |
| 69 | + n_risk = np.sum(ctrl_t_sorted >= t) |
| 70 | + d = np.sum((ctrl_t_sorted == t) & (ctrl_e_sorted == 1)) |
| 71 | + if n_risk > 0 and d > 0: |
| 72 | + ctrl_surv *= (n_risk - d) / n_risk |
| 73 | + if n_risk > d: |
| 74 | + ctrl_var += d / (n_risk * (n_risk - d)) |
| 75 | + ctrl_time_pts.append(t) |
| 76 | + ctrl_surv_probs.append(ctrl_surv) |
| 77 | + ctrl_std_errs.append(np.sqrt(ctrl_var) * ctrl_surv if ctrl_surv > 0 else 0) |
| 78 | + |
| 79 | +ctrl_times_km = np.array(ctrl_time_pts) |
| 80 | +ctrl_surv_km = np.array(ctrl_surv_probs) |
| 81 | +ctrl_se_km = np.array(ctrl_std_errs) |
| 82 | + |
| 83 | +# Get censored observation times for tick marks |
| 84 | +treat_censor_times = treatment_times[treatment_events == 0] |
| 85 | +ctrl_censor_times = control_times[control_events == 0] |
| 86 | + |
| 87 | +# Interpolate survival at censored times for tick marks |
| 88 | +treat_censor_surv = np.interp(treat_censor_times, treat_times_km, treat_surv_km) |
| 89 | +ctrl_censor_surv = np.interp(ctrl_censor_times, ctrl_times_km, ctrl_surv_km) |
| 90 | + |
| 91 | +# Plot |
| 92 | +fig, ax = plt.subplots(figsize=(16, 9)) |
| 93 | + |
| 94 | +# Python colors |
| 95 | +treatment_color = "#306998" |
| 96 | +control_color = "#FFD43B" |
| 97 | + |
| 98 | +# Treatment group curve with CI |
| 99 | +ax.step(treat_times_km, treat_surv_km, where="post", color=treatment_color, linewidth=3, label="Treatment Group") |
| 100 | +treat_upper = np.clip(treat_surv_km + 1.96 * treat_se_km, 0, 1) |
| 101 | +treat_lower = np.clip(treat_surv_km - 1.96 * treat_se_km, 0, 1) |
| 102 | +ax.fill_between(treat_times_km, treat_lower, treat_upper, step="post", alpha=0.2, color=treatment_color) |
| 103 | + |
| 104 | +# Control group curve with CI |
| 105 | +ax.step(ctrl_times_km, ctrl_surv_km, where="post", color=control_color, linewidth=3, label="Control Group") |
| 106 | +ctrl_upper = np.clip(ctrl_surv_km + 1.96 * ctrl_se_km, 0, 1) |
| 107 | +ctrl_lower = np.clip(ctrl_surv_km - 1.96 * ctrl_se_km, 0, 1) |
| 108 | +ax.fill_between(ctrl_times_km, ctrl_lower, ctrl_upper, step="post", alpha=0.3, color=control_color) |
| 109 | + |
| 110 | +# Censored observation tick marks |
| 111 | +ax.scatter(treat_censor_times, treat_censor_surv, marker="|", s=400, color=treatment_color, linewidth=2, zorder=5) |
| 112 | +ax.scatter(ctrl_censor_times, ctrl_censor_surv, marker="|", s=400, color="#CC9A00", linewidth=2, zorder=5) |
| 113 | + |
| 114 | +# Calculate median survival times |
| 115 | +treat_median_idx = np.where(treat_surv_km <= 0.5)[0] |
| 116 | +ctrl_median_idx = np.where(ctrl_surv_km <= 0.5)[0] |
| 117 | + |
| 118 | +treat_median = treat_times_km[treat_median_idx[0]] if len(treat_median_idx) > 0 else None |
| 119 | +ctrl_median = ctrl_times_km[ctrl_median_idx[0]] if len(ctrl_median_idx) > 0 else None |
| 120 | + |
| 121 | +# Add median survival annotation lines |
| 122 | +if treat_median: |
| 123 | + ax.axhline(y=0.5, color="gray", linestyle=":", linewidth=1.5, alpha=0.5) |
| 124 | + ax.axvline(x=treat_median, color=treatment_color, linestyle=":", linewidth=1.5, alpha=0.7) |
| 125 | +if ctrl_median: |
| 126 | + ax.axvline(x=ctrl_median, color="#CC9A00", linestyle=":", linewidth=1.5, alpha=0.7) |
| 127 | + |
| 128 | +# Add median text |
| 129 | +median_text = "" |
| 130 | +if treat_median: |
| 131 | + median_text += f"Treatment median: {treat_median:.1f} mo" |
| 132 | +if ctrl_median: |
| 133 | + median_text += f"\nControl median: {ctrl_median:.1f} mo" |
| 134 | + |
| 135 | +ax.text( |
| 136 | + 0.98, |
| 137 | + 0.02, |
| 138 | + median_text.strip(), |
| 139 | + transform=ax.transAxes, |
| 140 | + fontsize=16, |
| 141 | + verticalalignment="bottom", |
| 142 | + horizontalalignment="right", |
| 143 | + bbox={"boxstyle": "round,pad=0.5", "facecolor": "white", "edgecolor": "gray", "alpha": 0.8}, |
| 144 | +) |
| 145 | + |
| 146 | +# Styling |
| 147 | +ax.set_xlabel("Time (months)", fontsize=20) |
| 148 | +ax.set_ylabel("Survival Probability", fontsize=20) |
| 149 | +ax.set_title("survival-kaplan-meier · matplotlib · pyplots.ai", fontsize=24) |
| 150 | +ax.tick_params(axis="both", labelsize=16) |
| 151 | +ax.set_xlim(0, 65) |
| 152 | +ax.set_ylim(0, 1.05) |
| 153 | +ax.legend(fontsize=16, loc="upper right") |
| 154 | +ax.grid(True, alpha=0.3, linestyle="--") |
| 155 | + |
| 156 | +plt.tight_layout() |
| 157 | +plt.savefig("plot.png", dpi=300, bbox_inches="tight") |
0 commit comments