4747# ' snapped to the nearest entry in \code{rf_model$time.interest} — see the
4848# ' \strong{Survival forests} section below. When \code{NULL} (default),
4949# ' three quartile points of \code{time.interest} are used.
50+ # ' @param partial.type Character; type of predicted value for survival
51+ # ' forests, passed through to \code{\link[randomForestSRC]{partial.rfsrc}}.
52+ # ' One of \code{"surv"} (default), \code{"chf"}, or \code{"mort"}. Ignored
53+ # ' for non-survival forests. \code{partial.rfsrc()} requires a non-\code{NULL}
54+ # ' value for survival families; supplying it here avoids a cryptic
55+ # ' \dQuote{argument is of length zero} error from the underlying C code.
5056# ' @param cat_limit Variables with fewer than \code{cat_limit} unique values in
5157# ' \code{newx} are treated as categorical; all others are continuous.
5258# ' Defaults to 10.
@@ -89,6 +95,7 @@ gg_partial_rfsrc <- function(rf_model,
8995 xvar2.name = NULL ,
9096 newx = NULL ,
9197 partial.time = NULL ,
98+ partial.type = c(" surv" , " chf" , " mort" ),
9299 cat_limit = 10 ,
93100 n_eval = 25 ) {
94101 if (is.null(newx )) {
@@ -112,17 +119,28 @@ gg_partial_rfsrc <- function(rf_model,
112119 is_surv <- ! is.null(rf_model $ family ) && grepl(" surv" , rf_model $ family )
113120 if (is_surv ) {
114121 partial.time <- snap_partial_time(rf_model , partial.time )
122+ # partial.rfsrc() requires a non-NULL partial.type for survival forests;
123+ # NULL triggers a zero-length comparison inside the C code.
124+ partial.type <- match.arg(partial.type )
125+ } else {
126+ partial.type <- NULL
115127 }
116128
117129 if (is.null(xvar2.name )) {
118130 pdta <- partial_no_group(xvar.names , newx , rf_model ,
119- cat_limit , n_eval , is_surv , partial.time )
131+ cat_limit , n_eval , is_surv , partial.time ,
132+ partial.type )
120133 } else {
121134 pdta <- partial_with_group(xvar.names , xvar2.name , newx , rf_model ,
122- cat_limit , n_eval , is_surv , partial.time )
135+ cat_limit , n_eval , is_surv , partial.time ,
136+ partial.type )
123137 }
124138
125- split_partial_result(do.call(" rbind" , pdta ))
139+ result <- split_partial_result(do.call(" rbind" , pdta ))
140+ # Carry partial.type so plot.gg_partial_rfsrc() can pick the correct
141+ # y-axis label (Survival / CHF / Mortality).
142+ attr(result , " partial.type" ) <- partial.type
143+ result
126144}
127145
128146# # ---- unexported helpers -------------------------------------------------------
@@ -184,7 +202,7 @@ make_eval_grid <- function(xname, newx, cat_limit, n_eval) {
184202
185203# # Thin wrapper around partial.rfsrc that builds the argument list.
186204call_partial_rfsrc <- function (rf_model , xname , xval ,
187- is_surv , partial.time ,
205+ is_surv , partial.time , partial.type ,
188206 xvar2.name = NULL , x2val = NULL ) {
189207 args <- list (
190208 object = rf_model ,
@@ -197,44 +215,62 @@ call_partial_rfsrc <- function(rf_model, xname, xval,
197215 }
198216 if (is_surv ) {
199217 args $ partial.time <- partial.time
218+ args $ partial.type <- partial.type
200219 }
201220 do.call(randomForestSRC :: partial.rfsrc , args )
202221}
203222
204223# # Process a single predictor variable and return a tidy data.frame (or NULL).
205224partial_one_var <- function (xname , newx , rf_model ,
206225 cat_limit , n_eval , is_surv , partial.time ,
226+ partial.type ,
207227 xvar2.name = NULL , x2val = NULL ) {
208228 eg <- make_eval_grid(xname , newx , cat_limit , n_eval )
209229 if (is.null(eg )) return (NULL )
210230 xval <- eg $ xval
211231 gr <- eg $ categorical
212232 partial.obj <- call_partial_rfsrc(rf_model , xname , xval ,
213- is_surv , partial.time ,
233+ is_surv , partial.time , partial.type ,
214234 xvar2.name , x2val )
215235 pout <- randomForestSRC :: get.partial.plot.data(partial.obj , granule = gr )
216- out_dta <- data.frame (x = pout $ x , yhat = pout $ yhat )
236+ # Survival forests with >1 partial.time return yhat as an
237+ # [length(partial.values) x length(partial.time)] matrix; expand to long form
238+ # so each (x, time) pair is its own row. For non-survival or single-time
239+ # cases yhat is already a vector of length(partial.values).
240+ if (is.matrix(pout $ yhat )) {
241+ pt <- if (! is.null(pout $ partial.time )) pout $ partial.time else seq_len(ncol(pout $ yhat ))
242+ out_dta <- data.frame (
243+ x = rep(pout $ x , times = length(pt )),
244+ yhat = as.numeric(pout $ yhat ),
245+ time = rep(pt , each = length(pout $ x ))
246+ )
247+ } else {
248+ out_dta <- data.frame (x = pout $ x , yhat = pout $ yhat )
249+ if (! is.null(pout $ partial.time )) {
250+ out_dta $ time <- pout $ partial.time
251+ }
252+ }
217253 out_dta $ name <- xname
218254 out_dta $ type <- c(" continuous" , " categorical" )[gr + 1L ]
219- if (! is.null(pout $ partial.time )) {
220- out_dta $ time <- pout $ partial.time
221- }
222255 out_dta
223256}
224257
225258# # Compute partial dependence across xvar.names (no grouping variable).
226259partial_no_group <- function (xvar.names , newx , rf_model ,
227- cat_limit , n_eval , is_surv , partial.time ) {
260+ cat_limit , n_eval , is_surv , partial.time ,
261+ partial.type ) {
228262 pdta <- lapply(xvar.names , partial_one_var ,
229263 newx = newx , rf_model = rf_model ,
230264 cat_limit = cat_limit , n_eval = n_eval ,
231- is_surv = is_surv , partial.time = partial.time )
265+ is_surv = is_surv , partial.time = partial.time ,
266+ partial.type = partial.type )
232267 Filter(Negate(is.null ), pdta )
233268}
234269
235270# # Compute partial dependence across xvar.names for each level of xvar2.name.
236271partial_with_group <- function (xvar.names , xvar2.name , newx , rf_model ,
237- cat_limit , n_eval , is_surv , partial.time ) {
272+ cat_limit , n_eval , is_surv , partial.time ,
273+ partial.type ) {
238274 xv2 <- unique(newx [[xvar2.name ]])
239275 xv2 <- xv2 [! is.na(xv2 )]
240276 if (length(xv2 ) == 0L ) {
@@ -248,6 +284,7 @@ partial_with_group <- function(xvar.names, xvar2.name, newx, rf_model,
248284 newx = newx , rf_model = rf_model ,
249285 cat_limit = cat_limit , n_eval = n_eval ,
250286 is_surv = is_surv , partial.time = partial.time ,
287+ partial.type = partial.type ,
251288 xvar2.name = xvar2.name , x2val = x2val )
252289 p1dta <- Filter(Negate(is.null ), p1dta )
253290 if (length(p1dta ) == 0L ) return (NULL )
0 commit comments