Skip to content

Commit 2addf16

Browse files
feat(altair): implement survival-kaplan-meier (#2482)
## Implementation: `survival-kaplan-meier` - altair Implements the **altair** version of `survival-kaplan-meier`. **File:** `plots/survival-kaplan-meier/implementations/altair.py` --- :robot: *[impl-generate workflow](https://github.com/MarkusNeusinger/pyplots/actions/runs/20584329111)* --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 414bebd commit 2addf16

2 files changed

Lines changed: 199 additions & 0 deletions

File tree

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
""" pyplots.ai
2+
survival-kaplan-meier: Kaplan-Meier Survival Plot
3+
Library: altair 6.0.0 | Python 3.13.11
4+
Quality: 92/100 | Created: 2025-12-29
5+
"""
6+
7+
import altair as alt
8+
import numpy as np
9+
import pandas as pd
10+
11+
12+
# Data - Clinical trial with two treatment groups
13+
np.random.seed(42)
14+
15+
# Generate survival data for two groups
16+
n_per_group = 80
17+
18+
# Treatment A (better survival)
19+
time_a = np.random.exponential(scale=24, size=n_per_group)
20+
time_a = np.clip(time_a, 1, 36) # Follow-up period: 36 months
21+
event_a = np.random.binomial(1, 0.65, size=n_per_group) # 65% event rate
22+
23+
# Treatment B (standard)
24+
time_b = np.random.exponential(scale=16, size=n_per_group)
25+
time_b = np.clip(time_b, 1, 36)
26+
event_b = np.random.binomial(1, 0.75, size=n_per_group) # 75% event rate
27+
28+
# Combine into dataframe
29+
df = pd.DataFrame(
30+
{
31+
"time": np.concatenate([time_a, time_b]),
32+
"event": np.concatenate([event_a, event_b]),
33+
"group": ["Treatment A"] * n_per_group + ["Treatment B"] * n_per_group,
34+
}
35+
)
36+
37+
38+
# Kaplan-Meier estimator function
39+
def kaplan_meier(time, event):
40+
"""Calculate Kaplan-Meier survival estimates with confidence intervals."""
41+
# Sort by time
42+
order = np.argsort(time)
43+
time = time[order]
44+
event = event[order]
45+
46+
# Get unique event times
47+
unique_times = np.unique(time[event == 1])
48+
49+
# Calculate survival at each time point
50+
survival = 1.0
51+
times = [0]
52+
survivals = [1.0]
53+
ci_lower = [1.0]
54+
ci_upper = [1.0]
55+
var_sum = 0
56+
57+
for t in unique_times:
58+
at_risk = np.sum(time >= t)
59+
events = np.sum((time == t) & (event == 1))
60+
61+
if at_risk > 0:
62+
survival *= (at_risk - events) / at_risk
63+
# Greenwood's formula for variance
64+
if at_risk > events:
65+
var_sum += events / (at_risk * (at_risk - events))
66+
67+
times.append(t)
68+
survivals.append(survival)
69+
70+
# 95% confidence interval using log transformation
71+
se = survival * np.sqrt(var_sum) if var_sum > 0 else 0
72+
ci_lower.append(max(0, survival - 1.96 * se))
73+
ci_upper.append(min(1, survival + 1.96 * se))
74+
75+
# Extend to max time
76+
max_time = time.max()
77+
times.append(max_time)
78+
survivals.append(survival)
79+
ci_lower.append(ci_lower[-1])
80+
ci_upper.append(ci_upper[-1])
81+
82+
return np.array(times), np.array(survivals), np.array(ci_lower), np.array(ci_upper)
83+
84+
85+
# Calculate KM estimates for each group
86+
km_data = []
87+
for group_name in ["Treatment A", "Treatment B"]:
88+
mask = df["group"] == group_name
89+
times, survivals, ci_low, ci_high = kaplan_meier(df.loc[mask, "time"].values, df.loc[mask, "event"].values)
90+
91+
for i in range(len(times)):
92+
km_data.append(
93+
{
94+
"Time (Months)": times[i],
95+
"Survival Probability": survivals[i],
96+
"CI Lower": ci_low[i],
97+
"CI Upper": ci_high[i],
98+
"Group": group_name,
99+
}
100+
)
101+
102+
km_df = pd.DataFrame(km_data)
103+
104+
# Get censored observations for tick marks
105+
censored = df[df["event"] == 0].copy()
106+
censored_marks = []
107+
for _, row in censored.iterrows():
108+
mask = (km_df["Group"] == row["group"]) & (km_df["Time (Months)"] <= row["time"])
109+
if mask.any():
110+
surv_at_censor = km_df.loc[mask, "Survival Probability"].iloc[-1]
111+
censored_marks.append(
112+
{"Time (Months)": row["time"], "Survival Probability": surv_at_censor, "Group": row["group"]}
113+
)
114+
115+
censored_df = pd.DataFrame(censored_marks)
116+
117+
# Define colors
118+
color_scale = alt.Scale(domain=["Treatment A", "Treatment B"], range=["#306998", "#FFD43B"])
119+
120+
# Step line for survival curves (with legend)
121+
survival_line = (
122+
alt.Chart(km_df)
123+
.mark_line(interpolate="step-after", strokeWidth=4)
124+
.encode(
125+
x=alt.X("Time (Months):Q", scale=alt.Scale(domain=[0, 38]), title="Time (Months)"),
126+
y=alt.Y("Survival Probability:Q", scale=alt.Scale(domain=[0, 1.05]), title="Survival Probability"),
127+
color=alt.Color("Group:N", scale=color_scale),
128+
)
129+
)
130+
131+
# Confidence interval bands
132+
ci_band = (
133+
alt.Chart(km_df)
134+
.mark_area(interpolate="step-after", opacity=0.25)
135+
.encode(
136+
x=alt.X("Time (Months):Q"),
137+
y=alt.Y("CI Lower:Q", title=""),
138+
y2=alt.Y2("CI Upper:Q"),
139+
color=alt.Color("Group:N", scale=color_scale, legend=None),
140+
)
141+
)
142+
143+
# Censored observation marks
144+
censor_marks = (
145+
alt.Chart(censored_df)
146+
.mark_tick(thickness=3, size=25)
147+
.encode(
148+
x=alt.X("Time (Months):Q"),
149+
y=alt.Y("Survival Probability:Q", title=""),
150+
color=alt.Color("Group:N", scale=color_scale, legend=None),
151+
)
152+
)
153+
154+
# Combine layers using + operator and resolve legend
155+
chart = (
156+
(ci_band + survival_line + censor_marks)
157+
.resolve_legend(color="independent")
158+
.properties(
159+
width=1600,
160+
height=900,
161+
title=alt.Title("survival-kaplan-meier · altair · pyplots.ai", fontSize=32, anchor="middle", offset=20),
162+
)
163+
.configure_axis(labelFontSize=18, titleFontSize=22, gridOpacity=0.3, gridDash=[4, 4])
164+
.configure_view(strokeWidth=0)
165+
.configure_legend(titleFontSize=20, labelFontSize=18, symbolStrokeWidth=4)
166+
)
167+
168+
# Save outputs
169+
chart.save("plot.png", scale_factor=3.0)
170+
chart.save("plot.html")
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
library: altair
2+
specification_id: survival-kaplan-meier
3+
created: '2025-12-29T22:50:58Z'
4+
updated: '2025-12-29T22:59:02Z'
5+
generated_by: claude-opus-4-5-20251101
6+
workflow_run: 20584329111
7+
issue: 0
8+
python_version: 3.13.11
9+
library_version: 6.0.0
10+
preview_url: https://storage.googleapis.com/pyplots-images/plots/survival-kaplan-meier/altair/plot.png
11+
preview_thumb: https://storage.googleapis.com/pyplots-images/plots/survival-kaplan-meier/altair/plot_thumb.png
12+
preview_html: https://storage.googleapis.com/pyplots-images/plots/survival-kaplan-meier/altair/plot.html
13+
quality_score: 92
14+
review:
15+
strengths:
16+
- Excellent visual execution with step-function curves, CI bands, and censoring
17+
marks all clearly visible and properly layered
18+
- Colorblind-safe color palette (blue vs yellow) with good contrast
19+
- Correct implementation of Kaplan-Meier algorithm including Greenwood formula for
20+
variance estimation
21+
- Proper use of Altair layered chart composition and mark types (line, area, tick)
22+
- Clinically realistic data scenario with appropriate sample sizes and event rates
23+
weaknesses:
24+
- Code contains a kaplan_meier() function definition which violates the KISS structure
25+
requirement (imports → data → plot → save, no functions)
26+
- Y-axis scale extends to 1.10 creating unnecessary whitespace above the survival
27+
probability ceiling of 1.0
28+
- Legend placed inside plot area (upper-right) could overlap data in some scenarios;
29+
consider placing outside

0 commit comments

Comments
 (0)