Commit 49e4c2f

ehrlinger and claude authored
feat: gg_roc per_class=TRUE — per-class OvR ROC curves (#88, closes #72) (#91)
* chore: open v2.7.3.9006 dev increment (PR #88)

* feat: gg_roc.randomForest per_class=TRUE — per-class OvR ROC + named AUC

* test: gg_roc per_class binary no-op and which_outcome conflict tests

* feat: plot.gg_roc per_class=TRUE — overlay and facet panel paths

* feat: summary.gg_roc handles named AUC vector for per_class=TRUE

* test: add vdiffr snapshots for per_class ROC overlay and facet (PR #88)

* docs: NEWS for PR #88 per_class ROC + roxygen for per_class/panel args

- NEWS.md: feature entry + version line bump to 2.7.3.9006
- Document gg_roc(per_class=) and plot.gg_roc(panel=) — fixes codoc warnings
- Replace em-dashes in R/plot.gg_roc.R with ASCII (non-ASCII warning)
- .Rbuildignore: exclude stray .superpowers/ dir

R CMD check --as-cran: 0 errors, 0 warnings, 0 notes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* refactor: extract .plot_gg_roc_per_class helper to satisfy cyclocomp lint

plot.gg_roc cyclomatic complexity hit 21 (lint budget 20) after the
per_class branch. Move per-class rendering into a private helper; the
method now delegates with a single return() call.

* fix: move plot.gg_roc panel arg after ... for positional back-compat (#91 review)

Copilot review: panel before ... is backward-incompatible — a positional
caller like plot(x, 1, FALSE) previously routed FALSE into ..., but with
panel as the 3rd formal it would bind to panel and error on match.arg.
Placing panel after ... makes it name-only; all test calls already use
panel = by name.

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent c580011 commit 49e4c2f

10 files changed

Lines changed: 299 additions & 30 deletions

.Rbuildignore

Lines changed: 1 addition & 0 deletions
@@ -48,4 +48,5 @@ framed.sty
 # FUSE filesystem temporaries (safe to ignore; R CMD build already skips dotfiles)
 ^R/\.fuse_hidden
 ^\.positai$
+^\.superpowers$
 ^vignettes/.*_cache$

DESCRIPTION

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 Package: ggRandomForests
 Type: Package
 Title: Visually Exploring Random Forests
-Version: 2.7.3.9005
+Version: 2.7.3.9006
 Date: 2026-05-21
 Authors@R: person("John", "Ehrlinger",
     role = c("aut", "cre"),

NEWS.md

Lines changed: 14 additions & 1 deletion
@@ -1,8 +1,21 @@
 Package: ggRandomForests
-Version: 2.7.3.9005
+Version: 2.7.3.9006
 
 ggRandomForests v2.8.0 (development) — continued
 =================================================
+* **`gg_roc`: per-class one-vs-rest ROC curves (#88, closes #72).**
+  - `gg_roc()` gains a `per_class = FALSE` argument. When `per_class = TRUE`
+    and the forest has more than two classes, returns a long-format `gg_roc`
+    data frame with a `class` factor column and a named AUC vector attribute
+    (one entry per class, ordered by descending AUC).
+  - `plot.gg_roc()` gains a `panel = c("overlay", "facet")` argument. When the
+    `gg_roc` object contains a `class` column, the overlay path colours curves
+    by class; the facet path wraps each class into its own panel.
+  - `summary.gg_roc()` now prints named per-class AUC values when the `class`
+    column is present.
+  - Binary forests: `per_class = TRUE` is a silent no-op (single-curve result
+    returned unchanged).
+  - ROC confidence intervals are deferred to v2.9.0 (issue #7 / #72-CIs).
 * **varPro variable dependency: `gg_udependent()` (Phase 3).**
   - `gg_udependent()` extracts cross-variable dependency scores from a
     `uvarpro` fit using `varPro::get.beta.entropy()` +

R/gg_roc.R

Lines changed: 44 additions & 18 deletions
@@ -40,6 +40,12 @@
 #' predictions. For \code{randomForest}, \code{oob = TRUE} uses out-of-bag
 #' vote probabilities (\code{object$votes}); \code{FALSE} uses in-bag
 #' \code{predict(type = "prob")}.
+#' @param per_class Logical; if \code{TRUE} and the forest has more than two
+#' classes, return per-class one-vs-rest ROC curves in a long-format
+#' \code{data.frame} with a \code{class} factor column and a named AUC
+#' vector attribute (ordered by descending AUC). Binary forests treat
+#' \code{per_class = TRUE} as a no-op. Currently honoured by the
+#' \code{randomForest} method only.
 #' @param ... Extra arguments (currently unused).
 #'
 #' @return A \code{gg_roc} \code{data.frame} with one row per unique prediction
@@ -93,7 +99,8 @@
 #' @aliases gg_roc gg_roc.rfsrc gg_roc.randomForest
 
 #' @export
-gg_roc.rfsrc <- function(object, which_outcome, oob = TRUE, ...) {
+gg_roc.rfsrc <- function(object, which_outcome, oob = TRUE,
+                         per_class = FALSE, ...) {
   # Validate that the object was grown with randomForestSRC (grow or predict)
   # or is a randomForest object — the two supported class signatures.
   if (sum(inherits(object, c("rfsrc", "grow"), TRUE) == c(1, 2)) != 2 &&
@@ -136,38 +143,57 @@ gg_roc.rfsrc <- function(object, which_outcome, oob = TRUE, ...) {
   invisible(gg_dta)
 }
 #' @export
-gg_roc <- function(object, which_outcome, oob = TRUE, ...) {
+gg_roc <- function(object, which_outcome, oob = TRUE, per_class = FALSE, ...) {
   UseMethod("gg_roc", object)
 }
 
 #' @export
-gg_roc.randomForest <- function(object, which_outcome, oob = TRUE, ...) {
-  # Validate that the object is a genuine randomForest instance.
+gg_roc.randomForest <- function(object, which_outcome, oob = TRUE,
+                                per_class = FALSE, ...) {
   if (!inherits(object, "randomForest")) {
-    stop(
-      "gg_roc.randomForest only works for objects of class 'randomForest'."
-    )
+    stop("gg_roc.randomForest only works for objects of class 'randomForest'.")
   }
-
-  # Default to computing the ROC curve for all outcome classes.
   if (missing(which_outcome)) {
     which_outcome <- "all"
   }
-
   if (!(object$type == "classification")) {
     stop("gg_roc only works with classification forests")
   }
 
+  lvls <- levels(object$y)
+  n_class <- length(lvls)
+
+  # ── per_class = TRUE path (multi-class only) ─────────────────────────────
+  if (isTRUE(per_class) && n_class > 2L) {
+    if (!missing(which_outcome) && !identical(which_outcome, "all")) {
+      message("which_outcome is ignored when per_class = TRUE.")
+    }
+    prob <- .rf_prob_matrix(object, oob, lvls)
+    dta <- object$y
+    curves <- lapply(seq_along(lvls), function(k) {
+      cv <- .rf_one_class_roc(dta, prob, k, lvls)
+      cv$class <- lvls[k]
+      cv
+    })
+    auc_vals <- vapply(curves, calc_auc, numeric(1L))
+    names(auc_vals) <- lvls
+    auc_ord <- order(auc_vals, decreasing = TRUE)
+    auc_vals <- auc_vals[auc_ord]
+    gg_dta <- do.call(rbind, curves)
+    gg_dta$class <- factor(gg_dta$class, levels = lvls[auc_ord])
+    class(gg_dta) <- c("gg_roc", class(gg_dta))
+    attr(gg_dta, "auc") <- auc_vals
+    gg_dta <- .set_provenance(gg_dta, object)
+    return(invisible(gg_dta))
+  }
+
+  # ── Standard path (binary, or per_class not requested) ──────────────────
   # For randomForest objects the response is stored in $y (not $yvar).
   gg_dta <- # nolint: object_usage_linter
-    calc_roc(object,
-      object$y,
-      which_outcome = which_outcome,
-      oob = oob
-    )
-  class(gg_dta) <- c("gg_roc", class(gg_dta))
-  gg_dta <- .set_provenance(gg_dta, object)
-
+    calc_roc(object, object$y, which_outcome = which_outcome, oob = oob)
+  class(gg_dta) <- c("gg_roc", class(gg_dta))
+  attr(gg_dta, "auc") <- calc_auc(gg_dta)
+  gg_dta <- .set_provenance(gg_dta, object)
   invisible(gg_dta)
 }
 
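Reviewer note: the per_class branch above builds one one-vs-rest curve per class, computes a named AUC vector, sorts it in descending order, and reuses that order for the `class` factor levels. The helpers it calls (`.rf_prob_matrix()`, `.rf_one_class_roc()`, `calc_auc()`) are not part of this diff, so the following is only a minimal base-R sketch of the same flow. `ovr_roc()` and `trap_auc()` are hypothetical stand-ins written for illustration, and the labels and probabilities are made-up numbers.

```r
# Illustrative stand-ins only: ovr_roc() and trap_auc() are hypothetical
# helpers, not the package's .rf_one_class_roc() or calc_auc().

# One-vs-rest ROC for class index k: class k is "positive", the rest "negative".
ovr_roc <- function(y, prob, k) {
  positive <- y == levels(y)[k]
  score <- prob[, k]
  cuts <- sort(unique(score), decreasing = TRUE)
  sens <- vapply(cuts, function(t) mean(score[positive] >= t), numeric(1L))
  spec <- vapply(cuts, function(t) mean(score[!positive] < t), numeric(1L))
  # Prepend the (FPR = 0, TPR = 0) corner so the curve starts at the origin.
  data.frame(sens = c(0, sens), spec = c(1, spec))
}

# Trapezoidal area under the (FPR, TPR) curve.
trap_auc <- function(roc) {
  fpr <- 1 - roc$spec
  ord <- order(fpr, roc$sens)
  fpr <- fpr[ord]
  tpr <- roc$sens[ord]
  n <- length(fpr)
  sum(diff(fpr) * (tpr[-1L] + tpr[-n]) / 2)
}

# Toy labels and class-probability matrix (made-up numbers).
y <- factor(c("a", "a", "b", "b", "c", "c"))
prob <- rbind(
  c(0.80, 0.10, 0.10),
  c(0.60, 0.30, 0.10),
  c(0.20, 0.70, 0.10),
  c(0.10, 0.60, 0.30),
  c(0.35, 0.40, 0.25),
  c(0.10, 0.20, 0.70)
)
lvls <- levels(y)
colnames(prob) <- lvls

# Same shape as the diff: one curve per class, tagged with its class label.
curves <- lapply(seq_along(lvls), function(k) {
  cv <- ovr_roc(y, prob, k)
  cv$class <- lvls[k]
  cv
})

# Named AUC vector, sorted descending; factor levels follow the same order.
auc_vals <- vapply(curves, trap_auc, numeric(1L))
names(auc_vals) <- lvls
auc_ord <- order(auc_vals, decreasing = TRUE)
auc_vals <- auc_vals[auc_ord]

roc_long <- do.call(rbind, curves)
roc_long$class <- factor(roc_long$class, levels = lvls[auc_ord])
attr(roc_long, "auc") <- auc_vals
```

Ordering the factor levels by AUC means a plot legend or facet strip built from `class` lists the best-separated class first, which matches the "already sorted" assumption the plotting helper makes about the AUC attribute.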

R/plot.gg_roc.R

Lines changed: 68 additions & 3 deletions
@@ -24,6 +24,10 @@
 #' the forest has more than two classes, ROC curves for all classes are
 #' overlaid in a single plot. For binary forests \code{NULL} defaults to
 #' class index 2.
+#' @param panel Character; layout for per-class ROC objects (those produced by
+#' \code{gg_roc(..., per_class = TRUE)}). \code{"overlay"} (default) draws all
+#' class curves in one panel coloured by class; \code{"facet"} wraps each
+#' class into its own panel. Ignored for single-class \code{gg_roc} objects.
 #' @param ... Additional arguments forwarded to \code{\link{gg_roc}} when
 #' \code{x} is a raw forest object (e.g. \code{oob = FALSE}).
 #'
@@ -73,7 +77,12 @@
 #' for (i in seq_len(n_cls)) print(plot(gg_roc(rfsrc_iris, which_outcome = i)))
 #'
 #' @export
-plot.gg_roc <- function(x, which_outcome = NULL, ...) {
+plot.gg_roc <- function(x, which_outcome = NULL, ...,
+                        panel = c("overlay", "facet")) {
+  # `panel` is placed after `...` so it is name-only: this preserves
+  # positional back-compatibility for existing callers (e.g.
+  # plot(x, 1, FALSE) still routes the 3rd positional arg into `...`).
+  panel <- match.arg(panel)
   gg_dta <- x
 
   ## ---- Accept a raw rfsrc or randomForest object -----------------------
@@ -118,8 +127,17 @@ plot.gg_roc <- function(x, which_outcome = NULL, ...) {
     }
   }
 
-  ## ---- Single-class ROC plot ------------------------------------------
+  ## ---- Single-class ROC plot (or per_class long-format) ----------------
   if (inherits(gg_dta, "gg_roc")) {
+
+    # Per-class detection: gg_roc produced by gg_roc(..., per_class = TRUE)
+    # carries a 'class' column (factor) + a named AUC vector attribute.
+    # Rendering is delegated to a helper to keep this method's cyclomatic
+    # complexity within the project lint budget.
+    if ("class" %in% names(gg_dta)) {
+      return(.plot_gg_roc_per_class(gg_dta, attr(x, "auc"), panel))
+    }
+
     # Sort by specificity so the ROC curve is drawn left-to-right
     gg_dta <- gg_dta[order(gg_dta$spec), ]
     # False positive rate = 1 - specificity
@@ -193,7 +211,54 @@ plot.gg_roc <- function(x, which_outcome = NULL, ...) {
       ) +
       ggplot2::coord_fixed()
 
-    # Multi-class: do not annotate a single AUC value each class has its own.
+    # Multi-class: do not annotate a single AUC value - each class has its own.
   }
   return(gg_plt)
 }
+
+# Render a per-class (one-vs-rest) ROC object produced by
+# gg_roc(..., per_class = TRUE). Split out of plot.gg_roc() so that method
+# stays within the project's cyclomatic-complexity lint budget.
+#
+# gg_dta : long-format gg_roc data frame with a 'class' factor column
+# auc : named numeric AUC vector (one entry per class) or NULL
+# panel : "overlay" (curves coloured by class) or "facet" (one panel each)
+.plot_gg_roc_per_class <- function(gg_dta, auc, panel) {
+  gg_dta$fpr <- 1 - gg_dta$spec
+
+  if (panel == "overlay") {
+    gg_plt <- ggplot2::ggplot(gg_dta) +
+      ggplot2::geom_line(ggplot2::aes(
+        x = .data$fpr, y = .data$sens, color = .data$class
+      )) +
+      ggplot2::labs(
+        x = "1 - Specificity (FPR)", y = "Sensitivity (TPR)",
+        color = "Class"
+      )
+  } else {
+    gg_plt <- ggplot2::ggplot(gg_dta) +
+      ggplot2::geom_line(ggplot2::aes(x = .data$fpr, y = .data$sens)) +
+      ggplot2::labs(x = "1 - Specificity (FPR)", y = "Sensitivity (TPR)") +
+      ggplot2::facet_wrap(~class)
+  }
+
+  gg_plt <- gg_plt +
+    ggplot2::geom_abline(
+      slope = 1, intercept = 0,
+      col = "red", linetype = 2, linewidth = .5
+    ) +
+    ggplot2::coord_fixed()
+
+  # AUC caption - top 5 classes by descending AUC (already sorted)
+  if (!is.null(auc) && length(auc) > 0L) {
+    top_n <- min(5L, length(auc))
+    auc_str <- paste(
+      sprintf("%s=%.3g", names(auc)[seq_len(top_n)], auc[seq_len(top_n)]),
+      collapse = ", "
+    )
+    if (length(auc) > 5L) auc_str <- paste0(auc_str, ", ...")
+    gg_plt <- gg_plt +
+      ggplot2::labs(caption = paste("OvR ROC, per_class=TRUE. AUC:", auc_str))
+  }
+  gg_plt
+}
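
Reviewer note: the back-compat argument for placing `panel` after `...` can be checked in isolation. The two functions below are hypothetical stand-ins for the "before" and "after" signatures of `plot.gg_roc()`; they only return the resolved `panel` value and do no plotting.

```r
# Hypothetical stand-ins for the two candidate signatures; not package code.

# panel BEFORE `...`: a third positional argument binds to `panel`.
plot_panel_first <- function(x, which_outcome = NULL,
                             panel = c("overlay", "facet"), ...) {
  match.arg(panel)
}

# panel AFTER `...`: a third positional argument is absorbed by `...`,
# and `panel` can only be supplied by its full name.
plot_panel_last <- function(x, which_outcome = NULL, ...,
                            panel = c("overlay", "facet")) {
  match.arg(panel)
}

# Legacy-style positional call: plot(x, 1, FALSE).
res_first <- tryCatch(
  plot_panel_first("roc", 1, FALSE),
  error = function(e) "error"  # match.arg() rejects a non-character value
)
res_last <- plot_panel_last("roc", 1, FALSE)

# Name-only use of the new argument.
res_named <- plot_panel_last("roc", panel = "facet")
```

Arguments that follow `...` in a formals list must be given by their exact name, so the second signature cannot capture a stray positional value. That is the property the review comment relies on.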

R/summary_methods.R

Lines changed: 13 additions & 5 deletions
@@ -189,11 +189,19 @@ summary.gg_partial_varpro <- function(object, ...) {
 #' @rdname summary.gg
 #' @export
 summary.gg_roc <- function(object, ...) {
-  body <- c(
-    sprintf("thresholds: %d", nrow(object)),
-    sprintf("AUC: %.4g",
-            attr(object, "auc") %||% .gg_auc_trap(object))
-  )
+  auc <- attr(object, "auc") %||% .gg_auc_trap(object)
+  if ("class" %in% names(object)) {
+    # per_class = TRUE path: named AUC vector, one entry per class
+    n_cls <- nlevels(object$class)
+    auc_str <- paste(sprintf("%s=%.4g", names(auc), auc), collapse = ", ")
+    body <- c(sprintf("classes: %d", n_cls),
+              sprintf("AUC: %s", auc_str))
+  } else {
+    body <- c(
+      sprintf("thresholds: %d", nrow(object)),
+      sprintf("AUC: %.4g", auc)
+    )
+  }
   .summary_skel(object, "gg_roc", body)
 }
 
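Reviewer note: both `summary.gg_roc()` and the plot caption turn the named AUC vector into a `name=value` string with a vectorised `sprintf()`. A small sketch of that formatting follows, including the "top 5, then `...`" truncation used in the caption. `format_auc()` is a hypothetical helper name and the AUC values are made up.

```r
# Hypothetical helper mirroring the string-building pattern in the diff.
# `fmt` takes a name and a number; `top_n` caps how many entries are shown.
format_auc <- function(auc, top_n = length(auc), fmt = "%s=%.4g") {
  keep <- seq_len(min(top_n, length(auc)))
  out <- paste(sprintf(fmt, names(auc)[keep], auc[keep]), collapse = ", ")
  # Signal truncation the same way the caption code does.
  if (length(auc) > top_n) out <- paste0(out, ", ...")
  out
}

# Made-up AUC values, already sorted in descending order.
auc3 <- c(setosa = 1, virginica = 0.95, versicolor = 0.9125)
summary_str <- format_auc(auc3)

auc7 <- c(a = 0.9, b = 0.8, c = 0.7, d = 0.6, e = 0.5, f = 0.4, g = 0.3)
caption_str <- format_auc(auc7, top_n = 5L, fmt = "%s=%.3g")
```

Because `sprintf()` recycles over the names and values together, no explicit loop over classes is needed; the truncation only matters for forests with more than five classes.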

man/gg_roc.rfsrc.Rd

Lines changed: 8 additions & 1 deletion
Generated file; diff not rendered by default.

man/plot.gg_roc.Rd

Lines changed: 6 additions & 1 deletion
Generated file; diff not rendered by default.
