|
| 1 | +""" pyplots.ai |
| 2 | +survival-kaplan-meier: Kaplan-Meier Survival Plot |
| 3 | +Library: letsplot 4.8.2 | Python 3.13.11 |
| 4 | +Quality: 91/100 | Created: 2025-12-29 |
| 5 | +""" |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import pandas as pd |
| 9 | +from lets_plot import ( |
| 10 | + LetsPlot, |
| 11 | + aes, |
| 12 | + element_text, |
| 13 | + geom_point, |
| 14 | + geom_ribbon, |
| 15 | + geom_step, |
| 16 | + ggplot, |
| 17 | + ggsave, |
| 18 | + ggsize, |
| 19 | + labs, |
| 20 | + scale_color_manual, |
| 21 | + scale_fill_manual, |
| 22 | + theme, |
| 23 | + theme_minimal, |
| 24 | +) |
| 25 | + |
| 26 | + |
| 27 | +LetsPlot.setup_html() |
| 28 | + |
| 29 | +# Data - Simulated clinical trial survival data for two treatment groups |
| 30 | +np.random.seed(42) |
| 31 | + |
| 32 | +n_per_group = 80 |
| 33 | + |
| 34 | + |
| 35 | +def generate_survival_data(n, hazard_rate, group_name): |
| 36 | + """Generate survival data with exponential distribution.""" |
| 37 | + times = np.random.exponential(scale=1 / hazard_rate, size=n) |
| 38 | + times = np.clip(times, 0, 36) # Max follow-up 36 months |
| 39 | + censored = times >= 36 |
| 40 | + times[censored] = 36 |
| 41 | + event = (~censored).astype(int) |
| 42 | + # Add random censoring (20% of non-terminal events) |
| 43 | + random_censor = np.random.random(n) < 0.2 |
| 44 | + event[random_censor] = 0 |
| 45 | + return pd.DataFrame({"time": times, "event": event, "group": group_name}) |
| 46 | + |
| 47 | + |
| 48 | +# Treatment group (lower hazard = better survival) |
| 49 | +treatment = generate_survival_data(n_per_group, hazard_rate=0.04, group_name="Treatment") |
| 50 | +# Control group (higher hazard = worse survival) |
| 51 | +control = generate_survival_data(n_per_group, hazard_rate=0.08, group_name="Control") |
| 52 | +df = pd.concat([treatment, control], ignore_index=True) |
| 53 | + |
| 54 | + |
| 55 | +# Kaplan-Meier estimator function |
| 56 | +def kaplan_meier(time, event): |
| 57 | + """Compute Kaplan-Meier survival curve with confidence intervals.""" |
| 58 | + df_km = pd.DataFrame({"time": time, "event": event}).sort_values("time") |
| 59 | + unique_times = np.sort(df_km["time"].unique()) |
| 60 | + n_at_risk = len(df_km) |
| 61 | + survival = 1.0 |
| 62 | + results = [{"time": 0, "survival": 1.0, "ci_lower": 1.0, "ci_upper": 1.0, "n_at_risk": n_at_risk}] |
| 63 | + var_sum = 0 |
| 64 | + |
| 65 | + for t in unique_times: |
| 66 | + at_time = df_km[df_km["time"] == t] |
| 67 | + d = at_time["event"].sum() # Number of events |
| 68 | + n = n_at_risk # Number at risk |
| 69 | + if n > 0 and d > 0: |
| 70 | + survival *= 1 - d / n |
| 71 | + var_sum += d / (n * (n - d)) if n > d else 0 |
| 72 | + # Greenwood's formula for variance |
| 73 | + se = survival * np.sqrt(var_sum) if var_sum > 0 else 0 |
| 74 | + ci_lower = max(0, survival - 1.96 * se) |
| 75 | + ci_upper = min(1, survival + 1.96 * se) |
| 76 | + results.append({"time": t, "survival": survival, "ci_lower": ci_lower, "ci_upper": ci_upper, "n_at_risk": n}) |
| 77 | + n_at_risk -= len(at_time) |
| 78 | + |
| 79 | + return pd.DataFrame(results) |
| 80 | + |
| 81 | + |
| 82 | +# Compute Kaplan-Meier for each group |
| 83 | +km_treatment = kaplan_meier(treatment["time"], treatment["event"]) |
| 84 | +km_treatment["group"] = "Treatment" |
| 85 | +km_control = kaplan_meier(control["time"], control["event"]) |
| 86 | +km_control["group"] = "Control" |
| 87 | +km_data = pd.concat([km_treatment, km_control], ignore_index=True) |
| 88 | + |
| 89 | +# Find censored observations for tick marks |
| 90 | +censored_treatment = treatment[treatment["event"] == 0].copy() |
| 91 | +censored_control = control[control["event"] == 0].copy() |
| 92 | + |
| 93 | + |
| 94 | +def get_survival_at_time(km_df, t): |
| 95 | + """Get survival probability at a given time.""" |
| 96 | + km_df = km_df.sort_values("time") |
| 97 | + idx = km_df[km_df["time"] <= t].index |
| 98 | + if len(idx) == 0: |
| 99 | + return 1.0 |
| 100 | + return km_df.loc[idx[-1], "survival"] |
| 101 | + |
| 102 | + |
| 103 | +censored_treatment["survival"] = censored_treatment["time"].apply(lambda t: get_survival_at_time(km_treatment, t)) |
| 104 | +censored_treatment["group"] = "Treatment" |
| 105 | +censored_control["survival"] = censored_control["time"].apply(lambda t: get_survival_at_time(km_control, t)) |
| 106 | +censored_control["group"] = "Control" |
| 107 | +censored_data = pd.concat([censored_treatment, censored_control], ignore_index=True) |
| 108 | + |
| 109 | +# Colors |
| 110 | +colors = ["#306998", "#DC2626"] # Python Blue for Treatment, Red for Control |
| 111 | + |
| 112 | +# Plot |
| 113 | +plot = ( |
| 114 | + ggplot() |
| 115 | + # Confidence interval ribbons |
| 116 | + + geom_ribbon(aes(x="time", ymin="ci_lower", ymax="ci_upper", fill="group"), data=km_data, alpha=0.2) |
| 117 | + # Step functions for survival curves |
| 118 | + + geom_step(aes(x="time", y="survival", color="group"), data=km_data, size=1.5) |
| 119 | + # Censored observation tick marks |
| 120 | + + geom_point( |
| 121 | + aes(x="time", y="survival", color="group"), |
| 122 | + data=censored_data, |
| 123 | + shape=3, # Plus sign for censoring ticks |
| 124 | + size=4, |
| 125 | + stroke=2, |
| 126 | + ) |
| 127 | + + scale_color_manual(values=colors) |
| 128 | + + scale_fill_manual(values=colors) |
| 129 | + + labs( |
| 130 | + x="Time (months)", |
| 131 | + y="Survival Probability", |
| 132 | + title="survival-kaplan-meier · letsplot · pyplots.ai", |
| 133 | + color="Group", |
| 134 | + fill="Group", |
| 135 | + ) |
| 136 | + + theme_minimal() |
| 137 | + + theme( |
| 138 | + axis_title=element_text(size=20), |
| 139 | + axis_text=element_text(size=16), |
| 140 | + plot_title=element_text(size=24), |
| 141 | + legend_text=element_text(size=16), |
| 142 | + legend_title=element_text(size=18), |
| 143 | + ) |
| 144 | + + ggsize(1600, 900) |
| 145 | +) |
| 146 | + |
| 147 | +# Save |
| 148 | +ggsave(plot, "plot.png", scale=3) |
| 149 | +ggsave(plot, "plot.html") |
0 commit comments