Skip to content

Commit f8a5805

Browse files
committed
use survival curve colors in increasing order of risk
1 parent 541ba48 commit f8a5805

1 file changed

Lines changed: 92 additions & 17 deletions

File tree

flexynesis/utils.py

Lines changed: 92 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,11 @@
3939
from lifelines import CoxPHFitter, KaplanMeierFitter
4040
from lifelines.statistics import logrank_test, multivariate_logrank_test
4141
from lifelines.utils import concordance_index
42-
from plotnine import (aes, annotate, element_text, geom_abline, geom_errorbarh,
43-
geom_line, geom_point, geom_smooth, geom_step, geom_text,
44-
ggplot, ggtitle, labs, scale_color_gradient,
45-
scale_color_manual, theme, theme_bw, theme_minimal)
42+
from plotnine import (aes, annotate, element_blank, element_line, element_text,
43+
geom_abline, geom_errorbarh, geom_line, geom_point,
44+
geom_smooth, geom_step, geom_text, ggplot, ggtitle, labs,
45+
scale_color_gradient, scale_color_manual, theme, theme_bw,
46+
theme_minimal)
4647
from sklearn.cluster import KMeans
4748
from sklearn.metrics import silhouette_score
4849
from sklearn.metrics.pairwise import euclidean_distances
@@ -173,9 +174,19 @@ def plot_dim_reduced(
173174
return p
174175

175176

176-
def plot_kaplan_meier_curves(durations, events, categorical_variable):
177+
def plot_kaplan_meier_curves(durations, events, categorical_variable, title=None):
177178
"""
178-
Kaplan–Meier curves with alphabetical label ordering + shared palette.
179+
Kaplan–Meier curves; groups are colored by increasing risk using the same
180+
discrete palette as ``get_color_mapping`` (lowest risk → first tab10 color).
181+
182+
Risk is the fraction in each group with an observed event on or before the
183+
pooled median follow-up time; ties break alphabetically by group name.
184+
185+
Args:
186+
durations: follow-up times
187+
events: event indicators
188+
categorical_variable: group labels
189+
title: plot title; default ``Kaplan-Meier Survival Curves by Group``
179190
"""
180191
data = pd.DataFrame(
181192
{
@@ -187,14 +198,35 @@ def plot_kaplan_meier_curves(durations, events, categorical_variable):
187198
}
188199
)
189200

190-
# shared palette + fixed legend order
191-
color_mapping = get_color_mapping(data["Group"])
192-
order = list(color_mapping.keys()) # alphabetical order
201+
d = pd.to_numeric(data["Duration"], errors="coerce")
202+
ev = pd.to_numeric(data["Event"], errors="coerce")
203+
grp = data["Group"].astype(str)
204+
ok = d.notna() & ev.notna()
205+
if not ok.any():
206+
order = sorted(pd.unique(grp))
207+
else:
208+
d0, ev0, g0 = d[ok], ev[ok], grp[ok]
209+
t_cut = float(np.nanpercentile(d0.to_numpy(dtype=float), 50))
210+
da = d0.to_numpy(dtype=float)
211+
early = (ev0.to_numpy(dtype=float) > 0) & (da <= t_cut)
212+
risks = {}
213+
for g in pd.unique(g0):
214+
m = (g0 == g).to_numpy()
215+
k = int(m.sum())
216+
risks[g] = float(early[m].sum() / k) if k else 0.0
217+
order = sorted(risks.keys(), key=lambda x: (risks[x], x))
218+
219+
# Reuse get_color_mapping palette order via zero-padded placeholders (sort order = 0..n-1)
220+
n = len(order)
221+
ph = [f"_{i:06d}" for i in range(n)]
222+
pal = get_color_mapping(ph)
223+
hexes = [pal[k] for k in sorted(pal.keys())]
224+
color_mapping = dict(zip(order, hexes))
193225

194226
# compute KM per group
195227
kmf = KaplanMeierFitter()
196228
survival_curves = []
197-
for g in order: # iterate in the same alphabetical order
229+
for g in order:
198230
gd = data[data["Group"] == g]
199231
if len(gd) == 0:
200232
continue
@@ -230,14 +262,33 @@ def plot_kaplan_meier_curves(durations, events, categorical_variable):
230262
else:
231263
p_text = "Only one group — log-rank test not applicable"
232264

265+
if title is None:
266+
title = "Kaplan-Meier Survival Curves by Group"
267+
233268
p = (
234269
ggplot(plot_data, aes(x="Time", y="Survival", color="Group"))
235-
+ geom_step()
236-
+ labs(x="Time", y="Survival Probability", color="Group")
237-
+ ggtitle("Kaplan-Meier Survival Curves by Group")
238-
+ annotate("text", x=0.1, y=0.1, label=p_text, size=10, ha="left")
270+
+ geom_step(size=1.15)
271+
+ labs(title=title, x="Time", y="Survival probability", color="Group")
272+
+ annotate(
273+
"text",
274+
x=0.1,
275+
y=0.1,
276+
label=p_text,
277+
size=10,
278+
ha="left",
279+
alpha=0.85,
280+
)
239281
+ theme_minimal()
240-
+ theme(legend_title=element_text(size=10, weight="bold"))
282+
+ theme(
283+
plot_title=element_text(size=13, weight="bold"),
284+
axis_title=element_text(size=11),
285+
axis_text=element_text(size=10),
286+
legend_title=element_text(size=10, weight="bold"),
287+
legend_text=element_text(size=10),
288+
legend_position="right",
289+
panel_grid_major=element_line(color="#e0e0e0", size=0.55),
290+
panel_grid_minor=element_blank(),
291+
)
241292
+ scale_color_manual(values=color_mapping)
242293
)
243294
return p
@@ -1229,13 +1280,15 @@ def recursive_binary_split_minN(
12291280
"""
12301281
Recursively split df by optimal cutoff from flexynesis.utils.find_optimal_cutoff.
12311282
Stop splitting when pval >= alpha or resulting child would have < min_samples_per_group.
1232-
Returns df.copy() with 'auto_group' integer labels.
1283+
Returns df.copy() with ``auto_group`` string labels ``G1``, ``G2``, ... ordered by
1284+
increasing risk (mean ``score`` among rows with follow-up time at or below the
1285+
pooled median follow-up time; if a group has no such rows, the group's overall
1286+
mean ``score`` is used).
12331287
"""
12341288
df = df.copy()
12351289
groups = {}
12361290
next_gid = 0
12371291
queue = deque([df])
1238-
12391292
while queue:
12401293
node = queue.popleft()
12411294
n = len(node)
@@ -1270,6 +1323,28 @@ def recursive_binary_split_minN(
12701323
queue.append(right)
12711324

12721325
df["auto_group"] = df.index.map(groups)
1326+
1327+
# Relabel G1, G2, ... by increasing risk (early-window mean pred score)
1328+
t_series = pd.to_numeric(df[time], errors="coerce")
1329+
t_cut = t_series.median()
1330+
early = t_series <= t_cut
1331+
uids = sorted(df["auto_group"].unique())
1332+
risk_by_gid = {}
1333+
for g in uids:
1334+
in_g = df["auto_group"] == g
1335+
early_in_g = in_g & early
1336+
if early_in_g.any():
1337+
risk_by_gid[g] = float(
1338+
pd.to_numeric(df.loc[early_in_g, score], errors="coerce").mean()
1339+
)
1340+
else:
1341+
risk_by_gid[g] = float(
1342+
pd.to_numeric(df.loc[in_g, score], errors="coerce").mean()
1343+
)
1344+
ordered = sorted(uids, key=lambda x: (risk_by_gid[x], x))
1345+
gid_to_label = {old: f"G{i + 1}" for i, old in enumerate(ordered)}
1346+
df["auto_group"] = df["auto_group"].map(gid_to_label)
1347+
12731348
return df
12741349

12751350

0 commit comments

Comments
 (0)