|
40 | 40 | #' predictions. For \code{randomForest}, \code{oob = TRUE} uses out-of-bag |
41 | 41 | #' vote probabilities (\code{object$votes}); \code{FALSE} uses in-bag |
42 | 42 | #' \code{predict(type = "prob")}. |
| 43 | +#' @param per_class Logical; if \code{TRUE} and the forest has more than two |
| 44 | +#' classes, return per-class one-vs-rest ROC curves in a long-format |
| 45 | +#' \code{data.frame} with a \code{class} factor column and a named AUC |
| 46 | +#' vector attribute (ordered by descending AUC). Binary forests treat |
| 47 | +#' \code{per_class = TRUE} as a no-op. Currently honoured by the |
| 48 | +#' \code{randomForest} method only. |
43 | 49 | #' @param ... Extra arguments (currently unused). |
44 | 50 | #' |
45 | 51 | #' @return A \code{gg_roc} \code{data.frame} with one row per unique prediction |
|
93 | 99 | #' @aliases gg_roc gg_roc.rfsrc gg_roc.randomForest |
94 | 100 |
|
95 | 101 | #' @export |
96 | | -gg_roc.rfsrc <- function(object, which_outcome, oob = TRUE, ...) { |
| 102 | +gg_roc.rfsrc <- function(object, which_outcome, oob = TRUE, |
| 103 | + per_class = FALSE, ...) { |
97 | 104 | # Validate that the object was grown with randomForestSRC (grow or predict) |
98 | 105 | # or is a randomForest object — the two supported class signatures. |
99 | 106 | if (sum(inherits(object, c("rfsrc", "grow"), TRUE) == c(1, 2)) != 2 && |
@@ -136,38 +143,57 @@ gg_roc.rfsrc <- function(object, which_outcome, oob = TRUE, ...) { |
136 | 143 | invisible(gg_dta) |
137 | 144 | } |
138 | 145 | #' @export |
139 | | -gg_roc <- function(object, which_outcome, oob = TRUE, ...) { |
| 146 | +gg_roc <- function(object, which_outcome, oob = TRUE, per_class = FALSE, ...) { |
140 | 147 | UseMethod("gg_roc", object) |
141 | 148 | } |
142 | 149 |
|
143 | 150 | #' @export |
144 | | -gg_roc.randomForest <- function(object, which_outcome, oob = TRUE, ...) { |
145 | | - # Validate that the object is a genuine randomForest instance. |
| 151 | +gg_roc.randomForest <- function(object, which_outcome, oob = TRUE, |
| 152 | + per_class = FALSE, ...) { |
146 | 153 | if (!inherits(object, "randomForest")) { |
147 | | - stop( |
148 | | - "gg_roc.randomForest only works for objects of class 'randomForest'." |
149 | | - ) |
| 154 | + stop("gg_roc.randomForest only works for objects of class 'randomForest'.") |
150 | 155 | } |
151 | | - |
152 | | - # Default to computing the ROC curve for all outcome classes. |
153 | 156 | if (missing(which_outcome)) { |
154 | 157 | which_outcome <- "all" |
155 | 158 | } |
156 | | - |
157 | 159 | if (!(object$type == "classification")) { |
158 | 160 | stop("gg_roc only works with classification forests") |
159 | 161 | } |
160 | 162 |
|
| 163 | + lvls <- levels(object$y) |
| 164 | + n_class <- length(lvls) |
| 165 | + |
| 166 | + # ── per_class = TRUE path (multi-class only) ───────────────────────────── |
| 167 | + if (isTRUE(per_class) && n_class > 2L) { |
| 168 | + if (!missing(which_outcome) && !identical(which_outcome, "all")) { |
| 169 | + message("which_outcome is ignored when per_class = TRUE.") |
| 170 | + } |
| 171 | + prob <- .rf_prob_matrix(object, oob, lvls) |
| 172 | + dta <- object$y |
| 173 | + curves <- lapply(seq_along(lvls), function(k) { |
| 174 | + cv <- .rf_one_class_roc(dta, prob, k, lvls) |
| 175 | + cv$class <- lvls[k] |
| 176 | + cv |
| 177 | + }) |
| 178 | + auc_vals <- vapply(curves, calc_auc, numeric(1L)) |
| 179 | + names(auc_vals) <- lvls |
| 180 | + auc_ord <- order(auc_vals, decreasing = TRUE) |
| 181 | + auc_vals <- auc_vals[auc_ord] |
| 182 | + gg_dta <- do.call(rbind, curves) |
| 183 | + gg_dta$class <- factor(gg_dta$class, levels = lvls[auc_ord]) |
| 184 | + class(gg_dta) <- c("gg_roc", class(gg_dta)) |
| 185 | + attr(gg_dta, "auc") <- auc_vals |
| 186 | + gg_dta <- .set_provenance(gg_dta, object) |
| 187 | + return(invisible(gg_dta)) |
| 188 | + } |
| 189 | + |
| 190 | + # ── Standard path (binary, or per_class not requested) ────────────────── |
161 | 191 | # For randomForest objects the response is stored in $y (not $yvar). |
162 | 192 | gg_dta <- # nolint: object_usage_linter |
163 | | - calc_roc(object, |
164 | | - object$y, |
165 | | - which_outcome = which_outcome, |
166 | | - oob = oob |
167 | | - ) |
168 | | - class(gg_dta) <- c("gg_roc", class(gg_dta)) |
169 | | - gg_dta <- .set_provenance(gg_dta, object) |
170 | | - |
| 193 | + calc_roc(object, object$y, which_outcome = which_outcome, oob = oob) |
| 194 | + class(gg_dta) <- c("gg_roc", class(gg_dta)) |
| 195 | + attr(gg_dta, "auc") <- calc_auc(gg_dta) |
| 196 | + gg_dta <- .set_provenance(gg_dta, object) |
171 | 197 | invisible(gg_dta) |
172 | 198 | } |
173 | 199 |
|
|
0 commit comments