diff --git a/.Rbuildignore b/.Rbuildignore index 8c59eb52..75af5ebb 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -48,4 +48,5 @@ framed.sty # FUSE filesystem temporaries (safe to ignore; R CMD build already skips dotfiles) ^R/\.fuse_hidden ^\.positai$ +^\.superpowers$ ^vignettes/.*_cache$ diff --git a/DESCRIPTION b/DESCRIPTION index 0426059a..dd57db74 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,7 +1,7 @@ Package: ggRandomForests Type: Package Title: Visually Exploring Random Forests -Version: 2.7.3.9005 +Version: 2.7.3.9006 Date: 2026-05-21 Authors@R: person("John", "Ehrlinger", role = c("aut", "cre"), diff --git a/NEWS.md b/NEWS.md index 6271461c..7529652a 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,8 +1,21 @@ Package: ggRandomForests -Version: 2.7.3.9005 +Version: 2.7.3.9006 ggRandomForests v2.8.0 (development) — continued ================================================= +* **`gg_roc`: per-class one-vs-rest ROC curves (#88, closes #72).** + - `gg_roc()` gains a `per_class = FALSE` argument. When `per_class = TRUE` + and the forest has more than two classes, returns a long-format `gg_roc` + data frame with a `class` factor column and a named AUC vector attribute + (one entry per class, ordered by descending AUC). + - `plot.gg_roc()` gains a `panel = c("overlay", "facet")` argument. When the + `gg_roc` object contains a `class` column, the overlay path colours curves + by class; the facet path wraps each class into its own panel. + - `summary.gg_roc()` now prints named per-class AUC values when the `class` + column is present. + - Binary forests: `per_class = TRUE` is a silent no-op (single-curve result + returned unchanged). + - ROC confidence intervals are deferred to v2.9.0 (issue #7 / #72-CIs). * **varPro variable dependency: `gg_udependent()` (Phase 3).** - `gg_udependent()` extracts cross-variable dependency scores from a `uvarpro` fit using `varPro::get.beta.entropy()` + diff --git a/R/gg_roc.R b/R/gg_roc.R index 71b279b9..41dee50a 100644 --- a/R/gg_roc.R +++ b/R/gg_roc.R @@ -40,6 +40,12 @@ #' predictions. For \code{randomForest}, \code{oob = TRUE} uses out-of-bag #' vote probabilities (\code{object$votes}); \code{FALSE} uses in-bag #' \code{predict(type = "prob")}. +#' @param per_class Logical; if \code{TRUE} and the forest has more than two +#' classes, return per-class one-vs-rest ROC curves in a long-format +#' \code{data.frame} with a \code{class} factor column and a named AUC +#' vector attribute (ordered by descending AUC). Binary forests treat +#' \code{per_class = TRUE} as a no-op. Currently honoured by the +#' \code{randomForest} method only. #' @param ... Extra arguments (currently unused). #' #' @return A \code{gg_roc} \code{data.frame} with one row per unique prediction @@ -93,7 +99,8 @@ #' @aliases gg_roc gg_roc.rfsrc gg_roc.randomForest #' @export -gg_roc.rfsrc <- function(object, which_outcome, oob = TRUE, ...) { +gg_roc.rfsrc <- function(object, which_outcome, oob = TRUE, + per_class = FALSE, ...) { # Validate that the object was grown with randomForestSRC (grow or predict) # or is a randomForest object — the two supported class signatures. if (sum(inherits(object, c("rfsrc", "grow"), TRUE) == c(1, 2)) != 2 && @@ -136,38 +143,57 @@ gg_roc.rfsrc <- function(object, which_outcome, oob = TRUE, ...) { invisible(gg_dta) } #' @export -gg_roc <- function(object, which_outcome, oob = TRUE, ...) { +gg_roc <- function(object, which_outcome, oob = TRUE, per_class = FALSE, ...) { UseMethod("gg_roc", object) } #' @export -gg_roc.randomForest <- function(object, which_outcome, oob = TRUE, ...) { - # Validate that the object is a genuine randomForest instance. +gg_roc.randomForest <- function(object, which_outcome, oob = TRUE, + per_class = FALSE, ...) { if (!inherits(object, "randomForest")) { - stop( - "gg_roc.randomForest only works for objects of class 'randomForest'." - ) + stop("gg_roc.randomForest only works for objects of class 'randomForest'.") } - - # Default to computing the ROC curve for all outcome classes. if (missing(which_outcome)) { which_outcome <- "all" } - if (!(object$type == "classification")) { stop("gg_roc only works with classification forests") } + lvls <- levels(object$y) + n_class <- length(lvls) + + # ── per_class = TRUE path (multi-class only) ───────────────────────────── + if (isTRUE(per_class) && n_class > 2L) { + if (!missing(which_outcome) && !identical(which_outcome, "all")) { + message("which_outcome is ignored when per_class = TRUE.") + } + prob <- .rf_prob_matrix(object, oob, lvls) + dta <- object$y + curves <- lapply(seq_along(lvls), function(k) { + cv <- .rf_one_class_roc(dta, prob, k, lvls) + cv$class <- lvls[k] + cv + }) + auc_vals <- vapply(curves, calc_auc, numeric(1L)) + names(auc_vals) <- lvls + auc_ord <- order(auc_vals, decreasing = TRUE) + auc_vals <- auc_vals[auc_ord] + gg_dta <- do.call(rbind, curves) + gg_dta$class <- factor(gg_dta$class, levels = lvls[auc_ord]) + class(gg_dta) <- c("gg_roc", class(gg_dta)) + attr(gg_dta, "auc") <- auc_vals + gg_dta <- .set_provenance(gg_dta, object) + return(invisible(gg_dta)) + } + + # ── Standard path (binary, or per_class not requested) ────────────────── # For randomForest objects the response is stored in $y (not $yvar). gg_dta <- # nolint: object_usage_linter - calc_roc(object, - object$y, - which_outcome = which_outcome, - oob = oob - ) - class(gg_dta) <- c("gg_roc", class(gg_dta)) - gg_dta <- .set_provenance(gg_dta, object) - + calc_roc(object, object$y, which_outcome = which_outcome, oob = oob) + class(gg_dta) <- c("gg_roc", class(gg_dta)) + attr(gg_dta, "auc") <- calc_auc(gg_dta) + gg_dta <- .set_provenance(gg_dta, object) invisible(gg_dta) } diff --git a/R/plot.gg_roc.R b/R/plot.gg_roc.R index 19fb9dbc..809f4431 100644 --- a/R/plot.gg_roc.R +++ b/R/plot.gg_roc.R @@ -24,6 +24,10 @@ #' the forest has more than two classes, ROC curves for all classes are #' overlaid in a single plot. For binary forests \code{NULL} defaults to #' class index 2. +#' @param panel Character; layout for per-class ROC objects (those produced by +#' \code{gg_roc(..., per_class = TRUE)}). \code{"overlay"} (default) draws all +#' class curves in one panel coloured by class; \code{"facet"} wraps each +#' class into its own panel. Ignored for single-class \code{gg_roc} objects. #' @param ... Additional arguments forwarded to \code{\link{gg_roc}} when #' \code{x} is a raw forest object (e.g. \code{oob = FALSE}). #' @@ -73,7 +77,12 @@ #' for (i in seq_len(n_cls)) print(plot(gg_roc(rfsrc_iris, which_outcome = i))) #' #' @export -plot.gg_roc <- function(x, which_outcome = NULL, ...) { +plot.gg_roc <- function(x, which_outcome = NULL, ..., + panel = c("overlay", "facet")) { + # `panel` is placed after `...` so it is name-only: this preserves + # positional back-compatibility for existing callers (e.g. + # plot(x, 1, FALSE) still routes the 3rd positional arg into `...`). + panel <- match.arg(panel) gg_dta <- x ## ---- Accept a raw rfsrc or randomForest object ----------------------- @@ -118,8 +127,17 @@ plot.gg_roc <- function(x, which_outcome = NULL, ...) { } } - ## ---- Single-class ROC plot ------------------------------------------ + ## ---- Single-class ROC plot (or per_class long-format) ---------------- if (inherits(gg_dta, "gg_roc")) { + + # Per-class detection: gg_roc produced by gg_roc(..., per_class = TRUE) + # carries a 'class' column (factor) + a named AUC vector attribute. + # Rendering is delegated to a helper to keep this method's cyclomatic + # complexity within the project lint budget. + if ("class" %in% names(gg_dta)) { + return(.plot_gg_roc_per_class(gg_dta, attr(x, "auc"), panel)) + } + # Sort by specificity so the ROC curve is drawn left-to-right gg_dta <- gg_dta[order(gg_dta$spec), ] # False positive rate = 1 - specificity @@ -193,7 +211,54 @@ plot.gg_roc <- function(x, which_outcome = NULL, ...) { ) + ggplot2::coord_fixed() - # Multi-class: do not annotate a single AUC value — each class has its own. + # Multi-class: do not annotate a single AUC value - each class has its own. } return(gg_plt) } + +# Render a per-class (one-vs-rest) ROC object produced by +# gg_roc(..., per_class = TRUE). Split out of plot.gg_roc() so that method +# stays within the project's cyclomatic-complexity lint budget. +# +# gg_dta : long-format gg_roc data frame with a 'class' factor column +# auc : named numeric AUC vector (one entry per class) or NULL +# panel : "overlay" (curves coloured by class) or "facet" (one panel each) +.plot_gg_roc_per_class <- function(gg_dta, auc, panel) { + gg_dta$fpr <- 1 - gg_dta$spec + + if (panel == "overlay") { + gg_plt <- ggplot2::ggplot(gg_dta) + + ggplot2::geom_line(ggplot2::aes( + x = .data$fpr, y = .data$sens, color = .data$class + )) + + ggplot2::labs( + x = "1 - Specificity (FPR)", y = "Sensitivity (TPR)", + color = "Class" + ) + } else { + gg_plt <- ggplot2::ggplot(gg_dta) + + ggplot2::geom_line(ggplot2::aes(x = .data$fpr, y = .data$sens)) + + ggplot2::labs(x = "1 - Specificity (FPR)", y = "Sensitivity (TPR)") + + ggplot2::facet_wrap(~class) + } + + gg_plt <- gg_plt + + ggplot2::geom_abline( + slope = 1, intercept = 0, + col = "red", linetype = 2, linewidth = .5 + ) + + ggplot2::coord_fixed() + + # AUC caption - top 5 classes by descending AUC (already sorted) + if (!is.null(auc) && length(auc) > 0L) { + top_n <- min(5L, length(auc)) + auc_str <- paste( + sprintf("%s=%.3g", names(auc)[seq_len(top_n)], auc[seq_len(top_n)]), + collapse = ", " + ) + if (length(auc) > 5L) auc_str <- paste0(auc_str, ", ...") + gg_plt <- gg_plt + + ggplot2::labs(caption = paste("OvR ROC, per_class=TRUE. AUC:", auc_str)) + } + gg_plt +} diff --git a/R/summary_methods.R b/R/summary_methods.R index 03caaad3..a4d721db 100644 --- a/R/summary_methods.R +++ b/R/summary_methods.R @@ -189,11 +189,19 @@ summary.gg_partial_varpro <- function(object, ...) { #' @rdname summary.gg #' @export summary.gg_roc <- function(object, ...) { - body <- c( - sprintf("thresholds: %d", nrow(object)), - sprintf("AUC: %.4g", - attr(object, "auc") %||% .gg_auc_trap(object)) - ) + auc <- attr(object, "auc") %||% .gg_auc_trap(object) + if ("class" %in% names(object)) { + # per_class = TRUE path: named AUC vector, one entry per class + n_cls <- nlevels(object$class) + auc_str <- paste(sprintf("%s=%.4g", names(auc), auc), collapse = ", ") + body <- c(sprintf("classes: %d", n_cls), + sprintf("AUC: %s", auc_str)) + } else { + body <- c( + sprintf("thresholds: %d", nrow(object)), + sprintf("AUC: %.4g", auc) + ) + } .summary_skel(object, "gg_roc", body) } diff --git a/man/gg_roc.rfsrc.Rd b/man/gg_roc.rfsrc.Rd index 881370c7..89a37eaa 100644 --- a/man/gg_roc.rfsrc.Rd +++ b/man/gg_roc.rfsrc.Rd @@ -6,7 +6,7 @@ \alias{gg_roc.randomForest} \title{ROC (Receiver Operating Characteristic) curve data from a classification forest.} \usage{ -\method{gg_roc}{rfsrc}(object, which_outcome, oob = TRUE, ...) +\method{gg_roc}{rfsrc}(object, which_outcome, oob = TRUE, per_class = FALSE, ...) } \arguments{ \item{object}{A classification \code{\link[randomForestSRC]{rfsrc}} or @@ -33,6 +33,13 @@ predictions. For \code{randomForest}, \code{oob = TRUE} uses out-of-bag vote probabilities (\code{object$votes}); \code{FALSE} uses in-bag \code{predict(type = "prob")}.} +\item{per_class}{Logical; if \code{TRUE} and the forest has more than two +classes, return per-class one-vs-rest ROC curves in a long-format +\code{data.frame} with a \code{class} factor column and a named AUC +vector attribute (ordered by descending AUC). Binary forests treat +\code{per_class = TRUE} as a no-op. Currently honoured by the +\code{randomForest} method only.} + \item{...}{Extra arguments (currently unused).} } \value{ diff --git a/man/plot.gg_roc.Rd b/man/plot.gg_roc.Rd index 27178253..c98dbab9 100644 --- a/man/plot.gg_roc.Rd +++ b/man/plot.gg_roc.Rd @@ -4,7 +4,7 @@ \alias{plot.gg_roc} \title{ROC plot generic function for a \code{\link{gg_roc}} object.} \usage{ -\method{plot}{gg_roc}(x, which_outcome = NULL, ...) +\method{plot}{gg_roc}(x, which_outcome = NULL, ..., panel = c("overlay", "facet")) } \arguments{ \item{x}{A \code{\link{gg_roc}} object, or a raw @@ -20,6 +20,11 @@ class index 2.} \item{...}{Additional arguments forwarded to \code{\link{gg_roc}} when \code{x} is a raw forest object (e.g. \code{oob = FALSE}).} + +\item{panel}{Character; layout for per-class ROC objects (those produced by +\code{gg_roc(..., per_class = TRUE)}). \code{"overlay"} (default) draws all +class curves in one panel coloured by class; \code{"facet"} wraps each +class into its own panel. Ignored for single-class \code{gg_roc} objects.} } \value{ A \code{ggplot} object. The x-axis shows 1 - Specificity (FPR) diff --git a/tests/testthat/test_gg_roc.R b/tests/testthat/test_gg_roc.R index 7d2ab4f9..f7c2a305 100644 --- a/tests/testthat/test_gg_roc.R +++ b/tests/testthat/test_gg_roc.R @@ -249,3 +249,124 @@ test_that("calc_roc.rfsrc output is unchanged for an explicit which_outcome (gua expect_true(all(c("sens", "spec", "pct") %in% colnames(g))) expect_gte(calc_auc(g), 0.9) # rfsrc iris setosa-vs-rest stays strong }) + +## ── per_class = TRUE (PR #88) ────────────────────────────────────────────── + +test_that("gg_roc per_class=TRUE: long format with class column", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg <- gg_roc(rf, per_class = TRUE) + expect_true("class" %in% names(gg)) + expect_true(all(c("sens", "spec", "pct") %in% names(gg))) # pct = threshold; same 3-col contract as calc_roc + expect_s3_class(gg$class, "factor") + expect_equal(nlevels(gg$class), 3L) +}) + +test_that("gg_roc per_class=TRUE: auc attr is named numeric vector length 3", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg <- gg_roc(rf, per_class = TRUE) + auc <- attr(gg, "auc") + expect_length(auc, 3L) + expect_named(auc) + # setosa is linearly separable in iris — AUC should be near-perfect + expect_gt(auc[["setosa"]], 0.99) + # AUC values must be sorted descending + expect_true(all(diff(auc) <= 0)) +}) + +test_that("gg_roc per_class=TRUE: class factor levels ordered by descending AUC", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg <- gg_roc(rf, per_class = TRUE) + auc <- attr(gg, "auc") + expect_equal(levels(gg$class), names(auc)) +}) + +test_that("gg_roc per_class=TRUE on binary forest: no class column (no-op)", { + skip_if_not_installed("randomForest") + set.seed(1L) + bin_data <- iris[iris$Species != "virginica", ] + bin_data$Species <- droplevels(bin_data$Species) + rf <- randomForest::randomForest(Species ~ ., data = bin_data, ntree = 100L) + gg <- gg_roc(rf, per_class = TRUE) + # Binary forest: per_class is a no-op — no class column, scalar AUC + expect_false("class" %in% names(gg)) + expect_length(attr(gg, "auc"), 1L) +}) + +test_that("gg_roc per_class=TRUE + which_outcome integer: message then per_class wins", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + expect_message( + gg <- gg_roc(rf, per_class = TRUE, which_outcome = 1L), + "which_outcome.*ignored.*per_class" + ) + expect_true("class" %in% names(gg)) +}) + +test_that("gg_roc which_outcome='all' still returns macro-average (no class column)", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg <- gg_roc(rf, which_outcome = "all") + expect_false("class" %in% names(gg)) + # Macro-average returns a single data frame, not a class-faceted one + expect_true(all(c("sens", "spec", "pct") %in% names(gg))) # pct = threshold; same 3-col contract as calc_roc +}) + +## ── plot.gg_roc per_class paths (PR #88) ───────────────────────────────── + +test_that("plot.gg_roc per_class=TRUE: overlay returns ggplot", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg <- gg_roc(rf, per_class = TRUE) + p <- plot(gg, panel = "overlay") + expect_s3_class(p, "ggplot") +}) + +test_that("plot.gg_roc per_class=TRUE: facet returns ggplot", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg <- gg_roc(rf, per_class = TRUE) + p <- plot(gg, panel = "facet") + expect_s3_class(p, "ggplot") +}) + +test_that("plot.gg_roc per_class=TRUE: layer_data smokeable for overlay", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg <- gg_roc(rf, per_class = TRUE) + p <- plot(gg, panel = "overlay") + expect_no_error(ggplot2::layer_data(p, 1L)) +}) + +test_that("plot.gg_roc existing single-class path unchanged", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg <- gg_roc(rf, which_outcome = 1L) + p <- plot(gg) + expect_s3_class(p, "ggplot") + expect_no_error(ggplot2::layer_data(p, 1L)) +}) + +test_that("summary.gg_roc per_class=TRUE: prints named AUC, no error", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg <- gg_roc(rf, per_class = TRUE) + s <- summary(gg) + expect_s3_class(s, "summary.gg") + # Body should mention all three class names + expect_true(any(grepl("setosa", s$body))) + expect_true(any(grepl("versicolor", s$body))) + expect_true(any(grepl("virginica", s$body))) +}) diff --git a/tests/testthat/test_snapshots.R b/tests/testthat/test_snapshots.R index 9c0b7738..3183bd5a 100644 --- a/tests/testthat/test_snapshots.R +++ b/tests/testthat/test_snapshots.R @@ -297,4 +297,27 @@ if (requireNamespace("randomForest", quietly = TRUE)) { }) } +## ── per_class ROC snapshots (PR #88) ───────────────────────────────────── +if (requireNamespace("randomForest", quietly = TRUE)) { + local({ + set.seed(1L) + rf_iris <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg_pc_iris <- gg_roc(rf_iris, per_class = TRUE) + + test_that("snapshot: gg-roc-multiclass-overlay", { + vdiffr::expect_doppelganger( + "gg-roc-multiclass-overlay", + plot(gg_pc_iris, panel = "overlay") + ) + }) + + test_that("snapshot: gg-roc-multiclass-facet", { + vdiffr::expect_doppelganger( + "gg-roc-multiclass-facet", + plot(gg_pc_iris, panel = "facet") + ) + }) + }) +} + } # end CI guard