From f010bfd64e9b8d460b6017c0a872add3d15d9847 Mon Sep 17 00:00:00 2001 From: John Ehrlinger Date: Thu, 21 May 2026 10:32:46 -0400 Subject: [PATCH 1/3] docs: add randomForest finishing-work design spec (PR #87 + #88) Covers two v2.8.0 completion PRs: - PR #87: gg_variable.randomForest classification yhat fix + stale #81/#82 close - PR #88: multi-class gg_roc with per_class=TRUE (issue #72, no CIs) Co-Authored-By: Claude Opus 4.7 (1M context) --- .../2026-05-21-rf-finishing-work-design.md | 371 ++++++++++++++++++ 1 file changed, 371 insertions(+) create mode 100644 dev/plans/2026-05-21-rf-finishing-work-design.md diff --git a/dev/plans/2026-05-21-rf-finishing-work-design.md b/dev/plans/2026-05-21-rf-finishing-work-design.md new file mode 100644 index 00000000..4720e3d4 --- /dev/null +++ b/dev/plans/2026-05-21-rf-finishing-work-design.md @@ -0,0 +1,371 @@ +--- +date: 2026-05-21 +status: approved +scope: v2.8.0 — randomForest engine finishing work +author: John Ehrlinger (design partner: Claude) +issues: "#81, #82 (close stale), #82 follow-up (gg_variable classification), #72 (multi-class ROC, no CIs)" +--- + +# randomForest Engine — Finishing Work Design + +Two independent PRs targeting v2.8.0 completion. All work is on the +`randomForest` engine path; no rfsrc or varPro behavior changes. + +--- + +## Scope + +| PR | Content | Version bump | +|---|---|---| +| #87 | `gg_variable.randomForest` classification fix + close stale #81/#82 | `2.7.3.9005` | +| #88 | Multi-class `gg_roc` with `per_class = TRUE` (issue #72, no CIs) | `2.7.3.9006` | + +CIs on ROC curves (#7, tracked under #72) are deferred to v2.9.0. + +--- + +## PR #87 — `gg_variable.randomForest` Classification Fix + +### Problem + +`gg_variable.randomForest` for classification stores a single `yhat` column +containing `as.vector(object$predicted)` — the predicted class labels +(`as.vector()` coerces the factor to a character vector), +not per-class probabilities.
This mismatches what `gg_variable.rfsrc` produces +(`yhat.` probability columns from `object$predicted.oob`), so +`plot.gg_variable`'s multi-class pivot never fires. Two downstream bugs follow: + +1. The multi-class numeric path in `plot.gg_variable` has no `smooth = TRUE` + block, so smooth curves are silently skipped for 3+ class forests. +2. The binary path's `smooth = TRUE` block calls `geom_smooth(...)` with no + `aes()`, so `stat_smooth()` fails ("requires missing aesthetics: x and y") + when `layer_data()` is called on the result. + +### Fix: Extractor (`R/gg_variable.R`) + +In `gg_variable.randomForest`, replace the single-yhat classification branch: + +```r +# Before (wrong — class labels, not probabilities): +gg_dta$yhat <- as.vector(object$predicted) +if (object$type == "classification") { + gg_dta$yvar <- response +} + +# After (correct — per-class OOB vote fractions, matching rfsrc shape): +if (object$type == "classification") { + preds <- object$votes # matrix: n × n_classes, OOB vote fractions + colnames(preds) <- paste0("yhat.", colnames(preds)) + gg_dta <- cbind(gg_dta, preds) + gg_dta$yvar <- response +} else { + gg_dta$yhat <- as.vector(object$predicted) +} +``` + +`object$votes` is the n × n_classes matrix of OOB per-class vote fractions +(equivalent to `object$predicted.oob` for rfsrc classification). With this +change, `gg_variable.randomForest` produces `yhat.setosa`, `yhat.versicolor`, +`yhat.virginica` for iris — the same shape as the rfsrc path — so +`plot.gg_variable` dispatches identically for both engines. + +### Fix: Plot method (`R/plot.gg_variable.R`) + +**Fix 1 — Binary classification smooth missing aes (line ~519):** + +```r +# Before: +if (smooth) { + gg_plt[[ind]] <- gg_plt[[ind]] + + ggplot2::geom_smooth(...) +} + +# After: +if (smooth) { + gg_plt[[ind]] <- gg_plt[[ind]] + + ggplot2::geom_smooth( + ggplot2::aes(x = .data$var, y = .data$yhat), ... 
+ ) +} +``` + +**Fix 2 — Multi-class numeric path missing smooth block (after the `geom_point` +at line ~557):** + +```r +# Add after the multi-class geom_point block: +if (smooth) { + gg_plt[[ind]] <- gg_plt[[ind]] + + ggplot2::geom_smooth( + ggplot2::aes(x = .data$var, y = .data$yhat), ... + ) +} +``` + +### Tests + +**`tests/testthat/test_gg_variable.R`** — add a `randomForest` classification +block: + +```r +## ── randomForest classification ────────────────────────────────────────────── +test_that("gg_variable.randomForest classification: yhat.* columns present", { + skip_if_not_installed("randomForest") + set.seed(42L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 50L) + gg <- gg_variable(rf) + expect_true(all(c("yhat.setosa", "yhat.versicolor", "yhat.virginica") + %in% names(gg))) + expect_false("yhat" %in% names(gg)) # no bare yhat for multi-class + expect_true("yvar" %in% names(gg)) +}) + +test_that("gg_variable.randomForest classification: plot returns patchwork", { + skip_if_not_installed("randomForest") + set.seed(42L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 50L) + gg <- gg_variable(rf) + p <- plot(gg) + expect_s3_class(p, "patchwork") +}) + +test_that("gg_variable.randomForest classification: layer_data smokeable", { + skip_if_not_installed("randomForest") + set.seed(42L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 50L) + gg <- gg_variable(rf) + p <- plot(gg, xvar = "Sepal.Length") + expect_no_error(ggplot2::layer_data(p, 1L)) +}) +``` + +**`tests/testthat/test_snapshots.R`** — add inside the `VDIFFR_RUN_TESTS=true` +guard: + +```r +if (requireNamespace("randomForest", quietly = TRUE)) { + local({ + set.seed(42L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 50L) + gg <- gg_variable(rf) + test_that("snapshot: gg-variable-rf-classification-default", { + vdiffr::expect_doppelganger("gg-variable-rf-classification-default", + plot(gg)) + }) + 
test_that("snapshot: gg-variable-rf-classification-smooth", { + vdiffr::expect_doppelganger("gg-variable-rf-classification-smooth", + plot(gg, xvar = "Sepal.Length", smooth = TRUE)) + }) + }) +} +``` + +### Housekeeping + +Close issues #81 and #82 via the PR description (`Closes #81, Closes #82`). +Both were fully resolved by PR #83 but GitHub's auto-close never fired. + +--- + +## PR #88 — Multi-class `gg_roc` with `per_class = TRUE` + +### Goal + +Extend `gg_roc()` to return per-class one-vs-rest ROC curves in a long-format +data frame when `per_class = TRUE`. `which_outcome = "all"` retains its current +macro-average meaning (backward compatible). No CI computation in this PR. + +### Extractor changes (`R/gg_roc.R`, `R/calc_roc.R`) + +**New signature:** + +```r +gg_roc(object, which_outcome = "all", per_class = FALSE, ...) +``` + +**`per_class = TRUE` behaviour:** + +- For forests with > 2 classes: compute one-vs-rest ROC for every class *k* + (scores = `votes[, k]`, positive = `(y == k)`), stack into a long data frame + with columns `class` (factor), `fpr`, `tpr`. +- `attr(gg, "auc")` becomes a named numeric vector: one entry per class, + ordered by descending AUC. Class factor levels follow the same order. +- For binary forests: `per_class = TRUE` is a no-op — returns the single-curve + result with no `class` column and a scalar `auc` attribute (same as + `per_class = FALSE`). +- If `per_class = TRUE` AND `which_outcome != "all"` (i.e., a specific class + integer): `per_class` wins, a `message()` informs the caller that + `which_outcome` is ignored when `per_class = TRUE`. + +**Internal helper `calc_roc_one_vs_rest(scores, y, k)`** — single class OvR +ROC, returns a data frame `(fpr, tpr)` plus scalar AUC. Extracted so it can be +reused for both per-class computation and macro-average. + +### Plot method (`R/plot.gg_roc.R`) + +Detection: `has_class <- "class" %in% names(x)`. + +```r +plot.gg_roc(x, panel = c("overlay", "facet"), ...) 
+``` + +**When `has_class = FALSE`** (no `class` column): behaves exactly as today — +single-panel ROC curve with diagonal reference. No change. + +**When `has_class = TRUE`:** + +- `panel = "overlay"` (default): single panel, `aes(color = class)`, one curve + per class, legend titled "Class". Diagonal reference line in grey. +- `panel = "facet"`: `facet_wrap(~ class)`, individual y-axis per class, no + color legend. Diagonal reference line in each panel. + +Both modes: `geom_step` for the ROC curve (same as current single-curve plot), +`geom_abline(slope = 1, intercept = 0, linetype = 2, color = "grey50")` for +reference. + +Caption: `"OvR ROC — per_class = TRUE. AUC: =, ..."` (truncated +if > 5 classes, showing top 5 by AUC). + +### Summary method (`R/summary_methods.R`) + +`summary.gg_roc` already reads `attr(object, "auc")`. Extend to handle named +vector: + +```r +auc_str <- if (length(auc) == 1L) { + sprintf("AUC: %.4g", auc) +} else { + paste("AUC:", paste(sprintf("%s=%.4g", names(auc), auc), collapse = ", ")) +} +``` + +### Tests (`tests/testthat/test_gg_roc.R`) + +```r +## ── per_class = TRUE ───────────────────────────────────────────────────────── +test_that("gg_roc per_class=TRUE: long format with class column", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg <- gg_roc(rf, per_class = TRUE) + expect_true("class" %in% names(gg)) + expect_true(all(c("fpr", "tpr") %in% names(gg))) + expect_equal(nlevels(gg$class), 3L) +}) + +test_that("gg_roc per_class=TRUE: auc attr is named vector length 3", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg <- gg_roc(rf, per_class = TRUE) + auc <- attr(gg, "auc") + expect_length(auc, 3L) + expect_named(auc) + # setosa should be near-perfect on iris + expect_gt(auc[["setosa"]], 0.99) +}) + +test_that("gg_roc per_class=TRUE on binary: no class column 
(no-op)", { + skip_if_not_installed("randomForest") + set.seed(1L) + bin_data <- iris[iris$Species != "virginica", ] + bin_data$Species <- droplevels(bin_data$Species) + rf <- randomForest::randomForest(Species ~ ., data = bin_data, ntree = 100L) + gg <- gg_roc(rf, per_class = TRUE) + expect_false("class" %in% names(gg)) + expect_length(attr(gg, "auc"), 1L) +}) + +test_that("gg_roc which_outcome='all' still returns macro-average", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg <- gg_roc(rf, which_outcome = "all") + expect_false("class" %in% names(gg)) +}) + +test_that("gg_roc per_class=TRUE + which_outcome integer: message + per_class wins", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + expect_message( + gg <- gg_roc(rf, per_class = TRUE, which_outcome = 1L), + "which_outcome.*ignored.*per_class" + ) + expect_true("class" %in% names(gg)) +}) + +## ── plot.gg_roc multi-class ────────────────────────────────────────────────── +test_that("plot.gg_roc per_class: overlay returns ggplot", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg <- gg_roc(rf, per_class = TRUE) + p <- plot(gg, panel = "overlay") + expect_s3_class(p, "ggplot") +}) + +test_that("plot.gg_roc per_class: facet returns ggplot", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg <- gg_roc(rf, per_class = TRUE) + p <- plot(gg, panel = "facet") + expect_s3_class(p, "ggplot") +}) +``` + +**`tests/testthat/test_snapshots.R`** — add inside the `VDIFFR_RUN_TESTS=true` +guard: + +```r +if (requireNamespace("randomForest", quietly = TRUE)) { + local({ + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg <- gg_roc(rf, 
per_class = TRUE) + test_that("snapshot: gg-roc-multiclass-overlay", { + vdiffr::expect_doppelganger("gg-roc-multiclass-overlay", + plot(gg, panel = "overlay")) + }) + test_that("snapshot: gg-roc-multiclass-facet", { + vdiffr::expect_doppelganger("gg-roc-multiclass-facet", + plot(gg, panel = "facet")) + }) + }) +} +``` + +### `_pkgdown.yml` + +No new exported functions. No change needed. + +--- + +## Non-Goals (this spec) + +- ROC confidence intervals (deferred to v2.9.0 as issue #7 / #72-CIs) +- Hazard estimates (#71 / #4 / #5) — post-v2.8.0 +- Issue #15 (consistent yhat scale) — pre-2015 enhancement, separate planning +- Any rfsrc-path behavior changes +- Any varPro changes + +--- + +## Files Touched + +| PR | File | Action | +|---|---|---| +| #87 | `R/gg_variable.R` | Modify: fix classification yhat columns | +| #87 | `R/plot.gg_variable.R` | Modify: binary smooth aes + multi-class smooth block | +| #87 | `tests/testthat/test_gg_variable.R` | Modify: add RF classification tests | +| #87 | `tests/testthat/test_snapshots.R` | Modify: add 2 RF classification snapshots | +| #87 | `DESCRIPTION` | Modify: version `2.7.3.9004` → `2.7.3.9005` | +| #87 | `NEWS.md` | Modify: add fix entry | +| #88 | `R/gg_roc.R` | Modify: add `per_class` argument | +| #88 | `R/calc_roc.R` | Modify: extract `calc_roc_one_vs_rest` helper; per-class dispatch | +| #88 | `R/plot.gg_roc.R` | Modify: `panel` argument; multi-class overlay + facet paths | +| #88 | `R/summary_methods.R` | Modify: `summary.gg_roc` named-vector AUC display | +| #88 | `tests/testthat/test_gg_roc.R` | Modify: add per_class tests | +| #88 | `tests/testthat/test_snapshots.R` | Modify: add 2 multi-class ROC snapshots | +| #88 | `DESCRIPTION` | Modify: version `2.7.3.9005` → `2.7.3.9006` | +| #88 | `NEWS.md` | Modify: add feature entry | From ba14a938eb1cf57d382a7eee2834b5ad9c62f95b Mon Sep 17 00:00:00 2001 From: John Ehrlinger Date: Thu, 21 May 2026 10:53:14 -0400 Subject: [PATCH 2/3] docs: add implementation plans for PR 
#87 (gg_variable classification) and PR #88 (per_class ROC) --- ...1-rf-87-gg-variable-classification-plan.md | 453 +++++++++++ .../2026-05-21-rf-88-multiclass-roc-plan.md | 712 ++++++++++++++++++ 2 files changed, 1165 insertions(+) create mode 100644 dev/plans/2026-05-21-rf-87-gg-variable-classification-plan.md create mode 100644 dev/plans/2026-05-21-rf-88-multiclass-roc-plan.md diff --git a/dev/plans/2026-05-21-rf-87-gg-variable-classification-plan.md b/dev/plans/2026-05-21-rf-87-gg-variable-classification-plan.md new file mode 100644 index 00000000..6ac8b6ca --- /dev/null +++ b/dev/plans/2026-05-21-rf-87-gg-variable-classification-plan.md @@ -0,0 +1,453 @@ +# PR #87: gg_variable.randomForest Classification Fix Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Fix `gg_variable.randomForest` for classification forests so it stores per-class OOB vote fractions (`yhat.` columns from `object$votes`) instead of predicted class labels, making it structurally identical to the rfsrc path; fix two downstream `plot.gg_variable` smooth bugs exposed by the extractor fix; close stale issues #81 and #82. + +**Architecture:** Single-root bug: `gg_variable.randomForest` stores `as.vector(object$predicted)` (predicted class labels, coerced from a factor to a character vector) instead of `object$votes` (the n × n_classes matrix of per-class OOB vote fractions). Swapping to `object$votes` makes `plot.gg_variable`'s multi-class pivot fire correctly for `randomForest` objects — the same path already used for `rfsrc`. Two downstream `plot.gg_variable` bugs become reachable after the fix: (1) binary `smooth=TRUE` calls `geom_smooth(...)` with no `aes()` (`stat_smooth()` fails when `layer_data()` is called), and (2) the multi-class numeric path is missing the `smooth=TRUE` block entirely.
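
The column reshaping described above can be sketched on a toy vote matrix. This is a minimal illustration, assuming `object$votes` holds normalized OOB vote fractions (the `randomForest` default, `norm.votes = TRUE`); the `votes` and `predictors` objects below are made-up stand-ins, not output from a fitted forest.

```r
# Toy stand-in for object$votes: 4 observations x 3 classes, rows sum to 1.
# (Assumption: hand-written values, used only to show the reshaping.)
votes <- matrix(
  c(0.9, 0.1, 0.0,
    0.2, 0.7, 0.1,
    0.0, 0.3, 0.7,
    0.1, 0.1, 0.8),
  nrow = 4, byrow = TRUE,
  dimnames = list(NULL, c("setosa", "versicolor", "virginica"))
)

# Toy stand-in for the predictor columns the extractor already holds.
predictors <- data.frame(Sepal.Length = c(5.1, 6.0, 6.5, 7.1))

# Prefix each class column with "yhat." and bind it onto the predictors,
# mirroring the classification branch proposed in this plan.
preds <- votes
colnames(preds) <- paste0("yhat.", colnames(preds))
gg_dta <- cbind(predictors, preds)

# One yhat.* column per class, and each row of vote fractions sums to 1.
vote_cols <- grep("^yhat\\.", names(gg_dta), value = TRUE)
row_sums  <- rowSums(gg_dta[, vote_cols])
```

If a forest were grown with `norm.votes = FALSE`, the matrix would hold raw vote counts rather than fractions, so the row-sum-to-one property checked in the tests would not apply.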
+ +**Tech Stack:** R, randomForest, ggplot2, patchwork, testthat, vdiffr (snapshots only) + +--- + +## File Map + +| File | Change | +|---|---| +| `R/gg_variable.R` | Fix classification branch: `object$predicted` → `object$votes` | +| `R/plot.gg_variable.R` | Fix binary smooth aes; add multi-class smooth block | +| `tests/testthat/test_gg_variable.R` | Add randomForest classification tests | +| `tests/testthat/test_snapshots.R` | Add 2 vdiffr snapshots | +| `DESCRIPTION` | Version `2.7.3.9004` → `2.7.3.9005` | +| `NEWS.md` | Add fix entry | + +--- + +## Task T0: Branch setup and version bump + +**Files:** +- Modify: `DESCRIPTION` + +- [ ] **Step 1: Create the feature branch** + +```bash +git checkout -b fix/rf-87-gg-variable-classification origin/main +``` + +Expected: `Switched to a new branch 'fix/rf-87-gg-variable-classification'` + +- [ ] **Step 2: Bump version in DESCRIPTION** + +Open `DESCRIPTION`. Change line: +``` +Version: 2.7.3.9004 +``` +to: +``` +Version: 2.7.3.9005 +``` +Also update: +``` +Date: 2026-05-20 +``` +to: +``` +Date: 2026-05-21 +``` + +- [ ] **Step 3: Confirm package loads** + +```r +devtools::load_all() +``` + +Expected: no errors. 
+ +- [ ] **Step 4: Commit** + +```bash +git add DESCRIPTION +git commit -m "chore: open v2.7.3.9005 dev increment (PR #87)" +``` + +--- + +## Task T1: Write failing tests — gg_variable.randomForest classification + +**Files:** +- Modify: `tests/testthat/test_gg_variable.R` + +- [ ] **Step 1: Append the three new test cases to the end of `tests/testthat/test_gg_variable.R`** + +```r +## ── randomForest classification (PR #87) ───────────────────────────────────── + +test_that("gg_variable.randomForest classification: produces yhat.* columns not yhat", { + skip_if_not_installed("randomForest") + set.seed(42L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 50L) + gg <- gg_variable(rf) + # Must have one column per class + expect_true(all(c("yhat.setosa", "yhat.versicolor", "yhat.virginica") + %in% names(gg))) + # Must NOT have a bare yhat column for multi-class + expect_false("yhat" %in% names(gg)) + # Observed-class column must be present + expect_true("yvar" %in% names(gg)) + # Vote fractions must be in [0, 1] and row-sum to ~1 + vote_cols <- c("yhat.setosa", "yhat.versicolor", "yhat.virginica") + expect_true(all(gg[, vote_cols] >= 0)) + expect_true(all(gg[, vote_cols] <= 1)) + expect_true(all(abs(rowSums(gg[, vote_cols]) - 1) < 1e-6)) +}) + +test_that("gg_variable.randomForest classification: plot returns patchwork for all xvar", { + skip_if_not_installed("randomForest") + set.seed(42L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 50L) + gg <- gg_variable(rf) + p <- plot(gg) + expect_true(inherits(p, "patchwork") || inherits(p, "ggplot")) +}) + +test_that("gg_variable.randomForest classification: layer_data works on single-xvar plot", { + skip_if_not_installed("randomForest") + set.seed(42L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 50L) + gg <- gg_variable(rf) + p <- plot(gg, xvar = "Sepal.Length") + expect_no_error(ggplot2::layer_data(p, 1L)) +}) +``` + +- [ ] **Step 2: Run new tests — expect 
all three to FAIL** + +```bash +Rscript -e "devtools::test(filter='gg_variable')" +``` + +Expected: 3 FAILures with errors like: +- `"yhat.setosa" %in% names(gg)` is FALSE (currently stores `yhat` factor) +- `layer_data` error about missing aesthetics + +--- + +## Task T2: Fix gg_variable.randomForest — use object$votes + +**Files:** +- Modify: `R/gg_variable.R` + +- [ ] **Step 1: Locate the block to replace in `R/gg_variable.R`** + +Find this exact code (near the end of `gg_variable.randomForest`): + +```r + gg_dta <- predictors + # Append the forest's in-bag predicted values. + gg_dta$yhat <- as.vector(object$predicted) + if (object$type == "classification") { + gg_dta$yvar <- response + } +``` + +- [ ] **Step 2: Replace it with** + +```r + gg_dta <- predictors + # For classification forests use per-class OOB vote fractions (object$votes), + # stored as yhat. columns — the same shape gg_variable.rfsrc + # produces. For regression a single numeric yhat column suffices. + if (object$type == "classification") { + preds <- object$votes # n × n_classes matrix of OOB vote fractions + colnames(preds) <- paste0("yhat.", colnames(preds)) + gg_dta <- cbind(gg_dta, preds) + gg_dta$yvar <- response + } else { + gg_dta$yhat <- as.vector(object$predicted) + } +``` + +- [ ] **Step 3: Run the new tests — expect all three to PASS** + +```bash +Rscript -e "devtools::test(filter='gg_variable')" +``` + +Expected: `[ FAIL 0 | ... | PASS ]` — the three new tests green, existing tests still passing. 
+ +- [ ] **Step 4: Commit** + +```bash +git add R/gg_variable.R tests/testthat/test_gg_variable.R +git commit -m "fix: gg_variable.randomForest classification uses object\$votes for yhat.* columns" +``` + +--- + +## Task T3: Write failing tests — smooth bugs in plot.gg_variable + +**Files:** +- Modify: `tests/testthat/test_gg_variable.R` + +- [ ] **Step 1: Append two more test cases to `tests/testthat/test_gg_variable.R`** + +```r +test_that("plot.gg_variable RF classification: smooth=TRUE layer_data smokeable (binary smooth aes bug)", { + skip_if_not_installed("randomForest") + # Two-class subset to exercise the *binary* classification path + set.seed(42L) + bin_data <- iris[iris$Species != "virginica", ] + bin_data$Species <- droplevels(bin_data$Species) + rf <- randomForest::randomForest(Species ~ ., data = bin_data, ntree = 50L) + gg <- gg_variable(rf) + p <- plot(gg, xvar = "Sepal.Length", smooth = TRUE) + # Before the fix, geom_smooth(...) has no aes and layer_data errors with + # "stat_smooth() requires the following missing aesthetics: x and y" + expect_no_error(ggplot2::layer_data(p, 2L)) +}) + +test_that("plot.gg_variable RF classification: smooth=TRUE works for multi-class (missing block)", { + skip_if_not_installed("randomForest") + set.seed(42L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 50L) + gg <- gg_variable(rf) + # Before the fix the multi-class numeric path silently skips smooth=TRUE + # but does not error; after the fix a smooth layer is present (layer 2). + p <- plot(gg, xvar = "Sepal.Length", smooth = TRUE) + expect_s3_class(p, "ggplot") + ld <- ggplot2::layer_data(p, 2L) # layer 2 = geom_smooth + expect_gt(nrow(ld), 0L) +}) +``` + +- [ ] **Step 2: Run new tests — the multi-class smooth test FAILS (binary smooth passes because fix in T2 made yhat.* work)** + +```bash +Rscript -e "devtools::test(filter='gg_variable')" +``` + +Expected: the multi-class smooth `layer_data` test FAILs (no layer 2 or layer_data errors). 
The binary smooth test may already pass after the T2 fix — if so, note it and proceed. + +--- + +## Task T4: Fix plot.gg_variable — binary smooth aes + multi-class smooth block + +**Files:** +- Modify: `R/plot.gg_variable.R` + +- [ ] **Step 1: Fix the binary smooth missing aes (around line 517)** + +Find this exact code in the **binary classification, numeric predictor** branch (`if (sum(colnames(gg_dta) == "outcome") == 0)` → `if (ccls_var == "numeric")`): + +```r + if (smooth) { + gg_plt[[ind]] <- gg_plt[[ind]] + + ggplot2::geom_smooth(...) + } +``` + +Replace with: + +```r + if (smooth) { + gg_plt[[ind]] <- gg_plt[[ind]] + + ggplot2::geom_smooth( + ggplot2::aes(x = .data$var, y = .data$yhat), ... + ) + } +``` + +- [ ] **Step 2: Add the missing smooth block to the multi-class numeric path (around line 553)** + +Find this exact code in the **multi-class** branch (`} else { # Multi-class: facet by outcome class`): + +```r + if (ccls_var == "numeric") { + gg_plt[[ind]] <- gg_plt[[ind]] + + ggplot2::geom_point( + ggplot2::aes( + x = .data$var, + y = .data$yhat, + color = .data$yvar, + shape = .data$yvar + ), + ... + ) + } else { +``` + +Replace with: + +```r + if (ccls_var == "numeric") { + gg_plt[[ind]] <- gg_plt[[ind]] + + ggplot2::geom_point( + ggplot2::aes( + x = .data$var, + y = .data$yhat, + color = .data$yvar, + shape = .data$yvar + ), + ... + ) + if (smooth) { + gg_plt[[ind]] <- gg_plt[[ind]] + + ggplot2::geom_smooth( + ggplot2::aes(x = .data$var, y = .data$yhat), ... + ) + } + } else { +``` + +- [ ] **Step 3: Run all gg_variable tests — expect all to PASS** + +```bash +Rscript -e "devtools::test(filter='gg_variable')" +``` + +Expected: `[ FAIL 0 | ... 
]` + +- [ ] **Step 4: Run full test suite to confirm no regressions** + +```bash +Rscript -e "devtools::test()" +``` + +Expected: `[ FAIL 0 | WARN | SKIP | PASS ]` + +- [ ] **Step 5: Commit** + +```bash +git add R/plot.gg_variable.R tests/testthat/test_gg_variable.R +git commit -m "fix: plot.gg_variable binary smooth aes + add multi-class smooth block" +``` + +--- + +## Task T5: vdiffr snapshots + +**Files:** +- Modify: `tests/testthat/test_snapshots.R` + +- [ ] **Step 1: Append a new randomForest classification section inside the `VDIFFR_RUN_TESTS` guard in `tests/testthat/test_snapshots.R`** + +Find the closing brace of the existing `if (identical(Sys.getenv("VDIFFR_RUN_TESTS"), "true"))` block (the last `}` before the file ends or the next top-level statement) and add the following block **inside** it, before the closing `}`: + +```r +## ── randomForest classification snapshots (PR #87) ─────────────────────────── +if (requireNamespace("randomForest", quietly = TRUE)) { + local({ + set.seed(42L) + rf_iris <- randomForest::randomForest(Species ~ ., data = iris, ntree = 50L) + gg_iris <- gg_variable(rf_iris) + + test_that("snapshot: gg-variable-rf-classification-default", { + vdiffr::expect_doppelganger( + "gg-variable-rf-classification-default", + plot(gg_iris) + ) + }) + + test_that("snapshot: gg-variable-rf-classification-smooth", { + vdiffr::expect_doppelganger( + "gg-variable-rf-classification-smooth", + plot(gg_iris, xvar = "Sepal.Length", smooth = TRUE) + ) + }) + }) +} +``` + +- [ ] **Step 2: Confirm vdiffr tests skip cleanly without the env var** + +```bash +Rscript -e "devtools::test(filter='snapshots')" +``` + +Expected: all snapshot tests SKIP (no VDIFFR_RUN_TESTS env var set). 
+ +- [ ] **Step 3: Commit** + +```bash +git add tests/testthat/test_snapshots.R +git commit -m "test: add vdiffr snapshots for gg_variable RF classification (PR #87)" +``` + +--- + +## Task T6: NEWS, housekeeping, and PR + +**Files:** +- Modify: `NEWS.md` + +- [ ] **Step 1: Add a fix entry to the top of the `ggRandomForests v2.8.0 (development)` section in `NEWS.md`** + +``` +* **`gg_variable.randomForest` classification fix (#87).** + - `gg_variable.randomForest()` for classification forests now stores + per-class OOB vote fractions as `yhat.` columns (from + `object$votes`), matching the `rfsrc` path. Previously a single + `yhat` factor column (class labels from `object$predicted`) was + stored, which prevented the multi-class pivot in `plot.gg_variable` + from firing. + - `plot.gg_variable` binary classification: `smooth = TRUE` now + correctly maps x/y aesthetics onto the smooth layer. + - `plot.gg_variable` multi-class numeric path: `smooth = TRUE` now + adds a smooth layer (was silently skipped). + - Closes stale issues #81 (fixed in PR #83) and #82. +``` + +- [ ] **Step 2: Run R CMD check** + +```bash +Rscript -e "devtools::check(args='--as-cran')" +``` + +Expected: `0 errors | 0 warnings | 0 notes` + +- [ ] **Step 3: Commit NEWS** + +```bash +git add NEWS.md +git commit -m "docs: update NEWS for PR #87 — gg_variable RF classification fix" +``` + +- [ ] **Step 4: Push the branch** + +```bash +git push -u origin fix/rf-87-gg-variable-classification +``` + +- [ ] **Step 5: Open the PR** + +```bash +gh pr create \ + --title "fix: gg_variable.randomForest classification — yhat.* columns + smooth bugs (#87)" \ + --body "$(cat <<'EOF' +## Summary + +- `gg_variable.randomForest` for classification forests now stores per-class OOB vote fractions as \`yhat.\` columns (from \`object\$votes\`), matching the rfsrc path. 
Previously a single \`yhat\` factor (class labels from \`object\$predicted\`) was stored, preventing the multi-class pivot in \`plot.gg_variable\` from firing. +- Binary classification \`smooth = TRUE\`: added explicit \`aes(x, y)\` to \`geom_smooth\` so \`layer_data()\` no longer errors. +- Multi-class numeric path: added missing \`smooth = TRUE\` block. +- 5 new unit tests + 2 vdiffr snapshots. + +## Housekeeping + +Closes #81, Closes #82 (both were fully fixed by PR #83 but auto-close never fired). + +## Test plan +- [x] \`devtools::test()\` — all tests pass, 0 failures +- [x] \`devtools::check(args="--as-cran")\` — 0 errors, 0 warnings, 0 notes + +🤖 Generated with [Claude Code](https://claude.com/claude-code) +EOF +)" +``` + +- [ ] **Step 6: Verify CI is green on the PR before asking for merge** + +```bash +gh pr checks +``` + +Expected: all checks pass. diff --git a/dev/plans/2026-05-21-rf-88-multiclass-roc-plan.md b/dev/plans/2026-05-21-rf-88-multiclass-roc-plan.md new file mode 100644 index 00000000..12c9a216 --- /dev/null +++ b/dev/plans/2026-05-21-rf-88-multiclass-roc-plan.md @@ -0,0 +1,712 @@ +# PR #88: Multi-class gg_roc with per_class=TRUE Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Extend `gg_roc()` to return per-class one-vs-rest ROC curves in a long-format data frame (with a `class` column and named AUC vector) when `per_class = TRUE` is passed, and teach `plot.gg_roc` to render them as an overlay or faceted panel. + +**Architecture:** Add `per_class = FALSE` to the `gg_roc` generic and the `randomForest` method. 
When `per_class = TRUE` and the forest has > 2 classes, the method reuses the existing `.rf_one_class_roc` and `.rf_prob_matrix` helpers already in `calc_roc.R` to build one OvR ROC curve per class, stacks them into a long-format data frame with a `class` factor column, and attaches a named AUC vector ordered by descending AUC. `plot.gg_roc` detects the `class` column and dispatches to a new overlay/facet path. `summary.gg_roc` is updated to print a named AUC vector when the `class` column is present. Binary forests with `per_class = TRUE` are a no-op (single curve returned unchanged). + +**Tech Stack:** R, randomForest, ggplot2, testthat, vdiffr (snapshots only) + +--- + +## File Map + +| File | Change | +|---|---| +| `R/gg_roc.R` | Add `per_class` argument to generic + both methods; implement per_class path in `gg_roc.randomForest` | +| `R/plot.gg_roc.R` | Add `panel` argument; insert per_class detection branch before single-class path | +| `R/summary_methods.R` | Update `summary.gg_roc` to handle named AUC vector | +| `tests/testthat/test_gg_roc.R` | Add per_class tests (T1, T3, T4) | +| `tests/testthat/test_snapshots.R` | Add 2 vdiffr snapshots (T7) | +| `DESCRIPTION` | Version `2.7.3.9005` → `2.7.3.9006` | +| `NEWS.md` | Add feature entry | + +--- + +## Task T0: Branch setup and version bump + +**Files:** +- Modify: `DESCRIPTION` + +- [ ] **Step 1: Create the feature branch from main** + +```bash +git checkout -b feat/rf-88-multiclass-roc origin/main +``` + +Expected: `Switched to a new branch 'feat/rf-88-multiclass-roc'` + +- [ ] **Step 2: Bump version in DESCRIPTION** + +Open `DESCRIPTION`. Change: +``` +Version: 2.7.3.9005 +``` +to: +``` +Version: 2.7.3.9006 +``` + +- [ ] **Step 3: Confirm package loads** + +```r +devtools::load_all() +``` + +Expected: no errors. 
+ +- [ ] **Step 4: Commit** + +```bash +git add DESCRIPTION +git commit -m "chore: open v2.7.3.9006 dev increment (PR #88)" +``` + +--- + +## Task T1: Write failing tests — per_class extractor + +**Files:** +- Modify: `tests/testthat/test_gg_roc.R` + +- [ ] **Step 1: Append the per_class extractor tests to `tests/testthat/test_gg_roc.R`** + +```r +## ── per_class = TRUE (PR #88) ────────────────────────────────────────────── + +test_that("gg_roc per_class=TRUE: long format with class column", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg <- gg_roc(rf, per_class = TRUE) + expect_true("class" %in% names(gg)) + expect_true(all(c("sens", "spec") %in% names(gg))) + expect_s3_class(gg$class, "factor") + expect_equal(nlevels(gg$class), 3L) +}) + +test_that("gg_roc per_class=TRUE: auc attr is named numeric vector length 3", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg <- gg_roc(rf, per_class = TRUE) + auc <- attr(gg, "auc") + expect_length(auc, 3L) + expect_named(auc) + # setosa is linearly separable in iris — AUC should be near-perfect + expect_gt(auc[["setosa"]], 0.99) + # AUC values must be sorted descending + expect_true(all(diff(auc) <= 0)) +}) + +test_that("gg_roc per_class=TRUE: class factor levels ordered by descending AUC", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg <- gg_roc(rf, per_class = TRUE) + auc <- attr(gg, "auc") + expect_equal(levels(gg$class), names(auc)) +}) +``` + +- [ ] **Step 2: Run the new tests — expect all three to FAIL** + +```bash +Rscript -e "devtools::test(filter='gg_roc')" +``` + +Expected: 3 FAILures — `gg_roc` does not accept `per_class` yet so it hits `...` and returns without a `class` column. 
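
The one-vs-rest computation these tests exercise can be sketched in base R. The helper `ovr_auc()` below is hypothetical (it is not the package's `.rf_one_class_roc` or `calc_auc`), and the labels and vote matrix are made-up toy data. It computes each class's AUC from the rank-sum (Mann-Whitney) identity, then orders the named vector by descending AUC, which is the ordering the tests above expect.

```r
# Hypothetical helper, not part of ggRandomForests: one-vs-rest AUC for
# class k, treating (y == k) as the positive label and the class-k vote
# fraction as the score.
ovr_auc <- function(scores, y, k) {
  pos   <- y == k
  n_pos <- sum(pos)
  n_neg <- sum(!pos)
  r     <- rank(scores)  # default ties.method = "average"
  (sum(r[pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Toy data (assumption): 6 observations, 3 classes, rows sum to 1.
y <- factor(c("a", "a", "b", "b", "c", "c"))
votes <- matrix(
  c(0.8, 0.1, 0.1,
    0.7, 0.2, 0.1,
    0.1, 0.6, 0.3,
    0.2, 0.3, 0.5,
    0.1, 0.4, 0.5,
    0.0, 0.2, 0.8),
  ncol = 3, byrow = TRUE,
  dimnames = list(NULL, levels(y))
)

# One AUC per class; vapply() names the result by class level.
auc <- vapply(levels(y), function(k) ovr_auc(votes[, k], y, k), numeric(1L))

# Order by descending AUC and reuse that order for the class factor levels.
auc          <- sort(auc, decreasing = TRUE)
class_levels <- names(auc)
```

Tied scores receive average ranks, which corresponds to counting a tied positive/negative pair as half a concordant pair.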
+ +--- + +## Task T2: Implement gg_roc per_class path + +**Files:** +- Modify: `R/gg_roc.R` + +- [ ] **Step 1: Add `per_class = FALSE` to the generic** + +Find: +```r +gg_roc <- function(object, which_outcome, oob = TRUE, ...) { + UseMethod("gg_roc", object) +} +``` + +Replace with: +```r +gg_roc <- function(object, which_outcome, oob = TRUE, per_class = FALSE, ...) { + UseMethod("gg_roc", object) +} +``` + +- [ ] **Step 2: Add `per_class = FALSE` to `gg_roc.rfsrc`** + +Find: +```r +gg_roc.rfsrc <- function(object, which_outcome, oob = TRUE, ...) { +``` + +Replace with: +```r +gg_roc.rfsrc <- function(object, which_outcome, oob = TRUE, per_class = FALSE, ...) { +``` + +Note: the rfsrc per_class path is out of scope for this PR (tracked under issue #72). The argument is added only so `gg_roc(rfsrc_obj, per_class = TRUE)` does not error with "unused argument". Leave the body unchanged. + +- [ ] **Step 3: Replace the entire `gg_roc.randomForest` body** + +Find and replace the entire `gg_roc.randomForest` function: + +```r +# BEFORE — existing body: +#' @export +gg_roc.randomForest <- function(object, which_outcome, oob = TRUE, ...) { + # Validate that the object is a genuine randomForest instance. + if (!inherits(object, "randomForest")) { + stop( + "gg_roc.randomForest only works for objects of class 'randomForest'." + ) + } + + # Default to computing the ROC curve for all outcome classes. + if (missing(which_outcome)) { + which_outcome <- "all" + } + + if (!(object$type == "classification")) { + stop("gg_roc only works with classification forests") + } + + # For randomForest objects the response is stored in $y (not $yvar). 
+ gg_dta <- # nolint: object_usage_linter + calc_roc(object, + object$y, + which_outcome = which_outcome, + oob = oob + ) + class(gg_dta) <- c("gg_roc", class(gg_dta)) + gg_dta <- .set_provenance(gg_dta, object) + + invisible(gg_dta) +} +``` + +Replace with: + +```r +#' @export +gg_roc.randomForest <- function(object, which_outcome, oob = TRUE, + per_class = FALSE, ...) { + if (!inherits(object, "randomForest")) { + stop("gg_roc.randomForest only works for objects of class 'randomForest'.") + } + if (missing(which_outcome)) { + which_outcome <- "all" + } + if (!(object$type == "classification")) { + stop("gg_roc only works with classification forests") + } + + lvls <- levels(object$y) + n_class <- length(lvls) + + # ── per_class = TRUE path (multi-class only) ───────────────────────────── + if (isTRUE(per_class) && n_class > 2L) { + if (!missing(which_outcome) && !identical(which_outcome, "all")) { + message("which_outcome is ignored when per_class = TRUE.") + } + prob <- .rf_prob_matrix(object, oob, lvls) + dta <- object$y + curves <- lapply(seq_along(lvls), function(k) { + cv <- .rf_one_class_roc(dta, prob, k, lvls) + cv$class <- lvls[k] + cv + }) + auc_vals <- vapply(curves, calc_auc, numeric(1L)) + names(auc_vals) <- lvls + auc_ord <- order(auc_vals, decreasing = TRUE) + auc_vals <- auc_vals[auc_ord] + gg_dta <- do.call(rbind, curves) + gg_dta$class <- factor(gg_dta$class, levels = lvls[auc_ord]) + class(gg_dta) <- c("gg_roc", class(gg_dta)) + attr(gg_dta, "auc") <- auc_vals + gg_dta <- .set_provenance(gg_dta, object) + return(invisible(gg_dta)) + } + + # ── Standard path (binary, or per_class not requested) ────────────────── + # For randomForest objects the response is stored in $y (not $yvar). 
+ gg_dta <- # nolint: object_usage_linter + calc_roc(object, object$y, which_outcome = which_outcome, oob = oob) + class(gg_dta) <- c("gg_roc", class(gg_dta)) + attr(gg_dta, "auc") <- calc_auc(gg_dta) + gg_dta <- .set_provenance(gg_dta, object) + invisible(gg_dta) +} +``` + +- [ ] **Step 4: Run the T1 tests — expect all three to PASS** + +```bash +Rscript -e "devtools::test(filter='gg_roc')" +``` + +Expected: T1's 3 new tests green; all existing tests still passing. + +- [ ] **Step 5: Commit** + +```bash +git add R/gg_roc.R tests/testthat/test_gg_roc.R +git commit -m "feat: gg_roc.randomForest per_class=TRUE — per-class OvR ROC + named AUC" +``` + +--- + +## Task T3: Write failing tests — binary no-op and which_outcome conflict + +**Files:** +- Modify: `tests/testthat/test_gg_roc.R` + +- [ ] **Step 1: Append edge-case tests** + +```r +test_that("gg_roc per_class=TRUE on binary forest: no class column (no-op)", { + skip_if_not_installed("randomForest") + set.seed(1L) + bin_data <- iris[iris$Species != "virginica", ] + bin_data$Species <- droplevels(bin_data$Species) + rf <- randomForest::randomForest(Species ~ ., data = bin_data, ntree = 100L) + gg <- gg_roc(rf, per_class = TRUE) + # Binary forest: per_class is a no-op — no class column, scalar AUC + expect_false("class" %in% names(gg)) + expect_length(attr(gg, "auc"), 1L) +}) + +test_that("gg_roc per_class=TRUE + which_outcome integer: message then per_class wins", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + expect_message( + gg <- gg_roc(rf, per_class = TRUE, which_outcome = 1L), + "which_outcome.*ignored.*per_class" + ) + expect_true("class" %in% names(gg)) +}) + +test_that("gg_roc which_outcome='all' still returns macro-average (no class column)", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg <- gg_roc(rf, which_outcome = "all") + 
expect_false("class" %in% names(gg)) + # Macro-average returns a single data frame, not a class-faceted one + expect_true(all(c("sens", "spec") %in% names(gg))) +}) +``` + +- [ ] **Step 2: Run the new tests — expect all to PASS immediately (T2 already implements them)** + +```bash +Rscript -e "devtools::test(filter='gg_roc')" +``` + +Expected: `[ FAIL 0 | ... ]` — all pass. (These tests verify the T2 implementation. They exercise the same function that T2 already modified, so they pass on the first run instead of going through a red phase.) + +- [ ] **Step 3: Commit** + +```bash +git add tests/testthat/test_gg_roc.R +git commit -m "test: gg_roc per_class binary no-op and which_outcome conflict tests" +``` + +--- + +## Task T4: Write failing tests — plot.gg_roc per_class paths + +**Files:** +- Modify: `tests/testthat/test_gg_roc.R` + +- [ ] **Step 1: Append plot tests** + +```r +## ── plot.gg_roc per_class paths (PR #88) ───────────────────────────────── + +test_that("plot.gg_roc per_class=TRUE: overlay returns ggplot", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg <- gg_roc(rf, per_class = TRUE) + p <- plot(gg, panel = "overlay") + expect_s3_class(p, "ggplot") +}) + +test_that("plot.gg_roc per_class=TRUE: facet returns ggplot", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg <- gg_roc(rf, per_class = TRUE) + p <- plot(gg, panel = "facet") + expect_s3_class(p, "ggplot") +}) + +test_that("plot.gg_roc per_class=TRUE: layer_data smokeable for overlay", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg <- gg_roc(rf, per_class = TRUE) + p <- plot(gg, panel = "overlay") + expect_no_error(ggplot2::layer_data(p, 1L)) +}) + +test_that("plot.gg_roc existing single-class 
path unchanged", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg <- gg_roc(rf, which_outcome = 1L) + p <- plot(gg) + expect_s3_class(p, "ggplot") + expect_no_error(ggplot2::layer_data(p, 1L)) +}) +``` + +- [ ] **Step 2: Run new tests — the per_class plot tests FAIL** + +```bash +Rscript -e "devtools::test(filter='gg_roc')" +``` + +Expected: the 3 `plot.gg_roc per_class` tests FAIL because `plot.gg_roc` does not yet accept `panel` and has no per_class detection. The "existing single-class path unchanged" test should PASS. + +--- + +## Task T5: Implement plot.gg_roc per_class detection + +**Files:** +- Modify: `R/plot.gg_roc.R` + +- [ ] **Step 1: Add `panel` argument to `plot.gg_roc`** + +Find: +```r +plot.gg_roc <- function(x, which_outcome = NULL, ...) { +``` + +Replace with: +```r +plot.gg_roc <- function(x, which_outcome = NULL, + panel = c("overlay", "facet"), ...) { + panel <- match.arg(panel) +``` + +- [ ] **Step 2: Insert per_class detection branch** + +Find the comment and `if` block that begins the single-class ROC plot section: + +```r + ## ---- Single-class ROC plot ------------------------------------------ + if (inherits(gg_dta, "gg_roc")) { + # Sort by specificity so the ROC curve is drawn left-to-right + gg_dta <- gg_dta[order(gg_dta$spec), ] +``` + +Replace with: + +```r + ## ---- Single-class ROC plot (or per_class long-format) ---------------- + if (inherits(gg_dta, "gg_roc")) { + + # Per-class detection: gg_roc produced by gg_roc(..., per_class = TRUE) + # carries a 'class' column (factor) + a named AUC vector attribute. 
+ if ("class" %in% names(gg_dta)) { + gg_dta$fpr <- 1 - gg_dta$spec + auc <- attr(x, "auc") + + if (panel == "overlay") { + gg_plt <- ggplot2::ggplot(gg_dta) + + ggplot2::geom_line(ggplot2::aes( + x = .data$fpr, y = .data$sens, color = .data$class + )) + + ggplot2::labs( + x = "1 - Specificity (FPR)", y = "Sensitivity (TPR)", + color = "Class" + ) + + ggplot2::geom_abline( + slope = 1, intercept = 0, + col = "red", linetype = 2, linewidth = .5 + ) + + ggplot2::coord_fixed() + } else { + gg_plt <- ggplot2::ggplot(gg_dta) + + ggplot2::geom_line(ggplot2::aes( + x = .data$fpr, y = .data$sens + )) + + ggplot2::labs( + x = "1 - Specificity (FPR)", y = "Sensitivity (TPR)" + ) + + ggplot2::geom_abline( + slope = 1, intercept = 0, + col = "red", linetype = 2, linewidth = .5 + ) + + ggplot2::facet_wrap(~class) + + ggplot2::coord_fixed() + } + + # AUC caption — top 5 classes by descending AUC (already sorted) + if (!is.null(auc) && length(auc) > 0L) { + top_n <- min(5L, length(auc)) + auc_str <- paste( + sprintf("%s=%.3g", names(auc)[seq_len(top_n)], auc[seq_len(top_n)]), + collapse = ", " + ) + if (length(auc) > 5L) auc_str <- paste0(auc_str, ", ...") + gg_plt <- gg_plt + + ggplot2::labs(caption = paste("OvR ROC — per_class=TRUE. AUC:", auc_str)) + } + return(gg_plt) + } + + # Sort by specificity so the ROC curve is drawn left-to-right + gg_dta <- gg_dta[order(gg_dta$spec), ] +``` + +- [ ] **Step 3: Run the T4 tests — expect all four to PASS** + +```bash +Rscript -e "devtools::test(filter='gg_roc')" +``` + +Expected: `[ FAIL 0 | ... 
]` + +- [ ] **Step 4: Run full test suite — confirm no regressions** + +```bash +Rscript -e "devtools::test()" +``` + +Expected: `[ FAIL 0 | WARN | SKIP | PASS ]` + +- [ ] **Step 5: Commit** + +```bash +git add R/plot.gg_roc.R tests/testthat/test_gg_roc.R +git commit -m "feat: plot.gg_roc per_class=TRUE — overlay and facet panel paths" +``` + +--- + +## Task T6: Update summary.gg_roc for named AUC vector + +**Files:** +- Modify: `R/summary_methods.R` + +- [ ] **Step 1: Write a failing test for summary with per_class** + +Append to `tests/testthat/test_gg_roc.R`: + +```r +test_that("summary.gg_roc per_class=TRUE: prints named AUC, no error", { + skip_if_not_installed("randomForest") + set.seed(1L) + rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg <- gg_roc(rf, per_class = TRUE) + s <- summary(gg) + expect_s3_class(s, "summary.gg") + # Body should mention all three class names + expect_true(any(grepl("setosa", s$body))) + expect_true(any(grepl("versicolor", s$body))) + expect_true(any(grepl("virginica", s$body))) +}) +``` + +- [ ] **Step 2: Run — expect FAIL** + +```bash +Rscript -e "devtools::test(filter='gg_roc')" +``` + +Expected: the new summary test FAIL because `summary.gg_roc` currently calls `sprintf("AUC: %.4g", auc)` on a named numeric vector, producing a misformatted string rather than named class entries. + +- [ ] **Step 3: Replace `summary.gg_roc` in `R/summary_methods.R`** + +Find: +```r +#' @rdname summary.gg +#' @export +summary.gg_roc <- function(object, ...) { + body <- c( + sprintf("thresholds: %d", nrow(object)), + sprintf("AUC: %.4g", + attr(object, "auc") %||% .gg_auc_trap(object)) + ) + .summary_skel(object, "gg_roc", body) +} +``` + +Replace with: +```r +#' @rdname summary.gg +#' @export +summary.gg_roc <- function(object, ...) 
{ + auc <- attr(object, "auc") %||% .gg_auc_trap(object) + if ("class" %in% names(object)) { + # per_class = TRUE path: named AUC vector, one entry per class + n_cls <- nlevels(object$class) + auc_str <- paste(sprintf("%s=%.4g", names(auc), auc), collapse = ", ") + body <- c(sprintf("classes: %d", n_cls), + sprintf("AUC: %s", auc_str)) + } else { + body <- c( + sprintf("thresholds: %d", nrow(object)), + sprintf("AUC: %.4g", auc) + ) + } + .summary_skel(object, "gg_roc", body) +} +``` + +- [ ] **Step 4: Run all gg_roc tests — expect all to PASS** + +```bash +Rscript -e "devtools::test(filter='gg_roc')" +``` + +Expected: `[ FAIL 0 | ... ]` + +- [ ] **Step 5: Commit** + +```bash +git add R/summary_methods.R tests/testthat/test_gg_roc.R +git commit -m "feat: summary.gg_roc handles named AUC vector for per_class=TRUE" +``` + +--- + +## Task T7: vdiffr snapshots + +**Files:** +- Modify: `tests/testthat/test_snapshots.R` + +- [ ] **Step 1: Append a per_class ROC section inside the `VDIFFR_RUN_TESTS` guard** + +Find the closing `}` of the `if (identical(Sys.getenv("VDIFFR_RUN_TESTS"), "true"))` block and add the following **inside** it, before the closing `}`: + +```r +## ── per_class ROC snapshots (PR #88) ───────────────────────────────────── +if (requireNamespace("randomForest", quietly = TRUE)) { + local({ + set.seed(1L) + rf_iris <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) + gg_pc_iris <- gg_roc(rf_iris, per_class = TRUE) + + test_that("snapshot: gg-roc-multiclass-overlay", { + vdiffr::expect_doppelganger( + "gg-roc-multiclass-overlay", + plot(gg_pc_iris, panel = "overlay") + ) + }) + + test_that("snapshot: gg-roc-multiclass-facet", { + vdiffr::expect_doppelganger( + "gg-roc-multiclass-facet", + plot(gg_pc_iris, panel = "facet") + ) + }) + }) +} +``` + +- [ ] **Step 2: Confirm snapshot tests skip cleanly without the env var** + +```bash +Rscript -e "devtools::test(filter='snapshots')" +``` + +Expected: all snapshot tests SKIP (no 
`VDIFFR_RUN_TESTS=true`). + +- [ ] **Step 3: Commit** + +```bash +git add tests/testthat/test_snapshots.R +git commit -m "test: add vdiffr snapshots for per_class ROC overlay and facet (PR #88)" +``` + +--- + +## Task T8: NEWS, final gate, and PR + +**Files:** +- Modify: `NEWS.md` + +- [ ] **Step 1: Add feature entry to the `ggRandomForests v2.8.0 (development)` section in `NEWS.md`** + +``` +* **`gg_roc`: per-class one-vs-rest ROC curves (#88, closes #72).** + - `gg_roc()` gains a `per_class = FALSE` argument. When `per_class = TRUE` + and the forest has more than two classes, returns a long-format `gg_roc` + data frame with a `class` factor column and a named AUC vector attribute + (one entry per class, ordered by descending AUC). + - `plot.gg_roc()` gains a `panel = c("overlay", "facet")` argument. When the + `gg_roc` object contains a `class` column, the overlay path colours curves + by class; the facet path wraps each class into its own panel. + - `summary.gg_roc()` now prints named per-class AUC values when the `class` + column is present. + - Binary forests: `per_class = TRUE` is a silent no-op (single-curve result + returned unchanged). + - ROC confidence intervals are deferred to v2.9.0 (issue #7 / #72-CIs). +``` + +- [ ] **Step 2: Run R CMD check** + +```bash +Rscript -e "devtools::check(args='--as-cran')" +``` + +Expected: `0 errors | 0 warnings | 0 notes` + +- [ ] **Step 3: Commit NEWS** + +```bash +git add NEWS.md +git commit -m "docs: update NEWS for PR #88 — per_class ROC" +``` + +- [ ] **Step 4: Push the branch** + +```bash +git push -u origin feat/rf-88-multiclass-roc +``` + +- [ ] **Step 5: Open the PR** + +```bash +gh pr create \ + --title "feat: gg_roc per_class=TRUE — per-class OvR ROC curves (#88, closes #72)" \ + --body "$(cat <<'EOF' +## Summary + +- `gg_roc()` gains `per_class = FALSE`. 
When `per_class = TRUE` on a multi-class forest, returns a long-format `gg_roc` data frame with a `class` factor column and a named AUC vector attribute (ordered by descending AUC). +- `plot.gg_roc()` gains `panel = c("overlay", "facet")`. Detects the `class` column and dispatches to the new multi-class overlay or faceted path. +- `summary.gg_roc()` updated to print named per-class AUC values when the `class` column is present. +- Binary forests: `per_class = TRUE` is a silent no-op. +- ROC CIs deferred to v2.9.0 (issue #7 / #72-CIs). + +## Test plan +- [x] `devtools::test()` — all tests pass, 0 failures +- [x] `devtools::check(args="--as-cran")` — 0 errors, 0 warnings, 0 notes + +Closes #72 + +🤖 Generated with [Claude Code](https://claude.com/claude-code) +EOF +)" +``` + +- [ ] **Step 6: Verify CI is green before requesting merge** + +```bash +gh pr checks +``` + +Expected: all checks pass. From 3e2eb0bcf5f31b3294553b1bfb03e5ff926f7898 Mon Sep 17 00:00:00 2001 From: John Ehrlinger Date: Thu, 21 May 2026 11:57:10 -0400 Subject: [PATCH 3/3] docs: address Copilot review on PR #87 — 6 plan/spec corrections MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - design doc: add oob=TRUE to gg_roc new signature (was omitted) - design doc: replace proposed calc_roc_one_vs_rest() with existing .rf_prob_matrix/.rf_one_class_roc/.rf_macro_average_roc helpers; document that output uses sens/spec/pct (package contract), not fpr/tpr - plan #87: add norm.votes=FALSE row-normalization to votes extraction step - plan #87: tighten patchwork test expectation (iris 4-predictor default must return patchwork, not patchwork||ggplot, to catch list regressions) - plan #88: add pct to per_class column assertions (both occurrences) - plan #88: correct rfsrc per_class rationale (argument is for API discoverability, not to prevent "unused argument" error — rfsrc uses ...) 
Co-Authored-By: Claude Opus 4.7 (1M context) --- ...1-rf-87-gg-variable-classification-plan.md | 14 +++++++++++-- .../2026-05-21-rf-88-multiclass-roc-plan.md | 10 +++++++--- .../2026-05-21-rf-finishing-work-design.md | 20 ++++++++++++++----- 3 files changed, 34 insertions(+), 10 deletions(-) diff --git a/dev/plans/2026-05-21-rf-87-gg-variable-classification-plan.md b/dev/plans/2026-05-21-rf-87-gg-variable-classification-plan.md index 6ac8b6ca..89c65459 100644 --- a/dev/plans/2026-05-21-rf-87-gg-variable-classification-plan.md +++ b/dev/plans/2026-05-21-rf-87-gg-variable-classification-plan.md @@ -107,7 +107,10 @@ test_that("gg_variable.randomForest classification: plot returns patchwork for a rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 50L) gg <- gg_variable(rf) p <- plot(gg) - expect_true(inherits(p, "patchwork") || inherits(p, "ggplot")) + # iris has 4 predictors, so the default (no xvar) path assembles a multi-panel + # patchwork. Assert patchwork specifically so a regression to a bare list (#80) + # would be caught. + expect_s3_class(p, "patchwork") }) test_that("gg_variable.randomForest classification: layer_data works on single-xvar plot", { @@ -158,7 +161,14 @@ Find this exact code (near the end of `gg_variable.randomForest`): # stored as yhat. columns — the same shape gg_variable.rfsrc # produces. For regression a single numeric yhat column suffices. if (object$type == "classification") { - preds <- object$votes # n × n_classes matrix of OOB vote fractions + preds <- object$votes # n × n_classes matrix; OOB vote fractions by default, + # but raw integer counts when forest is fit with + # norm.votes = FALSE. Row-normalise when any row sum exceeds 1 so + # values are always in [0, 1] with rowSums ≈ 1. 
+ rs <- rowSums(preds) + if (any(rs > 1 + 1e-8, na.rm = TRUE)) { + preds <- preds / rs + } colnames(preds) <- paste0("yhat.", colnames(preds)) gg_dta <- cbind(gg_dta, preds) gg_dta$yvar <- response diff --git a/dev/plans/2026-05-21-rf-88-multiclass-roc-plan.md b/dev/plans/2026-05-21-rf-88-multiclass-roc-plan.md index 12c9a216..c7719b7c 100644 --- a/dev/plans/2026-05-21-rf-88-multiclass-roc-plan.md +++ b/dev/plans/2026-05-21-rf-88-multiclass-roc-plan.md @@ -81,7 +81,7 @@ test_that("gg_roc per_class=TRUE: long format with class column", { rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100L) gg <- gg_roc(rf, per_class = TRUE) expect_true("class" %in% names(gg)) - expect_true(all(c("sens", "spec") %in% names(gg))) + expect_true(all(c("sens", "spec", "pct") %in% names(gg))) # pct = threshold; same 3-col contract as calc_roc expect_s3_class(gg$class, "factor") expect_equal(nlevels(gg$class), 3L) }) @@ -153,7 +153,11 @@ Replace with: gg_roc.rfsrc <- function(object, which_outcome, oob = TRUE, per_class = FALSE, ...) { ``` -Note: the rfsrc per_class path is out of scope for this PR (tracked under issue #72). The argument is added only so `gg_roc(rfsrc_obj, per_class = TRUE)` does not error with "unused argument". Leave the body unchanged. +Note: the rfsrc per_class path is out of scope for this PR (tracked under issue #72). +`gg_roc.rfsrc` already accepts `...`, so `per_class = TRUE` would currently be +silently swallowed rather than erroring. The argument is added explicitly for API +discoverability and to keep the signature consistent with `gg_roc.randomForest`. +Leave the body unchanged — no per_class logic is wired up in the rfsrc method. 
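The corrected rationale above rests on a general property of R argument matching: a method that declares `...` absorbs any named argument it has no formal for, so nothing errors. The sketch below shows that behaviour with a toy generic; `toy_roc` and its method are hypothetical names for illustration only and are unrelated to the real `gg_roc.rfsrc` body.

```r
# Toy generic and method (hypothetical names, unrelated to ggRandomForests).
toy_roc <- function(object, ...) UseMethod("toy_roc")

# The method has no `per_class` formal, but it does accept `...`.
toy_roc.default <- function(object, oob = TRUE, ...) {
  extras <- list(...)                # unmatched named arguments land here
  list(oob = oob, swallowed = names(extras))
}

# No "unused argument" error is raised; per_class is captured by `...`.
res <- toy_roc(1, per_class = TRUE)
res$swallowed
```

Declaring `per_class` as an explicit formal, as the plan does, does not change this behaviour, but it makes the argument visible in the signature and in the generated help page.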
- [ ] **Step 3: Replace the entire `gg_roc.randomForest` body** @@ -302,7 +306,7 @@ test_that("gg_roc which_outcome='all' still returns macro-average (no class colu gg <- gg_roc(rf, which_outcome = "all") expect_false("class" %in% names(gg)) # Macro-average returns a single data frame, not a class-faceted one - expect_true(all(c("sens", "spec") %in% names(gg))) + expect_true(all(c("sens", "spec", "pct") %in% names(gg))) # pct = threshold; same 3-col contract as calc_roc }) ``` diff --git a/dev/plans/2026-05-21-rf-finishing-work-design.md b/dev/plans/2026-05-21-rf-finishing-work-design.md index 4720e3d4..e1ac4e47 100644 --- a/dev/plans/2026-05-21-rf-finishing-work-design.md +++ b/dev/plans/2026-05-21-rf-finishing-work-design.md @@ -179,14 +179,18 @@ macro-average meaning (backward compatible). No CI computation in this PR. **New signature:** ```r -gg_roc(object, which_outcome = "all", per_class = FALSE, ...) +gg_roc(object, which_outcome = "all", per_class = FALSE, oob = TRUE, ...) ``` +(`oob = TRUE` already exists in both dispatch methods; listed here for completeness.) + **`per_class = TRUE` behaviour:** - For forests with > 2 classes: compute one-vs-rest ROC for every class *k* (scores = `votes[, k]`, positive = `(y == k)`), stack into a long data frame - with columns `class` (factor), `fpr`, `tpr`. + with columns `class` (factor), `sens`, `spec`, `pct` — the same three-column + contract used by the rest of the package (`calc_roc` returns `sens`/`spec`/`pct`; + `fpr = 1 − spec` and `tpr = sens` if desired by callers). - `attr(gg, "auc")` becomes a named numeric vector: one entry per class, ordered by descending AUC. Class factor levels follow the same order. - For binary forests: `per_class = TRUE` is a no-op — returns the single-curve @@ -196,9 +200,15 @@ gg_roc(object, which_outcome = "all", per_class = FALSE, ...) integer): `per_class` wins, a `message()` informs the caller that `which_outcome` is ignored when `per_class = TRUE`. 
-**Internal helper `calc_roc_one_vs_rest(scores, y, k)`** — single class OvR -ROC, returns a data frame `(fpr, tpr)` plus scalar AUC. Extracted so it can be -reused for both per-class computation and macro-average. +**Internal helpers** — the implementation reuses existing private functions in +`R/calc_roc.R` rather than introducing new ones: + +- `.rf_prob_matrix(object, oob)` — extracts and normalises the vote matrix +- `.rf_one_class_roc(probs, y, k)` — single-class OvR ROC, returns + `data.frame(sens, spec, pct)` plus scalar AUC +- `.rf_macro_average_roc(probs, y)` — macro-average across all classes + +No new top-level helpers are needed. ### Plot method (`R/plot.gg_roc.R`)