3939from lifelines import CoxPHFitter , KaplanMeierFitter
4040from lifelines .statistics import logrank_test , multivariate_logrank_test
4141from lifelines .utils import concordance_index
42- from plotnine import (aes , annotate , element_text , geom_abline , geom_errorbarh ,
43- geom_line , geom_point , geom_smooth , geom_step , geom_text ,
44- ggplot , ggtitle , labs , scale_color_gradient ,
45- scale_color_manual , theme , theme_bw , theme_minimal )
42+ from plotnine import (aes , annotate , element_blank , element_line , element_text ,
43+ geom_abline , geom_errorbarh , geom_line , geom_point ,
44+ geom_smooth , geom_step , geom_text , ggplot , ggtitle , labs ,
45+ scale_color_gradient , scale_color_manual , theme , theme_bw ,
46+ theme_minimal )
4647from sklearn .cluster import KMeans
4748from sklearn .metrics import silhouette_score
4849from sklearn .metrics .pairwise import euclidean_distances
@@ -173,9 +174,19 @@ def plot_dim_reduced(
173174 return p
174175
175176
176- def plot_kaplan_meier_curves (durations , events , categorical_variable ):
177+ def plot_kaplan_meier_curves (durations , events , categorical_variable , title = None ):
177178 """
178- Kaplan–Meier curves with alphabetical label ordering + shared palette.
179+ Kaplan–Meier curves; groups are colored by increasing risk using the same
180+ discrete palette as ``get_color_mapping`` (lowest risk → first tab10 color).
181+
182+ Risk is the fraction in each group with an observed event on or before the
183+ pooled median follow-up time; ties break alphabetically by group name.
184+
185+ Args:
186+ durations: follow-up times
187+ events: event indicators
188+ categorical_variable: group labels
189+ title: plot title; default ``Kaplan-Meier Survival Curves by Group``
179190 """
180191 data = pd .DataFrame (
181192 {
@@ -187,14 +198,35 @@ def plot_kaplan_meier_curves(durations, events, categorical_variable):
187198 }
188199 )
189200
190- # shared palette + fixed legend order
191- color_mapping = get_color_mapping (data ["Group" ])
192- order = list (color_mapping .keys ()) # alphabetical order
201+ d = pd .to_numeric (data ["Duration" ], errors = "coerce" )
202+ ev = pd .to_numeric (data ["Event" ], errors = "coerce" )
203+ grp = data ["Group" ].astype (str )
204+ ok = d .notna () & ev .notna ()
205+ if not ok .any ():
206+ order = sorted (pd .unique (grp ))
207+ else :
208+ d0 , ev0 , g0 = d [ok ], ev [ok ], grp [ok ]
209+ t_cut = float (np .nanpercentile (d0 .to_numpy (dtype = float ), 50 ))
210+ da = d0 .to_numpy (dtype = float )
211+ early = (ev0 .to_numpy (dtype = float ) > 0 ) & (da <= t_cut )
212+ risks = {}
213+ for g in pd .unique (g0 ):
214+ m = (g0 == g ).to_numpy ()
215+ k = int (m .sum ())
216+ risks [g ] = float (early [m ].sum () / k ) if k else 0.0
217+ order = sorted (risks .keys (), key = lambda x : (risks [x ], x ))
218+
219+ # Reuse get_color_mapping palette order via zero-padded placeholders (sort order = 0..n-1)
220+ n = len (order )
221+ ph = [f"_{ i :06d} " for i in range (n )]
222+ pal = get_color_mapping (ph )
223+ hexes = [pal [k ] for k in sorted (pal .keys ())]
224+ color_mapping = dict (zip (order , hexes ))
193225
194226 # compute KM per group
195227 kmf = KaplanMeierFitter ()
196228 survival_curves = []
197- for g in order : # iterate in the same alphabetical order
229+ for g in order :
198230 gd = data [data ["Group" ] == g ]
199231 if len (gd ) == 0 :
200232 continue
@@ -230,14 +262,33 @@ def plot_kaplan_meier_curves(durations, events, categorical_variable):
230262 else :
231263 p_text = "Only one group — log-rank test not applicable"
232264
265+ if title is None :
266+ title = "Kaplan-Meier Survival Curves by Group"
267+
233268 p = (
234269 ggplot (plot_data , aes (x = "Time" , y = "Survival" , color = "Group" ))
235- + geom_step ()
236- + labs (x = "Time" , y = "Survival Probability" , color = "Group" )
237- + ggtitle ("Kaplan-Meier Survival Curves by Group" )
238- + annotate ("text" , x = 0.1 , y = 0.1 , label = p_text , size = 10 , ha = "left" )
270+ + geom_step (size = 1.15 )
271+ + labs (title = title , x = "Time" , y = "Survival probability" , color = "Group" )
272+ + annotate (
273+ "text" ,
274+ x = 0.1 ,
275+ y = 0.1 ,
276+ label = p_text ,
277+ size = 10 ,
278+ ha = "left" ,
279+ alpha = 0.85 ,
280+ )
239281 + theme_minimal ()
240- + theme (legend_title = element_text (size = 10 , weight = "bold" ))
282+ + theme (
283+ plot_title = element_text (size = 13 , weight = "bold" ),
284+ axis_title = element_text (size = 11 ),
285+ axis_text = element_text (size = 10 ),
286+ legend_title = element_text (size = 10 , weight = "bold" ),
287+ legend_text = element_text (size = 10 ),
288+ legend_position = "right" ,
289+ panel_grid_major = element_line (color = "#e0e0e0" , size = 0.55 ),
290+ panel_grid_minor = element_blank (),
291+ )
241292 + scale_color_manual (values = color_mapping )
242293 )
243294 return p
@@ -1229,13 +1280,15 @@ def recursive_binary_split_minN(
12291280 """
12301281 Recursively split df by optimal cutoff from flexynesis.utils.find_optimal_cutoff.
12311282 Stop splitting when pval >= alpha or resulting child would have < min_samples_per_group.
1232- Returns df.copy() with 'auto_group' integer labels.
1283+ Returns df.copy() with ``auto_group`` string labels ``G1``, ``G2``, ... ordered by
1284+ increasing risk (mean ``score`` among rows with follow-up time at or below the
1285+ pooled median follow-up time; if a group has no such rows, the group's overall
1286+ mean ``score`` is used).
12331287 """
12341288 df = df .copy ()
12351289 groups = {}
12361290 next_gid = 0
12371291 queue = deque ([df ])
1238-
12391292 while queue :
12401293 node = queue .popleft ()
12411294 n = len (node )
@@ -1270,6 +1323,28 @@ def recursive_binary_split_minN(
12701323 queue .append (right )
12711324
12721325 df ["auto_group" ] = df .index .map (groups )
1326+
1327+ # Relabel G1, G2, ... by increasing risk (early-window mean pred score)
1328+ t_series = pd .to_numeric (df [time ], errors = "coerce" )
1329+ t_cut = t_series .median ()
1330+ early = t_series <= t_cut
1331+ uids = sorted (df ["auto_group" ].unique ())
1332+ risk_by_gid = {}
1333+ for g in uids :
1334+ in_g = df ["auto_group" ] == g
1335+ early_in_g = in_g & early
1336+ if early_in_g .any ():
1337+ risk_by_gid [g ] = float (
1338+ pd .to_numeric (df .loc [early_in_g , score ], errors = "coerce" ).mean ()
1339+ )
1340+ else :
1341+ risk_by_gid [g ] = float (
1342+ pd .to_numeric (df .loc [in_g , score ], errors = "coerce" ).mean ()
1343+ )
1344+ ordered = sorted (uids , key = lambda x : (risk_by_gid [x ], x ))
1345+ gid_to_label = {old : f"G{ i + 1 } " for i , old in enumerate (ordered )}
1346+ df ["auto_group" ] = df ["auto_group" ].map (gid_to_label )
1347+
12731348 return df
12741349
12751350
0 commit comments