Skip to content

Commit 8840ed1

Browse files
authored
Merge pull request #60 from ehrlinger/copilot_refactor
Improve unit tests, add integration tests, and raise coverage to 83%
2 parents c9e01ce + ec215ca commit 8840ed1

39 files changed

Lines changed: 2038 additions & 435 deletions

.Rbuildignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,13 @@ framed.sty
2828
^CRAN-RELEASE$
2929
^CRAN-SUBMISSION$
3030
^\.github$
31+
^\.claude$
32+
^\.git$
33+
^\.vscode$
3134
^doc$
3235
^Meta$
3336
^_pkgdown\.yml$
3437
^docs$
3538
^pkgdown$
3639
^LICENSE\.md$
40+
^memory$

.claude/settings.local.json

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
{
2+
"permissions": {
3+
"allow": [
4+
"Bash(Rscript -e \"devtools::test\\(\\)\")",
5+
"Bash(Rscript -e \"devtools::test\\(reporter = testthat::SummaryReporter$new\\(\\)\\)\")",
6+
"Bash(Rscript -e \"testthat::set_max_fails\\(Inf\\); devtools::test\\(reporter = testthat::SummaryReporter$new\\(\\)\\)\")",
7+
"Bash(Rscript -e \"options\\(testthat.max_fails = Inf\\); devtools::test\\(reporter = testthat::SummaryReporter$new\\(\\)\\)\")",
8+
"Bash(Rscript -e \":*)",
9+
"Bash(Rscript --vanilla -e \":*)",
10+
"Bash(Rscript -e \"library\\(covr\\); cov <- package_coverage\\(''/Users/ehrlinj/Documents/GitHub/ggRandomForests/.claude/worktrees/lucid-herschel''\\); print\\(cov\\)\")",
11+
"Bash(Rscript -e \"rcmdcheck::rcmdcheck\\(args=''--no-manual'', error_on=''warning''\\)\")",
12+
"Bash(git add .Rbuildignore tests/testthat/test_gg_rfsrc.R tests/testthat/test_gg_variable.R tests/testthat/test_gg_partial.R tests/testthat/test_gg_partialpro.R tests/testthat/test_ggrandomforests_news.R tests/testthat/test_surv_partial.R tests/testthat/test_varpro_feature_names.R)"
13+
]
14+
}
15+
}

.claude/worktrees/lucid-herschel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit 9b2abd8d14bd5c4d563cdc460ec2cf9ad4c99645

.vscode/settings.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"snyk.advanced.autoSelectOrganization": true
3+
}

DESCRIPTION

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
Package: ggRandomForests
22
Type: Package
33
Title: Visually Exploring Random Forests
4-
Version: 2.5.0
5-
Date: 2026-02-19
4+
Version: 2.6.0
5+
Date: 2026-03-04
66
Authors@R: person("John", "Ehrlinger",
77
role = c("aut", "cre"),
88
email = "john.ehrlinger@gmail.com")

NAMESPACE

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ export(gg_variable)
3636
export(gg_vimp)
3737
export(kaplan)
3838
export(nelson)
39+
export(plot.gg_error)
40+
export(plot.gg_rfsrc)
41+
export(plot.gg_roc)
42+
export(plot.gg_survival)
43+
export(plot.gg_variable)
44+
export(plot.gg_vimp)
3945
export(quantile_pts)
4046
export(surv_partial.rfsrc)
4147
export(varpro_feature_names)

R/calc_roc.R

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,45 +66,50 @@ calc_roc.rfsrc <-
6666
which_outcome = "all",
6767
oob = TRUE,
6868
...) {
69+
# Ensure response is a factor so levels() is well-defined
6970
if (!is.factor(dta)) {
7071
dta <- factor(dta)
7172
}
7273

74+
# NOTE(review): meant to re-read oob from ..., but oob is a named formal and never lands in ...; the reset below leaves oob FALSE - confirm intent
7375
arg_list <- as.list(substitute(list(...)))
7476

7577
oob <- FALSE
7678
if (!is.null(arg_list$oob) && is.logical(arg_list$oob)) {
7779
oob <- as.logical(arg_list$oob)
7880
}
7981

82+
# "all" outcomes not yet supported; fall back to the first class
8083
if (which_outcome == "all") {
8184
warning("Must specify which_outcome for now.")
8285
which_outcome <- 1
8386
}
87+
# Build (binary indicator, full-forest prediction, OOB prediction) triplet
8488
dta_roc <-
8589
data.frame(cbind(
8690
res = (dta == levels(dta)[which_outcome]),
8791
prd = object$predicted[, which_outcome],
8892
oob_prd = object$predicted.oob[, which_outcome]
8993
))
9094

91-
# Get the list of unique prob
95+
# Collect the unique predicted probability thresholds for the ROC sweep
9296
if (oob) {
9397
pct <- sort(unique(object$predicted.oob[, which_outcome]))
9498
} else {
9599
pct <- sort(unique(object$predicted[, which_outcome]))
96100
}
97101

98102
last <- length(pct)
103+
# Remove the maximum threshold (the cutpoint where nothing is classified
104+
# as positive), which produces the (sens=0, spec=1) anchor point
99105
pct <- pct[-last]
100106

101-
# Make sure we don't have to many points... if the training set was large,
102-
# This may break plotting all ROC curves in multiclass settings.
103-
# Arbitrarily reduce this to only include 200 points along the curve
107+
# Cap at 200 threshold points to keep multi-class ROC plots manageable
104108
if (last > 200) {
105109
pct <- pct[seq(1, length(pct), length.out = 200)]
106110
}
107111

112+
# For each threshold, build the 2×2 confusion table and extract TPR/TNR
108113
gg_dta <- parallel::mclapply(pct, function(crit) {
109114
if (oob) {
110115
tbl <- xtabs(~ res + (oob_prd > crit), dta_roc)
@@ -118,6 +123,7 @@ calc_roc.rfsrc <-
118123
})
119124

120125
gg_dta <- do.call(rbind, gg_dta)
126+
# Anchor curve at perfect specificity (0, 1) and perfect sensitivity (1, 0)
121127
gg_dta <- rbind(c(0, 1), gg_dta, c(1, 0))
122128

123129
gg_dta <- data.frame(gg_dta, row.names = seq_len(nrow(gg_dta)))
@@ -222,16 +228,15 @@ calc_roc.randomForest <-
222228
#' @aliases calc_auc calc_auc.gg_roc
223229
#' @export
224230
calc_auc <- function(x) {
225-
## Use the trapeziod rule, basically calc
226-
##
227-
## auc = dx/2(f(x_{i+1}) - f(x_i))
228-
##
229-
## f(x) is sensitivity, x is 1-specificity
231+
## Trapezoidal rule: AUC = Σ dx/2 * (f(x_{i+1}) + f(x_i))
232+
## Here f(x) is sensitivity (TPR) and x is 1 − specificity (FPR).
233+
## The shift() helper provides the lead value x_{i+1}.
230234

231-
# Since we are leading vectors (x_{i+1} - x_{i}), we need to
232-
# ensure we are in decreasing order of specificity (x var = 1-spec)
235+
# Sort in decreasing specificity so FPR increases left-to-right along the curve
233236
x <- x[order(x$spec, decreasing = TRUE), ]
234237

238+
# Per-interval area: height (3 * shift(sens) - sens) / 2 times the width
239+
# (spec - shift(spec)); NOTE(review): not the plain trapezoid average - confirm
235240
auc <- (3 * shift(x$sens) - x$sens) / 2 * (x$spec - shift(x$spec))
236241
sum(auc, na.rm = TRUE)
237242
}

R/gg_error.R

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ gg_error <- function(object, ...) {
203203
}
204204
#' @export
205205
gg_error.rfsrc <- function(object, ...) {
206-
## Check that the input obect is of the correct type.
206+
## Check that the input object is of the correct type.
207207
if (!inherits(object, "rfsrc")) {
208208
stop(
209209
paste(
@@ -212,20 +212,27 @@ gg_error.rfsrc <- function(object, ...) {
212212
)
213213
)
214214
}
215+
# The forest must have been grown with tree.err = TRUE so that per-tree
216+
# OOB error rates are recorded in $err.rate.
215217
if (is.null(object$err.rate)) {
216218
stop("Performance values are not available for this forest.")
217219
}
218220

221+
# Convert the err.rate matrix (ntree × n_outcomes) to a data frame.
219222
gg_dta <- data.frame(object$err.rate)
220223

221-
# If there is only one column in the error rate... name it reasonably.
224+
# data.frame() names a lone unnamed column after the call expression,
225+
# "object.err.rate"; rename it to the neutral label "error" for downstream use.
222226
if ("object.err.rate" %in% colnames(gg_dta)) {
223227
colnames(gg_dta)[which(colnames(gg_dta) == "object.err.rate")] <-
224228
"error"
225229
}
226230

231+
# Add a sequential tree counter required by the x-axis of plot.gg_error.
227232
gg_dta$ntree <- seq_len(dim(gg_dta)[1])
228233

234+
# Optional in-bag training error: re-predict on the full training set using
235+
# the stored forest and record the resulting per-tree error trajectory.
229236
arg_list <- as.list(substitute(list(...)))
230237
training <- FALSE
231238
if (!is.null(arg_list$training)) {
@@ -249,7 +256,7 @@ gg_error.rfsrc <- function(object, ...) {
249256

250257
#' @export
251258
gg_error.randomForest <- function(object, ...) {
252-
## Check that the input obect is of the correct type.
259+
## Check that the input object is of the correct type.
253260
if (!inherits(object, "randomForest")) {
254261
stop(
255262
paste(
@@ -260,10 +267,10 @@ gg_error.randomForest <- function(object, ...) {
260267
}
261268

262269
if (!is.null(object$mse)) {
263-
# For regression
270+
# Regression forests store the cumulative OOB mean squared error in $mse.
264271
gg_dta <- data.frame(object$mse)
265272

266-
# If there is only one column in the error rate... name it reasonably.
273+
# Normalise the auto-generated column name to "error".
267274
if ("object.mse" %in% colnames(gg_dta)) {
268275
colnames(gg_dta)[which(colnames(gg_dta) == "object.mse")] <-
269276
"error"
@@ -277,14 +284,16 @@ gg_error.randomForest <- function(object, ...) {
277284
training <- arg_list$training
278285
}
279286

287+
# Optionally compute and append the per-tree in-bag training error curve.
280288
if (training) {
281289
train_curve <- .rf_training_curve(object)
282290
if (!is.null(train_curve)) {
283291
gg_dta$train <- train_curve
284292
}
285293
}
286294
} else if (!is.null(object$err.rate)) {
287-
# For classification
295+
# Classification forests store the cumulative OOB error matrix in
296+
# $err.rate (rows = trees, columns = overall + per-class error rates).
288297
gg_dta <- data.frame(object$err.rate)
289298

290299
gg_dta$ntree <- seq_len(nrow(gg_dta))

R/gg_partial.R

Lines changed: 76 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,98 @@
11
##=============================================================================
2-
#' Split partial lots into continuous or categorical datasets
2+
#' Split partial dependence data into continuous or categorical datasets
3+
#'
4+
#' Takes the list returned by \code{rfsrc::plot.variable(partial = TRUE)} and
5+
#' separates the variables into two data frames: one for continuous predictors
6+
#' and one for categorical (factor-like) predictors. The split is controlled
7+
#' by \code{cat_limit}: variables with more unique x-values than this threshold
8+
#' are treated as continuous; all others are categorical.
9+
#'
310
#' @param part_dta partial plot data from \code{rfsrc::plot.variable}
411
#' @param nvars how many of the partial plot variables to calculate
5-
#' @param cat_limit Categorical features are build when there are fewer than
6-
#' cat_limit unique features.
12+
#' @param cat_limit Categorical features are built when there are at most
13+
#' \code{cat_limit} unique feature values.
714
#' @param model a label name applied to all features. Useful when combining
815
#' multiple partial plot objects in figures.
916
#'
17+
#' @return A named list with two elements:
18+
#' \describe{
19+
#' \item{continuous}{data.frame with columns \code{x}, \code{yhat},
20+
#' \code{name} (and optionally \code{model}) for continuous variables}
21+
#' \item{categorical}{data.frame with the same columns but with \code{x}
22+
#' as a factor, for low-cardinality / categorical variables}
23+
#' }
24+
#'
25+
#' @seealso \code{\link{gg_partial_rfsrc}} \code{\link{gg_partialpro}}
26+
#'
27+
#' @examples
28+
#' ## Build a small regression forest on the airquality dataset
29+
#' set.seed(42)
30+
#' airq <- na.omit(airquality)
31+
#' rf <- rfsrc(Ozone ~ ., data = airq, ntree = 50)
32+
#'
33+
#' ## Compute partial dependence via plot.variable (show.plots = FALSE to
34+
#' ## suppress the base-graphics output — we only want the data)
35+
#' pv <- randomForestSRC::plot.variable(rf, partial = TRUE,
36+
#' show.plots = FALSE)
37+
#'
38+
#' ## Split into continuous and categorical data frames
39+
#' result <- gg_partial(pv)
40+
#' head(result$continuous)
41+
#'
42+
#' ## Label this model for later comparison with a second forest
43+
#' result_labelled <- gg_partial(pv, model = "airq_model")
44+
#' unique(result_labelled$continuous$model)
45+
#'
1046
#' @export
11-
gg_partial = function(part_dta,
12-
nvars = NULL,
13-
cat_limit = 10,
14-
model = NULL) {
15-
## Prepare the partial dependencies data for panel plots
47+
gg_partial <- function(part_dta,
48+
nvars = NULL,
49+
cat_limit = 10,
50+
model = NULL) {
51+
## Default: process all variables returned by plot.variable
1652
if (is.null(nvars)) {
17-
nvars = length(part_dta$plotthis)
53+
nvars <- length(part_dta$plotthis)
1854
}
19-
20-
cont_list = list()
21-
cat_list = list()
55+
56+
# Accumulate per-variable data frames before binding
57+
cont_list <- list()
58+
cat_list <- list()
59+
2260
for (feature in seq(nvars)) {
23-
## Format any continuous features (those with fewer than cat_limit unique values)
24-
if (length(unique(part_dta$plotthis[[feature]]$x)) > cat_limit) {
25-
plt.df = dplyr::bind_cols(
26-
x = part_dta$plotthis[[feature]]$x,
61+
x_vals <- part_dta$plotthis[[feature]]$x
62+
63+
## ---- Continuous variable: more unique x values than cat_limit -------
64+
if (length(unique(x_vals)) > cat_limit) {
65+
plt.df <- dplyr::bind_cols(
66+
x = x_vals,
2767
yhat = part_dta$plotthis[[feature]]$yhat
2868
)
29-
plt.df$name = names(part_dta$plotthis)[[feature]]
30-
69+
# Tag each row with the variable name for downstream faceting
70+
plt.df$name <- names(part_dta$plotthis)[[feature]]
71+
3172
cont_list[[feature]] <- plt.df
32-
} else{
33-
## Categorical features
34-
35-
## Though VarPro works with logical or continuous only. Factors are
36-
## one hot encoded internal to the varPro call.
37-
plt.df = dplyr::bind_cols(
38-
x = factor(part_dta$plotthis[[feature]]$x),
73+
74+
} else {
75+
## ---- Categorical variable: few unique x values -------------------
76+
## VarPro works with logical or continuous only; factors are
77+
## one-hot encoded internally in the varPro call.
78+
plt.df <- dplyr::bind_cols(
79+
x = factor(x_vals),
3980
yhat = part_dta$plotthis[[feature]]$yhat
4081
)
41-
plt.df$name = names(part_dta$plotthis)[[feature]]
42-
82+
plt.df$name <- names(part_dta$plotthis)[[feature]]
83+
4384
cat_list[[feature]] <- plt.df
4485
}
4586
}
46-
continuous = dplyr::bind_rows(cont_list)
47-
categorical = dplyr::bind_rows(cat_list)
48-
87+
88+
# Combine per-variable lists into single data frames (NULL entries dropped)
89+
continuous <- dplyr::bind_rows(cont_list)
90+
categorical <- dplyr::bind_rows(cat_list)
91+
92+
## Optionally attach a model label (useful when overlaying multiple forests)
4993
if (!is.null(model)) {
5094
continuous$model <- categorical$model <- model
5195
}
52-
96+
5397
return(list(continuous = continuous, categorical = categorical))
54-
}
98+
}

0 commit comments

Comments
 (0)