Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,5 @@ framed.sty
# FUSE filesystem temporaries (safe to ignore; R CMD build already skips dotfiles)
^R/\.fuse_hidden
^\.positai$
^\.superpowers$
^vignettes/.*_cache$
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: ggRandomForests
Type: Package
Title: Visually Exploring Random Forests
Version: 2.7.3.9005
Version: 2.7.3.9006
Date: 2026-05-21
Authors@R: person("John", "Ehrlinger",
role = c("aut", "cre"),
Expand Down
15 changes: 14 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
Package: ggRandomForests
Version: 2.7.3.9005
Version: 2.7.3.9006

ggRandomForests v2.8.0 (development) — continued
=================================================
* **`gg_roc`: per-class one-vs-rest ROC curves (#88, closes #72).**
- `gg_roc()` gains a `per_class = FALSE` argument. When `per_class = TRUE`
and the forest has more than two classes, returns a long-format `gg_roc`
data frame with a `class` factor column and a named AUC vector attribute
(one entry per class, ordered by descending AUC).
- `plot.gg_roc()` gains a `panel = c("overlay", "facet")` argument. When the
`gg_roc` object contains a `class` column, the overlay path colours curves
by class; the facet path wraps each class into its own panel.
- `summary.gg_roc()` now prints named per-class AUC values when the `class`
column is present.
- Binary forests: `per_class = TRUE` is a silent no-op (single-curve result
returned unchanged).
- ROC confidence intervals are deferred to v2.9.0 (issue #7 / #72-CIs).
* **varPro variable dependency: `gg_udependent()` (Phase 3).**
- `gg_udependent()` extracts cross-variable dependency scores from a
`uvarpro` fit using `varPro::get.beta.entropy()` +
Expand Down
62 changes: 44 additions & 18 deletions R/gg_roc.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@
#' predictions. For \code{randomForest}, \code{oob = TRUE} uses out-of-bag
#' vote probabilities (\code{object$votes}); \code{FALSE} uses in-bag
#' \code{predict(type = "prob")}.
#' @param per_class Logical; if \code{TRUE} and the forest has more than two
#' classes, return per-class one-vs-rest ROC curves in a long-format
#' \code{data.frame} with a \code{class} factor column and a named AUC
#' vector attribute (ordered by descending AUC). Binary forests treat
#' \code{per_class = TRUE} as a no-op. Currently honoured by the
#' \code{randomForest} method only.
#' @param ... Extra arguments (currently unused).
#'
#' @return A \code{gg_roc} \code{data.frame} with one row per unique prediction
Expand Down Expand Up @@ -93,7 +99,8 @@
#' @aliases gg_roc gg_roc.rfsrc gg_roc.randomForest

#' @export
gg_roc.rfsrc <- function(object, which_outcome, oob = TRUE, ...) {
gg_roc.rfsrc <- function(object, which_outcome, oob = TRUE,
per_class = FALSE, ...) {
# Validate that the object was grown with randomForestSRC (grow or predict)
# or is a randomForest object — the two supported class signatures.
if (sum(inherits(object, c("rfsrc", "grow"), TRUE) == c(1, 2)) != 2 &&
Expand Down Expand Up @@ -136,38 +143,57 @@ gg_roc.rfsrc <- function(object, which_outcome, oob = TRUE, ...) {
invisible(gg_dta)
}
#' @export
gg_roc <- function(object, which_outcome, oob = TRUE, ...) {
gg_roc <- function(object, which_outcome, oob = TRUE, per_class = FALSE, ...) {
UseMethod("gg_roc", object)
}

#' @export
gg_roc.randomForest <- function(object, which_outcome, oob = TRUE, ...) {
# Validate that the object is a genuine randomForest instance.
gg_roc.randomForest <- function(object, which_outcome, oob = TRUE,
per_class = FALSE, ...) {
if (!inherits(object, "randomForest")) {
stop(
"gg_roc.randomForest only works for objects of class 'randomForest'."
)
stop("gg_roc.randomForest only works for objects of class 'randomForest'.")
}

# Default to computing the ROC curve for all outcome classes.
if (missing(which_outcome)) {
which_outcome <- "all"
}

if (!(object$type == "classification")) {
stop("gg_roc only works with classification forests")
}

lvls <- levels(object$y)
n_class <- length(lvls)

# ── per_class = TRUE path (multi-class only) ─────────────────────────────
if (isTRUE(per_class) && n_class > 2L) {
if (!missing(which_outcome) && !identical(which_outcome, "all")) {
message("which_outcome is ignored when per_class = TRUE.")
}
prob <- .rf_prob_matrix(object, oob, lvls)
dta <- object$y
curves <- lapply(seq_along(lvls), function(k) {
cv <- .rf_one_class_roc(dta, prob, k, lvls)
cv$class <- lvls[k]
cv
})
auc_vals <- vapply(curves, calc_auc, numeric(1L))
names(auc_vals) <- lvls
auc_ord <- order(auc_vals, decreasing = TRUE)
auc_vals <- auc_vals[auc_ord]
gg_dta <- do.call(rbind, curves)
gg_dta$class <- factor(gg_dta$class, levels = lvls[auc_ord])
class(gg_dta) <- c("gg_roc", class(gg_dta))
attr(gg_dta, "auc") <- auc_vals
gg_dta <- .set_provenance(gg_dta, object)
return(invisible(gg_dta))
}

# ── Standard path (binary, or per_class not requested) ──────────────────
# For randomForest objects the response is stored in $y (not $yvar).
gg_dta <- # nolint: object_usage_linter
calc_roc(object,
object$y,
which_outcome = which_outcome,
oob = oob
)
class(gg_dta) <- c("gg_roc", class(gg_dta))
gg_dta <- .set_provenance(gg_dta, object)

calc_roc(object, object$y, which_outcome = which_outcome, oob = oob)
class(gg_dta) <- c("gg_roc", class(gg_dta))
attr(gg_dta, "auc") <- calc_auc(gg_dta)
gg_dta <- .set_provenance(gg_dta, object)
invisible(gg_dta)
}

Expand Down
71 changes: 68 additions & 3 deletions R/plot.gg_roc.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
#' the forest has more than two classes, ROC curves for all classes are
#' overlaid in a single plot. For binary forests \code{NULL} defaults to
#' class index 2.
#' @param panel Character; layout for per-class ROC objects (those produced by
#' \code{gg_roc(..., per_class = TRUE)}). \code{"overlay"} (default) draws all
#' class curves in one panel coloured by class; \code{"facet"} wraps each
#' class into its own panel. Ignored for single-class \code{gg_roc} objects.
#' @param ... Additional arguments forwarded to \code{\link{gg_roc}} when
#' \code{x} is a raw forest object (e.g. \code{oob = FALSE}).
#'
Expand Down Expand Up @@ -73,7 +77,12 @@
#' for (i in seq_len(n_cls)) print(plot(gg_roc(rfsrc_iris, which_outcome = i)))
#'
#' @export
plot.gg_roc <- function(x, which_outcome = NULL, ...) {
plot.gg_roc <- function(x, which_outcome = NULL, ...,
panel = c("overlay", "facet")) {
# `panel` is placed after `...` so it is name-only: this preserves
# positional back-compatibility for existing callers (e.g.
# plot(x, 1, FALSE) still routes the 3rd positional arg into `...`).
panel <- match.arg(panel)
gg_dta <- x

## ---- Accept a raw rfsrc or randomForest object -----------------------
Expand Down Expand Up @@ -118,8 +127,17 @@ plot.gg_roc <- function(x, which_outcome = NULL, ...) {
}
}

## ---- Single-class ROC plot ------------------------------------------
## ---- Single-class ROC plot (or per_class long-format) ----------------
if (inherits(gg_dta, "gg_roc")) {

# Per-class detection: gg_roc produced by gg_roc(..., per_class = TRUE)
# carries a 'class' column (factor) + a named AUC vector attribute.
# Rendering is delegated to a helper to keep this method's cyclomatic
# complexity within the project lint budget.
if ("class" %in% names(gg_dta)) {
return(.plot_gg_roc_per_class(gg_dta, attr(x, "auc"), panel))
}

# Sort by specificity so the ROC curve is drawn left-to-right
gg_dta <- gg_dta[order(gg_dta$spec), ]
# False positive rate = 1 - specificity
Expand Down Expand Up @@ -193,7 +211,54 @@ plot.gg_roc <- function(x, which_outcome = NULL, ...) {
) +
ggplot2::coord_fixed()

# Multi-class: do not annotate a single AUC value each class has its own.
# Multi-class: do not annotate a single AUC value - each class has its own.
}
return(gg_plt)
}

# Render a per-class (one-vs-rest) ROC object produced by
# gg_roc(..., per_class = TRUE). Split out of plot.gg_roc() so that method
# stays within the project's cyclomatic-complexity lint budget.
#
# gg_dta : long-format gg_roc data frame with a 'class' factor column
# auc : named numeric AUC vector (one entry per class) or NULL
# panel : "overlay" (curves coloured by class) or "facet" (one panel each)
.plot_gg_roc_per_class <- function(gg_dta, auc, panel) {
gg_dta$fpr <- 1 - gg_dta$spec

if (panel == "overlay") {
gg_plt <- ggplot2::ggplot(gg_dta) +
ggplot2::geom_line(ggplot2::aes(
x = .data$fpr, y = .data$sens, color = .data$class
)) +
ggplot2::labs(
x = "1 - Specificity (FPR)", y = "Sensitivity (TPR)",
color = "Class"
)
} else {
gg_plt <- ggplot2::ggplot(gg_dta) +
ggplot2::geom_line(ggplot2::aes(x = .data$fpr, y = .data$sens)) +
ggplot2::labs(x = "1 - Specificity (FPR)", y = "Sensitivity (TPR)") +
ggplot2::facet_wrap(~class)
}

gg_plt <- gg_plt +
ggplot2::geom_abline(
slope = 1, intercept = 0,
col = "red", linetype = 2, linewidth = .5
) +
ggplot2::coord_fixed()

# AUC caption - top 5 classes by descending AUC (already sorted)
if (!is.null(auc) && length(auc) > 0L) {
top_n <- min(5L, length(auc))
auc_str <- paste(
sprintf("%s=%.3g", names(auc)[seq_len(top_n)], auc[seq_len(top_n)]),
collapse = ", "
)
if (length(auc) > 5L) auc_str <- paste0(auc_str, ", ...")
gg_plt <- gg_plt +
ggplot2::labs(caption = paste("OvR ROC, per_class=TRUE. AUC:", auc_str))
}
gg_plt
}
18 changes: 13 additions & 5 deletions R/summary_methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,19 @@ summary.gg_partial_varpro <- function(object, ...) {
#' @rdname summary.gg
#' @export
summary.gg_roc <- function(object, ...) {
body <- c(
sprintf("thresholds: %d", nrow(object)),
sprintf("AUC: %.4g",
attr(object, "auc") %||% .gg_auc_trap(object))
)
auc <- attr(object, "auc") %||% .gg_auc_trap(object)
if ("class" %in% names(object)) {
# per_class = TRUE path: named AUC vector, one entry per class
n_cls <- nlevels(object$class)
auc_str <- paste(sprintf("%s=%.4g", names(auc), auc), collapse = ", ")
body <- c(sprintf("classes: %d", n_cls),
sprintf("AUC: %s", auc_str))
} else {
body <- c(
sprintf("thresholds: %d", nrow(object)),
sprintf("AUC: %.4g", auc)
)
}
.summary_skel(object, "gg_roc", body)
}

Expand Down
9 changes: 8 additions & 1 deletion man/gg_roc.rfsrc.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion man/plot.gg_roc.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading