Commit ae0b18e

ehrlinger and claude authored
fix: gg_variable.randomForest classification — yhat.* columns + smooth bugs (#87) (#88)
* chore: open v2.7.3.9005 dev increment (PR #87)

* test: add three failing tests for gg_variable.randomForest classification (PR #87)

  Tests verify that gg_variable.randomForest classification forests produce:
  - Per-class vote fraction columns (yhat.setosa, yhat.versicolor, yhat.virginica)
  - No bare yhat column for multi-class prediction
  - Valid vote fractions (0-1) that row-sum to 1
  - Plottable output for multi-xvar case
  - Layer_data access on single-xvar plots

  All three tests currently FAIL, as expected (TDD red phase). Implementation fix comes in T2.

* fix: gg_variable.randomForest classification uses object$votes for yhat.* columns

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: normalise object$votes rows to defend against norm.votes=FALSE; update stale oob comment

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* test: add failing tests for plot.gg_variable smooth bugs (PR #87)

* fix: plot.gg_variable binary smooth aes + add multi-class smooth block (#87)

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* test: add vdiffr snapshots for gg_variable RF classification (PR #87)

* docs: update NEWS for PR #87 — gg_variable RF classification fix

* fix: plot.gg_variable multi-class facets use class names not integer indices

  Strip the yhat. prefix from column names when building the outcome column in the multi-class pivot loop (line ~169), so facet labels show "setosa"/"versicolor"/"virginica" instead of 1/2/3. Add a regression test that verifies p@data$outcome contains class names after plotting.

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* test: tighten patchwork assertion to match corrected plan (#87 Copilot review)

  Plan #87 was corrected (via PR #87 Copilot review) to assert patchwork specifically for the 4-predictor iris default-plot case, rather than the loose patchwork||ggplot form that wouldn't catch a regression to a bare list (issue #80). Align the implementation test to match.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix: address Copilot review on PR #88 — oob warning + factor level order

  - gg_variable.randomForest: replace silent no-op oob override with an explicit warning when oob=FALSE is supplied, since in-bag class probabilities are not available via the randomForest API.
  - plot.gg_variable: set outcome factor levels from gg_dta_y column order rather than factor() default (alphabetical), so multi-class facet panels follow the model's class ordering regardless of locale.
  - Add two new tests: oob=FALSE warning, and outcome factor level order.

  Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent cf6229a commit ae0b18e
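
The vote normalisation described in the commit message can be illustrated with a minimal base-R sketch. The vote matrix below is made-up illustrative data, not output from a real forest; the guard and the `yhat.` prefixing follow the logic shown in the `R/gg_variable.R` diff.

```r
# Hypothetical raw OOB vote counts for 3 observations and 3 classes
# (illustrative numbers only, standing in for object$votes).
votes <- matrix(
  c(10, 0, 0,
     2, 7, 1,
     0, 3, 9),
  nrow = 3, byrow = TRUE,
  dimnames = list(NULL, c("setosa", "versicolor", "virginica"))
)

# Same guard as in the diff: rescale only when some row sums to more than 1,
# i.e. when the matrix holds raw counts rather than fractions.
rs <- rowSums(votes)
if (any(rs > 1 + 1e-8, na.rm = TRUE)) {
  # Recycling runs down columns, so element [i, j] is divided by rs[i].
  votes <- votes / rs
}

# Prefix the class names so the columns read yhat.<classname>.
colnames(votes) <- paste0("yhat.", colnames(votes))
print(votes)
```

Because the guard checks the row sums rather than a flag on the model, a matrix that is already normalised passes through unchanged.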

6 files changed

Lines changed: 203 additions & 13 deletions

DESCRIPTION

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
 Package: ggRandomForests
 Type: Package
 Title: Visually Exploring Random Forests
-Version: 2.7.3.9004
-Date: 2026-05-20
+Version: 2.7.3.9005
+Date: 2026-05-21
 Authors@R: person("John", "Ehrlinger",
     role = c("aut", "cre"),
     email = "john.ehrlinger@gmail.com")

NEWS.md

Lines changed: 14 additions & 1 deletion
@@ -1,5 +1,5 @@
 Package: ggRandomForests
-Version: 2.7.3.9004
+Version: 2.7.3.9005
 
 ggRandomForests v2.8.0 (development) — continued
 =================================================
@@ -31,6 +31,19 @@ ggRandomForests v2.8.0 (development) — continued
 
 ggRandomForests v2.8.0 (development)
 ====================================
+* **`gg_variable.randomForest` classification fix (#87).**
+  - `gg_variable.randomForest()` for classification forests now stores
+    per-class OOB vote fractions as `yhat.<classname>` columns (from
+    `object$votes`), matching the `rfsrc` path. Previously a single
+    `yhat` factor column (class labels from `object$predicted`) was
+    stored, which prevented the multi-class pivot in `plot.gg_variable`
+    from firing. Vote fractions are row-normalised to `[0, 1]` even
+    when the forest was fit with `norm.votes = FALSE`.
+  - `plot.gg_variable` binary classification: `smooth = TRUE` now
+    correctly maps x/y aesthetics onto the smooth layer.
+  - `plot.gg_variable` multi-class numeric path: `smooth = TRUE` now
+    adds a smooth layer (was silently skipped).
+  - Closes stale issues #81 (fixed in PR #83) and #82.
 * **varPro partial dependence: `gg_partial_varpro()` (#84).**
   - `gg_partial_varpro()` replaces `gg_partialpro()` as the primary entry
     point for varPro partial dependence plots. The new extractor accepts

R/gg_variable.R

Lines changed: 23 additions & 7 deletions
@@ -271,10 +271,16 @@ gg_variable.randomForest <- function(object,
                                      ...) {
   arg_list <- list(...)
 
-  # randomForest objects do not store OOB predictions in a way that maps back
-  # to the predictor space, so we always use in-bag (full-forest) predictions.
-  if (!is.null(arg_list$oob)) {
-    arg_list$oob <- FALSE
+  # randomForest uses object$votes (OOB vote matrix) unconditionally — it is the
+  # only honest per-class probability estimate. In-bag class probabilities are
+  # not exposed through a consistent randomForest API, so oob=FALSE is not
+  # supported. Warn the caller rather than silently ignoring the argument.
+  if (!is.null(arg_list$oob) && identical(arg_list$oob, FALSE)) {
+    warning(
+      "oob = FALSE is not supported for randomForest objects: ",
+      "in-bag class probabilities are unavailable. ",
+      "OOB vote fractions (object$votes) will be used instead."
+    )
   }
 
   if (!inherits(object, "randomForest")) {
@@ -307,10 +313,20 @@ gg_variable.randomForest <- function(object,
   }
 
   gg_dta <- predictors
-  # Append the forest's in-bag predicted values.
-  gg_dta$yhat <- as.vector(object$predicted)
+  # For classification forests use per-class OOB vote fractions (object$votes),
+  # stored as yhat.<classname> columns — the same shape gg_variable.rfsrc
+  # produces. For regression a single numeric yhat column suffices.
   if (object$type == "classification") {
-    gg_dta$yvar <- response
+    preds <- object$votes  # n × n_classes matrix; may be raw counts or fractions
+    rs <- rowSums(preds)
+    if (any(rs > 1 + 1e-8, na.rm = TRUE)) {
+      preds <- preds / rs  # normalise raw vote counts to [0, 1]
+    }
+    colnames(preds) <- paste0("yhat.", colnames(preds))
+    gg_dta <- cbind(gg_dta, preds)
+    gg_dta$yvar <- response
+  } else {
+    gg_dta$yhat <- as.vector(object$predicted)
   }
 
   # randomForest uses object$type ("classification" / "regression"); the
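
The `oob = FALSE` guard in the first hunk above can be exercised in isolation. The following is a rough sketch only: `check_oob_arg()` is a hypothetical stand-in written for this illustration and is not a function in ggRandomForests; it copies the condition from the diff and shortens the warning text.

```r
# Hypothetical helper (not part of ggRandomForests) mirroring the guard:
# warn only when the caller explicitly passes oob = FALSE.
check_oob_arg <- function(...) {
  arg_list <- list(...)
  if (!is.null(arg_list$oob) && identical(arg_list$oob, FALSE)) {
    warning(
      "oob = FALSE is not supported for randomForest objects: ",
      "OOB vote fractions will be used instead."
    )
  }
  invisible(TRUE)
}

# Capture the warning message instead of letting it print.
msg <- tryCatch(
  {
    check_oob_arg(oob = FALSE)
    NA_character_
  },
  warning = function(w) conditionMessage(w)
)

# With oob = TRUE (or no oob argument at all) the guard stays silent.
quiet <- tryCatch(
  {
    check_oob_arg(oob = TRUE)
    "no warning"
  },
  warning = function(w) conditionMessage(w)
)
```

Using `identical(arg_list$oob, FALSE)` rather than `!arg_list$oob` keeps the check from erroring or misfiring on `NA` or non-logical input.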

R/plot.gg_variable.R

Lines changed: 17 additions & 3 deletions
@@ -166,10 +166,14 @@ plot.gg_variable <- function(x, # nolint: cyclocomp_linter
       gg_dta_y <- gg_dta[, grep("yhat.", colnames(gg_dta))]
       lng <- ncol(gg_dta_y)
       gg2 <- parallel::mclapply(seq_len(ncol(gg_dta_y)), function(ind) {
-        cbind(gg_dta_x, yhat = gg_dta_y[, ind], outcome = ind)
+        cbind(gg_dta_x, yhat = gg_dta_y[, ind],
+              outcome = sub("^yhat\\.", "", colnames(gg_dta_y)[ind]))
       })
       gg3 <- do.call(rbind, gg2)
-      gg3$outcome <- factor(gg3$outcome)
+      # Use column order from gg_dta_y (not alphabetical) so facet panels
+      # appear in the same order as the model's class levels.
+      outcome_levels <- sub("^yhat\\.", "", colnames(gg_dta_y))
+      gg3$outcome <- factor(gg3$outcome, levels = outcome_levels)
       gg_dta <- gg3
     }
   }
@@ -516,7 +520,10 @@ plot.gg_variable <- function(x, # nolint: cyclocomp_linter
           }
           if (smooth) {
             gg_plt[[ind]] <- gg_plt[[ind]] +
-              ggplot2::geom_smooth(...)
+              ggplot2::geom_smooth(
+                ggplot2::aes(x = .data$var, y = .data$yhat),
+                ...
+              )
           }
         } else {
           # Factor predictor: jitter + boxplot coloured by observed class
@@ -550,6 +557,13 @@ plot.gg_variable <- function(x, # nolint: cyclocomp_linter
               ),
               ...
             )
+          if (smooth) {
+            gg_plt[[ind]] <- gg_plt[[ind]] +
+              ggplot2::geom_smooth(
+                ggplot2::aes(x = .data$var, y = .data$yhat),
+                ...
+              )
+          }
         } else {
           gg_plt[[ind]] <- gg_plt[[ind]] +
             ggplot2::geom_boxplot(
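
The factor-level change in the first hunk above comes down to how base R's `factor()` orders levels. Here is a small sketch under assumed inputs: the class names `low`, `medium`, and `high` are hypothetical and chosen only because their model order differs from their sorted order.

```r
# Hypothetical yhat.* column names, in the order the model stores its classes.
yhat_cols <- c("yhat.low", "yhat.medium", "yhat.high")

# Strip the prefix to recover class labels, as the diff does with sub().
outcome_levels <- sub("^yhat\\.", "", yhat_cols)

# Long-format outcome labels, two rows per class for illustration.
outcome <- rep(outcome_levels, each = 2)

# factor() with no levels argument sorts the unique values.
default_order <- levels(factor(outcome))

# Passing levels explicitly preserves the column (model) order.
model_order <- levels(factor(outcome, levels = outcome_levels))

print(default_order)
print(model_order)
```

Since facet panels follow factor levels, fixing the levels explicitly is what keeps panel order tied to the model rather than to the sort order of the labels.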

tests/testthat/test_gg_variable.R

Lines changed: 124 additions & 0 deletions
@@ -368,3 +368,127 @@ test_that("gg_variable.randomForest classification: class attr uses 'class' not
   expect_false("classification" %in% class(gg_dta))
   expect_s3_class(gg_dta, "gg_variable")
 })
+
+## ── randomForest classification (PR #87) ─────────────────────────────────────
+
+test_that("gg_variable.randomForest classification: produces yhat.* columns not yhat", {
+  skip_if_not_installed("randomForest")
+  set.seed(42L)
+  rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 50L)
+  gg <- gg_variable(rf)
+  # Must have one column per class
+  expect_true(all(c("yhat.setosa", "yhat.versicolor", "yhat.virginica")
+                  %in% names(gg)))
+  # Must NOT have a bare yhat column for multi-class
+  expect_false("yhat" %in% names(gg))
+  # Observed-class column must be present
+  expect_true("yvar" %in% names(gg))
+  # Vote fractions must be in [0, 1] and row-sum to ~1
+  vote_cols <- c("yhat.setosa", "yhat.versicolor", "yhat.virginica")
+  expect_true(all(gg[, vote_cols] >= 0))
+  expect_true(all(gg[, vote_cols] <= 1))
+  expect_true(all(abs(rowSums(gg[, vote_cols]) - 1) < 1e-6))
+})
+
+test_that("gg_variable.randomForest classification: plot returns patchwork for all xvar", {
+  skip_if_not_installed("randomForest")
+  set.seed(42L)
+  rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 50L)
+  gg <- gg_variable(rf)
+  p <- plot(gg)
+  # iris has 4 predictors so the no-xvar default assembles a multi-panel
+  # patchwork; assert patchwork specifically to catch regressions to a bare
+  # list (#80).
+  expect_s3_class(p, "patchwork")
+})
+
+test_that("gg_variable.randomForest classification: layer_data works on single-xvar plot", {
+  skip_if_not_installed("randomForest")
+  set.seed(42L)
+  rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 50L)
+  gg <- gg_variable(rf)
+  p <- plot(gg, xvar = "Sepal.Length")
+  expect_no_error(ggplot2::layer_data(p, 1L))
+})
+
+test_that("gg_variable.randomForest classification: norm.votes=FALSE still gives [0,1] fractions", {
+  skip_if_not_installed("randomForest")
+  set.seed(42L)
+  rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 50L,
+                                   norm.votes = FALSE)
+  gg <- gg_variable(rf)
+  vote_cols <- c("yhat.setosa", "yhat.versicolor", "yhat.virginica")
+  expect_true(all(c("yhat.setosa", "yhat.versicolor", "yhat.virginica") %in% names(gg)))
+  expect_true(all(gg[, vote_cols] >= 0))
+  expect_true(all(gg[, vote_cols] <= 1))
+  expect_true(all(abs(rowSums(gg[, vote_cols]) - 1) < 1e-6))
+})
+
+test_that("plot.gg_variable RF classification: smooth=TRUE layer_data smokeable (binary smooth aes bug)", {
+  skip_if_not_installed("randomForest")
+  # Two-class subset to exercise the *binary* classification path
+  set.seed(42L)
+  bin_data <- iris[iris$Species != "virginica", ]
+  bin_data$Species <- droplevels(bin_data$Species)
+  rf <- randomForest::randomForest(Species ~ ., data = bin_data, ntree = 50L)
+  gg <- gg_variable(rf)
+  p <- plot(gg, xvar = "Sepal.Length", smooth = TRUE)
+  # Before the fix, geom_smooth(...) has no aes and layer_data errors with
+  # "stat_smooth() requires the following missing aesthetics: x and y"
+  expect_no_error(ggplot2::layer_data(p, 2L))
+})
+
+test_that("plot.gg_variable RF classification: smooth=TRUE works for multi-class (missing block)", {
+  skip_if_not_installed("randomForest")
+  set.seed(42L)
+  rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 50L)
+  gg <- gg_variable(rf)
+  # Before the fix the multi-class numeric path silently skips smooth=TRUE
+  # but does not error; after the fix a smooth layer is present (layer 2).
+  p <- plot(gg, xvar = "Sepal.Length", smooth = TRUE)
+  expect_s3_class(p, "ggplot")
+  ld <- ggplot2::layer_data(p, 2L)  # layer 2 = geom_smooth
+  expect_gt(nrow(ld), 0L)
+})
+
+test_that("plot.gg_variable RF classification multi-class: outcome column is class names not integers", {
+  skip_if_not_installed("randomForest")
+  set.seed(42L)
+  rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 50L)
+  gg <- gg_variable(rf)
+  p <- plot(gg, xvar = "Sepal.Length")
+  expect_s3_class(p, "ggplot")
+  # The 'outcome' column in the plot data drives facet labels.
+  # It must contain class names, not integer indices.
+  # ggplot2 >= 3.5 uses S7 slots; fall back to $ accessor for older versions.
+  pd <- tryCatch(p@data, error = function(e) p$data)
+  expect_false(is.numeric(pd$outcome))
+  expect_true(all(c("setosa", "versicolor", "virginica") %in% as.character(pd$outcome)))
+})
+
+test_that("plot.gg_variable RF classification multi-class: outcome factor levels match column order", {
+  skip_if_not_installed("randomForest")
+  set.seed(42L)
+  rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 50L)
+  gg <- gg_variable(rf)
+  p <- plot(gg, xvar = "Sepal.Length")
+  pd <- tryCatch(p@data, error = function(e) p$data)
+  # Levels must follow the yhat.* column order in gg_variable output,
+  # not alphabetical order (which factor() would impose by default).
+  expected_levels <- sub("^yhat\\.", "", grep("^yhat\\.", names(gg), value = TRUE))
+  expect_equal(levels(pd$outcome), expected_levels)
+})
+
+test_that("gg_variable.randomForest: oob=FALSE triggers a warning", {
+  skip_if_not_installed("randomForest")
+  set.seed(42L)
+  rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 50L)
+  # oob=FALSE is not supported for randomForest; a warning must be emitted
+  # and OOB vote fractions are still returned.
+  expect_warning(
+    gg <- gg_variable(rf, oob = FALSE),
+    regexp = "oob = FALSE is not supported"
+  )
+  expect_s3_class(gg, "gg_variable")
+  expect_true("yhat.setosa" %in% names(gg))
+})

tests/testthat/test_snapshots.R

Lines changed: 23 additions & 0 deletions
@@ -274,4 +274,27 @@ local({
     })
   }
 
+  ## ── randomForest classification snapshots (PR #87) ───────────────────────────
+  if (requireNamespace("randomForest", quietly = TRUE)) {
+    local({
+      set.seed(42L)
+      rf_iris <- randomForest::randomForest(Species ~ ., data = iris, ntree = 50L)
+      gg_iris <- gg_variable(rf_iris)
+
+      test_that("snapshot: gg-variable-rf-classification-default", {
+        vdiffr::expect_doppelganger(
+          "gg-variable-rf-classification-default",
+          plot(gg_iris)
+        )
+      })
+
+      test_that("snapshot: gg-variable-rf-classification-smooth", {
+        vdiffr::expect_doppelganger(
+          "gg-variable-rf-classification-smooth",
+          plot(gg_iris, xvar = "Sepal.Length", smooth = TRUE)
+        )
+      })
+    })
+  }
+
 } # end CI guard
