Skip to content

Commit cf6229a

Browse files
ehrlingerclaude
andauthored
fix: map randomForest 'classification' type to 'class' in gg_variable class attr (#90)
randomForest stores family as object$type ("classification"), but the plot.gg_variable dispatcher and the rfsrc path both use "class". Add an explicit mapping so the class vector is consistent across engines and callers never need to special-case "classification". Adds one test verifying "class" is present and "classification" is absent. Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 7e4197c commit cf6229a

2 files changed

Lines changed: 19 additions & 1 deletion

File tree

R/gg_variable.R

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,12 @@ gg_variable.randomForest <- function(object,
313313
gg_dta$yvar <- response
314314
}
315315

316-
class(gg_dta) <- c("gg_variable", object$type, class(gg_dta))
316+
# randomForest uses object$type ("classification" / "regression"); the
317+
# plot.gg_variable dispatcher and the rfsrc path both use "class" for
318+
# classification forests. Map here so the class attribute is consistent
319+
# and callers never need to special-case "classification".
320+
family_lbl <- if (object$type == "classification") "class" else object$type
321+
class(gg_dta) <- c("gg_variable", family_lbl, class(gg_dta))
317322
gg_dta <- .set_provenance(gg_dta, object)
318323
invisible(gg_dta)
319324
}

tests/testthat/test_gg_variable.R

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,3 +355,16 @@ test_that("plot.gg_variable: missing xvar returns list for all predictors", {
355355
# Returns a single plottable object (patchwork composite for multiple predictors)
356356
expect_true(inherits(gg_plt, "patchwork") || inherits(gg_plt, "ggplot"))
357357
})
358+
359+
test_that("gg_variable.randomForest classification: class attr uses 'class' not 'classification'", {
360+
# randomForest stores the family in $type as "classification", but
361+
# plot.gg_variable and the rfsrc path both dispatch on "class".
362+
# Verify the mapping is applied so callers see a consistent class attribute.
363+
set.seed(42)
364+
rf_iris <- randomForest::randomForest(Species ~ ., data = iris, ntree = 50L)
365+
gg_dta <- gg_variable(rf_iris)
366+
367+
expect_true("class" %in% class(gg_dta))
368+
expect_false("classification" %in% class(gg_dta))
369+
expect_s3_class(gg_dta, "gg_variable")
370+
})

0 commit comments

Comments
 (0)