Skip to content

Commit 9dc6c9c

Browse files
authored
Merge pull request #61 from ehrlinger/copilot_refactor
Refactor loop indices to use seq_along for improved readability and p…
2 parents 8840ed1 + b486919 commit 9dc6c9c

3 files changed

Lines changed: 19 additions & 6 deletions

File tree

R/gg_partial.R

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,12 @@ gg_partial <- function(part_dta,
7575
## ---- Categorical variable: few unique x values -------------------
7676
## VarPro works with logical or continuous only; factors are
7777
## one-hot encoded internally in the varPro call.
78+
## Normalize to character so bind_rows sees a consistent type; we'll
79+
## re-factor within each feature after stacking.
80+
x_chr <- as.character(x_vals)
81+
7882
plt.df <- dplyr::bind_cols(
79-
x = factor(x_vals),
83+
x = x_chr,
8084
yhat = part_dta$plotthis[[feature]]$yhat
8185
)
8286
plt.df$name <- names(part_dta$plotthis)[[feature]]
@@ -87,8 +91,17 @@ gg_partial <- function(part_dta,
8791

8892
# Combine per-variable lists into single data frames (NULL entries dropped)
8993
continuous <- dplyr::bind_rows(cont_list)
90-
categorical <- dplyr::bind_rows(cat_list)
91-
94+
if(length(cat_list) == 0) {
95+
categorical <- NA
96+
} else {
97+
categorical <- dplyr::bind_rows(cat_list)
98+
categorical <- dplyr::group_by(categorical, name)
99+
categorical <- dplyr::mutate(
100+
categorical,
101+
x = factor(x, levels = unique(x))
102+
)
103+
categorical <- dplyr::ungroup(categorical)
104+
}
92105
## Optionally attach a model label (useful when overlaying multiple forests)
93106
if (!is.null(model)) {
94107
continuous$model <- categorical$model <- model

R/plot.gg_roc.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ plot.gg_roc <- function(x, which_outcome = NULL, ...) {
146146
st
147147
})
148148
# Tag each subset with its outcome index for colour/linetype mapping
149-
gg_dta <- parallel::mclapply(seq_len(length(gg_dta)), function(ind) {
149+
gg_dta <- parallel::mclapply(seq_along(gg_dta), function(ind) {
150150
gg_dta[[ind]]$outcome <- ind
151151
gg_dta[[ind]]
152152
})

tests/testthat/test_gg_variable.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ test_that("gg_variable classifications", {
4141
# Test return is s ggplot object
4242
expect_is(gg_plt, "list")
4343
expect_equal(length(gg_plt), length(rfsrc_iris$xvar.names))
44-
for (ind in seq_len(length(rfsrc_iris$xvar.names)))
44+
for (ind in seq_along(rfsrc_iris$xvar.names))
4545
expect_is(gg_plt[[ind]], "ggplot")
4646
## Test plotting the gg_error object
4747
gg_plt <- plot.gg_variable(gg_dta, xvar = rfsrc_iris$xvar.names,
@@ -109,7 +109,7 @@ test_that("gg_variable regression", {
109109
# Test return is s ggplot object
110110
expect_is(gg_plt, "list")
111111
expect_equal(length(gg_plt), length(rfsrc_boston$xvar.names))
112-
for (ind in seq_len(length(rfsrc_boston$xvar.names)))
112+
for (ind in seq_along(rfsrc_boston$xvar.names))
113113
expect_is(gg_plt[[ind]], "ggplot")
114114

115115

0 commit comments

Comments
 (0)