Skip to content

Commit c408621

Browse files
committed
test: add three failing tests for gg_variable.randomForest classification (PR #87)
Tests verify that gg_variable.randomForest classification forests produce: - Per-class vote fraction columns (yhat.setosa, yhat.versicolor, yhat.virginica) - No bare yhat column for multi-class prediction - Valid vote fractions (0-1) that row-sum to 1 - Plottable output for multi-xvar case - Layer_data access on single-xvar plots All three tests currently FAIL, as expected (TDD red phase). Implementation fix comes in T2.
1 parent 3052e6d commit c408621

1 file changed

Lines changed: 39 additions & 0 deletions

File tree

tests/testthat/test_gg_variable.R

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,3 +368,42 @@ test_that("gg_variable.randomForest classification: class attr uses 'class' not
368368
expect_false("classification" %in% class(gg_dta))
369369
expect_s3_class(gg_dta, "gg_variable")
370370
})
371+
372+
## ── randomForest classification (PR #87) ─────────────────────────────────────
373+
374+
test_that("gg_variable.randomForest classification: produces yhat.* columns not yhat", {
375+
skip_if_not_installed("randomForest")
376+
set.seed(42L)
377+
rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 50L)
378+
gg <- gg_variable(rf)
379+
# Must have one column per class
380+
expect_true(all(c("yhat.setosa", "yhat.versicolor", "yhat.virginica")
381+
%in% names(gg)))
382+
# Must NOT have a bare yhat column for multi-class
383+
expect_false("yhat" %in% names(gg))
384+
# Observed-class column must be present
385+
expect_true("yvar" %in% names(gg))
386+
# Vote fractions must be in [0, 1] and row-sum to ~1
387+
vote_cols <- c("yhat.setosa", "yhat.versicolor", "yhat.virginica")
388+
expect_true(all(gg[, vote_cols] >= 0))
389+
expect_true(all(gg[, vote_cols] <= 1))
390+
expect_true(all(abs(rowSums(gg[, vote_cols]) - 1) < 1e-6))
391+
})
392+
393+
test_that("gg_variable.randomForest classification: plot returns patchwork for all xvar", {
394+
skip_if_not_installed("randomForest")
395+
set.seed(42L)
396+
rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 50L)
397+
gg <- gg_variable(rf)
398+
p <- plot(gg)
399+
expect_true(inherits(p, "patchwork") || inherits(p, "ggplot"))
400+
})
401+
402+
test_that("gg_variable.randomForest classification: layer_data works on single-xvar plot", {
403+
skip_if_not_installed("randomForest")
404+
set.seed(42L)
405+
rf <- randomForest::randomForest(Species ~ ., data = iris, ntree = 50L)
406+
gg <- gg_variable(rf)
407+
p <- plot(gg, xvar = "Sepal.Length")
408+
expect_no_error(ggplot2::layer_data(p, 1L))
409+
})

0 commit comments

Comments
 (0)