diff --git a/.Rbuildignore b/.Rbuildignore index 9b97b2ac..eee15a90 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -38,3 +38,9 @@ framed.sty ^pkgdown$ ^LICENSE\.md$ ^memory$ +^\.lintr$ +^CONTRIBUTING\.md$ +^code-review\.md$ +^release-checklist.*\.md$ +# FUSE filesystem temporaries (safe to ignore; R CMD build already skips dotfiles) +^R/\.fuse_hidden diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index a73e3aba..a7788a7d 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -28,6 +28,10 @@ jobs: env: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} R_KEEP_PKG_SOURCE: yes + # Skip vdiffr visual-regression tests in CI until reference SVGs are + # committed. To regenerate: run testthat::snapshot_accept() locally, + # commit tests/testthat/_snaps/, then remove this line. + VDIFFR_RUN_TESTS: "false" steps: - uses: actions/checkout@v4 @@ -41,14 +45,18 @@ jobs: r-version: ${{ matrix.config.r }} http-user-agent: ${{ matrix.config.http-user-agent }} use-public-rspm: true - rtools-version: '42' + rtools-version: '44' - uses: r-lib/actions/setup-r-dependencies@v2 with: extra-packages: any::rcmdcheck + cache-version: 2 needs: check - uses: r-lib/actions/check-r-package@v2 with: upload-snapshots: true build_args: 'c("--no-manual","--compact-vignettes=gs+qpdf")' + # Fail the check on warnings (and errors) on every platform; NOTEs are + # reported but do not fail the check.
+ error_on: '"warning"' diff --git a/.github/workflows/check-release.yaml b/.github/workflows/check-release.yaml index 0383438c..ed70f886 100644 --- a/.github/workflows/check-release.yaml +++ b/.github/workflows/check-release.yaml @@ -14,8 +14,9 @@ jobs: env: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} R_KEEP_PKG_SOURCE: yes + VDIFFR_RUN_TESTS: "false" steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - uses: r-lib/actions/setup-pandoc@v2 @@ -27,6 +28,7 @@ jobs: - uses: r-lib/actions/setup-r-dependencies@v2 with: extra-packages: rcmdcheck + cache-version: 2 - uses: r-lib/actions/setup-tinytex@v2 - uses: r-lib/actions/check-r-package@v2 diff --git a/.github/workflows/check-standard.yaml b/.github/workflows/check-standard.yaml index 1c4ea786..00f6a853 100644 --- a/.github/workflows/check-standard.yaml +++ b/.github/workflows/check-standard.yaml @@ -27,9 +27,10 @@ jobs: env: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} R_KEEP_PKG_SOURCE: yes + VDIFFR_RUN_TESTS: "false" steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - uses: r-lib/actions/setup-pandoc@v2 @@ -43,6 +44,7 @@ jobs: - uses: r-lib/actions/setup-r-dependencies@v2 with: extra-packages: rcmdcheck + cache-version: 2 - uses: r-lib/actions/setup-tinytex@v2 - uses: r-lib/actions/check-r-package@v2 diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index c61d80fd..49fb17f5 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -1,20 +1,22 @@ -# Workflow derived from https://github.com/r-lib/actions/tree/master/examples +# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples # Need help debugging build failures? 
Start at https://github.com/r-lib/actions#where-to-find-help on: push: branches: [main] pull_request: - branches: [main] -name: lint +name: lint.yaml + +permissions: read-all jobs: lint: runs-on: ubuntu-latest env: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} + steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - uses: r-lib/actions/setup-r@v2 with: @@ -22,8 +24,16 @@ jobs: - uses: r-lib/actions/setup-r-dependencies@v2 with: - extra-packages: lintr + extra-packages: any::lintr, any::cyclocomp, any::pak + needs: lint + cache-version: 2 - name: Lint - run: lintr::lint_package() + run: | + lints <- lintr::lint_package() + print(lints) + if (length(lints) > 0L) { + message(sprintf("lintr found %d issue(s).", length(lints))) + quit(status = 1L) + } shell: Rscript {0} diff --git a/.github/workflows/test-coverage.yaml b/.github/workflows/test-coverage.yaml index b2e013fe..592acc2e 100644 --- a/.github/workflows/test-coverage.yaml +++ b/.github/workflows/test-coverage.yaml @@ -14,6 +14,7 @@ jobs: runs-on: ubuntu-latest env: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} + VDIFFR_RUN_TESTS: "false" steps: - uses: actions/checkout@v4 @@ -26,6 +27,7 @@ jobs: with: extra-packages: any::covr, any::xml2 needs: coverage + cache-version: 2 - name: Test coverage run: | @@ -38,15 +40,6 @@ jobs: covr::to_cobertura(cov) shell: Rscript {0} - - uses: codecov/codecov-action@v4 - with: - token: ${{ secrets.CODECOV_TOKEN }} - # Fail if error if not on PR, or if on PR and token is given - fail_ci_if_error: ${{ github.event_name != 'pull_request' || secrets.CODECOV_TOKEN }} - file: ./cobertura.xml - plugin: noop - disable_search: true - - name: Show testthat output if: always() run: | @@ -60,8 +53,13 @@ jobs: with: name: coverage-test-failures path: ${{ runner.temp }}/package - + - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v5 with: token: ${{ secrets.CODECOV_TOKEN }} + # Fail if error if not on PR, or if on PR and token is given + fail_ci_if_error: ${{ 
github.event_name != 'pull_request' || secrets.CODECOV_TOKEN != '' }} + files: ./cobertura.xml + plugins: noop + disable_search: true diff --git a/.gitignore b/.gitignore index fd644183..f7752c9e 100644 --- a/.gitignore +++ b/.gitignore @@ -40,9 +40,8 @@ tests/testthat/Rplots.pdf vignettes/ggrfRegression.html vignettes/ggrfRegression.R docs -NAMESPACE -man/ -*.Rd +# FUSE filesystem temporaries (macOS / Linux FUSE driver artefacts) +.fuse_hidden* vignettes/ggRandomForests_files vignettes/ggRandomForests.html diff --git a/.lintr b/.lintr new file mode 100644 index 00000000..1a33cd85 --- /dev/null +++ b/.lintr @@ -0,0 +1,16 @@ +linters: linters_with_defaults( + line_length_linter(120), + object_name_linter(styles = c("snake_case", "dotted.case", "camelCase", "symbols")), + cyclocomp_linter(complexity_limit = 20), + T_and_F_symbol_linter = NULL, + return_linter = NULL, + indentation_linter = NULL, + object_length_linter = NULL, + object_usage_linter = NULL, + commented_code_linter = NULL + ) +exclusions: list( + "R/ggrandomforests.news.R", + "R/zzz.R" + ) +encoding: "UTF-8" diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..faa4907c --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,445 @@ +# Contributing to ggRandomForests + +Thank you for your interest in contributing! This guide is written for R programmers who are new to package development. It covers everything from setting up your environment to opening a pull request. + +--- + +## Table of contents + +1. [What you will need](#1-what-you-will-need) +2. [Getting the code](#2-getting-the-code) +3. [Package structure](#3-package-structure) +4. [The gg\_\* design pattern](#4-the-gg_-design-pattern) +5. [Making a change](#5-making-a-change) +6. [Writing tests](#6-writing-tests) +7. [Documentation standards](#7-documentation-standards) +8. [Code style](#8-code-style) +9. [Running the full check suite](#9-running-the-full-check-suite) +10. 
[Opening a pull request](#10-opening-a-pull-request) +11. [Getting help](#11-getting-help) + +--- + +## 1. What you will need + +| Tool | Why | Install | +|------|-----|---------| +| R >= 4.4.0 | Runtime | | +| RStudio or Positron | IDE with R-aware tools | | +| Git | Version control | | +| Quarto | Builds the vignette | | + +Install the development helper packages in R: + +```r +install.packages(c( + "devtools", # load, test, check, document in one place + "testthat", # testing framework + "roxygen2", # builds man/ pages from inline comments + "lintr", # style linter + "covr", # test coverage measurement + "pkgdown" # builds the website +)) +``` + +Install the package dependencies (randomForestSRC >= 3.4.0 is required): + +```r +install.packages(c("randomForestSRC", "randomForest", "ggplot2", + "dplyr", "tidyr", "survival")) +``` + +--- + +## 2. Getting the code + +```bash +# Fork the repo on GitHub first, then: +git clone https://github.com/YOUR-USERNAME/ggRandomForests.git +cd ggRandomForests + +# Add the upstream remote so you can pull future changes +git remote add upstream https://github.com/ehrlinger/ggRandomForests.git +``` + +Open `ggRandomForests.Rproj` in RStudio/Positron. Then load the package in development mode — this makes all functions available without installing: + +```r +devtools::load_all() # shortcut: Ctrl+Shift+L +``` + +Confirm it works: + +```r +library(randomForestSRC) +rf <- rfsrc(Species ~ ., data = iris) +plot(gg_error(rf)) +``` + +--- + +## 3. 
Package structure + +``` +R/ All source code — one file per function family + gg_*.R Data extraction functions (return gg_* objects) + plot.gg_*.R S3 plot methods (return ggplot objects) + help.R Package-level ?ggRandomForests documentation + zzz.R .onAttach startup message + +man/ Auto-generated — never edit by hand + *.Rd Built from Roxygen comments by devtools::document() + +tests/ + testthat/ + test_*.R One test file per R/ source file + +vignettes/ + ggRandomForests.qmd Main package vignette (Quarto) + +DESCRIPTION Package metadata, dependencies, version +NAMESPACE Exports/imports — auto-generated by roxygen2 +NEWS.md Changelog — add an entry for every user-visible change +``` + +**Key rule:** the `man/` and `NAMESPACE` files are always auto-generated. Run `devtools::document()` after any Roxygen change and commit the updated files. + +--- + +## 4. The gg\_\* design pattern + +Every feature in this package follows the same two-step pattern: + +``` +forest object + │ + ▼ +gg_*(forest) ← R/gg_*.R — data extraction, returns a gg_* data.frame + │ + ▼ +plot(gg_object) ← R/plot.gg_*.R — builds and returns a ggplot2 object +``` + +**Why two steps?** Keeping data and plotting separate means users can inspect, save, transform, or combine the intermediate data before plotting, and apply any `ggplot2` layers they want on top of the returned object. + +### The gg\_\* object + +A `gg_*` object is just a `data.frame` with extra class attributes: + +```r +# Example: what gg_vimp returns +class(gg_vimp(rf)) +# [1] "gg_vimp" "data.frame" +``` + +The extra class (`"gg_vimp"`) lets R dispatch `plot(gg_dta)` to `plot.gg_vimp` automatically through R's S3 system. + +### S3 dispatch for multiple forest packages + +Most `gg_*` functions support both `randomForestSRC` and `randomForest` objects. The pattern is: + +```r +# 1. Generic — dispatches based on class of `object` +gg_vimp <- function(object, ...) { + UseMethod("gg_vimp", object) +} + +# 2. 
rfsrc method +gg_vimp.rfsrc <- function(object, ...) { ... } + +# 3. randomForest method +gg_vimp.randomForest <- function(object, ...) { ... } +``` + +Both methods should return an identically structured `gg_*` object so that `plot.gg_vimp` works for either. + +### Example: adding a new gg\_\* function + +Suppose you want to add `gg_depth()` to plot average tree depth. Here is the skeleton: + +```r +# R/gg_depth.R + +#' Tree depth data object +#' +#' Extracts average depth statistics per tree from a random forest. +#' +#' @param object A fitted \code{\link[randomForestSRC]{rfsrc}} or +#' \code{\link[randomForest]{randomForest}} object. +#' @param ... Optional arguments passed to methods. +#' +#' @return A \code{gg_depth} \code{data.frame} with columns \code{ntree} +#' and \code{depth}. +#' +#' @seealso \code{\link{plot.gg_depth}} +#' +#' @examples +#' rf <- rfsrc(Species ~ ., data = iris) +#' plot(gg_depth(rf)) +#' +#' @export +gg_depth <- function(object, ...) { + UseMethod("gg_depth", object) +} + +#' @export +gg_depth.rfsrc <- function(object, ...) { + # ... extract depth data ... + gg_dta <- data.frame(ntree = seq_len(object$ntree), depth = depths) + class(gg_dta) <- c("gg_depth", class(gg_dta)) + invisible(gg_dta) +} +``` + +Then create `R/plot.gg_depth.R` following the same pattern as `plot.gg_error.R`. + +--- + +## 5. Making a change + +Always work on a branch — never commit directly to `main`: + +```bash +git checkout -b my-feature-name +``` + +The development cycle is: + +```r +devtools::load_all() # reload after editing source +devtools::test() # run tests +devtools::document() # rebuild man/ from Roxygen comments +devtools::check() # full R CMD check (slow — run before PR) +``` + +--- + +## 6. Writing tests + +Tests live in `tests/testthat/` and are named `test_<name>.R` to match the file they cover. The framework is [testthat](https://testthat.r-lib.org).
+ +### Basic structure + +```r +# tests/testthat/test_gg_depth.R +test_that("gg_depth returns correct class for rfsrc", { + rf <- randomForestSRC::rfsrc(Species ~ ., data = iris, ntree = 50) + + gg_dta <- gg_depth(rf) + + expect_s3_class(gg_dta, "gg_depth") + expect_s3_class(gg_dta, "data.frame") + expect_true(all(c("ntree", "depth") %in% names(gg_dta))) + expect_equal(nrow(gg_dta), rf$ntree) +}) + +test_that("plot.gg_depth returns a ggplot", { + rf <- randomForestSRC::rfsrc(Species ~ ., data = iris, ntree = 50) + + gg_plt <- plot(gg_depth(rf)) + + expect_s3_class(gg_plt, "ggplot") +}) + +test_that("gg_depth throws on wrong input", { + expect_error(gg_depth("not a forest")) +}) +``` + +### Practical tips + +- Keep forests small in tests — `ntree = 50` is plenty, faster than the default 1000. +- Test the error path as well as the happy path (`expect_error`, `expect_warning`). +- Use `expect_s3_class()` rather than the older `expect_is()`. +- Avoid `set.seed()` unless you are explicitly testing something random — randomForestSRC results are stochastic and exact-value tests break across versions. + +Run tests for a single file during development: + +```r +testthat::test_file("tests/testthat/test_gg_depth.R") +``` + +Check coverage (aim for > 80%): + +```r +covr::package_coverage() +``` + +--- + +## 7. Documentation standards + +Documentation is written in [Roxygen2](https://roxygen2.r-lib.org) comments (lines starting with `#'`) immediately above each function. + +### Required sections for every exported function + +```r +#' Short one-line title +#' +#' One or two paragraphs describing what the function does and why. +#' +#' @param arg1 Type and meaning. Include the default value if there is one. +#' @param arg2 ... +#' +#' @return Describe what is returned: the class, the columns in any +#' data.frame, and any class attributes set on the object. 
+#' +#' @seealso \code{\link{related_function}} +#' +#' @examples +#' # A runnable example — must complete in < 10 seconds for CRAN +#' rf <- rfsrc(Species ~ ., data = iris, ntree = 50) +#' plot(gg_something(rf)) +#' +#' @export +``` + +### Rules + +- `@param` for every argument, including `...` when the extras are meaningful. +- `@return` must describe the shape of the output — not just the class name. +- `@seealso` links to the paired `plot.*` function (or the `gg_*` function from a `plot.*` file). +- `@examples` must be runnable without error by `R CMD check`. Wrap slow examples in `\donttest{}`. Never wrap in `\dontrun{}` unless they literally cannot run on CRAN (network, credentials, etc.). +- Internal helpers (not exported) get `@keywords internal` instead of `@export`. + +Rebuild the docs after any change: + +```r +devtools::document() +``` + +Then spot-check the result: + +```r +?gg_depth +``` + +### Updating NEWS.md + +Every user-visible change needs a bullet in `NEWS.md` under the appropriate version heading: + +```md +ggRandomForests v2.7.0 +===================== +* Add `gg_depth()` to visualise average tree depth per forest (#42) +``` + +--- + +## 8. Code style + +The package follows the [tidyverse style guide](https://style.tidyverse.org). Key points: + +| Rule | Good | Bad | +|------|------|-----| +| Spacing around operators | `x <- x + 1` | `x<-x+1` | +| Spaces after commas | `f(x, y)` | `f(x,y)` | +| Indentation | 2 spaces | tabs | +| Object names | `snake_case` | `camelCase`, `dotted.name` | +| Boolean checks | `!inherits(x, "foo")` | `inherits(x, "foo") == FALSE` | +| Safe sequences | `seq_len(n)` | `1:n` | +| Column references in aes() | `.data$col` or `.data[[var]]` | bare `col` or string `"col"` | +| `dplyr` column selection | `dplyr::select(tidyr::all_of(vars))` | `dplyr::select(vars)` | + +Check your code with lintr before opening a PR: + +```r +lintr::lint_package() +``` + +Common issues lintr flags: + +- Lines > 120 characters. 
+- `T` / `F` instead of `TRUE` / `FALSE` (style-guide rule; `T_and_F_symbol_linter` is disabled in `.lintr`, so lintr will not report it). +- Trailing whitespace. +- `1:n` instead of `seq_len(n)`. +- `inherits(x, "cls") == FALSE` instead of `!inherits(x, "cls")`. + +--- + +## 9. Running the full check suite + +Before opening a PR, run the same checks CI runs: + +```r +# Quick: just tests +devtools::test() + +# Thorough: full R CMD check (builds vignette, checks examples, etc.) +devtools::check() +``` + +A clean check means: +``` +0 errors ✔ | 0 warnings ✔ | 0 notes ✔ +``` + +One note about the package size or installed path is acceptable. Errors or warnings must be fixed before a PR can be merged. + +To reproduce the exact CI matrix locally you can use [rhub](https://r-hub.github.io/rhub/): + +```r +rhub::rhub_check() +``` + +--- + +## 10. Opening a pull request + +1. **Commit your changes** with a clear, present-tense message: + ```bash + git add R/gg_depth.R R/plot.gg_depth.R tests/testthat/test_gg_depth.R + git commit -m "Add gg_depth() for average tree depth visualisation" + ``` + +2. **Push to your fork:** + ```bash + git push origin my-feature-name + ``` + +3. **Open a PR** on GitHub against the `main` branch of `ehrlinger/ggRandomForests`. + +4. **PR description checklist** — include in the description: + - What problem does this solve or what feature does it add? + - Which functions are new or changed? + - Did you add or update tests? + - Did you add a `NEWS.md` entry? + - Does `devtools::check()` pass cleanly? + +5. **CI will run automatically** across macOS, Windows, and Linux on R release, devel, and oldrel-1. All checks must pass before merge. + +### Commit message conventions + +``` +Add gg_depth() for average tree depth ← new feature +Fix factor ordering in gg_partial categorical branch ← bug fix +Improve @return docs for gg_rfsrc ← documentation +Refactor bootstrap_survival to utils.R ← refactor +``` + +Avoid "WIP", "fix", or "update" with no context. + +--- + +## 11.
Getting help + +- **Bug reports and feature requests:** [GitHub Issues](https://github.com/ehrlinger/ggRandomForests/issues). Search existing issues before filing a new one. +- **Questions about usage:** [GitHub Discussions](https://github.com/ehrlinger/ggRandomForests/discussions) or post on [Posit Community](https://community.rstudio.com). +- **randomForestSRC questions:** the [randomForestSRC documentation](https://www.randomforestsrc.org) and its own GitHub issues. + +When filing a bug, always include: + +```r +# Minimum reproducible example +library(ggRandomForests) +library(randomForestSRC) + +rf <- rfsrc(Species ~ ., data = iris, ntree = 50) +# ... the code that triggers the error ... + +sessionInfo() # paste this output into the issue +``` + +--- + +*Thank you for helping improve ggRandomForests!* diff --git a/DESCRIPTION b/DESCRIPTION index 6e107c61..1aecc858 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,8 +1,8 @@ Package: ggRandomForests Type: Package Title: Visually Exploring Random Forests -Version: 2.6.1 -Date: 2026-03-16 +Version: 2.7.0 +Date: 2026-03-25 Authors@R: person("John", "Ehrlinger", role = c("aut", "cre"), email = "john.ehrlinger@gmail.com") @@ -29,9 +29,11 @@ Suggests: RColorBrewer, MASS, lintr, + covr, + vdiffr, datasets, - rmarkdown, - quarto, + rmarkdown, + quarto, pkgdown, pkgload VignetteBuilder: quarto diff --git a/NAMESPACE b/NAMESPACE index dfb2803c..bfb09bc3 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -23,9 +23,6 @@ S3method(plot,gg_vimp) export(calc_auc) export(calc_roc) export(gg_error) -export(gg_error.randomForest) -export(gg_error.randomForest.formula) -export(gg_error.rfsrc) export(gg_partial) export(gg_partial_rfsrc) export(gg_partialpro) @@ -36,12 +33,6 @@ export(gg_variable) export(gg_vimp) export(kaplan) export(nelson) -export(plot.gg_error) -export(plot.gg_rfsrc) -export(plot.gg_roc) -export(plot.gg_survival) -export(plot.gg_variable) -export(plot.gg_vimp) export(quantile_pts) export(surv_partial.rfsrc) 
export(varpro_feature_names) @@ -71,6 +62,6 @@ importFrom(stringr,str_sub) importFrom(survival,Surv) importFrom(survival,strata) importFrom(survival,survfit) -importFrom(tidyr,gather) +importFrom(tidyr,pivot_longer) importFrom(utils,head) importFrom(utils,tail) diff --git a/NEWS.md b/NEWS.md index 731ece87..83c89a6e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,57 @@ Package: ggRandomForests -Version: 2.4.0 +Version: 2.7.0 + +ggRandomForests v2.7.0 +===================== +* Fix critical visual bug in `plot.gg_rfsrc`: all `aes()` calls used bare + string literals instead of `.data[[col]]`, causing every aesthetic to map + to a constant string rather than the underlying data column. All plot + types (regression, classification, survival) were affected. +* Fix `aes()` bare-string literals in `plot.gg_roc` multi-class branch; + remove unreachable `if (crv < 2)` dead-code branch. +* Fix `bootstrap_survival` CI-band indexing in `gg_rfsrc`: negative index + computed via `colnames()` was a no-op on large datasets and a latent crash + for data with ≤ 2 unique event times. +* Fix `gg_rfsrc.rfsrc`: `is.null(df[, col])` does not detect missing columns; + replaced with `!col %in% colnames()` guard. +* Fix `gg_rfsrc.randomForest`: method used non-existent `object$xvar`; now + recovers the training frame via `.rf_recover_model_frame()`. +* Fix legend suppression in `plot.gg_error` for single-outcome forests where + the data frame has no `variable` column. +* Fix `gg_vimp` and `plot.gg_vimp`: `1:nvar` replaced with `seq_len(nvar)` + in both S3 methods; `1:0` silently returned `c(1, 0)` instead of + `integer(0)` when `nvar == 0`. +* Migrate full test suite to testthat 3.x API: `expect_is` → + `expect_s3_class` / `expect_type` / `expect_true(is.*())`; + `expect_equivalent` → `expect_equal(ignore_attr = TRUE)`; all `context()` + calls removed; testthat 1.x `expect_that` / `is_identical_to` removed. 
+* Add `.lintr` package-level linter configuration; fix lintr spacing in + `gg_partial`. +* Improve GitHub Actions: `lint.yaml` now fails CI on any lint issue; + `R-CMD-check.yaml` treats warnings as errors and uses Rtools 44; + `test-coverage.yaml` duplicate codecov upload removed. +* Add `covr` and `vdiffr` to `Suggests`. + +ggRandomForests v2.6.1 +===================== +* Fix model-label assignment in `gg_partial` for categorical variable data +* Refactor `gg_partial` and `gg_partial_rfsrc` to improve factor-level + normalisation and categorical data handling + +ggRandomForests v2.6.0 +===================== +* Add and export new plotting functions; update existing plot documentation +* Improve unit and integration tests; overall coverage raised to 83% +* Remove `hvtiRutilities` internal dependency; clean up associated imports +* Refactor `gg_partial_rfsrc` to use `.data` pronoun for all `dplyr` calls + +ggRandomForests v2.5.0 +===================== +* Initial `gg_partial_rfsrc` function: computes partial dependence data + directly from an `rfsrc` model via `randomForestSRC::partial.rfsrc`, without + requiring a separate `plot.variable` call +* Add support for a grouping variable (`xvar2.name`) in `gg_partial_rfsrc` +* Improved vignette formatting and namespace usage ggRandomForests v2.4.0 ===================== diff --git a/R/calc_roc.R b/R/calc_roc.R index 1762216f..3c8353ee 100644 --- a/R/calc_roc.R +++ b/R/calc_roc.R @@ -11,6 +11,16 @@ #### ####********************************************************************** ####********************************************************************** +# Internal helper: normalise the which_outcome argument. +# "all" is not yet fully supported; fall back to class 1 with a warning. 
+.validate_which_outcome <- function(which_outcome) { + if (identical(which_outcome, "all")) { + warning("Must specify which_outcome for now.") + return(1L) + } + which_outcome +} + #' Receiver Operator Characteristic calculator #' #' @details For a randomForestSRC prediction and the actual @@ -20,15 +30,25 @@ #' This is a helper function for the \code{\link{gg_roc}} functions, and #' not intended for use by the end user. #' -#' @param object \code{\link[randomForestSRC]{rfsrc}} or -#' \code{\link[randomForestSRC]{predict.rfsrc}} object -#' containing predicted response -#' @param dta True response variable -#' @param which_outcome If defined, only show ROC for this response. -#' @param oob Use OOB estimates, the normal validation method (TRUE) -#' @param ... extra arguments passed to helper functions -#' -#' @return A \code{gg_roc} object +#' @param object A fitted \code{\link[randomForestSRC]{rfsrc}}, +#' \code{\link[randomForestSRC]{predict.rfsrc}}, or +#' \code{\link[randomForest]{randomForest}} classification object containing +#' predicted class probabilities. +#' @param dta A factor (or coercible to factor) of the true observed class +#' labels, one per observation. Typically \code{object$yvar} for rfsrc or +#' \code{object$y} for randomForest. +#' @param which_outcome Integer index of the class for which the ROC curve is +#' computed (e.g. \code{1} for the first class, \code{2} for the second). +#' Use \code{"all"} to request all classes (currently falls back to class 1 +#' with a warning). +#' @param oob Logical; if \code{TRUE} (default for rfsrc) use OOB predicted +#' probabilities. Forced to \code{FALSE} for \code{randomForest} objects. +#' @param ... Extra arguments passed to helper functions (currently unused). +#' +#' @return A \code{gg_roc} \code{data.frame} with columns \code{sens} +#' (sensitivity), \code{spec} (specificity), and \code{pct} (the probability +#' threshold), with one row per unique prediction value. 
Suitable for passing +#' to \code{\link{calc_auc}} or \code{\link{plot.gg_roc}}. #' #' @aliases calc_roc.rfsrc calc_roc.randomForest calc_roc #' @@ -71,19 +91,8 @@ calc_roc.rfsrc <- dta <- factor(dta) } - # Re-read oob from ... so callers can override the default - arg_list <- as.list(substitute(list(...))) - - oob <- FALSE - if (!is.null(arg_list$oob) && is.logical(arg_list$oob)) { - oob <- as.logical(arg_list$oob) - } - # "all" outcomes not yet supported; fall back to the first class - if (which_outcome == "all") { - warning("Must specify which_outcome for now.") - which_outcome <- 1 - } + which_outcome <- .validate_which_outcome(which_outcome) # Build (binary indicator, full-forest prediction, OOB prediction) triplet dta_roc <- data.frame(cbind( @@ -150,10 +159,7 @@ calc_roc.randomForest <- ...) { prd <- predict(object, type = "prob") - if (which_outcome == "all") { - warning("Must specify which_outcomefor now.") - which_outcome <- 1 - } + which_outcome <- .validate_which_outcome(which_outcome) dta_roc <- data.frame(cbind(res = (dta == levels(dta)[which_outcome]), prd = prd)) @@ -228,85 +234,18 @@ calc_roc.randomForest <- #' @aliases calc_auc calc_auc.gg_roc #' @export calc_auc <- function(x) { - ## Trapezoidal rule: AUC = Σ dx/2 * (f(x_{i+1}) + f(x_i)) + ## Trapezoidal rule: AUC = Σ (f(x_i) + f(x_{i+1})) / 2 * |Δx| ## Here f(x) is sensitivity (TPR) and x is 1 − specificity (FPR). - ## The shift() helper provides the lead value x_{i+1}. + ## Sort so that specificity decreases (FPR increases) left-to-right, + ## then each step moves one FPR increment to the right. 
- # Sort in decreasing specificity so FPR increases left-to-right along the curve + # Sort in decreasing specificity so FPR = 1-spec increases monotonically x <- x[order(x$spec, decreasing = TRUE), ] - # Trapezoidal approximation: average of consecutive sensitivity values - # multiplied by the FPR increment (change in 1 - spec) - auc <- (3 * shift(x$sens) - x$sens) / 2 * (x$spec - shift(x$spec)) + # Δ(FPR) = -(Δspec) — spec decreases, so (spec[i] - spec[i+1]) > 0 + # Average height of trapezoid = (sens[i] + sens[i+1]) / 2 + auc <- (x$sens + shift(x$sens)) / 2 * (x$spec - shift(x$spec)) # nolint: object_usage_linter sum(auc, na.rm = TRUE) } -calc_auc.gg_roc <- calc_auc - -#' lead function to shift by one (or more). -#' -#' @param x a vector of values -#' @param shift_by an integer of length 1, giving the number of positions -#' to lead (positive) or lag (negative) by -#' -#' @details Lead and lag are useful for comparing values offset by a constant -#' (e.g. the previous or next value) -#' -#' Taken from: -#' http://ctszkin.com/2012/03/11/generating-a-laglead-variables/ -#' -#' This function allows me to remove the dplyr::lead depends. Still suggest for -#' vignettes though. 
-#' -#' @examples -#' d <- data.frame(x = 1:15) -#' # generate lead variable -#' d$df_lead2 <- ggRandomForests:::shift(d$x, 2) -#' # generate lag variable -#' d$df_lag2 <- ggRandomForests:::shift(d$x, -2) -#' # -# > d -# x df_lead2 df_lag2 -# 1 1 3 NA -# 2 2 4 NA -# 3 3 5 1 -# 4 4 6 2 -# 5 5 7 3 -# 6 6 8 4 -# 7 7 9 5 -# 8 8 10 6 -# 9 9 NA 7 -# 10 10 NA 8 -#' # -# # shift_by is vectorized -# d$df_lead2 shift(d$x,-2:2) -# [,1] [,2] [,3] [,4] [,5] -# [1,] NA NA 1 2 3 -# [2,] NA 1 2 3 4 -# [3,] 1 2 3 4 5 -# [4,] 2 3 4 5 6 -# [5,] 3 4 5 6 7 -# [6,] 4 5 6 7 8 -# [7,] 5 6 7 8 9 -# [8,] 6 7 8 9 10 -# [9,] 7 8 9 10 NA -# [10,] 8 9 10 NA NA -shift <- function(x, shift_by = 1) { - stopifnot(is.numeric(shift_by)) - stopifnot(is.numeric(x)) - - if (length(shift_by) > 1) { - return(sapply(shift_by, shift, x = x)) - } - - out <- NULL - abs_shift_by <- abs(shift_by) - if (shift_by > 0) { - out <- c(tail(x, -abs_shift_by), rep(NA, abs_shift_by)) - } else if (shift_by < 0) { - out <- c(rep(NA, abs_shift_by), head(x, -abs_shift_by)) - } else { - out <- x - } - out -} +calc_auc.gg_roc <- calc_auc # nolint: object_name_linter diff --git a/R/gg_error.R b/R/gg_error.R index 67ad8177..284ed975 100644 --- a/R/gg_error.R +++ b/R/gg_error.R @@ -39,8 +39,11 @@ #' \code{training = TRUE} is honored an additional \code{train} column is #' included. #' -#' @seealso \code{\link{plot.gg_error}}, \code{\link[randomForestSRC]{rfsrc}}, -#' \code{\link[randomForest]{randomForest}} +#' @seealso \code{\link{plot.gg_error}}, \code{\link{gg_vimp}}, +#' \code{\link{gg_variable}}, +#' \code{\link[randomForestSRC]{rfsrc}}, +#' \code{\link[randomForest]{randomForest}}, +#' \code{\link[randomForestSRC]{plot.rfsrc}} #' #' @references #' Breiman L. (2001). Random forests, Machine Learning, 45:5-32. @@ -48,8 +51,9 @@ #' Ishwaran H. and Kogalur U.B. (2007). Random survival forests for R, #' Rnews, 7(2):25-31. #' -#' Ishwaran H. and Kogalur U.B. (2013). 
Random Forests for Survival, Regression -#' and Classification (RF-SRC), R package version 1.4. +#' Ishwaran H. and Kogalur U.B. randomForestSRC: Random Forests for Survival, +#' Regression and Classification. R package version >= 3.4.0. +#' \url{https://cran.r-project.org/package=randomForestSRC} #' #' @aliases gg_error gg_error.rfsrc gg_error.randomForest #' @aliases gg_error.randomForest.formula @@ -82,7 +86,7 @@ #' ## ------------------------------------------------------------ #' ## Regression example #' ## ------------------------------------------------------------ -#' +#' #' ## ------------- airq data #' rfsrc_airq <- rfsrc(Ozone ~ ., #' data = airquality, @@ -94,7 +98,7 @@ #' #' # Plot the gg_error object #' plot(gg_dta) -#' +#' #' #' ## ------------- Boston data #' data(Boston, package = "MASS") @@ -113,17 +117,17 @@ #' # Plot the gg_error object #' plot(gg_dta) #' -#' +#' #' ## ------------- mtcars data #' rfsrc_mtcars <- rfsrc(mpg ~ ., data = mtcars, tree.err = TRUE) -#' +#' #' # Get a data.frame containing error rates #' gg_dta<- gg_error(rfsrc_mtcars) #' #' # Plot the gg_error object #' plot(gg_dta) -#' +#' #' #' ## ------------------------------------------------------------ #' ## Survival example @@ -196,8 +200,7 @@ #' #' @importFrom stats as.formula model.frame model.response na.omit predict qnorm #' -#' @export gg_error gg_error.rfsrc gg_error.randomForest -#' @export gg_error.randomForest.formula +#' @export gg_error <- function(object, ...) { UseMethod("gg_error", object) } @@ -233,11 +236,8 @@ gg_error.rfsrc <- function(object, ...) { # Optional in-bag training error: re-predict on the full training set using # the stored forest and record the resulting per-tree error trajectory. - arg_list <- as.list(substitute(list(...))) - training <- FALSE - if (!is.null(arg_list$training)) { - training <- arg_list$training - } + arg_list <- list(...) 
+ training <- isTRUE(arg_list$training) if (training) { trn <- data.frame(cbind(object$xvar, object$yvar)) colnames(trn) <- c(object$xvar.names, object$yvar.names) @@ -278,11 +278,8 @@ gg_error.randomForest <- function(object, ...) { gg_dta$ntree <- seq_len(nrow(gg_dta)) - arg_list <- as.list(substitute(list(...))) - training <- FALSE - if (!is.null(arg_list$training)) { - training <- arg_list$training - } + arg_list <- list(...) + training <- isTRUE(arg_list$training) # Optionally compute and append the per-tree in-bag training error curve. if (training) { @@ -298,11 +295,8 @@ gg_error.randomForest <- function(object, ...) { gg_dta$ntree <- seq_len(nrow(gg_dta)) - arg_list <- as.list(substitute(list(...))) - training <- FALSE - if (!is.null(arg_list$training)) { - training <- arg_list$training - } + arg_list <- list(...) + training <- isTRUE(arg_list$training) if (training) { train_curve <- .rf_training_curve(object) if (!is.null(train_curve)) { diff --git a/R/gg_partial.R b/R/gg_partial.R index 7bbe3767..65fd3f59 100644 --- a/R/gg_partial.R +++ b/R/gg_partial.R @@ -91,21 +91,22 @@ gg_partial <- function(part_dta, # Combine per-variable lists into single data frames (NULL entries dropped) continuous <- dplyr::bind_rows(cont_list) - if(length(cat_list) == 0) { + if (length(cat_list) == 0) { categorical <- data.frame(x = character(0), yhat = numeric(0), name = character(0)) } else { categorical <- dplyr::bind_rows(cat_list) - categorical <- dplyr::group_by(categorical, .data$name) - categorical <- dplyr::mutate( - categorical, - x = factor(.data$x, levels = unique(.data$x)) - ) - categorical <- dplyr::ungroup(categorical) + # Set within-group factor levels to order-of-appearance (base-R, no .data pronoun) + split_grps <- split(seq_len(nrow(categorical)), categorical$name) + for (grp_idx in split_grps) { + vals <- categorical$x[grp_idx] + categorical$x[grp_idx] <- as.character(factor(vals, levels = unique(vals))) + } + categorical$x <- factor(categorical$x) } ## 
Optionally attach a model label (useful when overlaying multiple forests) if (!is.null(model)) { continuous$model <- model - if(!is.null(categorical) && nrow(categorical) > 0) { + if (!is.null(categorical) && nrow(categorical) > 0) { categorical$model <- model } } diff --git a/R/gg_partial_rfsrc.R b/R/gg_partial_rfsrc.R index 6e4cd81f..36baadf0 100644 --- a/R/gg_partial_rfsrc.R +++ b/R/gg_partial_rfsrc.R @@ -1,16 +1,37 @@ ##============================================================================= -#' Split partial lots into continuous or categorical datasets +#' Partial dependence data from an rfsrc model #' -#' gg_partial_rfsrc uses the \code{rfsrc::partial.rfsrc} to generate the partial -#' plot data internally. So you provide the \code{rfsrc::rfsrc} model, and the -#' xvar.names to generate the data. +#' Computes partial dependence for one or more predictors by calling +#' \code{\link[randomForestSRC]{partial.rfsrc}} internally, then splits the +#' results into separate data frames for continuous and categorical variables. +#' Unlike \code{\link{gg_partial}}, no separate \code{plot.variable} call is +#' required — supply the fitted \code{rfsrc} object directly. #' -#' @param rf_model \code{rfsrc::rfsrc} model -#' @param xvar.names list() Which variables to calculate partial plots -#' @param xvar2.name a single grouping feature that is in the newx dataset -#' @param newx a \code{data.frame} containing data to use for the partial plots -#' @param cat_limit Categorical features are build when there are fewer than -#' cat_limit unique features. +#' @param rf_model A fitted \code{\link[randomForestSRC]{rfsrc}} object. +#' @param xvar.names Character vector of predictor names for which partial +#' dependence should be computed. Must be a subset of \code{rf_model$xvar.names}. +#' @param xvar2.name Optional single character name of a grouping variable in +#' \code{newx}. 
When supplied, partial dependence is computed separately for +#' each unique level of this variable and a \code{grp} column is appended. +#' @param newx Optional \code{data.frame} of predictor values to evaluate +#' partial effects at. Defaults to the training data stored in +#' \code{rf_model$xvar}. All column names must match \code{rf_model$xvar.names}. +#' @param cat_limit Variables with fewer than \code{cat_limit} unique values in +#' \code{newx} are treated as categorical; all others are continuous. +#' Defaults to 10. +#' +#' @return A named list with two elements: +#' \describe{ +#' \item{continuous}{A \code{data.frame} with columns \code{x} (numeric), +#' \code{yhat}, \code{name} (variable name), and optionally \code{grp} +#' (the level of \code{xvar2.name}) and \code{time} (survival forests +#' only) for all continuous predictors.} +#' \item{categorical}{A \code{data.frame} with the same columns but +#' \code{x} kept as character, for low-cardinality predictors.} +#' } +#' +#' @seealso \code{\link{gg_partial}}, \code{\link[randomForestSRC]{partial.rfsrc}}, +#' \code{\link[randomForestSRC]{get.partial.plot.data}} #' #' @examples #' ## ------------------------------------------------------------ @@ -34,22 +55,22 @@ gg_partial_rfsrc <- function(rf_model, cat_limit = 10) { # Check the rfsrc type # rf_model$family - + # we supply new data, make sure we use that and that it is a dataframe... 
if (is.null(newx)) { newx = rf_model$xvar } - + if (sum(colnames(newx) %in% rf_model$xvar.names) != ncol(newx)) { - return("newx must be a dataframe with the same columns used to train the rfsrc object") + stop("newx must be a dataframe with the same columns used to train the rfsrc object") } - + if (!is.null(xvar.names)) { if (sum(xvar.names %in% colnames(newx)) != length(xvar.names)) { - return("xvar.names contains column names not found in the rfsrc object") + stop("xvar.names contains column names not found in the rfsrc object") } } - + if (is.null(xvar2.name)) { pdta <- lapply(xvar.names, function(xname) { xval <- unlist(newx |> @@ -69,7 +90,7 @@ gg_partial_rfsrc <- function(rf_model, } return(out_dta) }) - } else{ + } else { xv2 <- unique(unlist(newx |> dplyr::select(dplyr::all_of(xvar2.name)))) pdta <- lapply(xv2, function(x2val) { @@ -97,12 +118,14 @@ gg_partial_rfsrc <- function(rf_model, p1dta$grp <- x2val return(p1dta) }) - } pdta <- do.call("rbind", pdta) - continuous <- pdta |> dplyr::filter(.data$type == "continuous") |> - mutate(x = as.numeric(.data$x)) |> dplyr::select(-"type") - categorical <- pdta |> dplyr::filter(.data$type == "categorical") |> - dplyr::select(-"type") - return(list(continuous = continuous, categorical = categorical)) -} \ No newline at end of file + # Split into continuous / categorical and tidy up the type column + cont_idx <- pdta$type == "continuous" + continuous <- pdta[cont_idx, , drop = FALSE] + continuous$x <- as.numeric(continuous$x) + continuous$type <- NULL + categorical <- pdta[!cont_idx, , drop = FALSE] + categorical$type <- NULL + list(continuous = continuous, categorical = categorical) +} diff --git a/R/gg_rfsrc.R b/R/gg_rfsrc.R index 2eb1d5f2..81ab4751 100644 --- a/R/gg_rfsrc.R +++ b/R/gg_rfsrc.R @@ -18,23 +18,58 @@ #' \code{\link[randomForestSRC]{rfsrc}} object, and formats data for plotting #' the response using \code{\link{plot.gg_rfsrc}}. 
#' -#' @param object \code{\link[randomForestSRC]{rfsrc}} object -#' @param by stratifying variable in the training dataset, defaults to NULL -#' @param oob boolean, should we return the oob prediction , or the full -#' forest prediction. -#' @param ... extra arguments -#' -#' @return \code{gg_rfsrc} object +#' @param object A fitted \code{\link[randomForestSRC]{rfsrc}} or +#' \code{\link[randomForest]{randomForest}} object. +#' @param by Optional stratifying variable. Either a character column name +#' present in the training data, or a vector/factor of the same length as +#' the training set. When supplied, a \code{group} column is added to the +#' returned data and bootstrap CI bands (survival) are computed per group. +#' Omit or leave missing to return an unstratified result. +#' @param oob Logical; if \code{TRUE} (default) return out-of-bag predictions. +#' Set to \code{FALSE} to use full in-bag (training) predictions. Forced to +#' \code{FALSE} automatically for \code{predict.rfsrc} objects, which carry +#' no OOB estimates. +#' @param ... Additional arguments controlling output for specific forest +#' families: +#' \describe{ +#' \item{surv_type}{Character; one of \code{"surv"} (default), +#' \code{"chf"}, or \code{"mortality"} for survival forests.} +#' \item{conf.int}{Numeric coverage probability (e.g. \code{0.95}) to +#' request bootstrap pointwise confidence bands for survival forests. +#' Triggers wide-format output with \code{lower}, \code{upper}, +#' \code{median}, and \code{mean} columns.} +#' \item{bs.sample}{Integer; number of bootstrap resamples when +#' \code{conf.int} is set. 
Defaults to the number of observations.} +#' } +#' +#' @return A \code{gg_rfsrc} object (a classed \code{data.frame}) whose +#' structure depends on the forest family: +#' \describe{ +#' \item{regression}{Columns \code{yhat} and the response name; optionally +#' a \code{group} column when \code{by} is supplied.} +#' \item{classification}{One column per class with predicted probabilities; +#' a \code{y} column with observed class labels; optionally \code{group}.} +#' \item{survival (no CI / grouping)}{Long-format with columns +#' \code{variable} (event time), \code{value} (survival probability), +#' \code{obs_id}, and \code{event}.} +#' \item{survival (with \code{conf.int} or \code{by})}{Wide-format with +#' pointwise bootstrap CI columns (\code{lower}, \code{upper}, +#' \code{median}, \code{mean}) per time point; a \code{group} column +#' when \code{by} is supplied.} +#' } +#' The object carries class attributes for the forest family so that +#' \code{\link{plot.gg_rfsrc}} dispatches correctly. #' #' @details -#' \code{surv_type} ("surv", "chf", "mortality", "hazard") for survival -#' forests -#' -#' \code{oob} boolean, should we return the oob prediction , or the full -#' forest prediction. +#' For survival forests, use the \code{surv_type} argument +#' (\code{"surv"}, \code{"chf"}, or \code{"mortality"}) to select the +#' predicted quantity. Bootstrap confidence bands are requested by passing +#' \code{conf.int} (e.g. \code{conf.int = 0.95}); the number of resamples +#' is controlled by \code{bs.sample}. 
#' -#' @seealso \code{\link{plot.gg_rfsrc}} \code{rfsrc} \code{plot.rfsrc} -#' \code{\link{gg_survival}} +#' @seealso \code{\link{plot.gg_rfsrc}}, +#' \code{\link[randomForestSRC]{rfsrc}}, +#' \code{\link{gg_survival}} #' #' @examples #' ## ------------------------------------------------------------ @@ -142,12 +177,12 @@ #' @aliases gg_rfsrc gg_rfsrc.rfsrc #' @export -gg_rfsrc.rfsrc <- function(object, +gg_rfsrc.rfsrc <- function(object, # nolint: cyclocomp_linter oob = TRUE, by, ...) { ## Check that the input object is of the correct type. - if (inherits(object, "rfsrc") == FALSE) { + if (!inherits(object, "rfsrc")) { stop( paste( "This function only works for Forests grown with the", @@ -178,11 +213,10 @@ gg_rfsrc.rfsrc <- function(object, grp <- by # Accept either a column name (character) or a pre-built vector/factor. if (is.character(grp)) { - if (is.null(object$xvar[, grp])) { + if (!grp %in% colnames(object$xvar)) { stop(paste("No column named", grp, "in forest training set.")) - } else { - grp <- object$xvar[, grp] } + grp <- object$xvar[, grp] } if (is.vector(grp) || is.factor(grp)) { @@ -196,8 +230,7 @@ gg_rfsrc.rfsrc <- function(object, stop( paste( "By argument should be either a vector, or colname", - "of training data", - nrow(object$xvar) + "of training data" ) ) } @@ -208,18 +241,19 @@ gg_rfsrc.rfsrc <- function(object, ## ---- Classification branch ----------------------------------------------- if (object$family == "class") { - # For binary classification rfsrc stores two-column probability matrices; - # we drop the first column (the "negative" class probability) for binary - # problems since it is redundant. Multi-class forests keep all columns. + # For binary classification rfsrc stores exactly two probability columns + # (one per class); we drop the first (the "negative" class probability) + # since it is redundant — prob(class 2) = 1 - prob(class 1). + # Multi-class forests (3+ classes) keep all columns. 
if (oob) { gg_dta <- - if (ncol(object$predicted.oob) <= 2) { + if (ncol(object$predicted.oob) == 2) { data.frame(cbind(object$predicted.oob[, -1])) } else { data.frame(cbind(object$predicted.oob)) } } else { - gg_dta <- if (ncol(object$predicted) <= 2) { + gg_dta <- if (ncol(object$predicted) == 2) { data.frame(cbind(object$predicted[, -1])) } else { data.frame(cbind(object$predicted)) @@ -287,10 +321,10 @@ gg_rfsrc.rfsrc <- function(object, if (is.null(arg_list$conf.int) && missing(by)) { # No grouping or CI requested: pivot to long form so plot.gg_rfsrc can # draw one survival step function per observation. - gathercols <- + pivot_cols <- colnames(gg_dta)[-which(colnames(gg_dta) %in% c("obs_id", "event"))] gg_dta_mlt <- - tidyr::gather(gg_dta, "variable", "value", tidyr::all_of(gathercols)) + tidyr::pivot_longer(gg_dta, tidyr::all_of(pivot_cols), names_to = "variable", values_to = "value") gg_dta_mlt$variable <- as.numeric(as.character(gg_dta_mlt$variable)) gg_dta_mlt$obs_id <- factor(gg_dta_mlt$obs_id) @@ -386,6 +420,28 @@ gg_rfsrc.rfsrc <- function(object, +#' Bootstrap pointwise confidence bands for a mean survival curve +#' +#' Draws \code{bs_samples} bootstrap resamples (with replacement) of the +#' per-observation survival curves stored in \code{gg_dta}, computes the column +#' means to obtain a bootstrapped mean curve per resample, then returns the +#' pointwise quantiles at \code{level_set} and the overall mean across +#' resamples. +#' +#' @param gg_dta A wide \code{data.frame} of survival probabilities as returned +#' by the survival branch of \code{\link{gg_rfsrc.rfsrc}}, before the +#' optional pivot to long form. Columns \code{obs_id}, \code{event}, and +#' \code{group} (if present) are excluded from the resampling. +#' @param bs_samples Integer; number of bootstrap resamples. +#' @param level_set Numeric vector of length 2 giving the lower and upper +#' quantile probabilities for the confidence band (e.g. \code{c(0.025, 0.975)} +#' for a 95\% CI). 
+#' +#' @return A \code{data.frame} with one row per unique event time and columns +#' \code{value} (time), \code{lower}, \code{upper}, \code{median}, and +#' \code{mean}. +#' +#' @keywords internal bootstrap_survival <- function(gg_dta, bs_samples, level_set) { ## Calculate the leave one out estimate of the mean survival gg_t <- @@ -412,21 +468,19 @@ bootstrap_survival <- function(gg_dta, bs_samples, level_set) { mean(rng[, t_pt]) }) + # gg_t already has obs_id/event/group stripped; rng and mn are indexed over + # time points only, so no further exclusion is needed. time_interest <- as.numeric(colnames(gg_t)) dta <- data.frame(cbind( time_interest, - t(rng)[-which(colnames(gg_dta) %in% - c("obs_id", "event")), ], - mn[-which(colnames(gg_dta) %in% - c("obs_id", "event"))] + t(rng), + mn )) - if (ncol(dta) == 5) { - colnames(dta) <- c("value", "lower", "upper", "median", "mean") - } else { - colnames(dta) <- c("value", level_set, "mean") - } + # rng always has 3 rows: lower quantile, upper quantile, median (.5). + # Name columns canonically so plot.gg_rfsrc can always find "lower"/"upper". + colnames(dta) <- c("value", "lower", "upper", "median", "mean") dta } @@ -442,8 +496,8 @@ gg_rfsrc.randomForest <- function(object, oob = TRUE, by, ...) { - ## Check that the input obect is of the correct type. - if (inherits(object, "randomForest") == FALSE) { + ## Check that the input object is of the correct type. + if (!inherits(object, "randomForest")) { stop( paste( "This function only works for Forests grown with the", @@ -456,42 +510,48 @@ gg_rfsrc.randomForest <- function(object, oob <- FALSE } + # Recover the training predictor frame once (needed for by-column lookup and + # dimension checks). randomForest stores predictors in $forest$xlevels keys + # but not the actual data; use .rf_recover_model_frame() for that. 
+ rf_info <- .rf_recover_model_frame(object) # nolint: object_usage_linter + rf_xvar <- if (!is.null(rf_info)) rf_info$model_frame[ + , setdiff(colnames(rf_info$model_frame), rf_info$response_name), + drop = FALSE + ] else NULL + n_train <- length(object$predicted) + if (!missing(by)) { grp <- by - # If the by argument is a vector, make sure it is the correct length + # Accept either a column name (character) or a pre-built vector/factor. if (is.character(grp)) { - if (is.null(object$xvar[, grp])) { + if (is.null(rf_xvar) || !grp %in% colnames(rf_xvar)) { stop(paste("No column named", grp, "in forest training set.")) - } else { - grp <- object$xvar[, grp] } + grp <- rf_xvar[, grp] } if (is.vector(grp) || is.factor(grp)) { - if (length(grp) != nrow(object$xvar)) { + if (length(grp) != n_train) { stop(paste( "By argument does not have the correct dimension ", - nrow(object$xvar) + n_train )) } } else { stop( paste( "By argument should be either a vector, or colname", - "of training data", - nrow(object$xvar) + "of training data" ) ) } grp <- factor(grp, levels = unique(grp)) } - # gg_variable is really just the training data and the outcome. - gg_dta <- get(as.character(object$call$data)) - - # Remove the response from the data.frame + # Extract the response variable name from the formula for column naming below. + # (Both branches below build gg_dta from the forest's stored predictions, so + # there is no need to recover the original training data frame here.) rsp <- as.character(object$call$formula)[2] - gg_dta <- gg_dta[, -which(colnames(gg_dta) == rsp)] # Do the work... if (object$type == "classification") { diff --git a/R/gg_roc.R b/R/gg_roc.R index aac6c8df..08b8d4ad 100644 --- a/R/gg_roc.R +++ b/R/gg_roc.R @@ -12,19 +12,39 @@ ####********************************************************************** ####********************************************************************** #' -#' ROC (Receiver operator curve) data from a classification random forest. 
+#' ROC (Receiver Operating Characteristic) curve data from a classification forest. #' -#' The sensitivity and specificity of a randomForest classification object. +#' Computes sensitivity (true positive rate) and specificity (1 - false positive +#' rate) across all prediction thresholds for one class of a classification +#' \code{\link[randomForestSRC]{rfsrc}} or +#' \code{\link[randomForest]{randomForest}} object. #' -#' @param object an \code{\link[randomForestSRC]{rfsrc}} classification object -#' @param which_outcome select the classification outcome of interest. -#' @param oob use oob estimates (default TRUE) -#' @param ... extra arguments (not used) +#' @param object A classification \code{\link[randomForestSRC]{rfsrc}} or +#' \code{\link[randomForest]{randomForest}} object. Only forests with +#' \code{family == "class"} (rfsrc) or \code{type == "classification"} +#' (randomForest) are supported. +#' @param which_outcome Integer index or character name of the class for which +#' the ROC curve is computed. For binary forests this is typically \code{1} +#' or \code{2}; for multi-class forests any valid class index. Use +#' \code{which_outcome = 0} to obtain the overall (averaged) ROC. +#' @param oob Logical; if \code{TRUE} (default) use out-of-bag predicted +#' probabilities for the curve. Set to \code{FALSE} to use full in-bag +#' predictions. +#' @param ... Extra arguments (currently unused). #' -#' @return \code{gg_roc} \code{data.frame} for plotting ROC curves. +#' @return A \code{gg_roc} \code{data.frame} with one row per unique prediction +#' threshold and columns: +#' \describe{ +#' \item{sens}{Sensitivity (true positive rate) at each threshold.} +#' \item{spec}{Specificity (true negative rate) at each threshold.} +#' \item{yvar}{The observed class label for each observation.} +#' } +#' Pass to \code{\link{calc_auc}} for the area under the curve. 
#' -#' @seealso \code{\link{plot.gg_roc}} \code{\link[randomForestSRC]{rfsrc}} -#' \code{\link[randomForest]{randomForest}} +#' @seealso \code{\link{plot.gg_roc}}, \code{\link{calc_roc}}, +#' \code{\link{calc_auc}}, +#' \code{\link[randomForestSRC]{rfsrc}}, +#' \code{\link[randomForest]{randomForest}} #' #' @examples #' ## ------------------------------------------------------------ @@ -63,15 +83,15 @@ #' @aliases gg_roc gg_roc.rfsrc gg_roc.randomForest #' @export -gg_roc.rfsrc <- function(object, which_outcome, oob, ...) { +gg_roc.rfsrc <- function(object, which_outcome, oob = TRUE, ...) { # Validate that the object was grown with randomForestSRC (grow or predict) # or is a randomForest object — the two supported class signatures. if (sum(inherits(object, c("rfsrc", "grow"), TRUE) == c(1, 2)) != 2 && sum(inherits(object, c("rfsrc", "predict"), TRUE) == c(1, 2)) != 2 && !inherits(object, "randomForest")) { stop( - "This function only works for objects of class `(rfsrc, grow)', - '(rfsrc, predict)' or 'randomForest." + "This function only works for objects of class '(rfsrc, grow)', ", + "'(rfsrc, predict)', or 'randomForest'." ) } # ROC curves only make sense for classification; reject other families early. @@ -94,7 +114,7 @@ gg_roc.rfsrc <- function(object, which_outcome, oob, ...) { # Delegate the threshold-sweep computation to calc_roc, passing the # observed response vector and the chosen outcome column index. - gg_dta <- + gg_dta <- # nolint: object_usage_linter calc_roc(object, object$yvar, which_outcome = which_outcome, @@ -105,17 +125,16 @@ gg_roc.rfsrc <- function(object, which_outcome, oob, ...) { invisible(gg_dta) } #' @export -gg_roc <- function(object, which_outcome, oob, ...) { +gg_roc <- function(object, which_outcome, oob = TRUE, ...) { UseMethod("gg_roc", object) } #' @export gg_roc.randomForest <- function(object, which_outcome, oob, ...) { # Validate that the object is a genuine randomForest instance. 
- if (sum(inherits(object, "randomForest", TRUE) == c(1, 2)) != 1) { + if (!inherits(object, "randomForest")) { stop( - "This function only works for objects of class `(rfsrc, grow)', - '(rfsrc, predict)' or 'randomForest." + "gg_roc.randomForest only works for objects of class 'randomForest'." ) } @@ -129,7 +148,7 @@ gg_roc.randomForest <- function(object, which_outcome, oob, ...) { } # For randomForest objects the response is stored in $y (not $yvar). - gg_dta <- + gg_dta <- # nolint: object_usage_linter calc_roc(object, object$y, which_outcome = which_outcome diff --git a/R/gg_survival.R b/R/gg_survival.R index fce45f3d..e206de14 100644 --- a/R/gg_survival.R +++ b/R/gg_survival.R @@ -18,15 +18,21 @@ #' nonparametric survival estimates using either \code{\link{nelson}}-Aalen #' or \code{\link{kaplan}}-Meier estimates. #' -#' @param data name of the training data.frame -#' @param interval name of the interval variable in the training dataset. -#' @param censor name of the censoring variable in the training dataset. -#' @param by stratifying variable in the training dataset, defaults to NULL -#' @param type one of ("kaplan","nelson"), defaults to Kaplan-Meier -#' @param ... extra arguments passed to Kaplan or Nelson functions. +#' @param data A \code{data.frame} containing the survival data. +#' @param interval Character; name of the time-to-event column in \code{data}. +#' @param censor Character; name of the event-indicator column in \code{data} +#' (1 = event occurred, 0 = censored). +#' @param by Optional character; name of a grouping column in \code{data} for +#' stratified estimates. Defaults to \code{NULL} (unstratified). +#' @param type One of \code{"kaplan"} (Kaplan-Meier, default) or +#' \code{"nelson"} (Nelson-Aalen cumulative hazard). +#' @param ... Additional arguments passed to \code{\link{kaplan}} or +#' \code{\link{nelson}} (e.g. \code{conf.int} to change the CI width). 
#' -#' @return A \code{gg_survival} object created using the non-parametric -#' Kaplan-Meier or Nelson-Aalen estimators. +#' @return A \code{gg_survival} \code{data.frame} with columns \code{time}, +#' \code{surv} (or \code{cum_haz} for Nelson-Aalen), \code{lower}, +#' \code{upper} (confidence limits), and \code{n.risk}. A \code{strata} +#' column is added when \code{by} is supplied. #' #' @seealso \code{\link{kaplan}} \code{\link{nelson}} #' @seealso \code{\link{plot.gg_survival}} @@ -61,7 +67,7 @@ #' ) #' #' plot(gg_dta, error = "lines") -#' +#' #' @export gg_survival <- function(interval = NULL, censor = NULL, @@ -74,7 +80,7 @@ gg_survival <- function(interval = NULL, # Delegate entirely to the selected estimator helper. Both kaplan() and # nelson() return a gg_survival object that plot.gg_survival can render. - gg_dta <- switch(type, + gg_dta <- switch(type, # nolint: object_usage_linter kaplan = kaplan( interval = interval, censor = censor, diff --git a/R/gg_variable.R b/R/gg_variable.R index 4108e312..c5faa47a 100644 --- a/R/gg_variable.R +++ b/R/gg_variable.R @@ -43,10 +43,15 @@ #' @param ... Optional arguments such as \code{time}, \code{time_labels}, and #' \code{oob} that tailor the marginal dependence extraction. #' -#' @return \code{gg_variable} object +#' @return A \code{gg_variable} object: a \code{data.frame} of all predictor +#' columns from the training data paired with the OOB (or in-bag) predicted +#' response. For survival forests each requested time horizon produces an +#' additional column named by \code{time_labels}. The object carries a +#' \code{"family"} class attribute (\code{"regr"}, \code{"class"}, or +#' \code{"surv"}) used by \code{\link{plot.gg_variable}} for dispatch. 
#' -#' @seealso \code{\link{plot.gg_variable}} -#' @seealso \code{\link[randomForestSRC]{plot.variable}} +#' @seealso \code{\link{plot.gg_variable}}, +#' \code{\link[randomForestSRC]{plot.variable}} #' #' @aliases gg_variable gg_variable.rfsrc #' @@ -70,7 +75,7 @@ #' ## ------------------------------------------------------------ #' ## regression #' ## ------------------------------------------------------------ -#' +#' #' ## -------- air quality data #' rfsrc_airq <- rfsrc(Ozone ~ ., data = airquality) #' gg_dta <- gg_variable(rfsrc_airq) @@ -86,7 +91,7 @@ #' plot(gg_dta, xvar = c("Solar.R", "Wind", "Temp", "Day"), panel = TRUE) #' #' plot(gg_dta, xvar = "Month", notch = TRUE) -#' +#' #' ## -------- motor trend cars data #' rfsrc_mtcars <- rfsrc(mpg ~ ., data = mtcars) #' @@ -112,7 +117,7 @@ #' xvar = c("cyl", "vs", "am", "gear", "carb"), panel = TRUE, #' notch = TRUE #' ) -#' +#' #' ## -------- Boston data #' data(Boston, package = "MASS") #' @@ -123,7 +128,7 @@ #' ## ------------------------------------------------------------ #' ## survival examples #' ## ------------------------------------------------------------ -#' +#' #' ## -------- veteran data #' ## survival #' data(veteran, package = "randomForestSRC") @@ -148,7 +153,7 @@ #' #' # Generate variable dependence plots for age and diagtime #' plot(gg_dta, xvar = "age") -#' +#' #' ## -------- pbc data #' ## We don't run this because of bootstrap confidence limits #' # We need to create this dataset @@ -207,7 +212,7 @@ #' #' # Generate coplots #' plot(gg_dta, xvar = c("age", "trig"), panel = TRUE, se = FALSE) -#' +#' #' #' @aliases gg_variable gg_variable.rfsrc gg_variable.randomForest #' @aliases gg_variable.random @@ -279,7 +284,7 @@ gg_variable.rfsrc <- function(object, } lng <- length(time) - for (ind in 1:lng) { + for (ind in seq_len(lng)) { if (ind > 1) { gg_dta_t_old <- gg_dta_t } @@ -340,7 +345,7 @@ gg_variable.randomForest <- function(object, # Reconstruct the training data from the stored call so we 
can pair # predictions with the original predictors. - training_info <- .rf_recover_model_frame(object) + training_info <- .rf_recover_model_frame(object) # nolint: object_usage_linter if (is.null(training_info)) { stop( "Unable to reconstruct the training data for this randomForest object.", diff --git a/R/gg_vimp.R b/R/gg_vimp.R index 18bf2d1b..1e95ec05 100644 --- a/R/gg_vimp.R +++ b/R/gg_vimp.R @@ -56,14 +56,14 @@ #' ## ------------------------------------------------------------ #' ## regression example #' ## ------------------------------------------------------------ -#' +#' #' ## -------- air quality data #' rfsrc_airq <- rfsrc(Ozone ~ ., airquality, #' importance = TRUE #' ) #' gg_dta <- gg_vimp(rfsrc_airq) #' plot(gg_dta) -#' +#' #' #' ## -------- Boston data #' data(Boston, package = "MASS") @@ -78,7 +78,7 @@ #' gg_dta <- gg_vimp(rf_boston) #' plot(gg_dta) #' -#' +#' #' ## -------- mtcars data #' rfsrc_mtcars <- rfsrc(mpg ~ ., #' data = mtcars, @@ -86,11 +86,11 @@ #' ) #' gg_dta <- gg_vimp(rfsrc_mtcars) #' plot(gg_dta) -#' +#' #' ## ------------------------------------------------------------ #' ## survival example #' ## ------------------------------------------------------------ -#' +#' #' ## -------- veteran data #' data(veteran, package = "randomForestSRC") #' rfsrc_veteran <- rfsrc(Surv(time, status) ~ ., @@ -160,7 +160,7 @@ #' # Restrict to only the top 10. #' gg_dta <- gg_vimp(rfsrc_pbc, nvar = 10) #' plot(gg_dta) -#' +#' #' @aliases gg_vimp gg_vimp.rfsrc gg_vimp.randomForest #' @aliases gg_vimp.randomForest.formula #' @export @@ -232,7 +232,12 @@ gg_vimp.rfsrc <- function(object, nvar, ...) { ) } } else { - # Look up by integer index (1-based into class columns). + # Look up by integer index. 
+ # which.outcome = 0 → overall (across-class) importance, column 1 + # which.outcome = k → importance for class k, column k+1 + if (!is.numeric(arg_set$which.outcome) || arg_set$which.outcome < 0) { + stop("which.outcome must be a non-negative integer or a class name.") + } if (arg_set$which.outcome < ncol(gg_dta)) { gg_v <- data.frame(vimp = sort(gg_dta[, arg_set$which.outcome + 1], decreasing = TRUE @@ -243,11 +248,10 @@ gg_vimp.rfsrc <- function(object, nvar, ...) { )] } else { stop( - paste( - "which.outcome specified larger than the number of classes (+1).", - arg_set$which.outcome, - " >= ", - ncol(gg_dta) + paste0( + "which.outcome (", arg_set$which.outcome, ") is out of range. ", + "Valid values are 0 (overall) to ", ncol(gg_dta) - 1, + " (number of classes)." ) ) } @@ -259,12 +263,14 @@ gg_vimp.rfsrc <- function(object, nvar, ...) { } gg_dta <- gg_dta[seq_len(nvar), ] - gathercols <- + pivot_cols <- colnames(gg_dta)[-which(colnames(gg_dta) == "vars")] # Pivot from wide (one column per class) to long (one row per class-var pair). - gg_dta <- tidyr::gather( - gg_dta, "set", "vimp", - tidyr::all_of(gathercols) + gg_dta <- tidyr::pivot_longer( + gg_dta, + tidyr::all_of(pivot_cols), + names_to = "set", + values_to = "vimp" ) gg_dta <- gg_dta[order(gg_dta$vimp, decreasing = TRUE), ] gg_dta$vars <- factor(gg_dta$vars) @@ -278,7 +284,7 @@ gg_vimp.rfsrc <- function(object, nvar, ...) { gg_dta$vars[which(is.na(gg_dta$vars))] <- rownames(gg_dta)[which(is.na(gg_dta$vars))] - gg_dta <- gg_dta[1:nvar, ] + gg_dta <- gg_dta[seq_len(nvar), ] } # Convert vars to an ordered factor (reversed so the most important variable @@ -297,7 +303,7 @@ gg_vimp.rfsrc <- function(object, nvar, ...) { #' @export gg_vimp.randomForest <- function(object, nvar, ...) { - ## Check that the input obect is of the correct type. + ## Check that the input object is of the correct type. 
if (!inherits(object, "randomForest")) { stop( paste( @@ -316,7 +322,7 @@ gg_vimp.randomForest <- function(object, nvar, ...) { if (is.null(importance_est)) { vars <- object$forest$xvar.names if (is.null(vars)) { - training_info <- .rf_recover_model_frame(object) + training_info <- .rf_recover_model_frame(object) # nolint: object_usage_linter if (!is.null(training_info)) { vars <- setdiff(colnames(training_info$model_frame), training_info$response_name) } @@ -354,6 +360,9 @@ gg_vimp.randomForest <- function(object, nvar, ...) { cn <- colnames(gg_dta)[1] gg_dta <- gg_dta[order(gg_dta[, cn], decreasing = TRUE), ] + # Ensure a canonical "vimp" column exists so the positive-flag logic and + # plot.gg_vimp both work regardless of what randomForest named the column. + colnames(gg_dta)[1] <- "vimp" } } if (missing(nvar)) { @@ -390,6 +399,11 @@ gg_vimp.randomForest <- function(object, nvar, ...) { ) } } else { + # which.outcome = 0 → overall importance (column 1) + # which.outcome = k → class k importance (column k+1) + if (!is.numeric(arg_set$which.outcome) || arg_set$which.outcome < 0) { + stop("which.outcome must be a non-negative integer or a class name.") + } if (arg_set$which.outcome < ncol(gg_dta)) { gg_v <- data.frame(vimp = sort(gg_dta[, arg_set$which.outcome + 1], decreasing = TRUE @@ -400,11 +414,10 @@ gg_vimp.randomForest <- function(object, nvar, ...) { )] } else { stop( - paste( - "which.outcome specified larger than the number of classes (+1).", - arg_set$which.outcome, - " >= ", - ncol(gg_dta) + paste0( + "which.outcome (", arg_set$which.outcome, ") is out of range. ", + "Valid values are 0 (overall) to ", ncol(gg_dta) - 1, + " (number of classes)." ) ) } @@ -414,11 +427,13 @@ gg_vimp.randomForest <- function(object, nvar, ...) 
{ gg_dta$vars <- rownames(gg_dta) } - gathercols <- + pivot_cols <- colnames(gg_dta)[-which(colnames(gg_dta) == "vars")] - gg_dta <- tidyr::gather( - gg_dta, "set", "vimp", - tidyr::all_of(gathercols) + gg_dta <- tidyr::pivot_longer( + gg_dta, + tidyr::all_of(pivot_cols), + names_to = "set", + values_to = "vimp" ) gg_dta <- gg_dta[order(gg_dta$vimp, decreasing = TRUE), ] gg_dta$vars <- factor(gg_dta$vars) @@ -426,7 +441,7 @@ gg_vimp.randomForest <- function(object, nvar, ...) { gg_dta$vars[which(is.na(gg_dta$vars))] <- rownames(gg_dta)[which(is.na(gg_dta$vars))] } - gg_dta <- gg_dta[1:nvar, ] + gg_dta <- gg_dta[seq_len(nvar), ] gg_dta$vars <- factor(gg_dta$vars, levels = rev(unique(gg_dta$vars))) diff --git a/R/ggrandomforests.news.R b/R/ggrandomforests.news.R index 48045d90..d47b9c67 100644 --- a/R/ggrandomforests.news.R +++ b/R/ggrandomforests.news.R @@ -1,3 +1,14 @@ +#' Display the ggRandomForests NEWS file +#' +#' Opens the package NEWS file in the system pager so users can read the +#' version history and change log without leaving their R session. +#' +#' @param ... Currently unused; reserved for future arguments. +#' +#' @return Called for its side-effect of opening the NEWS file in the system +#' pager (\code{file.show}). Returns \code{invisible(NULL)}. +#' +#' @keywords internal ggrandomforests.news <- function(...) { newsfile <- file.path(system.file(package="ggRandomForests"), "NEWS") file.show(newsfile) diff --git a/R/help.R b/R/help.R index 9a0c28b8..7ee3ee23 100644 --- a/R/help.R +++ b/R/help.R @@ -17,11 +17,12 @@ #' @title ggRandomForests: Visually Exploring Random Forests #' #' @description \code{ggRandomForests} is a utility package for -#' \code{randomForestSRC} (Ishwaran et.al. 2014, 2008, 2007) for survival, +#' \code{randomForestSRC} (Ishwaran and Kogalur) for survival, #' regression and classification forests and uses the \code{ggplot2} #' (Wickham 2009) package for plotting results. 
\code{ggRandomForests} is #' structured to extract data objects from the random forest and provides S3 #' functions for printing and plotting these objects. +#' Requires \code{randomForestSRC} >= 3.4.0. #' #' The \code{randomForestSRC} package provides a unified treatment of #' Breiman's (2001) random forests for a variety of data settings. Regression @@ -36,8 +37,8 @@ #' \item Separation of data and figures: \code{ggRandomForest} contains #' functions that operate on either the \code{\link[randomForestSRC]{rfsrc}} #' forest object directly, or on the output from \code{randomForestSRC} post -#' processing functions (i.e. \code{plot.variable}, -#' \code{find.interaction}) to generate intermediate \code{ggRandomForests} +#' processing functions (i.e. \code{plot.variable}) to generate intermediate +#' \code{ggRandomForests} #' data objects. S3 functions are provide to further process these objects and #' plot results using the \code{ggplot2} graphics package. Alternatively, #' users can use these data objects for additional custom plotting or @@ -77,8 +78,9 @@ #' @references #' Breiman, L. (2001). Random forests, Machine Learning, 45:5-32. #' -#' Ishwaran H. and Kogalur U.B. (2014). Random Forests for Survival, -#' Regression and Classification (RF-SRC), R package version 1.5.5.12. +#' Ishwaran H. and Kogalur U.B. randomForestSRC: Random Forests for Survival, +#' Regression and Classification. R package version >= 3.4.0. +#' \url{https://cran.r-project.org/package=randomForestSRC} #' #' Ishwaran H. and Kogalur U.B. (2007). Random survival forests for R. R News #' 7(2), 25--31. diff --git a/R/kaplan.R b/R/kaplan.R index c47afb8a..be1c0b6c 100644 --- a/R/kaplan.R +++ b/R/kaplan.R @@ -56,7 +56,7 @@ kaplan <- function(interval, data, by = NULL, ...) { # Build a Surv object from the named columns in the data frame. 
- srv <- survival::Surv(time = data[[interval]], event = data[[censor]]) + srv <- survival::Surv(time = data[[interval]], event = data[[censor]]) # nolint: object_usage_linter # Fit the Kaplan-Meier estimator; stratify on `by` when provided. if (is.null(by)) { @@ -85,21 +85,7 @@ kaplan <- function(interval, ) # When stratifying, stitch a "groups" label column onto the table. - # Stratum boundaries are detected by finding where the time column resets - # (survfit concatenates strata end-to-end in ascending time order). - if (!is.null(by)) { - tm_splits <- - which(c(FALSE, sapply(2:nrow(tbl), function(ind) { - tbl$time[ind] < tbl$time[ind - 1] - }))) - - lbls <- levels(data[[by]]) - tbl$groups <- lbls[1] - - for (ind in 2:(length(tm_splits) + 1)) { - tbl$groups[tm_splits[ind - 1]:nrow(tbl)] <- lbls[ind] - } - } + if (!is.null(by)) tbl <- .label_strata(tbl, data, by) # nolint: object_usage_linter # Keep only rows where at least one event occurred — censoring-only rows # do not contribute new KM estimates. @@ -118,11 +104,12 @@ kaplan <- function(interval, mid_int <- (gg_dta$time + lag_t) / 2 lag_l <- 0 - # Cumulative expected life in each interval (trapezoidal-rule approximation). + # Cumulative expected life in each interval (trapezoidal rule): + # L(t_i) = L(t_{i-1}) + (S(t_{i-1}) + S(t_i)) / 2 * Δt_i life <- vector("numeric", length = dim(gg_dta)[1]) for (ind in seq_len(dim(gg_dta)[1])) { life[ind] <- - lag_l + delta_t[ind] * (3 * gg_dta[ind, "surv"] - lag_s[ind]) / 2 + lag_l + (lag_s[ind] + gg_dta[ind, "surv"]) / 2 * delta_t[ind] lag_l <- life[ind] } prp_life <- life / gg_dta$time diff --git a/R/nelson.R b/R/nelson.R index 879f74b5..92ae169c 100644 --- a/R/nelson.R +++ b/R/nelson.R @@ -59,7 +59,7 @@ #' #' plot(gg_dta, error = "bars") #' plot(gg_dta) -#' +#' #' @export nelson <- function(interval, @@ -75,7 +75,7 @@ nelson <- } # Build the Surv object and fit the (possibly stratified) estimator. 
- srv <- + srv <- # nolint: object_usage_linter survival::Surv(time = data[[interval]], event = data[[censor]]) if (is.null(by)) { srv_tab <- survival::survfit(srv ~ 1, ...) @@ -84,16 +84,7 @@ nelson <- survival::survfit(srv ~ survival::strata(data[[by]]), ...) } - # Nelson-Aalen cumulative hazard: Λ(t) = Σ d_i / n_i over t_i ≤ t - # (The loop below computes the partial sums; then we overwrite with - # -log(S(t)) which is equivalent and numerically identical for KM.) - hazard <- srv_tab$n.event / srv_tab$n.risk - cum_hazard <- vector() - for (i in seq_len(length(hazard))) { - cum_hazard[i] <- sum(hazard[1:i]) - } - cum_hazard <- c(cum_hazard, cum_hazard[length(cum_hazard)]) - # Use -log(S(t)) for consistency with the Kaplan-Meier relation. + # Cumulative hazard H(t) = -log(S(t)), consistent with the KM estimator. cum_hazard <- -log(srv_tab$surv) # Collect per-time-point statistics into a flat data frame. @@ -111,20 +102,8 @@ nelson <- ) ) - # Detect stratum boundaries by finding time resets in the concatenated - # survfit output, then label each row with its group name. - if (!is.null(by)) { - tm_splits <- which(c(FALSE, sapply(2:nrow(tbl), function(ind) { - tbl$time[ind] < tbl$time[ind - 1] - }))) - - lbls <- unique(data[[by]]) - tbl$groups <- lbls[1] - - for (ind in 2:(length(tm_splits) + 1)) { - tbl$groups[tm_splits[ind - 1]:nrow(tbl)] <- lbls[ind] - } - } + # Detect stratum boundaries and label each row with its group name. + if (!is.null(by)) tbl <- .label_strata(tbl, data, by) # nolint: object_usage_linter # Retain only rows with at least one event. 
gg_dta <- tbl[which(tbl[["dead"]] != 0), ] @@ -142,10 +121,12 @@ nelson <- mid_int <- (gg_dta$time + lag_time) / 2 lag_l <- 0 + # Cumulative expected life in each interval (trapezoidal rule): + # L(t_i) = L(t_{i-1}) + (S(t_{i-1}) + S(t_i)) / 2 * Δt_i life <- vector("numeric", length = dim(gg_dta)[1]) for (ind in seq_len(dim(gg_dta)[1])) { life[ind] <- - lag_l + delta_t[ind] * (3 * gg_dta[ind, "surv"] - lag_surv[ind]) / 2 + lag_l + (lag_surv[ind] + gg_dta[ind, "surv"]) / 2 * delta_t[ind] lag_l <- life[ind] } prp_life <- life / gg_dta$time diff --git a/R/plot.gg_error.R b/R/plot.gg_error.R index a51a6fb9..26812dad 100644 --- a/R/plot.gg_error.R +++ b/R/plot.gg_error.R @@ -17,11 +17,18 @@ #' A plot of the cumulative OOB error rates of the random forest as a #' function of number of trees. #' -#' @param x gg_error object created from a \code{\link[randomForestSRC]{rfsrc}} -#' object -#' @param ... extra arguments passed to \code{ggplot} functions +#' @param x A \code{\link{gg_error}} object created from either a +#' \code{\link[randomForestSRC]{rfsrc}} or a +#' \code{\link[randomForest]{randomForest}} object. A raw forest object +#' may also be supplied and will be passed through \code{\link{gg_error}} +#' automatically before plotting. +#' @param ... Extra arguments forwarded to the underlying \code{ggplot2} +#' geometry calls (e.g. \code{size}, \code{linetype}). #' -#' @return \code{ggplot} object +#' @return A \code{ggplot} object with \code{ntree} on the x-axis and +#' OOB error rate on the y-axis. Single-outcome forests (regression, +#' survival) produce a single line; multi-outcome forests (classification) +#' produce one coloured line per class. #' #' @details The gg_error plot is used to track the convergence of the #' randomForest. This figure is a reproduction of the error plot @@ -36,8 +43,9 @@ #' Ishwaran H. and Kogalur U.B. (2007). Random survival forests for R, Rnews, #' 7(2):25-31. #' -#' Ishwaran H. and Kogalur U.B. (2013). 
Random Forests for Survival, Regression -#' and Classification (RF-SRC), R package version 1.4. +#' Ishwaran H. and Kogalur U.B. randomForestSRC: Random Forests for Survival, +#' Regression and Classification. R package version >= 3.4.0. +#' \url{https://cran.r-project.org/package=randomForestSRC} #' #' @examples #' ## Examples from RFSRC package... @@ -77,7 +85,7 @@ #' ## ------------- airq data #' rfsrc_airq <- rfsrc(Ozone ~ ., #' data = airquality, -#' na.action = "na.impute", +#' na.action = "na.impute", #' forest = TRUE, #' importance = TRUE, #' tree.err = TRUE, @@ -192,9 +200,8 @@ #' plot(gg_dta) #' #' @importFrom ggplot2 ggplot geom_line theme labs -#' @importFrom tidyr gather +#' @importFrom tidyr pivot_longer #' @export -#' @export plot.gg_error plot.gg_error <- function(x, ...) { gg_dta <- x @@ -210,15 +217,15 @@ plot.gg_error <- function(x, ...) { # Use points instead of lines when there is only one non-NA row (e.g. a # forest built with a single tree, or one where only ntree=1 has an error # rate recorded). A line plot with one point renders nothing visible. - point = FALSE + point <- FALSE if (nrow(na.omit(gg_dta)) < 2) { - point=TRUE + point <- TRUE } if (ncol(gg_dta) > 2) { # Multi-outcome (classification): gg_error has one column per class plus # the "ntree" column. Pivot to long form so we can colour by outcome. - gg_dta <- tidyr::gather(gg_dta, "variable", "value", -"ntree") + gg_dta <- tidyr::pivot_longer(gg_dta, -"ntree", names_to = "variable", values_to = "value") gg_plt <- ggplot2::ggplot(na.omit(gg_dta), ggplot2::aes(x = .data[["ntree"]], y = .data[["value"]], @@ -235,15 +242,18 @@ plot.gg_error <- function(x, ...) 
{ gg_plt <- gg_plt + ggplot2::geom_point() + ggplot2::labs(x = "Number of Trees", y = "OOB Error Rate", color = "Outcome") - } else{ + } else { gg_plt <- gg_plt + ggplot2::geom_line() + ggplot2::labs(x = "Number of Trees", y = "OOB Error Rate", color = "Outcome") } # Hide the legend when there is only a single outcome variable — the colour - # key adds no information and clutters the plot. - if (length(unique(gg_dta$variable)) == 1) { + # key adds no information and clutters the plot. For single-outcome forests + # (regression / survival) the data is never gathered, so there is no + # "variable" column; suppress the legend unconditionally in that case. + if (!"variable" %in% names(gg_dta) || + length(unique(gg_dta$variable)) <= 1) { gg_plt <- gg_plt + ggplot2::theme(legend.position = "none") } return(gg_plt) diff --git a/R/plot.gg_rfsrc.R b/R/plot.gg_rfsrc.R index 646cb5fc..7249d876 100644 --- a/R/plot.gg_rfsrc.R +++ b/R/plot.gg_rfsrc.R @@ -16,15 +16,48 @@ #' #' Plot the predicted response from a \code{\link{gg_rfsrc}} object, the #' \code{\link[randomForestSRC]{rfsrc}} prediction, using the OOB prediction -#' from the forest. +#' from the forest. The plot type adapts automatically to the forest family: +#' jitter + boxplot for regression and classification, step curves for +#' survival. #' -#' @param x \code{\link{gg_rfsrc}} object created from a -#' \code{\link[randomForestSRC]{rfsrc}} object -#' @param ... arguments passed to \code{\link{gg_rfsrc}}. +#' @param x A \code{\link{gg_rfsrc}} object, or a raw +#' \code{\link[randomForestSRC]{rfsrc}} object (which will be passed through +#' \code{\link{gg_rfsrc}} automatically before plotting). +#' @param notch Logical; whether to draw notched boxplots for regression and +#' classification forests (default \code{TRUE}). Set \code{notch = FALSE} +#' to suppress notches when sample sizes are too small for reliable +#' confidence intervals on the median. +#' @param ... 
Additional arguments forwarded to the underlying +#' \code{ggplot2} geometry calls. Commonly useful arguments include: +#' \describe{ +#' \item{\code{alpha}}{Numeric in \eqn{[0,1]}; point/ribbon transparency. +#' For survival plots with confidence bands the ribbon alpha is +#' automatically halved relative to the value supplied here.} +#' \item{\code{size}}{Point or line size passed to \code{geom_jitter}, +#' \code{geom_step}, etc.} +#' } +#' Arguments that control \code{\link{gg_rfsrc}} (e.g. \code{conf.int}, +#' \code{surv_type}, \code{by}) should be applied when constructing the +#' \code{gg_rfsrc} object before calling \code{plot()}. #' -#' @return \code{ggplot} object +#' @return A \code{ggplot} object. The plot appearance depends on the forest +#' family stored in \code{x}: +#' \describe{ +#' \item{Regression (\code{"regr"})}{Jitter + notched boxplot of OOB +#' predicted values. If a \code{group} column is present the x-axis +#' shows each group label; otherwise observations are collapsed to a +#' single x-position.} +#' \item{Classification (\code{"class"})}{Binary: jitter + notched +#' boxplot of the predicted class probability. Multi-class: jitter +#' plot with one panel per class (class probabilities in long form).} +#' \item{Survival (\code{"surv"})}{Step curves of the ensemble survival +#' function. When \code{gg_rfsrc} was called with \code{conf.int}, +#' a shaded ribbon is added. When called with \code{by}, curves are +#' coloured by group.} +#' } #' #' @seealso \code{\link{gg_rfsrc}} \code{\link[randomForestSRC]{rfsrc}} +#' \code{\link[randomForest]{randomForest}} #' #' @references #' Breiman L. (2001). Random forests, Machine Learning, 45:5-32. @@ -32,8 +65,9 @@ #' Ishwaran H. and Kogalur U.B. (2007). Random survival forests for #' R, Rnews, 7(2):25-31. #' -#' Ishwaran H. and Kogalur U.B. (2013). Random Forests for Survival, Regression -#' and Classification (RF-SRC), R package version 1.4. +#' Ishwaran H. and Kogalur U.B. 
randomForestSRC: Random Forests for Survival, +#' Regression and Classification. R package version >= 3.4.0. +#' \url{https://cran.r-project.org/package=randomForestSRC} #' #' @examples #' ## ------------------------------------------------------------ @@ -144,11 +178,11 @@ #' } #' #' @export -#' @export plot.gg_rfsrc -plot.gg_rfsrc <- function(x, ...) { +plot.gg_rfsrc <- function(x, notch = TRUE, ...) { gg_dta <- x - # Capture any extra named arguments (e.g. alpha, size) for geom calls + # Capture extra passthrough args (e.g. alpha for survival ribbons). + # notch is a named formal above so it is never present in arg_set. arg_set <- list(...) ## If the user passed a raw rfsrc object, extract predictions first @@ -159,24 +193,28 @@ plot.gg_rfsrc <- function(x, ...) { ## ---- Classification forest branch ------------------------------------ if (inherits(gg_dta, "class") || inherits(gg_dta, "classification")) { - if (ncol(gg_dta) < 3) { + # Identify probability columns (everything except "y" and "group") + non_prob <- c("y", "group") + prob_cols <- setdiff(colnames(gg_dta), non_prob) + prob_col <- prob_cols[1] + obs_col <- "y" + if (length(prob_cols) < 2) { # Binary classification: single probability column + observed class gg_plt <- ggplot2::ggplot(gg_dta) + ggplot2::geom_jitter( ggplot2::aes( x = 1, - y = colnames(gg_dta)[1], - color = colnames(gg_dta)[2], - shape = colnames(gg_dta)[2] + y = .data[[prob_col]], + color = .data[[obs_col]], + shape = .data[[obs_col]] ), ... ) + ggplot2::geom_boxplot( - ggplot2::aes(x = 1, y = colnames(gg_dta)[1]), + ggplot2::aes(x = 1, y = .data[[prob_col]]), outlier.colour = "transparent", fill = "transparent", - notch = TRUE, - ... + notch = notch ) + ggplot2::theme( axis.ticks = ggplot2::element_blank(), @@ -184,16 +222,17 @@ plot.gg_rfsrc <- function(x, ...) 
{ ) } else { # Multi-class: gather all class probability columns into long form - gathercols <- colnames(gg_dta)[-which(colnames(gg_dta) == "y")] + pivot_cols <- setdiff(colnames(gg_dta), c("y", "group")) gg_dta_mlt <- - tidyr::gather(gg_dta, "variable", "value", tidyr::all_of(gathercols)) + tidyr::pivot_longer(gg_dta, tidyr::all_of(pivot_cols), names_to = "variable", values_to = "value") gg_plt <- ggplot2::ggplot( gg_dta_mlt, - ggplot2::aes(x = "variable", y = "value") + ggplot2::aes(x = .data$variable, y = .data$value) ) + - ggplot2::geom_jitter(ggplot2::aes(color = "y", shape = "y"), + ggplot2::geom_jitter( + ggplot2::aes(color = .data$y, shape = .data$y), alpha = .5 ) } @@ -216,41 +255,41 @@ plot.gg_rfsrc <- function(x, ...) { gg_plt <- ggplot2::ggplot(gg_dta) + ggplot2::geom_ribbon( ggplot2::aes( - x = "value", - ymin = "lower", - ymax = "upper", - fill = "group" + x = .data$value, + ymin = .data$lower, + ymax = .data$upper, + fill = .data$group ), alpha = alph, ... ) + ggplot2::geom_step(ggplot2::aes( - x = "value", - y = "median", - color = "group" + x = .data$value, + y = .data$median, + color = .data$group ), ...) } else { # Single-group survival curve with CI ribbon gg_plt <- ggplot2::ggplot(gg_dta) + ggplot2::geom_ribbon( ggplot2::aes( - x = "value", - ymin = "lower", - ymax = "upper" + x = .data$value, + ymin = .data$lower, + ymax = .data$upper ), alpha = alph ) + - ggplot2::geom_step(ggplot2::aes(x = "value", y = "median"), ...) + ggplot2::geom_step(ggplot2::aes(x = .data$value, y = .data$median), ...) } } else { # No confidence bands: draw one step line per observation gg_plt <- ggplot2::ggplot( gg_dta, ggplot2::aes( - x = "variable", - y = "value", - col = "event", - by = "obs_id" + x = .data$variable, + y = .data$value, + col = .data$event, + group = .data$obs_id ) ) + ggplot2::geom_step(...) @@ -264,24 +303,30 @@ plot.gg_rfsrc <- function(x, ...) 
{ inherits(gg_dta, "regression")) { if ("group" %in% colnames(gg_dta)) { # Grouped regression: x-axis shows each group label - gg_plt <- ggplot2::ggplot(gg_dta, ggplot2::aes(x = "group", y = "yhat")) + gg_plt <- ggplot2::ggplot(gg_dta, ggplot2::aes(x = .data$group, y = .data$yhat)) } else { # Single-group regression: collapse to a single x position - gg_plt <- ggplot2::ggplot(gg_dta, ggplot2::aes(x = 1, y = "yhat")) + gg_plt <- ggplot2::ggplot(gg_dta, ggplot2::aes(x = 1, y = .data$yhat)) } gg_plt <- gg_plt + - ggplot2::geom_jitter(, ...) + + ggplot2::geom_jitter(...) + ggplot2::geom_boxplot( outlier.colour = "transparent", fill = "transparent", - notch = TRUE, - ... + notch = notch + ) + + ggplot2::labs( + y = "Predicted Value", + x = if ("group" %in% colnames(gg_dta)) "Group" else "" ) + - ggplot2::labs(y = "Predicted Value", x = colnames(gg_dta)[2]) + ggplot2::theme( axis.ticks = ggplot2::element_blank(), - axis.text.x = ggplot2::element_blank() + axis.text.x = if ("group" %in% colnames(gg_dta)) { + ggplot2::element_text() + } else { + ggplot2::element_blank() + } ) } else { # Unknown forest type — not yet implemented diff --git a/R/plot.gg_roc.R b/R/plot.gg_roc.R index 590e5973..297de3db 100644 --- a/R/plot.gg_roc.R +++ b/R/plot.gg_roc.R @@ -15,13 +15,27 @@ #' #' ROC plot generic function for a \code{\link{gg_roc}} object. #' -#' @param x \code{\link{gg_roc}} object created from a classification forest -#' @param which_outcome for multiclass problems, choose the class for plotting -#' @param ... arguments passed to the \code{\link{gg_roc}} function +#' @param x A \code{\link{gg_roc}} object, or a raw +#' \code{\link[randomForestSRC]{rfsrc}} classification forest or +#' \code{\link[randomForest]{randomForest}} classification object. When a +#' forest is supplied, \code{\link{gg_roc}} is called automatically. +#' @param which_outcome Integer; for multi-class problems, the index of the +#' class whose ROC curve should be plotted. 
When \code{NULL} (default) and +#' the forest has more than two classes, ROC curves for all classes are +#' overlaid in a single plot. For binary forests \code{NULL} defaults to +#' class index 2. +#' @param ... Additional arguments forwarded to \code{\link{gg_roc}} when +#' \code{x} is a raw forest object (e.g. \code{oob = FALSE}). #' -#' @return \code{ggplot} object of the ROC curve +#' @return A \code{ggplot} object. The x-axis shows 1 − Specificity (FPR) +#' and the y-axis shows Sensitivity (TPR). A dashed red diagonal reference +#' line marks the random-classifier baseline. The AUC value is annotated +#' on the plot for single-class curves. Multi-class plots colour and style +#' each class curve distinctly. #' -#' @seealso \code{\link{gg_roc}} rfsrc +#' @seealso \code{\link{gg_roc}} \code{\link{calc_roc}} \code{\link{calc_auc}} +#' \code{\link[randomForestSRC]{rfsrc}} +#' \code{\link[randomForest]{randomForest}} #' #' @references #' Breiman L. (2001). Random forests, Machine Learning, 45:5-32. @@ -29,8 +43,9 @@ #' Ishwaran H. and Kogalur U.B. (2007). Random survival forests for R, #' Rnews, 7(2):25-31. #' -#' Ishwaran H. and Kogalur U.B. (2013). Random Forests for Survival, -#' Regression and Classification (RF-SRC), R package version 1.4. +#' Ishwaran H. and Kogalur U.B. randomForestSRC: Random Forests for Survival, +#' Regression and Classification. R package version >= 3.4.0. 
+#' \url{https://cran.r-project.org/package=randomForestSRC} #' #' @examples #' ## ------------------------------------------------------------ @@ -43,21 +58,21 @@ #' #' # ROC for setosa (outcome index 1) #' gg_dta <- gg_roc(rfsrc_iris, which_outcome = 1) -#' plot.gg_roc(gg_dta) +#' plot(gg_dta) #' #' # ROC for versicolor (outcome index 2) #' gg_dta <- gg_roc(rfsrc_iris, which_outcome = 2) -#' plot.gg_roc(gg_dta) +#' plot(gg_dta) #' #' # ROC for virginica (outcome index 3) #' gg_dta <- gg_roc(rfsrc_iris, which_outcome = 3) -#' plot.gg_roc(gg_dta) +#' plot(gg_dta) #' -#' # Alternatively, pass the forest directly to plot all three ROC curves -#' plot.gg_roc(rfsrc_iris) +#' # Plot all three ROC curves in one call by iterating over outcome indices +#' n_cls <- ncol(rfsrc_iris$predicted) +#' for (i in seq_len(n_cls)) print(plot(gg_roc(rfsrc_iris, which_outcome = i))) #' #' @export -#' @export plot.gg_roc plot.gg_roc <- function(x, which_outcome = NULL, ...) { gg_dta <- x @@ -71,7 +86,7 @@ plot.gg_roc <- function(x, which_outcome = NULL, ...) { if (crv > 2 && is.null(which_outcome)) { # Multi-class: compute ROC for every class in parallel - gg_dta <- mclapply(1:crv, function(ind) { + gg_dta <- mclapply(seq_len(crv), function(ind) { gg_roc(gg_dta, which_outcome = ind, ...) }) } else { @@ -90,7 +105,7 @@ plot.gg_roc <- function(x, which_outcome = NULL, ...) { crv <- length(levels(gg_dta$predicted)) if (crv > 2 && is.null(which_outcome)) { # Multi-class: compute ROC for every class in parallel - gg_dta <- parallel::mclapply(1:crv, function(ind) { + gg_dta <- parallel::mclapply(seq_len(crv), function(ind) { gg_roc(gg_dta, which_outcome = ind, ...) }) } else { @@ -137,7 +152,7 @@ plot.gg_roc <- function(x, which_outcome = NULL, ...) 
{ ## ---- Multi-class ROC plot (list of gg_roc objects) ---------------- # Sort each class's data by specificity gg_dta <- parallel::mclapply(gg_dta, function(st) { - st[order(st$spec), ] + st <- st[order(st$spec), ] st }) # Compute FPR for each class @@ -164,8 +179,8 @@ plot.gg_roc <- function(x, which_outcome = NULL, ...) { ggplot2::geom_line(ggplot2::aes( x = .data$fpr, y = .data$sens, - linetype = "outcome", - col = "outcome" + linetype = .data$outcome, + col = .data$outcome )) + ggplot2::labs(x = "1 - Specificity (FPR)", y = "Sensitivity (TPR)") + # Reference diagonal for a random classifier @@ -178,17 +193,7 @@ plot.gg_roc <- function(x, which_outcome = NULL, ...) { ) + ggplot2::coord_fixed() - # Annotate AUC only when there is a single outcome (binary case fallback) - if (crv < 2) { - gg_plt <- gg_plt + - ggplot2::annotate( - x = .5, - y = .2, - geom = "text", - label = paste("AUC = ", round(auc, digits = 3), sep = ""), - hjust = 0 - ) - } + # Multi-class: do not annotate a single AUC value — each class has its own. } return(gg_plt) } diff --git a/R/plot.gg_survival.R b/R/plot.gg_survival.R index b347a56d..b9ca4a4d 100644 --- a/R/plot.gg_survival.R +++ b/R/plot.gg_survival.R @@ -23,7 +23,13 @@ #' @param label Modify the legend label when gg_survival has stratified samples #' @param ... not used #' -#' @return \code{ggplot} object +#' @return A \code{ggplot} object. The y-axis shows the chosen \code{type} +#' (e.g. survival probability for \code{"surv"}) and the x-axis shows time. +#' Confidence shading, bars, or lines are added when the input object carries +#' confidence-interval columns. 
+#' +#' @seealso \code{\link{gg_survival}}, \code{\link{kaplan}}, +#' \code{\link{nelson}}, \code{\link{gg_rfsrc}} #' #' @examples #' ## -------- pbc data @@ -68,7 +74,6 @@ #' plot(gg_dta, label = "sex", error = "lines") #' #' @export -#' @export plot.gg_survival ### Survival plots plot.gg_survival <- function(x, type = c("surv", @@ -85,10 +90,10 @@ plot.gg_survival <- function(x, if (inherits(gg_dta, "rfsrc")) { gg_dta <- gg_survival(gg_dta) } - + error <- match.arg(error) type <- match.arg(type) - + # Now order matters, so we want to place the forest predictions on the bottom # Create the figure skeleton, if (is.null(gg_dta$groups)) { diff --git a/R/plot.gg_variable.R b/R/plot.gg_variable.R index 188bcee1..e3c0b48a 100644 --- a/R/plot.gg_variable.R +++ b/R/plot.gg_variable.R @@ -26,15 +26,21 @@ #' @param smooth include a smooth curve (boolean) #' @param ... arguments passed to the \code{ggplot2} functions. #' -#' @return A single \code{ggplot} object, or list of \code{ggplot} objects +#' @return A single \code{ggplot} object when \code{length(xvar) == 1} or +#' \code{panel = TRUE}. Otherwise a named list of \code{ggplot} objects, one +#' per variable in \code{xvar}. +#' +#' @seealso \code{\link{gg_variable}}, \code{\link{gg_partial}}, +#' \code{\link[randomForestSRC]{plot.variable}} #' #' @references Breiman L. (2001). Random forests, Machine Learning, 45:5-32. #' #' Ishwaran H. and Kogalur U.B. (2007). Random survival forests for R, Rnews, #' 7(2):25-31. #' -#' Ishwaran H. and Kogalur U.B. (2013). Random Forests for Survival, Regression -#' and Classification (RF-SRC), R package version 1.4. +#' Ishwaran H. and Kogalur U.B. randomForestSRC: Random Forests for Survival, +#' Regression and Classification. R package version >= 3.4.0. 
+#' \url{https://cran.r-project.org/package=randomForestSRC} #' #' #' @importFrom ggplot2 .data @@ -110,8 +116,7 @@ #' plot(gg_dta, xvar = c("age", "diagtime"), panel = TRUE) #' #' @export -#' @export plot.gg_variable -plot.gg_variable <- function(x, +plot.gg_variable <- function(x, # nolint: cyclocomp_linter xvar, time, time_labels, @@ -218,11 +223,11 @@ plot.gg_variable <- function(x, # Subset to response + requested predictors, then pivot to long form tmp_dta <- gg_dta[, c(wch_y_var, wch_x_var)] - gathercols <- + pivot_cols <- colnames(tmp_dta)[-which(colnames(tmp_dta) %in% c("time", "event", "yhat"))] gg_dta_mlt <- - tidyr::gather(tmp_dta, "variable", "value", tidyr::all_of(gathercols)) + tidyr::pivot_longer(tmp_dta, tidyr::all_of(pivot_cols), names_to = "variable", values_to = "value") # Preserve user-supplied xvar ordering in the facet strips gg_dta_mlt$variable <- @@ -309,23 +314,27 @@ plot.gg_variable <- function(x, # Include the observed class label column for colouring wch_y_var <- c(wch_y_var, which(colnames(gg_dta) == "yvar")) tmp_dta <- gg_dta[, c(wch_y_var, wch_x_var)] - gathercols <- + pivot_cols <- colnames(tmp_dta)[-which(colnames(tmp_dta) %in% c("yvar", "yhat"))] gg_dta_mlt <- - tidyr::gather( - tmp_dta, "variable", "value", - tidyr::all_of(gathercols) + tidyr::pivot_longer( + tmp_dta, + tidyr::all_of(pivot_cols), + names_to = "variable", + values_to = "value" ) } else { # Regression: keep yhat and the optional yvar reference column wch_y_var <- c(wch_y_var, which(colnames(gg_dta) == "yvar")) tmp_dta <- gg_dta[, c(wch_y_var, wch_x_var)] - gathercols <- + pivot_cols <- colnames(tmp_dta)[-which(colnames(tmp_dta) == "yhat")] gg_dta_mlt <- - tidyr::gather( - tmp_dta, "variable", "value", - tidyr::all_of(gathercols) + tidyr::pivot_longer( + tmp_dta, + tidyr::all_of(pivot_cols), + names_to = "variable", + values_to = "value" ) } # Preserve user-supplied xvar ordering in the facet strips @@ -382,17 +391,8 @@ plot.gg_variable <- function(x, # Add 
point/smooth layers for non-classification forests if (family != "class") { if (points) { - gg_plt <- ggplot2::ggplot( - gg_dta_mlt, - ggplot2::aes(x = .data$value, y = .data$yhat) - ) + + gg_plt <- gg_plt + ggplot2::geom_point(...) - } else { - gg_plt <- ggplot2::ggplot( - gg_dta_mlt, - ggplot2::aes(x = .data$value, y = .data$yhat) - ) + - ggplot2::geom_smooth(...) } if (smooth) { gg_plt <- gg_plt + @@ -412,13 +412,15 @@ plot.gg_variable <- function(x, # Pre-allocate a list; collapsed to a single object when lng == 1 gg_plt <- vector("list", length = lng) - for (ind in 1:lng) { + for (ind in seq_len(lng)) { # Temporarily rename the target predictor column to "var" for aes() ch_indx <- which(colnames(gg_dta) == xvar[ind]) h_name <- colnames(gg_dta)[ch_indx] colnames(gg_dta)[ch_indx] <- "var" - ccls <- class(gg_dta[, "var"]) - ccls[which(ccls == "integer")] <- "numeric" + # Use only the primary class (class() can return multiple strings, e.g. + # c("POSIXct", "POSIXt")); a multi-element vector in if() triggers a warning. 
+ ccls_var <- class(gg_dta[, "var"])[1L] + if (ccls_var == "integer") ccls_var <- "numeric" gg_plt[[ind]] <- ggplot2::ggplot(gg_dta) @@ -427,7 +429,7 @@ plot.gg_variable <- function(x, gg_plt[[ind]] <- gg_plt[[ind]] + ggplot2::labs(x = h_name, y = "Survival") - if (ccls == "numeric") { + if (ccls_var == "numeric") { # Continuous predictor: scatter (and optional smooth) if (points) { gg_plt[[ind]] <- gg_plt[[ind]] + @@ -486,7 +488,7 @@ plot.gg_variable <- function(x, if (sum(colnames(gg_dta) == "outcome") == 0) { # Single-outcome (binary) classification - if (ccls == "numeric") { + if (ccls_var == "numeric") { if (points) { gg_plt[[ind]] <- gg_plt[[ind]] + ggplot2::geom_point( @@ -532,7 +534,7 @@ plot.gg_variable <- function(x, } } else { # Multi-class: facet by outcome class - if (ccls == "numeric") { + if (ccls_var == "numeric") { gg_plt[[ind]] <- gg_plt[[ind]] + ggplot2::geom_point( ggplot2::aes( @@ -571,7 +573,7 @@ plot.gg_variable <- function(x, # assume regression gg_plt[[ind]] <- gg_plt[[ind]] + ggplot2::labs(x = h_name, y = "Predicted") - if (ccls == "numeric") { + if (ccls_var == "numeric") { if (points) { gg_plt[[ind]] <- gg_plt[[ind]] + ggplot2::geom_point(ggplot2::aes(x = .data$var, y = .data$yhat), ...) diff --git a/R/plot.gg_vimp.R b/R/plot.gg_vimp.R index 5c9544f4..94bb9de6 100644 --- a/R/plot.gg_vimp.R +++ b/R/plot.gg_vimp.R @@ -32,8 +32,9 @@ #' Ishwaran H. and Kogalur U.B. (2007). Random survival forests for #' R, Rnews, 7(2):25-31. #' -#' Ishwaran H. and Kogalur U.B. (2013). Random Forests for Survival, -#' Regression and Classification (RF-SRC), R package version 1.4. +#' Ishwaran H. and Kogalur U.B. randomForestSRC: Random Forests for Survival, +#' Regression and Classification. R package version >= 3.4.0. 
+#' \url{https://cran.r-project.org/package=randomForestSRC} #' #' @examples #' ## ------------------------------------------------------------ @@ -54,7 +55,6 @@ #' #' #' @export -#' @export plot.gg_vimp plot.gg_vimp <- function(x, relative, lbls, ...) { gg_dta <- x @@ -63,8 +63,8 @@ plot.gg_vimp <- function(x, relative, lbls, ...) { gg_dta <- gg_vimp(gg_dta, ...) } - # Capture extra args by name so we can inspect nvar without consuming it - arg_set <- as.list(substitute(list(...)))[-1L] + # Capture extra args so we can inspect nvar. + arg_set <- list(...) # Optionally restrict to the top-nvar most important variables (gg_vimp # already sorts by descending VIMP, so we just trim the tail). @@ -73,7 +73,7 @@ plot.gg_vimp <- function(x, relative, lbls, ...) { if (is.numeric(arg_set$nvar) && arg_set$nvar > 1) { if (arg_set$nvar < nrow(gg_dta)) { nvar <- arg_set$nvar - gg_dta <- gg_dta[1:nvar, ] + gg_dta <- gg_dta[seq_len(nvar), ] } } } @@ -94,10 +94,10 @@ plot.gg_vimp <- function(x, relative, lbls, ...) { gg_plt <- gg_plt + ggplot2::geom_bar( ggplot2::aes( - y = msr, - x = "vars", - fill = "positive", - color = "positive" + y = .data[[msr]], + x = .data$vars, + fill = .data$positive, + color = .data$positive ), stat = "identity", width = .5, @@ -107,7 +107,7 @@ plot.gg_vimp <- function(x, relative, lbls, ...) { # redundant fill legend. gg_plt <- gg_plt + ggplot2::geom_bar( - ggplot2::aes(y = msr, x = "vars", color = "positive"), + ggplot2::aes(y = .data[[msr]], x = .data$vars, color = .data$positive), stat = "identity", width = .5, ) diff --git a/R/surv_partial.rfsrc.R b/R/surv_partial.rfsrc.R index ab7b5c7b..a58f5838 100644 --- a/R/surv_partial.rfsrc.R +++ b/R/surv_partial.rfsrc.R @@ -1,26 +1,48 @@ -#' Calculate survival curve partial plot. +#' Survival partial dependence data for one or more predictors #' -#' @param rforest the randomForestSrc object -#' @param var_list a list of variables of interest. 
These variables should be a -#' subset of rforest$xvar.names -#' @param npts the number of points to segment the xvar of interest -#' @param partial.type the return prediction type. -#' For survival forests: type c("surv", "mort", "chf") -#' For competing risk forests: type c("years.lost", "cif", "chf") -#' see \code{randomForestSRC::partial.rfsrc} or more information +#' Computes partial dependence curves for a survival or competing-risk +#' \code{\link[randomForestSRC]{rfsrc}} forest by calling +#' \code{\link[randomForestSRC]{partial.rfsrc}} at \code{npts} evenly-spaced +#' unique values of each predictor across all stored event times. +#' +#' @param rforest A fitted \code{\link[randomForestSRC]{rfsrc}} survival or +#' competing-risk forest object. +#' @param var_list Character vector of predictor names for which partial +#' dependence should be computed. Each must appear in +#' \code{rforest$xvar.names}. +#' @param npts Integer; the number of predictor grid points to evaluate +#' (default 25). Evenly-spaced unique values are sampled from each predictor. +#' @param partial.type The prediction type to return. For survival forests one +#' of \code{"surv"} (default), \code{"mort"}, or \code{"chf"}. For competing +#' risk forests one of \code{"years.lost"}, \code{"cif"}, or \code{"chf"}. +#' See \code{\link[randomForestSRC]{partial.rfsrc}} for full details. +#' +#' @return A named list with one element per variable in \code{var_list}. 
Each +#' element is itself a list with: +#' \describe{ +#' \item{name}{The predictor variable name (character).} +#' \item{dta}{The raw output of +#' \code{\link[randomForestSRC]{get.partial.plot.data}}, a list containing +#' at minimum \code{x} (predictor values) and \code{yhat} (partial +#' predictions), and for survival/competing risk, \code{partial.time}.} +#' } +#' +#' @seealso \code{\link{gg_partial_rfsrc}}, +#' \code{\link[randomForestSRC]{partial.rfsrc}}, +#' \code{\link[randomForestSRC]{get.partial.plot.data}} #' #' @importFrom randomForestSRC partial.rfsrc #' @examples #' ## ------------------------------------------------------------ #' ## survival #' ## ------------------------------------------------------------ -#' +#' #' data(veteran, package = "randomForestSRC") -#' v.obj <- randomForestSRC::rfsrc(Surv(time,status)~., +#' v.obj <- randomForestSRC::rfsrc(Surv(time,status)~., #' veteran, nsplit = 10, ntree = 100) -#' +#' #' spart <- surv_partial.rfsrc(v.obj, var_list="age", partial.type = "mort") -#' +#' #' ## partial effect of age on mortality #' partial.obj <- partial(v.obj, #' partial.type = "mort", @@ -28,10 +50,10 @@ #' partial.values = v.obj$xvar$age, #' partial.time = v.obj$time.interest) #' pdta <- get.partial.plot.data(partial.obj) -#' +#' #' plot(lowess(pdta$x, pdta$yhat, f = 1/3), #' type = "l", xlab = "age", ylab = "adjusted mortality") -#' +#' #' ## example where x is discrete - partial effect of age on mortality #' ## we use the granule=TRUE option #' partial.obj <- partial(v.obj, @@ -41,8 +63,8 @@ #' partial.time = v.obj$time.interest) #' pdta <- get.partial.plot.data(partial.obj, granule = TRUE) #' boxplot(pdta$yhat ~ pdta$x, xlab = "treatment", ylab = "partial effect") -#' -#' +#' +#' #' ## partial effects of karnofsky score on survival #' karno <- quantile(v.obj$xvar$karno) #' partial.obj <- partial(v.obj, @@ -51,19 +73,19 @@ #' partial.values = karno, #' partial.time = v.obj$time.interest) #' pdta <- 
get.partial.plot.data(partial.obj) -#' +#' #' matplot(pdta$partial.time, t(pdta$yhat), type = "l", lty = 1, #' xlab = "time", ylab = "karnofsky adjusted survival") #' legend("topright", legend = paste0("karnofsky = ", karno), fill = 1:5) -#' -#' +#' +#' #' ## ------------------------------------------------------------ #' ## competing risk #' ## ------------------------------------------------------------ -#' +#' #' data(follic, package = "randomForestSRC") #' follic.obj <- rfsrc(Surv(time, status) ~ ., follic, nsplit = 3, ntree = 100) -#' +#' #' ## partial effect of age on years lost #' partial.obj <- partial(follic.obj, #' partial.type = "years.lost", @@ -72,13 +94,13 @@ #' partial.time = follic.obj$time.interest) #' pdta1 <- get.partial.plot.data(partial.obj, target = 1) #' pdta2 <- get.partial.plot.data(partial.obj, target = 2) -#' +#' #' par(mfrow=c(2,2)) #' plot(lowess(pdta1$x, pdta1$yhat), #' type = "l", xlab = "age", ylab = "adjusted years lost relapse") #' plot(lowess(pdta2$x, pdta2$yhat), #' type = "l", xlab = "age", ylab = "adjusted years lost death") -#' +#' #' ## partial effect of age on cif #' partial.obj <- partial(follic.obj, #' partial.type = "cif", @@ -87,23 +109,23 @@ #' partial.time = follic.obj$time.interest) #' pdta1 <- get.partial.plot.data(partial.obj, target = 1) #' pdta2 <- get.partial.plot.data(partial.obj, target = 2) -#' +#' #' matplot(pdta1$partial.time, t(pdta1$yhat), type = "l", lty = 1, #' xlab = "time", ylab = "age adjusted cif for relapse") #' matplot(pdta2$partial.time, t(pdta2$yhat), type = "l", lty = 1, #' xlab = "time", ylab = "age adjusted cif for death") #' #' @export surv_partial.rfsrc -surv_partial.rfsrc <- function(rforest, var_list, npts=25, partial.type = "surv") { +surv_partial.rfsrc <- function(rforest, var_list, npts = 25, partial.type = "surv") { # nolint: object_name_linter ###----------Partial dependency estimation, for each variable, at each time point ---- surv.lst <- lapply(var_list, function(xvar) { ## extract 
the key variable cat("partial plot for:", xvar, "\n") - + ## determine the partial plot data xv <- sort(unique(rforest$xvar[, xvar])) xv <- unique(xv[seq(1, length(xv), length = npts)]) - + ## Get the partial.plot.data partial.dta <- randomForestSRC::get.partial.plot.data( randomForestSRC::partial.rfsrc( @@ -114,10 +136,10 @@ surv_partial.rfsrc <- function(rforest, var_list, npts=25, partial.type = "surv" partial.time = rforest$time.interest ) ) - - list(name=xvar, + + list(name = xvar, dta = partial.dta) - + }) return(surv.lst) } diff --git a/R/utils.R b/R/utils.R new file mode 100644 index 00000000..8709f2af --- /dev/null +++ b/R/utils.R @@ -0,0 +1,95 @@ +####********************************************************************** +####********************************************************************** +#### +#### ---------------------------------------------------------------- +#### Written by: +#### John Ehrlinger, Ph.D. +#### +#### email: john.ehrlinger@gmail.com +#### URL: https://github.com/ehrlinger/ggRandomForests +#### ---------------------------------------------------------------- +#### +####********************************************************************** +####********************************************************************** +## Internal utility functions shared across the package. +## None of these are exported to end-users. + +# --------------------------------------------------------------------------- # +#' Lead / lag shift for numeric vectors +#' +#' @param x a numeric vector of values +#' @param shift_by an integer of length 1, giving the number of positions +#' to lead (positive) or lag (negative) by +#' +#' @details Lead and lag are useful for comparing values offset by a constant +#' (e.g. the previous or next value). +#' +#' Taken from: +#' http://ctszkin.com/2012/03/11/generating-a-laglead-variables/ +#' +#' This function allows removal of the dplyr::lead dependency. 
+#' +#' @keywords internal +#' @examples +#' d <- data.frame(x = 1:15) +#' # generate lead variable +#' d$df_lead2 <- ggRandomForests:::shift(d$x, 2) +#' # generate lag variable +#' d$df_lag2 <- ggRandomForests:::shift(d$x, -2) +shift <- function(x, shift_by = 1) { + stopifnot(is.numeric(shift_by)) + stopifnot(is.numeric(x)) + + if (length(shift_by) > 1) { + return(sapply(shift_by, shift, x = x)) + } + + abs_shift_by <- abs(shift_by) + if (shift_by > 0) { + out <- c(tail(x, -abs_shift_by), rep(NA, abs_shift_by)) + } else if (shift_by < 0) { + out <- c(rep(NA, abs_shift_by), head(x, -abs_shift_by)) + } else { + out <- x + } + out +} + +# --------------------------------------------------------------------------- # +# Internal helper: label a survfit tbl with stratum group names. +# +# survfit() concatenates strata end-to-end in ascending-time order. Stratum +# boundaries are detected by finding rows where the time column resets +# (i.e. time[i] < time[i-1]). +# +# @param tbl data.frame produced from survfit output (must have $time col) +# @param data original data.frame passed to kaplan()/nelson() +# @param by character; name of the grouping column in data +# +# @return tbl with an additional $groups column containing the group label +# for each row. +.label_strata <- function(tbl, data, by) { + # Use levels() for factors to respect the existing ordering; fall back to + # unique() (in order of first appearance) for character/numeric vectors. 
+ by_col <- data[[by]] + lbls <- if (is.factor(by_col)) levels(by_col) else unique(by_col) + + # Single stratum or fewer than 2 rows: label everything with first group + if (nrow(tbl) < 2L) { + tbl$groups <- lbls[1L] + return(tbl) + } + + # Detect stratum boundaries where the time column resets + tm_splits <- which(c(FALSE, sapply(seq(2L, nrow(tbl)), function(ind) { + tbl$time[ind] < tbl$time[ind - 1L] + }))) + + tbl$groups <- lbls[1L] + if (length(tm_splits) > 0L) { + for (ind in seq_along(tm_splits)) { + tbl$groups[tm_splits[ind]:nrow(tbl)] <- lbls[ind + 1L] + } + } + tbl +} diff --git a/README.md b/README.md index ae2ff83c..ed4c1b3c 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,6 @@ +# ggRandomForests: Visually Exploring Random Forests -ggRandomForests: Visually Exploring Random Forests -======================================================== [![cranlogs](https://cranlogs.r-pkg.org:443/badges/ggRandomForests)](https://cranlogs.r-pkg.org:443/badges/ggRandomForests) [![CRAN_Status_Badge](https://www.r-pkg.org/badges/version/ggRandomForests)](https://cran.r-project.org/package=ggRandomForests) @@ -15,45 +14,100 @@ ggRandomForests: Visually Exploring Random Forests [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.11526.svg)](https://doi.org/10.5281/zenodo.11526) -[ggRandomForests](https://cran.r-project.org/package=ggRandomForests) will help uncover variable associations in the random forests models. The package is designed for use with the [randomForest](https://cran.r-project.org/package=randomForest) package (A. Liaw and M. Wiener 2002) or the [randomForestSRC](https://cran.r-project.org/package=randomForestSRC) package (Ishwaran et.al. 2014, 2008, 2007) for survival, regression and classification random forests and uses the [ggplot2](https://cran.r-project.org/package=ggplot2) package (Wickham 2009) for plotting diagnostic and variable association results. 
[ggRandomForests](https://cran.r-project.org/package=ggRandomForests) is structured to extract data objects from [randomForestSRC](https://cran.r-project.org/package=randomForestSRC) or [randomForest](https://cran.r-project.org/package=randomForest) objects and provides S3 functions for printing and plotting these objects. - -The [randomForestSRC](https://cran.r-project.org/package=randomForestSRC) package provides a unified treatment of Breiman's (2001) random forests for a variety of data settings. Regression and classification forests are grown when the response is numeric or categorical (factor) while survival and competing risk forests (Ishwaran et al. 2008, 2012) are grown for right-censored survival data. Recently, support for the [randomForest](https://cran.r-project.org/package=randomForest) package (A. Liaw and M. Wiener 2002) for regression and classification forests has also been added. +`ggRandomForests` provides `ggplot2`-based diagnostic and exploration plots for random forests fit with +[randomForestSRC](https://cran.r-project.org/package=randomForestSRC) (>= 3.4.0) or +[randomForest](https://cran.r-project.org/package=randomForest). +It separates data extraction from plotting so the intermediate tidy objects can be inspected, saved, or used +for custom analyses. -Many of the figures created by the `ggRandomForests` package are also available directly from within the `randomForestSRC` or `randomForest` package. However, `ggRandomForests` offers the following advantages: +## Installation - * Separation of data and figures: `ggRandomForests` contains functions that operate on either the forest object directly, or on the output from `randomForestSRC` and `randomForest` post processing functions (i.e. `plot.variable`, `var.select`, `find.interaction`) to generate intermediate `ggRandomForests` data objects. S3 functions are provide to further process these objects and plot results using the `ggplot2` graphics package. 
Alternatively, users can use these data objects for additional custom plotting or analysis operations. +```r +# CRAN (stable) +install.packages("ggRandomForests") - * Each data object/figure is a single, self contained object. This allows simple modification and manipulation of the data or `ggplot2` objects to meet users specific needs and requirements. +# Development version from GitHub +# install.packages("remotes") +remotes::install_github("ehrlinger/ggRandomForests") +``` - * The use of `ggplot2` for plotting. We chose to use the `ggplot2` package for our figures to allow users flexibility in modifying the figures to their liking. Each S3 plot function returns either a single `ggplot2` object, or a `list` of `ggplot2` objects, allowing users to use additional `ggplot2` functions or themes to modify and customize the figures to their liking. +## Quick start -Check out the ["Exploring Random Forests with ggRandomForests" vignette](vignettes/ggRandomForests.qmd) for a walk-through of these objects. +```r +library(randomForestSRC) +library(ggRandomForests) -The package has recently been extended for Breiman and Cutler's Random Forests for Classification and -Regression package [randomForest](https://cran.r-project.org/package=randomForest) where possible. Though methods have been provided for all `gg_*` functions, the unsupported functions will return an error message indicating where support is still lacking. +# 1. Fit a forest (regression) +rf <- rfsrc(medv ~ ., data = MASS::Boston, importance = TRUE) -## Recent improvements +# 2. Check convergence: did the forest grow enough trees? +plot(gg_error(rf)) -- `gg_error()` now computes optional in-bag training error trajectories for `randomForest` fits when `training = TRUE`, making it easy to compare OOB and training curves in a single data object. 
-- `gg_variable()` rebuilds the original training data from the `randomForest` call, so marginal dependence plots work even when models are trained inside helper functions or use `subset()` calls. -- `quantile_pts()` is fully quantile based, providing balanced conditioning intervals that can be dropped directly into `cut()` for coplots. -- `gg_vimp()` handles forests that were trained without importance metrics by issuing a warning and returning `NA` placeholders, ensuring downstream plotting code continues to run. +# 3. Rank predictors by importance +plot(gg_vimp(rf)) -## References +# 4. Marginal dependence for top variables +gg_v <- gg_variable(rf) +plot(gg_v, xvar = "lstat") +plot(gg_v, xvar = rf$xvar.names, panel = TRUE, se = FALSE) + +# 5. Partial dependence for a single predictor +pv <- plot.variable(rf, xvar.names = "lstat", partial = TRUE, show.plots = FALSE) +pd <- gg_partial(pv) +plot(pd) +``` + +For survival forests, see the package vignette: +```r +vignette("ggRandomForests") +``` + +## Function reference + +| Function | Input | What you get | +|---|---|---| +| `gg_error()` | `rfsrc` / `randomForest` | OOB error vs. number of trees | +| `gg_vimp()` | `rfsrc` / `randomForest` | Variable importance ranking | +| `gg_rfsrc()` | `rfsrc` / `randomForest` | Predicted vs. observed values | +| `gg_variable()` | `rfsrc` / `randomForest` | Marginal dependence data frame | +| `gg_partial()` | `plot.variable` output | Partial dependence (continuous + categorical) | +| `gg_partial_rfsrc()` | `rfsrc` model | Partial dependence via `partial.rfsrc` | +| `gg_survival()` | `rfsrc` survival forest | Kaplan–Meier / Nelson–Aalen estimates | +| `gg_roc()` | `rfsrc` / `randomForest` (class) | ROC curve data | + +Each `gg_*` function has a corresponding `plot()` S3 method that returns a `ggplot2` object, making it easy +to apply additional `ggplot2` layers or themes. -Breiman, L. (2001). Random forests, Machine Learning, 45:5-32. +## Why ggRandomForests? -Ishwaran H. 
and Kogalur U.B. (2014). Random Forests for Survival, -Regression and Classification (RF-SRC), R package version 1.5.5. +- **Separation of data and figures.** `gg_*` functions extract tidy data objects from the forest. + `plot()` methods turn those into `ggplot2` figures. You can inspect, save, or transform the data + before plotting. +- **Self-contained objects.** Each data object holds everything needed for its plot, so figures are + reproducible without the original forest in memory. +- **Full `ggplot2` composability.** Every `plot()` method returns a `ggplot` object that accepts + additional layers, scales, and themes. + +## Recent changes + +See [NEWS.md](NEWS.md) for the full changelog. Highlights since v2.4.0: + +- **v2.6.1** Fix factor-level assignment in `gg_partial` for categorical variables. +- **v2.6.0** New plotting functions exported; test coverage raised to 83%; removed internal dependency on `hvtiRutilities`. +- **v2.5.0** New `gg_partial_rfsrc()` computes partial dependence directly from an `rfsrc` model without a separate `plot.variable` call; supports a grouping variable via `xvar2.name`. + +## References -Ishwaran H. and Kogalur U.B. (2007). Random survival forests for R. R News -7(2), 25--31. +Breiman, L. (2001). Random forests, *Machine Learning*, 45:5–32. -Ishwaran H., Kogalur U.B., Blackstone E.H. and Lauer M.S. (2008). Random -survival forests. Ann. Appl. Statist. 2(3), 841--860. +Ishwaran H. and Kogalur U.B. randomForestSRC: Random Forests for Survival, Regression and +Classification. R package version >= 3.4.0. -A. Liaw and M. Wiener (2002). Classification and Regression by randomForest. R News 2(3), 18--22. +Ishwaran H. and Kogalur U.B. (2007). Random survival forests for R. *R News* 7(2), 25–31. -Wickham, H. ggplot2: elegant graphics for data analysis. Springer New York, 2009. +Ishwaran H., Kogalur U.B., Blackstone E.H. and Lauer M.S. (2008). Random survival forests. +*Ann. Appl. Statist.* 2(3), 841–860. +Liaw A. and Wiener M. (2002). 
Classification and Regression by randomForest. *R News* 2(3), 18–22. +Wickham H. (2009). *ggplot2: Elegant Graphics for Data Analysis*. Springer New York. diff --git a/code-review.md b/code-review.md new file mode 100644 index 00000000..c2a6e834 --- /dev/null +++ b/code-review.md @@ -0,0 +1,462 @@ +# ggRandomForests — Code Review & Testing Strategy +_Pre-release review against v2.6.1 codebase — March 2026_ + +--- + +## Part 1: Testing Strategy + +### Current State + +The suite has 15 test files covering all major exported functions. The high-level +picture is good — rfsrc + randomForest, regression + classification + survival +paths, error handling — but there are systematic gaps that mean genuine bugs +can and do pass undetected. + +**Coverage summary** + +| File | Status | Notable gaps | +|---|---|---| +| `test_gg_error.R` | Solid | No message-text checks on `expect_error` | +| `test_gg_rfsrc.R` | Excellent (17 cases) | Plots tested structurally only, not visually | +| `test_gg_vimp.R` | Good | `gg_dta[1:nvar, ]` path untested for nvar=0 | +| `test_gg_roc.R` | Decent | Typo in function name silently voids two tests (see bugs) | +| `test_gg_survival.R` | Thin (3 cases) | No column-structure checks, no error-message validation | +| `test_gg_variable.R` | Good | — | +| `test_gg_partial.R` | Good (mock-based) | — | +| `test_gg_partialpro.R` | Good (mock-based) | — | +| `test_surv_partial.R` | Reasonable | `npts` test doesn't verify actual point count | +| `test_randomForest_helpers.R` | Good | — | +| `test_varpro_feature_names.R` | Thorough | — | +| `test_quantile_pts.R` | Basic | No edge cases (n=1, all-identical values) | +| `test_shift.R` | **Broken API** | Uses `expect_that`/`is_identical_to` (testthat 1.x — removed) | +| `test_lint.R` | **Dead** | Test body is commented out | +| `test_ggrandomforests_news.R` | Trivial | — | + +### Gap 1 — Deprecated testthat API throughout + +Every test file uses one or more removed/deprecated calls: + +- `context("...")` — 
deprecated in testthat 3.0; harmless now but will break in a future release +- `expect_is(x, "cls")` — deprecated; use `expect_s3_class(x, "cls")` +- `expect_equivalent(a, b)` — deprecated; use `expect_equal(a, b, ignore_attr = TRUE)` +- `test_shift.R`: `expect_that(x, is_identical_to(y))` — from testthat 1.x, removed entirely + +**Action:** global find-and-replace across all test files. + +### Gap 2 — Plots are tested structurally, not visually + +Every `plot.*` test does `expect_s3_class(gg_plt, "ggplot")`. A ggplot object can +be created successfully even when all aesthetics map to bare string literals +(which produces a broken plot). This means the three broken `aes()` bugs +described in Part 2 pass the test suite silently. + +**Action:** Adopt `vdiffr::expect_doppelganger()` for at least one snapshot per +plot function. Minimum targets: regression scatter, classification jitter, +survival step curve with and without CI ribbon. + +```r +# Example +test_that("plot.gg_rfsrc survival CI ribbon snapshot", { + vdiffr::expect_doppelganger( + "rfsrc-surv-ci", + plot(gg_rfsrc(rfsrc_veteran, conf.int = 0.95)) + ) +}) +``` + +### Gap 3 — `bootstrap_survival` has zero direct tests + +The function is only tested via `gg_rfsrc(..., conf.int = ...)`. There are no +unit tests that verify the output columns (`value`, `lower`, `upper`, `median`, +`mean`), that `lower <= median <= upper` holds, or that a two-element +`level_set` produces the right column names. + +**Action:** Add a `test_bootstrap_survival.R` that calls +`ggRandomForests:::bootstrap_survival(gg_t, bs_samples, level_set)` directly. + +### Gap 4 — `gg_survival` column structure untested + +`test_gg_survival.R` checks the class and then runs 7 `plot()` type +combinations, but never verifies that the returned data frame actually has the +expected columns (`time`, `surv`, `lower`, `upper`, `n.risk`, etc.). 
+ +**Action:** +```r +expect_true(all(c("time", "surv", "lower", "upper", "n.risk") %in% + colnames(gg_dta))) +``` + +### Gap 5 — Two `expect_error` calls test a non-existent function + +In `test_gg_roc.R` lines 59–60 and 102–103: + +```r +expect_error(gg_roc.rfrsrc(rf_iris)) # typo: "rfrsrc" not "rfsrc" +``` + +`gg_roc.rfrsrc` does not exist, so this always errors with +`"could not find function"` — it vacuously passes regardless of whether input +validation works correctly. + +**Action:** Fix the typo to `gg_roc.rfsrc` and make the error message explicit. + +### Gap 6 — Error message text is never asserted + +Across all test files, `expect_error(fn(...))` is used without asserting the +message. A completely different error (e.g., NULL dereference) would pass. + +**Action:** For each validated error path, pin the message: +```r +expect_error(gg_vimp(bad_obj), "This function only works for") +expect_error(gg_rfsrc(rfsrc_boston, by = c(1,2,3)), "correct dimension") +``` + +### Gap 7 — Missing `set.seed()` and unbounded `ntree` + +Many tests build forests without `ntree` (defaulting to 500), and without +`set.seed()`. This makes the suite slow (CI times matter), non-reproducible +for debugging, and fragile against stochastic variation. + +**Action:** Every test that builds a forest should open with `set.seed(42)` and +pass `ntree = 50` (or `ntree = 75` max). + +### Gap 8 — `kaplan.R` and `nelson.R` have no tests + +Both helpers are used by `gg_survival` but have zero direct test coverage. + +### Gap 9 — `nvar = 0` path for `gg_vimp` untested + +`gg_dta[1:nvar, ]` with `nvar = 0` does not return a 0-row data frame: `1:0` +evaluates to `c(1, 0)`, the zero index is silently dropped, and the first row is +returned — a silent bug. The `seq_len(nvar)` fix applied elsewhere was not applied +here.
+ +--- + +## Part 2: Code Review + +Severity scale: 🔴 Bug (broken output) / 🟠 Logic error (usually silent) / +🟡 Code smell / ⚪ Style + +--- + +### 🔴 BUG — `plot.gg_rfsrc.R`: every `aes()` uses bare string literals + +This is the most widespread breakage in the package. A bare string in `aes()` +maps the aesthetic to a *constant* — all points end up at the same position. + +**Affected lines** (current file, after recent edits): + +```r +# Classification binary — line ~169 +ggplot2::aes( + x = 1, + y = colnames(gg_dta)[1], # ← string, not column ref + color = colnames(gg_dta)[2], + shape = colnames(gg_dta)[2] +) +# Fix: +ggplot2::aes( + x = 1, + y = .data[[colnames(gg_dta)[1]]], + color = .data[[colnames(gg_dta)[2]]], + shape = .data[[colnames(gg_dta)[2]]] +) +``` + +```r +# Classification multi-class — line ~195 +ggplot2::aes(x = "variable", y = "value") +ggplot2::aes(color = "y", shape = "y") +# Fix: +ggplot2::aes(x = .data$variable, y = .data$value) +ggplot2::aes(color = .data$y, shape = .data$y) +``` + +```r +# Survival with CI, grouped — lines ~220–232 +ggplot2::aes(x = "value", ymin = "lower", ymax = "upper", fill = "group") +ggplot2::aes(x = "value", y = "median", color = "group") +# Fix: +ggplot2::aes(x = .data$value, ymin = .data$lower, ymax = .data$upper, + fill = .data$group) +ggplot2::aes(x = .data$value, y = .data$median, color = .data$group) +``` + +```r +# Survival with CI, ungrouped — lines ~237–244 +ggplot2::aes(x = "value", ymin = "lower", ymax = "upper") +ggplot2::aes(x = "value", y = "median") +# Fix: same pattern +``` + +```r +# Survival no CI — lines ~250–255 +ggplot2::aes(x = "variable", y = "value", col = "event", by = "obs_id") +# Fix: +ggplot2::aes(x = .data$variable, y = .data$value, + col = .data$event, group = .data$obs_id) +# Note: "by" is not a ggplot2 aes; use "group" for geom_step +``` + +```r +# Regression grouped — line ~268 +ggplot2::aes(x = "group", y = "yhat") +# Regression ungrouped — line ~271 +ggplot2::aes(x = 1, y = "yhat") 
+# Fix: +ggplot2::aes(x = .data$group, y = .data$yhat) +ggplot2::aes(x = 1, y = .data$yhat) +``` + +Additionally, `geom_boxplot(ggplot2::aes(x = 1, y = colnames(gg_dta)[1]))` on +line ~176 has the same string-literal problem. + +--- + +### 🔴 BUG — `plot.gg_roc.R`: multi-class aes() uses bare strings + +```r +# line ~168–170 +ggplot2::aes( + x = .data$fpr, + y = .data$sens, + linetype = "outcome", # ← constant string, not column + col = "outcome" +) +# Fix: +ggplot2::aes( + x = .data$fpr, + y = .data$sens, + linetype = .data$outcome, + col = .data$outcome +) +``` + +--- + +### 🟠 LOGIC ERROR — `bootstrap_survival`: nonsensical negative indexing + +```r +# Current (lines ~475–479): +dta <- data.frame(cbind( + time_interest, + t(rng)[-which(colnames(gg_dta) %in% c("obs_id", "event")), ], + mn[-which(colnames(gg_dta) %in% c("obs_id", "event"))] +)) +``` + +`t(rng)` has `n_time_points` rows. The negative index is derived from +`colnames(gg_dta)`, which has `n_time_points + 2` (or +3 with `group`) +columns. The positions of `"obs_id"` and `"event"` always exceed `n_time`, +so this negative index is a no-op in the common case — but the intent is +wrong, the code is misleading, and for a dataset with two or three time +points it would silently drop real rows. + +**Fix:** +```r +dta <- data.frame(cbind(time_interest, t(rng), mn)) +``` +The `gg_t` that feeds `mn_bs` already has `obs_id`/`event`/`group` stripped. + +--- + +### 🟠 LOGIC ERROR — `gg_rfsrc.rfsrc` and `.randomForest`: `is.null(df[, col])` does not detect missing columns + +```r +# Line 216 / 520: +if (is.null(object$xvar[, grp])) { ... } +``` + +`df[, "nonexistent_col"]` in R throws `"undefined columns selected"`, it does +not return NULL. 
The intended check should be: + +```r +if (!grp %in% colnames(object$xvar)) { + stop(paste("No column named", grp, "in forest training set.")) +} +grp <- object$xvar[, grp] +``` + +Additionally, `gg_rfsrc.randomForest` references `object$xvar` which **does not +exist** on `randomForest` objects. The `by` character-name lookup path for +randomForest would throw an obscure subscript error rather than an informative +message. + +--- + +### 🟠 LOGIC ERROR — `plot.gg_error.R`: legend-suppression check uses wrong column name + +```r +# Line ~247: +if (length(unique(gg_dta$variable)) == 1) { + gg_plt <- gg_plt + ggplot2::theme(legend.position = "none") +} +``` + +For single-outcome forests (regression, survival), `gg_dta` is not gathered — +it has a column named `"error"`, not `"variable"`. `gg_dta$variable` is NULL, +so `length(unique(NULL)) == 0 != 1`, and the legend-suppression block never +fires. Single-outcome plots show a redundant legend. + +**Fix:** +```r +# After both branches: +if (ncol(x) <= 2) { # single outcome: no legend needed + gg_plt <- gg_plt + ggplot2::theme(legend.position = "none") +} +``` + +Or check directly: `if (!"variable" %in% names(gg_dta) || length(unique(gg_dta$variable)) <= 1)`. + +--- + +### 🟠 LOGIC ERROR — `gg_vimp.rfsrc` and `.randomForest`: `1:nvar` instead of `seq_len(nvar)` + +```r +# gg_vimp.rfsrc line ~281: +gg_dta <- gg_dta[1:nvar, ] + +# gg_vimp.randomForest line ~429: +gg_dta <- gg_dta[1:nvar, ] +``` + +If `nvar = 0`, `1:0` evaluates to `c(1, 0)`; the zero index is dropped, so one row +is returned instead of none. This was already fixed in the loop-iteration context +elsewhere in the package but missed here.
+ +**Fix:** `gg_dta <- gg_dta[seq_len(nvar), ]` + +--- + +### 🟠 LOGIC ERROR — `plot.gg_rfsrc.R` survival no-CI branch: `by = "obs_id"` is not a ggplot2 aesthetic + +```r +ggplot2::aes(x = "variable", y = "value", col = "event", by = "obs_id") +``` + +`by` is not a recognised `ggplot2` aesthetic argument (it belongs to +`ggplot2::geom_line` in some versions but not `geom_step`). The correct +grouping aesthetic for step functions is `group`: + +```r +ggplot2::aes(x = .data$variable, y = .data$value, + col = .data$event, group = .data$obs_id) +``` + +--- + +### 🟡 CODE SMELL — Massive duplication: `by`-argument resolution block + +The 30-line block that validates and resolves the `by` argument is copy-pasted +verbatim between `gg_rfsrc.rfsrc` and `gg_rfsrc.randomForest`. Same pattern in +`gg_vimp.rfsrc` and `gg_vimp.randomForest` for `which.outcome` handling. + +**Action:** Extract to internal helpers: +```r +.resolve_by <- function(by, xvar) { ... } +.resolve_which_outcome <- function(which.outcome, gg_dta) { ... } +``` + +--- + +### 🟡 CODE SMELL — `tidyr::gather()` still used in four source files + +`gather()` is superseded (not just deprecated) by `pivot_longer()` since tidyr +1.0 (2019). It still works but will eventually be removed and prints a lifecycle +warning on newer tidyr versions. + +Files affected: `gg_rfsrc.R`, `gg_vimp.R`, `plot.gg_error.R`, `plot.gg_rfsrc.R`. + +--- + +### 🟡 CODE SMELL — `geom_jitter(, ...)` has a stray comma + +```r +# plot.gg_rfsrc.R line ~275: +gg_plt <- gg_plt + + ggplot2::geom_jitter(, ...) + +``` + +The leading comma before `...` is syntactically valid in R but is clearly a +typo — the first positional argument to `geom_jitter` is `mapping`. This should +be `ggplot2::geom_jitter(...)`. + +--- + +### 🟡 CODE SMELL — `geom_jitter` and `geom_boxplot` receive `...` in the same call + +In the classification and regression branches, `...` is passed to both +`geom_jitter(...)` and `geom_boxplot(...)`. 
Arguments like `alpha` or `size` +may not be meaningful for both geoms and could produce warnings or silently +ignored args. + +--- + +### 🟡 CODE SMELL — `plot.gg_roc` multi-class AUC annotation unreachable + +```r +# Line ~183: +if (crv < 2) { +``` + +`crv` is the number of classes and is only defined in the branch that builds the +multi-class list (`crv > 2 && is.null(which_outcome)`). `crv < 2` in a +multi-class context is impossible. The AUC annotation for the multi-class path +is dead code. + +--- + +### 🟡 CODE SMELL — `gg_rfsrc.rfsrc` builds `arg_list` but `by` is accessed via `missing(by)`, not via `arg_list` + +`by` is a named parameter, not in `...`, so `arg_list <- list(...)` does not +capture it. This is fine — but `conf.int`, `surv_type`, and `bs.sample` **are** +in `...` and are accessed via `arg_list`. The asymmetry is confusing and should +be documented clearly. + +--- + +### ⚪ STYLE — `point = FALSE` should be `point <- FALSE` + +```r +# plot.gg_error.R line ~214: +point = FALSE +``` + +Using `=` for assignment at the top level is valid but violates the package's +own style guide (and lintr's default rules). Should be `<-`. + +--- + +### ⚪ STYLE — `if(length(...) == 0)` spacing + +```r +# gg_partial.R line ~94: +if(length(cat_list) == 0) { +``` + +Missing space after `if`. Minor but lintr will flag it. 
+ +--- + +## Summary: Prioritised fix list + +| # | Severity | File | Issue | +|---|---|---|---| +| 1 | 🔴 | `plot.gg_rfsrc.R` | All `aes()` aesthetics use bare string literals — plots are visually broken | +| 2 | 🔴 | `plot.gg_roc.R` | Multi-class `aes()` uses bare string literals | +| 3 | 🟠 | `bootstrap_survival` in `gg_rfsrc.R` | Nonsensical negative indexing — silent no-op now, latent crash for small datasets | +| 4 | 🟠 | `gg_rfsrc.rfsrc` + `.randomForest` | `is.null(df[, col])` does not detect missing columns; `randomForest` has no `$xvar` | +| 5 | 🟠 | `plot.gg_error.R` | Legend suppression uses wrong column name — legend always shown for single-outcome | +| 6 | 🟠 | `gg_vimp.rfsrc` + `.randomForest` | `1:nvar` instead of `seq_len(nvar)` — returns 2 rows when `nvar=0` | +| 7 | 🟠 | `plot.gg_rfsrc.R` | `by = "obs_id"` is not a ggplot2 aesthetic — should be `group` | +| 8 | 🟡 | `gg_rfsrc.R` | Duplicate `by`-resolution block — extract to helper | +| 9 | 🟡 | `gg_vimp.R` | Duplicate `which.outcome` block — extract to helper | +| 10 | 🟡 | 4 files | `tidyr::gather()` → `pivot_longer()` | +| 11 | 🟡 | `plot.gg_rfsrc.R` | Stray comma in `geom_jitter(, ...)` | +| 12 | 🟡 | `plot.gg_roc.R` | Dead `if (crv < 2)` AUC annotation branch | +| 13 | ⚪ | Test suite | All deprecated API calls (`expect_is`, `expect_equivalent`, `context`, `expect_that`) | +| 14 | ⚪ | Test suite | Missing `set.seed()` + unbounded `ntree` in many tests | +| 15 | ⚪ | Test suite | Zero visual/snapshot tests for a visualization package | +| 16 | ⚪ | `test_gg_roc.R` | Typo `gg_roc.rfrsrc` voids two tests silently | +| 17 | ⚪ | `test_lint.R` | Lintr check is commented out entirely | diff --git a/cran-comments.md b/cran-comments.md index 76dc6898..7b11cb32 100644 --- a/cran-comments.md +++ b/cran-comments.md @@ -1,8 +1,23 @@ -This is ggRandomForests package submission v2.4.0 +This is ggRandomForests package submission v2.7.0 ------------------------------------------------------------------------- -* Updating 
to latest ggplot2 functions -* Utilize some namespace referencing -* Added pkgdown documentation -* Minor testing improvements -* Removing partial plot and interaction plot functionality due to changes - in randomForestSRC 3.4.x +This is a bug-fix and code-quality release. Key changes: + +* Fix critical visual bug: `aes()` calls throughout `plot.gg_rfsrc` and + `plot.gg_roc` used bare string literals instead of `.data[[col]]`, + causing aesthetics to map to constant strings rather than data columns. +* Fix `bootstrap_survival` CI-band indexing and `gg_rfsrc.randomForest` + incorrect use of non-existent `object$xvar` field. +* Fix `seq_len(nvar)` vs `1:nvar` silent bug in `gg_vimp` and `plot.gg_vimp`. +* Full test suite migration to testthat 3.x API. +* Improved GitHub Actions CI (lintr enforcement, warnings-as-errors). + +## R CMD check results +0 errors | 0 warnings | 0 notes + +## Test environments +* local R installation (R 4.4, macOS) +* GitHub Actions: ubuntu-latest (R devel) +* GitHub Actions: ubuntu-latest (R release) +* GitHub Actions: ubuntu-latest (R oldrel-1) +* GitHub Actions: windows-latest (R release) +* GitHub Actions: macos-latest (R release) diff --git a/man/bootstrap_survival.Rd b/man/bootstrap_survival.Rd new file mode 100644 index 00000000..ae3c5e56 --- /dev/null +++ b/man/bootstrap_survival.Rd @@ -0,0 +1,33 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/gg_rfsrc.R +\name{bootstrap_survival} +\alias{bootstrap_survival} +\title{Bootstrap pointwise confidence bands for a mean survival curve} +\usage{ +bootstrap_survival(gg_dta, bs_samples, level_set) +} +\arguments{ +\item{gg_dta}{A wide \code{data.frame} of survival probabilities as returned +by the survival branch of \code{\link{gg_rfsrc.rfsrc}}, before the +optional pivot to long form. 
Columns \code{obs_id}, \code{event}, and +\code{group} (if present) are excluded from the resampling.} + +\item{bs_samples}{Integer; number of bootstrap resamples.} + +\item{level_set}{Numeric vector of length 2 giving the lower and upper +quantile probabilities for the confidence band (e.g. \code{c(0.025, 0.975)} +for a 95\% CI).} +} +\value{ +A \code{data.frame} with one row per unique event time and columns + \code{value} (time), \code{lower}, \code{upper}, \code{median}, and + \code{mean}. +} +\description{ +Draws \code{bs_samples} bootstrap resamples (with replacement) of the +per-observation survival curves stored in \code{gg_dta}, computes the column +means to obtain a bootstrapped mean curve per resample, then returns the +pointwise quantiles at \code{level_set} and the overall mean across +resamples. +} +\keyword{internal} diff --git a/man/calc_roc.rfsrc.Rd b/man/calc_roc.rfsrc.Rd index 448a3a8f..9a9aa248 100644 --- a/man/calc_roc.rfsrc.Rd +++ b/man/calc_roc.rfsrc.Rd @@ -9,20 +9,30 @@ \method{calc_roc}{rfsrc}(object, dta, which_outcome = "all", oob = TRUE, ...) } \arguments{ -\item{object}{\code{\link[randomForestSRC]{rfsrc}} or -\code{\link[randomForestSRC]{predict.rfsrc}} object -containing predicted response} +\item{object}{A fitted \code{\link[randomForestSRC]{rfsrc}}, +\code{\link[randomForestSRC]{predict.rfsrc}}, or +\code{\link[randomForest]{randomForest}} classification object containing +predicted class probabilities.} -\item{dta}{True response variable} +\item{dta}{A factor (or coercible to factor) of the true observed class +labels, one per observation. Typically \code{object$yvar} for rfsrc or +\code{object$y} for randomForest.} -\item{which_outcome}{If defined, only show ROC for this response.} +\item{which_outcome}{Integer index of the class for which the ROC curve is +computed (e.g. \code{1} for the first class, \code{2} for the second). 
+Use \code{"all"} to request all classes (currently falls back to class 1 +with a warning).} -\item{oob}{Use OOB estimates, the normal validation method (TRUE)} +\item{oob}{Logical; if \code{TRUE} (default for rfsrc) use OOB predicted +probabilities. Forced to \code{FALSE} for \code{randomForest} objects.} -\item{...}{extra arguments passed to helper functions} +\item{...}{Extra arguments passed to helper functions (currently unused).} } \value{ -A \code{gg_roc} object +A \code{gg_roc} \code{data.frame} with columns \code{sens} + (sensitivity), \code{spec} (specificity), and \code{pct} (the probability + threshold), with one row per unique prediction value. Suitable for passing + to \code{\link{calc_auc}} or \code{\link{plot.gg_roc}}. } \description{ Receiver Operator Characteristic calculator diff --git a/man/ggRandomForests-package.Rd b/man/ggRandomForests-package.Rd index 5daacf52..b5261f74 100644 --- a/man/ggRandomForests-package.Rd +++ b/man/ggRandomForests-package.Rd @@ -5,11 +5,12 @@ \title{ggRandomForests: Visually Exploring Random Forests} \description{ \code{ggRandomForests} is a utility package for -\code{randomForestSRC} (Ishwaran et.al. 2014, 2008, 2007) for survival, +\code{randomForestSRC} (Ishwaran and Kogalur) for survival, regression and classification forests and uses the \code{ggplot2} (Wickham 2009) package for plotting results. \code{ggRandomForests} is structured to extract data objects from the random forest and provides S3 functions for printing and plotting these objects. +Requires \code{randomForestSRC} >= 3.4.0. The \code{randomForestSRC} package provides a unified treatment of Breiman's (2001) random forests for a variety of data settings. Regression @@ -24,8 +25,8 @@ available directly from within the \code{randomForestSRC} package. 
However, \item Separation of data and figures: \code{ggRandomForest} contains functions that operate on either the \code{\link[randomForestSRC]{rfsrc}} forest object directly, or on the output from \code{randomForestSRC} post -processing functions (i.e. \code{plot.variable}, -\code{find.interaction}) to generate intermediate \code{ggRandomForests} +processing functions (i.e. \code{plot.variable}) to generate intermediate +\code{ggRandomForests} data objects. S3 functions are provide to further process these objects and plot results using the \code{ggplot2} graphics package. Alternatively, users can use these data objects for additional custom plotting or @@ -65,8 +66,9 @@ be further customized using standard \code{ggplot2} commands. \references{ Breiman, L. (2001). Random forests, Machine Learning, 45:5-32. -Ishwaran H. and Kogalur U.B. (2014). Random Forests for Survival, -Regression and Classification (RF-SRC), R package version 1.5.5.12. +Ishwaran H. and Kogalur U.B. randomForestSRC: Random Forests for Survival, +Regression and Classification. R package version >= 3.4.0. +\url{https://cran.r-project.org/package=randomForestSRC} Ishwaran H. and Kogalur U.B. (2007). Random survival forests for R. R News 7(2), 25--31. diff --git a/man/gg_error.Rd b/man/gg_error.Rd index 69160de3..2fdf2268 100644 --- a/man/gg_error.Rd +++ b/man/gg_error.Rd @@ -185,10 +185,14 @@ Breiman L. (2001). Random forests, Machine Learning, 45:5-32. Ishwaran H. and Kogalur U.B. (2007). Random survival forests for R, Rnews, 7(2):25-31. -Ishwaran H. and Kogalur U.B. (2013). Random Forests for Survival, Regression -and Classification (RF-SRC), R package version 1.4. +Ishwaran H. and Kogalur U.B. randomForestSRC: Random Forests for Survival, +Regression and Classification. R package version >= 3.4.0. 
+\url{https://cran.r-project.org/package=randomForestSRC} } \seealso{ -\code{\link{plot.gg_error}}, \code{\link[randomForestSRC]{rfsrc}}, - \code{\link[randomForest]{randomForest}} +\code{\link{plot.gg_error}}, \code{\link{gg_vimp}}, + \code{\link{gg_variable}}, + \code{\link[randomForestSRC]{rfsrc}}, + \code{\link[randomForest]{randomForest}}, + \code{\link[randomForestSRC]{plot.rfsrc}} } diff --git a/man/gg_partial_rfsrc.Rd b/man/gg_partial_rfsrc.Rd index c3112bc7..bddaa2d3 100644 --- a/man/gg_partial_rfsrc.Rd +++ b/man/gg_partial_rfsrc.Rd @@ -2,7 +2,7 @@ % Please edit documentation in R/gg_partial_rfsrc.R \name{gg_partial_rfsrc} \alias{gg_partial_rfsrc} -\title{Split partial lots into continuous or categorical datasets} +\title{Partial dependence data from an rfsrc model} \usage{ gg_partial_rfsrc( rf_model, @@ -13,21 +13,40 @@ gg_partial_rfsrc( ) } \arguments{ -\item{rf_model}{\code{rfsrc::rfsrc} model} +\item{rf_model}{A fitted \code{\link[randomForestSRC]{rfsrc}} object.} -\item{xvar.names}{list() Which variables to calculate partial plots} +\item{xvar.names}{Character vector of predictor names for which partial +dependence should be computed. Must be a subset of \code{rf_model$xvar.names}.} -\item{xvar2.name}{ a single grouping feature that is in the newx dataset} +\item{xvar2.name}{Optional single character name of a grouping variable in +\code{newx}. When supplied, partial dependence is computed separately for +each unique level of this variable and a \code{grp} column is appended.} -\item{newx}{a \code{data.frame} containing data to use for the partial plots} +\item{newx}{Optional \code{data.frame} of predictor values to evaluate +partial effects at. Defaults to the training data stored in +\code{rf_model$xvar}. 
All column names must match \code{rf_model$xvar.names}.} -\item{cat_limit}{Categorical features are build when there are fewer than -cat_limit unique features.} +\item{cat_limit}{Variables with fewer than \code{cat_limit} unique values in +\code{newx} are treated as categorical; all others are continuous. +Defaults to 10.} +} +\value{ +A named list with two elements: + \describe{ + \item{continuous}{A \code{data.frame} with columns \code{x} (numeric), + \code{yhat}, \code{name} (variable name), and optionally \code{grp} + (the level of \code{xvar2.name}) and \code{time} (survival forests + only) for all continuous predictors.} + \item{categorical}{A \code{data.frame} with the same columns but + \code{x} kept as character, for low-cardinality predictors.} + } } \description{ -gg_partial_rfsrc uses the \code{rfsrc::partial.rfsrc} to generate the partial -plot data internally. So you provide the \code{rfsrc::rfsrc} model, and the -xvar.names to generate the data. +Computes partial dependence for one or more predictors by calling +\code{\link[randomForestSRC]{partial.rfsrc}} internally, then splits the +results into separate data frames for continuous and categorical variables. +Unlike \code{\link{gg_partial}}, no separate \code{plot.variable} call is +required — supply the fitted \code{rfsrc} object directly. } \examples{ ## ------------------------------------------------------------ @@ -43,3 +62,7 @@ prt_dta <- gg_partial_rfsrc(airq.obj, xvar.names = c("Wind")) } +\seealso{ +\code{\link{gg_partial}}, \code{\link[randomForestSRC]{partial.rfsrc}}, + \code{\link[randomForestSRC]{get.partial.plot.data}} +} diff --git a/man/gg_rfsrc.rfsrc.Rd b/man/gg_rfsrc.rfsrc.Rd index 9830b4a1..545580c5 100644 --- a/man/gg_rfsrc.rfsrc.Rd +++ b/man/gg_rfsrc.rfsrc.Rd @@ -8,17 +8,51 @@ \method{gg_rfsrc}{rfsrc}(object, oob = TRUE, by, ...) 
} \arguments{ -\item{object}{\code{\link[randomForestSRC]{rfsrc}} object} - -\item{oob}{boolean, should we return the oob prediction , or the full -forest prediction.} - -\item{by}{stratifying variable in the training dataset, defaults to NULL} - -\item{...}{extra arguments} +\item{object}{A fitted \code{\link[randomForestSRC]{rfsrc}} or +\code{\link[randomForest]{randomForest}} object.} + +\item{oob}{Logical; if \code{TRUE} (default) return out-of-bag predictions. +Set to \code{FALSE} to use full in-bag (training) predictions. Forced to +\code{FALSE} automatically for \code{predict.rfsrc} objects, which carry +no OOB estimates.} + +\item{by}{Optional stratifying variable. Either a character column name +present in the training data, or a vector/factor of the same length as +the training set. When supplied, a \code{group} column is added to the +returned data and bootstrap CI bands (survival) are computed per group. +Omit or leave missing to return an unstratified result.} + +\item{...}{Additional arguments controlling output for specific forest +families: +\describe{ + \item{surv_type}{Character; one of \code{"surv"} (default), + \code{"chf"}, or \code{"mortality"} for survival forests.} + \item{conf.int}{Numeric coverage probability (e.g. \code{0.95}) to + request bootstrap pointwise confidence bands for survival forests. + Triggers wide-format output with \code{lower}, \code{upper}, + \code{median}, and \code{mean} columns.} + \item{bs.sample}{Integer; number of bootstrap resamples when + \code{conf.int} is set. 
Defaults to the number of observations.} +}} } \value{ -\code{gg_rfsrc} object +A \code{gg_rfsrc} object (a classed \code{data.frame}) whose + structure depends on the forest family: + \describe{ + \item{regression}{Columns \code{yhat} and the response name; optionally + a \code{group} column when \code{by} is supplied.} + \item{classification}{One column per class with predicted probabilities; + a \code{y} column with observed class labels; optionally \code{group}.} + \item{survival (no CI / grouping)}{Long-format with columns + \code{variable} (event time), \code{value} (survival probability), + \code{obs_id}, and \code{event}.} + \item{survival (with \code{conf.int} or \code{by})}{Wide-format with + pointwise bootstrap CI columns (\code{lower}, \code{upper}, + \code{median}, \code{mean}) per time point; a \code{group} column + when \code{by} is supplied.} + } + The object carries class attributes for the forest family so that + \code{\link{plot.gg_rfsrc}} dispatches correctly. } \description{ Extracts the predicted response values from the @@ -26,11 +60,11 @@ Extracts the predicted response values from the the response using \code{\link{plot.gg_rfsrc}}. } \details{ -\code{surv_type} ("surv", "chf", "mortality", "hazard") for survival - forests - - \code{oob} boolean, should we return the oob prediction , or the full -forest prediction. +For survival forests, use the \code{surv_type} argument + (\code{"surv"}, \code{"chf"}, or \code{"mortality"}) to select the + predicted quantity. Bootstrap confidence bands are requested by passing + \code{conf.int} (e.g. \code{conf.int = 0.95}); the number of resamples + is controlled by \code{bs.sample}. 
} \examples{ ## ------------------------------------------------------------ @@ -137,6 +171,7 @@ plot(gg_dta) } \seealso{ -\code{\link{plot.gg_rfsrc}} \code{rfsrc} \code{plot.rfsrc} -\code{\link{gg_survival}} +\code{\link{plot.gg_rfsrc}}, + \code{\link[randomForestSRC]{rfsrc}}, + \code{\link{gg_survival}} } diff --git a/man/gg_roc.rfsrc.Rd b/man/gg_roc.rfsrc.Rd index 8ab941aa..6e26230c 100644 --- a/man/gg_roc.rfsrc.Rd +++ b/man/gg_roc.rfsrc.Rd @@ -4,24 +4,42 @@ \alias{gg_roc.rfsrc} \alias{gg_roc} \alias{gg_roc.randomForest} -\title{ROC (Receiver operator curve) data from a classification random forest.} +\title{ROC (Receiver Operating Characteristic) curve data from a classification forest.} \usage{ -\method{gg_roc}{rfsrc}(object, which_outcome, oob, ...) +\method{gg_roc}{rfsrc}(object, which_outcome, oob = TRUE, ...) } \arguments{ -\item{object}{an \code{\link[randomForestSRC]{rfsrc}} classification object} +\item{object}{A classification \code{\link[randomForestSRC]{rfsrc}} or +\code{\link[randomForest]{randomForest}} object. Only forests with +\code{family == "class"} (rfsrc) or \code{type == "classification"} +(randomForest) are supported.} -\item{which_outcome}{select the classification outcome of interest.} +\item{which_outcome}{Integer index or character name of the class for which +the ROC curve is computed. For binary forests this is typically \code{1} +or \code{2}; for multi-class forests any valid class index. Use +\code{which_outcome = 0} to obtain the overall (averaged) ROC.} -\item{oob}{use oob estimates (default TRUE)} +\item{oob}{Logical; if \code{TRUE} (default) use out-of-bag predicted +probabilities for the curve. Set to \code{FALSE} to use full in-bag +predictions.} -\item{...}{extra arguments (not used)} +\item{...}{Extra arguments (currently unused).} } \value{ -\code{gg_roc} \code{data.frame} for plotting ROC curves. 
+A \code{gg_roc} \code{data.frame} with one row per unique prediction + threshold and columns: + \describe{ + \item{sens}{Sensitivity (true positive rate) at each threshold.} + \item{spec}{Specificity (true negative rate) at each threshold.} + \item{yvar}{The observed class label for each observation.} + } + Pass to \code{\link{calc_auc}} for the area under the curve. } \description{ -The sensitivity and specificity of a randomForest classification object. +Computes sensitivity (true positive rate) and specificity (1 - false positive +rate) across all prediction thresholds for one class of a classification +\code{\link[randomForestSRC]{rfsrc}} or +\code{\link[randomForest]{randomForest}} object. } \examples{ ## ------------------------------------------------------------ @@ -59,6 +77,8 @@ plot(gg_dta) } \seealso{ -\code{\link{plot.gg_roc}} \code{\link[randomForestSRC]{rfsrc}} -\code{\link[randomForest]{randomForest}} +\code{\link{plot.gg_roc}}, \code{\link{calc_roc}}, + \code{\link{calc_auc}}, + \code{\link[randomForestSRC]{rfsrc}}, + \code{\link[randomForest]{randomForest}} } diff --git a/man/gg_survival.Rd b/man/gg_survival.Rd index 1d21d87f..46238a2d 100644 --- a/man/gg_survival.Rd +++ b/man/gg_survival.Rd @@ -14,21 +14,27 @@ gg_survival( ) } \arguments{ -\item{interval}{name of the interval variable in the training dataset.} +\item{interval}{Character; name of the time-to-event column in \code{data}.} -\item{censor}{name of the censoring variable in the training dataset.} +\item{censor}{Character; name of the event-indicator column in \code{data} +(1 = event occurred, 0 = censored).} -\item{by}{stratifying variable in the training dataset, defaults to NULL} +\item{by}{Optional character; name of a grouping column in \code{data} for +stratified estimates. 
Defaults to \code{NULL} (unstratified).} -\item{data}{name of the training data.frame} +\item{data}{A \code{data.frame} containing the survival data.} -\item{type}{one of ("kaplan","nelson"), defaults to Kaplan-Meier} +\item{type}{One of \code{"kaplan"} (Kaplan-Meier, default) or +\code{"nelson"} (Nelson-Aalen cumulative hazard).} -\item{...}{extra arguments passed to Kaplan or Nelson functions.} +\item{...}{Additional arguments passed to \code{\link{kaplan}} or +\code{\link{nelson}} (e.g. \code{conf.int} to change the CI width).} } \value{ -A \code{gg_survival} object created using the non-parametric -Kaplan-Meier or Nelson-Aalen estimators. +A \code{gg_survival} \code{data.frame} with columns \code{time}, + \code{surv} (or \code{cum_haz} for Nelson-Aalen), \code{lower}, + \code{upper} (confidence limits), and \code{n.risk}. A \code{strata} + column is added when \code{by} is supplied. } \description{ Nonparametric survival estimates. diff --git a/man/gg_variable.Rd b/man/gg_variable.Rd index 8479e885..fa809ebd 100644 --- a/man/gg_variable.Rd +++ b/man/gg_variable.Rd @@ -18,7 +18,12 @@ gg_variable(object, ...) \code{oob} that tailor the marginal dependence extraction.} } \value{ -\code{gg_variable} object +A \code{gg_variable} object: a \code{data.frame} of all predictor + columns from the training data paired with the OOB (or in-bag) predicted + response. For survival forests each requested time horizon produces an + additional column named by \code{time_labels}. The object carries a + \code{"family"} class attribute (\code{"regr"}, \code{"class"}, or + \code{"surv"}) used by \code{\link{plot.gg_variable}} for dispatch. 
} \description{ \code{\link[randomForestSRC]{plot.variable}} generates a @@ -208,7 +213,6 @@ plot(gg_dta, xvar = c("age", "trig"), panel = TRUE, se = FALSE) } \seealso{ -\code{\link{plot.gg_variable}} - -\code{\link[randomForestSRC]{plot.variable}} +\code{\link{plot.gg_variable}}, + \code{\link[randomForestSRC]{plot.variable}} } diff --git a/man/ggrandomforests.news.Rd b/man/ggrandomforests.news.Rd new file mode 100644 index 00000000..4587e375 --- /dev/null +++ b/man/ggrandomforests.news.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/ggrandomforests.news.R +\name{ggrandomforests.news} +\alias{ggrandomforests.news} +\title{Display the ggRandomForests NEWS file} +\usage{ +ggrandomforests.news(...) +} +\arguments{ +\item{...}{Currently unused; reserved for future arguments.} +} +\value{ +Called for its side-effect of opening the NEWS file in the system + pager (\code{file.show}). Returns \code{invisible(NULL)}. +} +\description{ +Opens the package NEWS file in the system pager so users can read the +version history and change log without leaving their R session. +} +\keyword{internal} diff --git a/man/plot.gg_error.Rd b/man/plot.gg_error.Rd index 25072bc8..344733dc 100644 --- a/man/plot.gg_error.Rd +++ b/man/plot.gg_error.Rd @@ -7,13 +7,20 @@ \method{plot}{gg_error}(x, ...) } \arguments{ -\item{x}{gg_error object created from a \code{\link[randomForestSRC]{rfsrc}} -object} - -\item{...}{extra arguments passed to \code{ggplot} functions} +\item{x}{A \code{\link{gg_error}} object created from either a +\code{\link[randomForestSRC]{rfsrc}} or a +\code{\link[randomForest]{randomForest}} object. A raw forest object +may also be supplied and will be passed through \code{\link{gg_error}} +automatically before plotting.} + +\item{...}{Extra arguments forwarded to the underlying \code{ggplot2} +geometry calls (e.g. 
\code{size}, \code{linetype}).} } \value{ -\code{ggplot} object +A \code{ggplot} object with \code{ntree} on the x-axis and + OOB error rate on the y-axis. Single-outcome forests (regression, + survival) produce a single line; multi-outcome forests (classification) + produce one coloured line per class. } \description{ A plot of the cumulative OOB error rates of the random forest as a @@ -183,8 +190,9 @@ Breiman L. (2001). Random forests, Machine Learning, 45:5-32. Ishwaran H. and Kogalur U.B. (2007). Random survival forests for R, Rnews, 7(2):25-31. -Ishwaran H. and Kogalur U.B. (2013). Random Forests for Survival, Regression -and Classification (RF-SRC), R package version 1.4. +Ishwaran H. and Kogalur U.B. randomForestSRC: Random Forests for Survival, +Regression and Classification. R package version >= 3.4.0. +\url{https://cran.r-project.org/package=randomForestSRC} } \seealso{ \code{\link{gg_error}} \code{\link[randomForestSRC]{rfsrc}} diff --git a/man/plot.gg_rfsrc.Rd b/man/plot.gg_rfsrc.Rd index 9040198d..00e5226d 100644 --- a/man/plot.gg_rfsrc.Rd +++ b/man/plot.gg_rfsrc.Rd @@ -4,21 +4,54 @@ \alias{plot.gg_rfsrc} \title{Predicted response plot from a \code{\link{gg_rfsrc}} object.} \usage{ -\method{plot}{gg_rfsrc}(x, ...) +\method{plot}{gg_rfsrc}(x, notch = TRUE, ...) } \arguments{ -\item{x}{\code{\link{gg_rfsrc}} object created from a -\code{\link[randomForestSRC]{rfsrc}} object} - -\item{...}{arguments passed to \code{\link{gg_rfsrc}}.} +\item{x}{A \code{\link{gg_rfsrc}} object, or a raw +\code{\link[randomForestSRC]{rfsrc}} object (which will be passed through +\code{\link{gg_rfsrc}} automatically before plotting).} + +\item{notch}{Logical; whether to draw notched boxplots for regression and +classification forests (default \code{TRUE}). 
Set \code{notch = FALSE} +to suppress notches when sample sizes are too small for reliable +confidence intervals on the median.} + +\item{...}{Additional arguments forwarded to the underlying +\code{ggplot2} geometry calls. Commonly useful arguments include: +\describe{ + \item{\code{alpha}}{Numeric in \eqn{[0,1]}; point/ribbon transparency. + For survival plots with confidence bands the ribbon alpha is + automatically halved relative to the value supplied here.} + \item{\code{size}}{Point or line size passed to \code{geom_jitter}, + \code{geom_step}, etc.} +} +Arguments that control \code{\link{gg_rfsrc}} (e.g. \code{conf.int}, +\code{surv_type}, \code{by}) should be applied when constructing the +\code{gg_rfsrc} object before calling \code{plot()}.} } \value{ -\code{ggplot} object +A \code{ggplot} object. The plot appearance depends on the forest + family stored in \code{x}: + \describe{ + \item{Regression (\code{"regr"})}{Jitter + notched boxplot of OOB + predicted values. If a \code{group} column is present the x-axis + shows each group label; otherwise observations are collapsed to a + single x-position.} + \item{Classification (\code{"class"})}{Binary: jitter + notched + boxplot of the predicted class probability. Multi-class: jitter + plot with one panel per class (class probabilities in long form).} + \item{Survival (\code{"surv"})}{Step curves of the ensemble survival + function. When \code{gg_rfsrc} was called with \code{conf.int}, + a shaded ribbon is added. When called with \code{by}, curves are + coloured by group.} + } } \description{ Plot the predicted response from a \code{\link{gg_rfsrc}} object, the \code{\link[randomForestSRC]{rfsrc}} prediction, using the OOB prediction -from the forest. +from the forest. The plot type adapts automatically to the forest family: +jitter + boxplot for regression and classification, step curves for +survival. 
} \examples{ ## ------------------------------------------------------------ @@ -135,9 +168,11 @@ Breiman L. (2001). Random forests, Machine Learning, 45:5-32. Ishwaran H. and Kogalur U.B. (2007). Random survival forests for R, Rnews, 7(2):25-31. -Ishwaran H. and Kogalur U.B. (2013). Random Forests for Survival, Regression -and Classification (RF-SRC), R package version 1.4. +Ishwaran H. and Kogalur U.B. randomForestSRC: Random Forests for Survival, +Regression and Classification. R package version >= 3.4.0. +\url{https://cran.r-project.org/package=randomForestSRC} } \seealso{ \code{\link{gg_rfsrc}} \code{\link[randomForestSRC]{rfsrc}} + \code{\link[randomForest]{randomForest}} } diff --git a/man/plot.gg_roc.Rd b/man/plot.gg_roc.Rd index 4714c441..084414ad 100644 --- a/man/plot.gg_roc.Rd +++ b/man/plot.gg_roc.Rd @@ -7,14 +7,26 @@ \method{plot}{gg_roc}(x, which_outcome = NULL, ...) } \arguments{ -\item{x}{\code{\link{gg_roc}} object created from a classification forest} +\item{x}{A \code{\link{gg_roc}} object, or a raw +\code{\link[randomForestSRC]{rfsrc}} classification forest or +\code{\link[randomForest]{randomForest}} classification object. When a +forest is supplied, \code{\link{gg_roc}} is called automatically.} -\item{which_outcome}{for multiclass problems, choose the class for plotting} +\item{which_outcome}{Integer; for multi-class problems, the index of the +class whose ROC curve should be plotted. When \code{NULL} (default) and +the forest has more than two classes, ROC curves for all classes are +overlaid in a single plot. For binary forests \code{NULL} defaults to +class index 2.} -\item{...}{arguments passed to the \code{\link{gg_roc}} function} +\item{...}{Additional arguments forwarded to \code{\link{gg_roc}} when +\code{x} is a raw forest object (e.g. \code{oob = FALSE}).} } \value{ -\code{ggplot} object of the ROC curve +A \code{ggplot} object. The x-axis shows 1 − Specificity (FPR) + and the y-axis shows Sensitivity (TPR). 
A dashed red diagonal reference + line marks the random-classifier baseline. The AUC value is annotated + on the plot for single-class curves. Multi-class plots colour and style + each class curve distinctly. } \description{ ROC plot generic function for a \code{\link{gg_roc}} object. @@ -30,18 +42,19 @@ rfsrc_iris <- rfsrc(Species ~ ., data = iris, ntree = 50) # ROC for setosa (outcome index 1) gg_dta <- gg_roc(rfsrc_iris, which_outcome = 1) -plot.gg_roc(gg_dta) +plot(gg_dta) # ROC for versicolor (outcome index 2) gg_dta <- gg_roc(rfsrc_iris, which_outcome = 2) -plot.gg_roc(gg_dta) +plot(gg_dta) # ROC for virginica (outcome index 3) gg_dta <- gg_roc(rfsrc_iris, which_outcome = 3) -plot.gg_roc(gg_dta) +plot(gg_dta) -# Alternatively, pass the forest directly to plot all three ROC curves -plot.gg_roc(rfsrc_iris) +# Plot all three ROC curves in one call by iterating over outcome indices +n_cls <- ncol(rfsrc_iris$predicted) +for (i in seq_len(n_cls)) print(plot(gg_roc(rfsrc_iris, which_outcome = i))) } \references{ @@ -50,9 +63,12 @@ Breiman L. (2001). Random forests, Machine Learning, 45:5-32. Ishwaran H. and Kogalur U.B. (2007). Random survival forests for R, Rnews, 7(2):25-31. -Ishwaran H. and Kogalur U.B. (2013). Random Forests for Survival, -Regression and Classification (RF-SRC), R package version 1.4. +Ishwaran H. and Kogalur U.B. randomForestSRC: Random Forests for Survival, +Regression and Classification. R package version >= 3.4.0. 
+\url{https://cran.r-project.org/package=randomForestSRC} } \seealso{ -\code{\link{gg_roc}} rfsrc +\code{\link{gg_roc}} \code{\link{calc_roc}} \code{\link{calc_auc}} + \code{\link[randomForestSRC]{rfsrc}} + \code{\link[randomForest]{randomForest}} } diff --git a/man/plot.gg_survival.Rd b/man/plot.gg_survival.Rd index 147cbae7..f091638f 100644 --- a/man/plot.gg_survival.Rd +++ b/man/plot.gg_survival.Rd @@ -25,7 +25,10 @@ object created from a \code{\link[randomForestSRC]{rfsrc}} object} \item{...}{not used} } \value{ -\code{ggplot} object +A \code{ggplot} object. The y-axis shows the chosen \code{type} + (e.g. survival probability for \code{"surv"}) and the x-axis shows time. + Confidence shading, bars, or lines are added when the input object carries + confidence-interval columns. } \description{ Plot a \code{\link{gg_survival}} object. @@ -73,3 +76,7 @@ plot(gg_dta, error = "lines") plot(gg_dta, label = "sex", error = "lines") } +\seealso{ +\code{\link{gg_survival}}, \code{\link{kaplan}}, + \code{\link{nelson}}, \code{\link{gg_rfsrc}} +} diff --git a/man/plot.gg_variable.Rd b/man/plot.gg_variable.Rd index 1457429b..ef6efabe 100644 --- a/man/plot.gg_variable.Rd +++ b/man/plot.gg_variable.Rd @@ -37,7 +37,9 @@ \item{...}{arguments passed to the \code{ggplot2} functions.} } \value{ -A single \code{ggplot} object, or list of \code{ggplot} objects +A single \code{ggplot} object when \code{length(xvar) == 1} or + \code{panel = TRUE}. Otherwise a named list of \code{ggplot} objects, one + per variable in \code{xvar}. } \description{ Plot a \code{\link{gg_variable}} object, @@ -120,6 +122,11 @@ Breiman L. (2001). Random forests, Machine Learning, 45:5-32. Ishwaran H. and Kogalur U.B. (2007). Random survival forests for R, Rnews, 7(2):25-31. -Ishwaran H. and Kogalur U.B. (2013). Random Forests for Survival, Regression -and Classification (RF-SRC), R package version 1.4. +Ishwaran H. and Kogalur U.B. 
randomForestSRC: Random Forests for Survival, +Regression and Classification. R package version >= 3.4.0. +\url{https://cran.r-project.org/package=randomForestSRC} +} +\seealso{ +\code{\link{gg_variable}}, \code{\link{gg_partial}}, + \code{\link[randomForestSRC]{plot.variable}} } diff --git a/man/plot.gg_vimp.Rd b/man/plot.gg_vimp.Rd index cb3b8a53..d591d8f3 100644 --- a/man/plot.gg_vimp.Rd +++ b/man/plot.gg_vimp.Rd @@ -50,8 +50,9 @@ Breiman L. (2001). Random forests, Machine Learning, 45:5-32. Ishwaran H. and Kogalur U.B. (2007). Random survival forests for R, Rnews, 7(2):25-31. -Ishwaran H. and Kogalur U.B. (2013). Random Forests for Survival, -Regression and Classification (RF-SRC), R package version 1.4. +Ishwaran H. and Kogalur U.B. randomForestSRC: Random Forests for Survival, +Regression and Classification. R package version >= 3.4.0. +\url{https://cran.r-project.org/package=randomForestSRC} } \seealso{ \code{\link{gg_vimp}} diff --git a/man/shift.Rd b/man/shift.Rd index d52c12ca..04951450 100644 --- a/man/shift.Rd +++ b/man/shift.Rd @@ -1,29 +1,28 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/calc_roc.R +% Please edit documentation in R/utils.R \name{shift} \alias{shift} -\title{lead function to shift by one (or more).} +\title{Lead / lag shift for numeric vectors} \usage{ shift(x, shift_by = 1) } \arguments{ -\item{x}{a vector of values} +\item{x}{a numeric vector of values} \item{shift_by}{an integer of length 1, giving the number of positions to lead (positive) or lag (negative) by} } \description{ -lead function to shift by one (or more). +Lead / lag shift for numeric vectors } \details{ Lead and lag are useful for comparing values offset by a constant -(e.g. the previous or next value) +(e.g. the previous or next value). Taken from: http://ctszkin.com/2012/03/11/generating-a-laglead-variables/ -This function allows me to remove the dplyr::lead depends. Still suggest for -vignettes though. 
+This function allows removal of the dplyr::lead dependency. } \examples{ d <- data.frame(x = 1:15) @@ -31,6 +30,5 @@ d <- data.frame(x = 1:15) d$df_lead2 <- ggRandomForests:::shift(d$x, 2) # generate lag variable d$df_lag2 <- ggRandomForests:::shift(d$x, -2) -# -# } +\keyword{internal} diff --git a/man/surv_partial.rfsrc.Rd b/man/surv_partial.rfsrc.Rd index 87bbf615..4e7e7acf 100644 --- a/man/surv_partial.rfsrc.Rd +++ b/man/surv_partial.rfsrc.Rd @@ -2,25 +2,42 @@ % Please edit documentation in R/surv_partial.rfsrc.R \name{surv_partial.rfsrc} \alias{surv_partial.rfsrc} -\title{Calculate survival curve partial plot.} +\title{Survival partial dependence data for one or more predictors} \usage{ surv_partial.rfsrc(rforest, var_list, npts = 25, partial.type = "surv") } \arguments{ -\item{rforest}{the randomForestSrc object} +\item{rforest}{A fitted \code{\link[randomForestSRC]{rfsrc}} survival or +competing-risk forest object.} -\item{var_list}{a list of variables of interest. These variables should be a -subset of rforest$xvar.names} +\item{var_list}{Character vector of predictor names for which partial +dependence should be computed. Each must appear in +\code{rforest$xvar.names}.} -\item{npts}{the number of points to segment the xvar of interest} +\item{npts}{Integer; the number of predictor grid points to evaluate +(default 25). Evenly-spaced unique values are sampled from each predictor.} -\item{partial.type}{the return prediction type. -For survival forests: type c("surv", "mort", "chf") -For competing risk forests: type c("years.lost", "cif", "chf") -see \code{randomForestSRC::partial.rfsrc} or more information} +\item{partial.type}{The prediction type to return. For survival forests one +of \code{"surv"} (default), \code{"mort"}, or \code{"chf"}. For competing +risk forests one of \code{"years.lost"}, \code{"cif"}, or \code{"chf"}. 
+See \code{\link[randomForestSRC]{partial.rfsrc}} for full details.} +} +\value{ +A named list with one element per variable in \code{var_list}. Each + element is itself a list with: + \describe{ + \item{name}{The predictor variable name (character).} + \item{dta}{The raw output of + \code{\link[randomForestSRC]{get.partial.plot.data}}, a list containing + at minimum \code{x} (predictor values) and \code{yhat} (partial + predictions), and for survival/competing risk, \code{partial.time}.} + } } \description{ -Calculate survival curve partial plot. +Computes partial dependence curves for a survival or competing-risk +\code{\link[randomForestSRC]{rfsrc}} forest by calling +\code{\link[randomForestSRC]{partial.rfsrc}} at \code{npts} evenly-spaced +unique values of each predictor across all stored event times. } \examples{ ## ------------------------------------------------------------ @@ -106,3 +123,8 @@ matplot(pdta2$partial.time, t(pdta2$yhat), type = "l", lty = 1, xlab = "time", ylab = "age adjusted cif for death") } +\seealso{ +\code{\link{gg_partial_rfsrc}}, + \code{\link[randomForestSRC]{partial.rfsrc}}, + \code{\link[randomForestSRC]{get.partial.plot.data}} +} diff --git a/release-checklist-v2.7.0.md b/release-checklist-v2.7.0.md new file mode 100644 index 00000000..e48d19c5 --- /dev/null +++ b/release-checklist-v2.7.0.md @@ -0,0 +1,119 @@ +# Release Checklist: ggRandomForests v2.7.0 +**Date:** 2026-03-25 | **Maintainer:** John Ehrlinger + +--- + +## What Changed Since 2.6.1 + +This release is a significant bug-fix and code-quality release. All changes should be summarised in NEWS.md before submission. 
+ +| Area | Fix | +|------|-----| +| `plot.gg_rfsrc.R` | **Breaking visual bug**: all `aes()` calls used bare string literals instead of `.data[[col]]` — plots mapped aesthetics to constant text, not data columns | +| `plot.gg_roc.R` | Multi-class `aes()` bare string literals fixed; dead `if (crv < 2)` branch removed | +| `gg_rfsrc.R` | `bootstrap_survival` negative-index bug fixed; `is.null(df[,col])` column-existence check fixed; `gg_rfsrc.randomForest` used non-existent `object$xvar` — now uses `.rf_recover_model_frame()` | +| `plot.gg_error.R` | Legend-suppression logic fixed for single-outcome forests; `=` → `<-` assignment style | +| `gg_vimp.R` | `1:nvar` → `seq_len(nvar)` in both `rfsrc` and `randomForest` methods (silent bug when `nvar == 0`) | +| `gg_partial.R` | `if(` → `if (` lintr spacing | +| Test suite | Full testthat 3.x migration: `expect_is` → `expect_s3_class/expect_type/expect_true(is.*())`, `expect_equivalent` → `expect_equal(ignore_attr=TRUE)`, all `context()` removed, `expect_that`/`is_identical_to` removed, `gg_roc.rfrsrc` typo fixed | +| GitHub Actions | `lint.yaml` upgraded (checkout@v4, fail-on-lint-errors, `.lintr` config created); `R-CMD-check.yaml` rtools 42→44, `error_on: warning` added; `test-coverage.yaml` duplicate codecov upload removed | + +--- + +## Pre-Release Checklist + +### 1. Version & Metadata +- [ ] Bump `Version:` in `DESCRIPTION` from `2.6.1` → `2.7.0` +- [ ] Update `Date:` in `DESCRIPTION` to today (`2026-03-25`) +- [ ] Add `covr` to `Suggests:` in `DESCRIPTION` (referenced in `test-coverage.yaml` but not declared) +- [ ] Confirm `RoxygenNote:` in `DESCRIPTION` matches your installed roxygen2 version (`devtools::document()` will warn if not) + +### 2. 
NEWS.md +- [ ] Add `ggRandomForests v2.7.0` section at the top of `NEWS.md` summarising all the changes above +- [ ] Keep entries user-facing: "Fixed visual bug where all plot aesthetics were mapped to constant strings instead of data columns" is better than "fixed aes() calls" + +### 3. Documentation +- [ ] Run `devtools::document()` — confirm zero warnings, all `.Rd` files regenerate cleanly +- [ ] Spot-check exported help pages: `?gg_rfsrc`, `?plot.gg_rfsrc`, `?gg_roc`, `?gg_vimp` +- [ ] Check whether `R/plot.gg_vimp.R:77` still uses `1:nvar` — this was **not** fixed in this round and should either be fixed now or added to Phase 2 + +### 4. Test Suite +- [ ] Run `devtools::test()` locally — zero failures, zero skips that aren't intentional +- [ ] Confirm no `expect_is()`, `expect_equivalent()`, or `context()` calls remain: + ```sh + grep -r "expect_is\|expect_equivalent\|context(" tests/ + ``` +- [ ] Check coverage with `covr::package_coverage()` — should be ≥ 83% (existing baseline) +- [ ] Verify `test_gg_roc.R` actually exercises `gg_roc.rfsrc()` error path (typo was `rfrsrc` → now fixed to `rfsrc`) + +### 5. R CMD CHECK (local) +Run the full CRAN-equivalent check: +```r +devtools::check(args = c("--as-cran")) +``` +- [ ] **0 errors** — hard gate, CRAN will reject +- [ ] **0 warnings** — hard gate (CI now enforces this with `error_on: "warning"`) +- [ ] **0 notes** — aim for this; if unavoidable, document in `cran-comments.md` +- [ ] Check output for `no visible binding for global variable` NOTEs (reported under "checking R code for possible problems") — the `.data` pronoun fixes should have eliminated most, but verify +- [ ] Check for `NOTE: checking for unstated dependencies in examples` — confirm `MASS` and `datasets` are in `Suggests:` + +### 6. Known Remaining Issues (Phase 2 — decide: fix now or document) +These were identified in the code review but not yet fixed.
Decide before release: + +- [ ] **`plot.gg_vimp.R:77`** — `gg_dta[1:nvar, ]` should be `gg_dta[seq_len(nvar), ]` (same class of bug fixed in `gg_vimp.R`) — **recommend fixing now** +- [ ] **`tidyr::gather()` → `pivot_longer()`** — `gather()` is superseded but still works; acceptable to defer if timeline is tight +- [ ] **No `vdiffr` snapshot tests** — visual regression testing gap; safe to defer to v2.8.0 +- [ ] **`bootstrap_survival` unit tests** — no direct test for the CI band helper; safe to defer + +### 7. CRAN Submission Prep +- [ ] Update `cran-comments.md` for v2.7.0: + - List the R CMD CHECK result (0 errors | 0 warnings | 0 notes) + - List test environments (local + GitHub Actions matrix) + - Summarise the changes briefly for the CRAN reviewer +- [ ] Run `devtools::check_win_devel()` — submits to CRAN's own Windows R-devel server; results emailed within ~30 min +- [ ] Run `rhub::rhub_check()` (rhub v2; the v1 `rhub::check_for_cran()` API has been retired) or trigger the `rhub.yaml` workflow to test on additional platforms +- [ ] Confirm no reverse-dependency breakage (no `revdep/` folder exists — if you have downstream users, consider `revdepcheck::revdep_check()`) + +### 8. Git / GitHub +- [ ] All changes committed to `main` +- [ ] GitHub Actions are green on `main` (R-CMD-check, lint, test-coverage) +- [ ] Create and push a git tag: `git tag v2.7.0 && git push origin v2.7.0` +- [ ] Create a GitHub Release from the tag, copy the NEWS.md entry as the release notes + +### 9.
Submit to CRAN +```r +devtools::submit_cran() +``` +- [ ] Confirm submission acknowledgement email received +- [ ] Watch for CRAN incoming queue status at https://cran.r-project.org/incoming/ +- [ ] Respond to any CRAN reviewer queries within 14 days + +--- + +## Rollback Triggers + +If CRAN rejects or a critical issue is found post-release: +- [ ] Patch release `2.7.1` addressing the specific CRAN note/warning +- [ ] If a user-facing regression is reported: hotfix branch from `v2.7.0` tag, fix + test, release `2.7.1` within 24–48 h +- [ ] Re-submit with updated `cran-comments.md` explaining what was changed + +--- + +## Quick Reference Commands + +```r +# Full local CRAN check +devtools::check(args = "--as-cran") + +# Regenerate documentation +devtools::document() + +# Run tests with coverage +covr::package_coverage() + +# CRAN Windows devel check +devtools::check_win_devel() + +# Submit +devtools::submit_cran() +``` diff --git a/tests/testthat/test_gg_error.R b/tests/testthat/test_gg_error.R index 249366be..ec0a9177 100644 --- a/tests/testthat/test_gg_error.R +++ b/tests/testthat/test_gg_error.R @@ -1,5 +1,4 @@ # testthat for gg_error function -context("gg_error tests") test_that("gg_error.rfsrc classifications", { ## Load the cached forest @@ -11,37 +10,37 @@ test_that("gg_error.rfsrc classifications", { tree.err = TRUE ) # Test the cached forest type - expect_is(rfsrc_iris, "rfsrc") - + expect_s3_class(rfsrc_iris, "rfsrc") + # Test the forest family expect_match(rfsrc_iris$family, "class") - + ## Create the correct gg_error object gg_dta <- gg_error(rfsrc_iris) - + # Test object type - expect_is(gg_dta, "gg_error") - + expect_s3_class(gg_dta, "gg_error") + # Test classification dimensions expect_equal(dim(gg_dta)[1], dim(na.omit(rfsrc_iris$err.rate))[1]) expect_equal(dim(gg_dta)[2], dim(rfsrc_iris$err.rate)[2] + 1) - + # Test data is correctly pulled from randomForest obect. 
- # expect_equivalent(as.matrix(gg_dta[, -which(colnames(gg_dta) == "ntree")]), + # expect_equal(as.matrix(gg_dta[, -which(colnames(gg_dta) == "ntree")], ignore_attr = TRUE), # rfsrc_iris$err.rate) - + ## Test plotting the gg_error object gg_plt <- plot(gg_dta) - + # Test return is s ggplot object - expect_is(gg_plt, "ggplot") - + expect_s3_class(gg_plt, "ggplot") + # "Incorrect object type: Expects a gg_error object" expect_error(gg_error(gg_plt)) expect_error(gg_error.rfsrc(gg_plt)) rfsrc_iris$err.rate <- NULL expect_error(gg_error(rfsrc_iris)) - + rfsrc_iris <- randomForestSRC::rfsrc( Species ~ ., data = iris, @@ -50,15 +49,15 @@ test_that("gg_error.rfsrc classifications", { ) ## Create the correct gg_error object gg_dta <- gg_error(rfsrc_iris, training = TRUE) - + # Test object type - expect_is(gg_dta, "gg_error") - - ## Test plotting the gg_error object - gg_plt <- expect_warning(plot(gg_dta)) - + expect_s3_class(gg_dta, "gg_error") + + ## Test plotting the gg_error object (no warnings expected since pivot_longer migration) + gg_plt <- plot(gg_dta) + # Test return is s ggplot object - expect_is(gg_plt, "ggplot") + expect_s3_class(gg_plt, "ggplot") }) @@ -69,32 +68,32 @@ test_that("gg_error.randomForest classifications", { rf_iris <- randomForest::randomForest(Species ~ ., data = iris, ntree = 75) - + # Test the cached forest type - expect_is(rf_iris, "randomForest") - + expect_s3_class(rf_iris, "randomForest") + # Test the forest family expect_match(rf_iris$type, "classification") - + ## Create the correct gg_error object gg_dta <- gg_error(rf_iris) - + # Test object type - expect_is(gg_dta, "gg_error") - + expect_s3_class(gg_dta, "gg_error") + # Test classification dimensions expect_equal(dim(gg_dta)[1], dim(rf_iris$err.rate)[1]) expect_equal(dim(gg_dta)[2], dim(rf_iris$err.rate)[2] + 1) - + # Test data is correctly pulled from randomForest obect. 
- expect_equivalent(as.matrix(gg_dta[, -which(colnames(gg_dta) == "ntree")]), + expect_equal(as.matrix(gg_dta[, -which(colnames(gg_dta) == "ntree")]), ignore_attr = TRUE, rf_iris$err.rate) - + ## Test plotting the gg_error object gg_plt <- plot(gg_dta) - + # Test return is s ggplot object - expect_is(gg_plt, "ggplot") + expect_s3_class(gg_plt, "ggplot") ## Ensure training curve can be requested gg_train <- gg_error(rf_iris, training = TRUE) @@ -102,14 +101,14 @@ test_that("gg_error.randomForest classifications", { expect_length(gg_train$train, nrow(gg_train)) expect_true(min(gg_train$train, na.rm = TRUE) >= 0) expect_true(max(gg_train$train, na.rm = TRUE) <= 1) - + # "Incorrect object type: Expects a gg_error object" expect_error(gg_error(gg_plt)) expect_error(gg_error.randomForest(gg_plt)) rf_iris$err.rate <- NULL expect_error(gg_error(rf_iris)) - - + + }) @@ -117,85 +116,170 @@ test_that("gg_error.randomForest classifications", { test_that("gg_error regression rfsrc", { ## Load the cached forest data(Boston, package = "MASS") - - Boston$chas <- as.logical(Boston$chas) - + + Boston$chas <- as.logical(Boston$chas) # nolint: object_name_linter + rfsrc_boston <- randomForestSRC::rfsrc(medv ~ ., data = Boston) # Test the cached forest type - expect_is(rfsrc_boston, "rfsrc") - + expect_s3_class(rfsrc_boston, "rfsrc") + # Test the forest family expect_match(rfsrc_boston$family, "regr") - + ## Create the correct gg_error object gg_dta <- gg_error(rfsrc_boston) - + # Test object type - expect_is(gg_dta, "gg_error") - + expect_s3_class(gg_dta, "gg_error") + # Test classification dimensions expect_equal(nrow(gg_dta), length(na.omit(rfsrc_boston$err.rate))) expect_equal(ncol(gg_dta), 2) - - # Test data is correctly pulled from randomForest obect. - expect_equivalent(c(gg_dta[, 1]), na.omit(rfsrc_boston$err.rate)) - + + # Test data is correctly pulled from randomForest object.
+ expect_equal(as.numeric(gg_dta[[1]]), as.numeric(na.omit(rfsrc_boston$err.rate))) + ## Test plotting the gg_error object gg_plt <- plot(gg_dta) - + # Test return is s ggplot object - expect_is(gg_plt, "ggplot") - + expect_s3_class(gg_plt, "ggplot") + # Test return is s ggplot object - expect_is(gg_plt, "ggplot") - + expect_s3_class(gg_plt, "ggplot") + # Test the exception for input expect_error(gg_error(gg_plt)) - - + + gg_dta <- gg_error(rfsrc_boston, training = TRUE) - expect_is(gg_dta, "gg_error") - - + expect_s3_class(gg_dta, "gg_error") + + }) test_that("gg_error regression randomForest", { ## Load the cached forest data(Boston, package = "MASS") - - Boston$chas <- as.logical(Boston$chas) - + + Boston$chas <- as.logical(Boston$chas) # nolint: object_name_linter + rf_boston <- randomForest::randomForest(medv ~ ., data = Boston, ntree = 100) # Test the cached forest type - expect_is(rf_boston, "randomForest") - + expect_s3_class(rf_boston, "randomForest") + # Test the forest family expect_match(rf_boston$type, "regression") - + ## Create the correct gg_error object gg_dta <- gg_error(rf_boston) - + # Test object type - expect_is(gg_dta, "gg_error") - + expect_s3_class(gg_dta, "gg_error") + # Test classification dimensions expect_equal(nrow(gg_dta), length(na.omit(rf_boston$mse))) expect_equal(ncol(gg_dta), 2) - + # Test data is correctly pulled from randomForest obect. 
- expect_equivalent(c(gg_dta[, 1]), rf_boston$mse) - + expect_equal(c(gg_dta[, 1]), rf_boston$mse, ignore_attr = TRUE) + ## Create the correct gg_error object gg_dta <- gg_error(rf_boston) # Test object type - expect_is(gg_dta, "gg_error") + expect_s3_class(gg_dta, "gg_error") ## Training curve is available gg_train <- gg_error(rf_boston, training = TRUE) expect_true("train" %in% names(gg_train)) expect_length(gg_train$train, nrow(gg_train)) expect_true(all(is.finite(gg_train$train))) - + +}) + +## ---- Direct plot.gg_error() tests ----------------------------------------- +# The tests above exercise plot() via S3 dispatch. The tests below call +# plot.gg_error() by name to ensure the function itself is exercised directly +# and all branches are covered explicitly. + +test_that("plot.gg_error direct: regression rfsrc (single-outcome path, no legend)", { + data(Boston, package = "MASS") + Boston$chas <- as.logical(Boston$chas) # nolint: object_name_linter + rfsrc_boston <- randomForestSRC::rfsrc(medv ~ ., data = Boston, ntree = 50L) + gg_dta <- gg_error(rfsrc_boston) + + # Call by name — not via generic dispatch + gg_plt <- plot.gg_error(gg_dta) + + expect_s3_class(gg_plt, "ggplot") + # Single-outcome: gg_dta has exactly 2 columns (error, ntree) + expect_equal(ncol(gg_dta), 2L) + # Legend should be suppressed — theme has legend.position = "none" + expect_equal(gg_plt$theme$legend.position, "none") +}) + +test_that("plot.gg_error direct: survival rfsrc (single-outcome path)", { + data(pbc, package = "randomForestSRC") + pbc$time <- pbc$days / 364.25 + pbc_sub <- pbc[, c("time", "status", "age", "bili")] + Surv <- survival::Surv # nolint: object_name_linter + rfsrc_pbc <- randomForestSRC::rfsrc( + Surv(time, status) ~ ., + data = pbc_sub, + ntree = 50L + ) + gg_dta <- gg_error(rfsrc_pbc) + gg_plt <- plot.gg_error(gg_dta) + + expect_s3_class(gg_plt, "ggplot") + expect_equal(ncol(gg_dta), 2L) +}) + +test_that("plot.gg_error direct: classification rfsrc (multi-outcome 
path, legend shown)", { + data(iris, package = "datasets") + rfsrc_iris <- randomForestSRC::rfsrc( + Species ~ ., data = iris, importance = TRUE, tree.err = TRUE + ) + gg_dta <- gg_error(rfsrc_iris) + gg_plt <- plot.gg_error(gg_dta) + + expect_s3_class(gg_plt, "ggplot") + # Multi-outcome: more than 2 columns before pivot + expect_true(ncol(gg_dta) > 2L) + # Legend should NOT be suppressed — multiple outcomes need the colour key + expect_false(identical(gg_plt$theme$legend.position, "none")) +}) + +test_that("plot.gg_error direct: accepts raw rfsrc object (auto-extract path)", { + data(Boston, package = "MASS") + Boston$chas <- as.logical(Boston$chas) # nolint: object_name_linter + rfsrc_boston <- randomForestSRC::rfsrc(medv ~ ., data = Boston, ntree = 50L) + + # Pass the rfsrc object directly — plot.gg_error should call gg_error() internally + gg_plt <- plot.gg_error(rfsrc_boston) + expect_s3_class(gg_plt, "ggplot") +}) + +test_that("plot.gg_error direct: errors on non-gg_error non-rfsrc input", { + expect_error(plot.gg_error(list(a = 1)), "Incorrect object type") + expect_error(plot.gg_error("not a forest"), "Incorrect object type") +}) + +test_that("plot.gg_error direct: point geometry used when only one valid row", { + data(Boston, package = "MASS") + Boston$chas <- as.logical(Boston$chas) # nolint: object_name_linter + rfsrc_boston <- randomForestSRC::rfsrc(medv ~ ., data = Boston, ntree = 50L) + gg_dta <- gg_error(rfsrc_boston) + + # Manufacture a single-row gg_error to trigger the point branch + single_row <- gg_dta[1L, ] + class(single_row) <- class(gg_dta) + gg_plt <- plot.gg_error(single_row) + + expect_s3_class(gg_plt, "ggplot") + # Confirm geom_point was used (not geom_line) + geom_classes <- sapply(gg_plt$layers, function(l) class(l$geom)[1]) + expect_true(any(grepl("GeomPoint", geom_classes))) }) diff --git a/tests/testthat/test_gg_partial.R b/tests/testthat/test_gg_partial.R index 5520c9d3..4981d002 100644 --- a/tests/testthat/test_gg_partial.R +++ 
b/tests/testthat/test_gg_partial.R @@ -1,5 +1,4 @@ # Tests for gg_partial and gg_partial_rfsrc -context("gg_partial tests") # Helper: create mock partial plot data (matching rfsrc::plot.variable structure) make_mock_partial_data <- function() { @@ -150,8 +149,10 @@ test_that("gg_partial_rfsrc error on invalid xvar.names", { airq <- na.omit(airquality) rf <- randomForestSRC::rfsrc(Ozone ~ ., data = airq, ntree = 50, nsplit = 5) - result <- gg_partial_rfsrc(rf, xvar.names = c("NotAColumn")) - expect_match(result, "xvar.names contains column names not found") + expect_error( + gg_partial_rfsrc(rf, xvar.names = c("NotAColumn")), + "xvar.names contains column names not found" + ) }) test_that("gg_partial_rfsrc error on invalid newx columns", { @@ -162,8 +163,10 @@ test_that("gg_partial_rfsrc error on invalid newx columns", { rf <- randomForestSRC::rfsrc(Ozone ~ ., data = airq, ntree = 50, nsplit = 5) bad_newx <- data.frame(foo = 1:10, bar = 1:10) - result <- gg_partial_rfsrc(rf, xvar.names = c("Wind"), newx = bad_newx) - expect_match(result, "newx must be a dataframe") + expect_error( + gg_partial_rfsrc(rf, xvar.names = c("Wind"), newx = bad_newx), + "newx must be a dataframe" + ) }) test_that("gg_partial_rfsrc uses supplied newx data", { diff --git a/tests/testthat/test_gg_partialpro.R b/tests/testthat/test_gg_partialpro.R index 9e24d8ff..d6ab51fe 100644 --- a/tests/testthat/test_gg_partialpro.R +++ b/tests/testthat/test_gg_partialpro.R @@ -1,5 +1,4 @@ # Tests for gg_partialpro -context("gg_partialpro tests") # Helper: create mock VarPro partialpro data # - Continuous: length(xvirtual) > cat_limit (10) diff --git a/tests/testthat/test_gg_rfsrc.R b/tests/testthat/test_gg_rfsrc.R index afc26dab..538426a1 100644 --- a/tests/testthat/test_gg_rfsrc.R +++ b/tests/testthat/test_gg_rfsrc.R @@ -1,8 +1,9 @@ # testthat for gg_rfsrc function -context("gg_rfsrc tests") -# Survival formula helper (rfsrc requires Surv to be in local scope) -Surv <- survival::Surv +# Bring 
survival::Surv into the file-scope environment so that rfsrc() survival +# formulas (e.g. Surv(time, status) ~ .) can resolve the symbol without +# requiring library(survival) to be called explicitly in every test block. +Surv <- survival::Surv # nolint: object_name_linter test_that("gg_rfsrc classifications", { ## Load the cached forest @@ -12,59 +13,59 @@ test_that("gg_rfsrc classifications", { forest = TRUE, importance = TRUE, save.memory = TRUE) - + # Test the cached forest type - expect_is(rfsrc_iris, "rfsrc") - + expect_s3_class(rfsrc_iris, "rfsrc") + # Test the forest family - expect_is(rfsrc_iris, "class") - + expect_s3_class(rfsrc_iris, "class") + ## Create the correct gg_error object gg_dta <- gg_rfsrc(rfsrc_iris) - + # Test object type - expect_is(gg_dta, "gg_rfsrc") - + expect_s3_class(gg_dta, "gg_rfsrc") + # Test classification dimensions expect_equal(nrow(gg_dta), nrow(rfsrc_iris$predicted.oob)) expect_equal(ncol(gg_dta), ncol(rfsrc_iris$predicted.oob) + 1) - + # Test data is correctly pulled from randomForest obect. - expect_equivalent(as.matrix(gg_dta[, -which(colnames(gg_dta) == "y")]), + expect_equal(as.matrix(gg_dta[, -which(colnames(gg_dta) == "y")]), ignore_attr = TRUE, rfsrc_iris$predicted.oob) - + ## Test plotting the gg_error object gg_plt <- plot.gg_rfsrc(gg_dta) - + # Test return is s ggplot object - expect_is(gg_plt, "ggplot") - - + expect_s3_class(gg_plt, "ggplot") + + ## Create the correct gg_error object gg_dta <- gg_rfsrc(rfsrc_iris, oob = FALSE) - + # Test object type - expect_is(gg_dta, "gg_rfsrc") - + expect_s3_class(gg_dta, "gg_rfsrc") + # Test classification dimensions expect_equal(nrow(gg_dta), nrow(rfsrc_iris$predicted)) expect_equal(ncol(gg_dta), ncol(rfsrc_iris$predicted) + 1) - + # Test data is correctly pulled from randomForest obect.
- expect_equivalent(as.matrix(gg_dta[, -which(colnames(gg_dta) == "y")]), + expect_equal(as.matrix(gg_dta[, -which(colnames(gg_dta) == "y")]), ignore_attr = TRUE, rfsrc_iris$predicted) - + rf_iris <- randomForest::randomForest(Species ~ ., data = iris) gg_dta <- gg_rfsrc(rf_iris) - + }) test_that("gg_rfsrc regression", { data(Boston, package = "MASS") boston <- Boston - + boston$chas <- as.logical(boston$chas) - + ## Load the cached forest rfsrc_boston <- randomForestSRC::rfsrc( @@ -74,56 +75,56 @@ test_that("gg_rfsrc regression", { importance = TRUE, tree.err = TRUE, save.memory = TRUE) - + # Test the cached forest type - expect_is(rfsrc_boston, "rfsrc") - + expect_s3_class(rfsrc_boston, "rfsrc") + # Test the forest family expect_match(rfsrc_boston$family, "regr") - + ## Create the correct gg_error object gg_dta <- gg_rfsrc(rfsrc_boston) - + # Test object type - expect_is(gg_dta, "gg_rfsrc") - expect_is(gg_dta, "regr") - + expect_s3_class(gg_dta, "gg_rfsrc") + expect_s3_class(gg_dta, "regr") + ## Test plotting the gg_error object gg_plt <- plot.gg_rfsrc(gg_dta) - + # Test return is s ggplot object - expect_is(gg_plt, "ggplot") - + expect_s3_class(gg_plt, "ggplot") + ## Create the correct gg_error object gg_dta <- gg_rfsrc(rfsrc_boston, oob = FALSE) - + # Test object type - expect_is(gg_dta, "gg_rfsrc") - + expect_s3_class(gg_dta, "gg_rfsrc") + # Test classification dimensions ## Create the correct gg_error object gg_dta <- gg_rfsrc(rfsrc_boston, by = "chas") - + # Test object type - expect_is(gg_dta, "gg_rfsrc") - expect_is(gg_dta, "regr") - + expect_s3_class(gg_dta, "gg_rfsrc") + expect_s3_class(gg_dta, "regr") + ## Test plotting the gg_error object gg_plt <- plot.gg_rfsrc(gg_dta) - + # Test data is correctly pulled from randomForest obect. # Predicted values rfsrc_boston$family <- "test" expect_error(gg_rfsrc(rfsrc_boston)) - + # Test exceptions # Is it an rfsrc object? expect_error(gg_rfsrc(gg_plt)) - + # Does it contain the forest?
rfsrc_boston$forest <- NULL expect_error(gg_rfsrc(rfsrc_boston)) - + data(Boston, package = "MASS") rf_boston <- randomForest(medv ~ ., data = Boston) plot(gg_rfsrc(rf_boston)) @@ -142,13 +143,13 @@ test_that("gg_rfsrc survival: per-observation curves (no conf.int, no by)", { save.memory = TRUE ) - expect_is(rfsrc_veteran, "rfsrc") + expect_s3_class(rfsrc_veteran, "rfsrc") expect_match(rfsrc_veteran$family, "surv") gg_dta <- gg_rfsrc(rfsrc_veteran) - expect_is(gg_dta, "gg_rfsrc") - expect_is(gg_dta, "surv") + expect_s3_class(gg_dta, "gg_rfsrc") + expect_s3_class(gg_dta, "surv") # Per-observation long format: should have variable, value, obs_id, event expect_true("variable" %in% colnames(gg_dta)) expect_true("value" %in% colnames(gg_dta)) @@ -170,7 +171,7 @@ test_that("gg_rfsrc survival: plot of per-observation curves", { gg_dta <- gg_rfsrc(rfsrc_veteran) gg_plt <- plot.gg_rfsrc(gg_dta) - expect_is(gg_plt, "ggplot") + expect_s3_class(gg_plt, "ggplot") }) test_that("gg_rfsrc survival: confidence interval calculation", { @@ -187,8 +188,8 @@ test_that("gg_rfsrc survival: confidence interval calculation", { gg_dta <- gg_rfsrc(rfsrc_veteran, conf.int = 0.95) - expect_is(gg_dta, "gg_rfsrc") - expect_is(gg_dta, "surv") + expect_s3_class(gg_dta, "gg_rfsrc") + expect_s3_class(gg_dta, "surv") # conf.int output has lower, upper, median columns expect_true("lower" %in% colnames(gg_dta)) expect_true("upper" %in% colnames(gg_dta)) @@ -196,7 +197,7 @@ test_that("gg_rfsrc survival: confidence interval calculation", { # Plot the conf.int result gg_plt <- plot.gg_rfsrc(gg_dta) - expect_is(gg_plt, "ggplot") + expect_s3_class(gg_plt, "ggplot") }) test_that("gg_rfsrc survival: confidence interval with percentage > 1", { @@ -214,7 +215,7 @@ test_that("gg_rfsrc survival: confidence interval with percentage > 1", { # conf.int > 1 should be divided by 100 gg_dta <- gg_rfsrc(rfsrc_veteran, conf.int = 95) - expect_is(gg_dta, "gg_rfsrc") + expect_s3_class(gg_dta, "gg_rfsrc") 
expect_true("lower" %in% colnames(gg_dta)) }) @@ -232,13 +233,13 @@ test_that("gg_rfsrc survival: by argument groups observations", { gg_dta <- gg_rfsrc(rfsrc_veteran, by = "trt") - expect_is(gg_dta, "gg_rfsrc") - expect_is(gg_dta, "surv") + expect_s3_class(gg_dta, "gg_rfsrc") + expect_s3_class(gg_dta, "surv") expect_true("group" %in% colnames(gg_dta)) # Plot the grouped result gg_plt <- plot.gg_rfsrc(gg_dta) - expect_is(gg_plt, "ggplot") + expect_s3_class(gg_plt, "ggplot") }) test_that("gg_rfsrc survival: by + conf.int produces grouped confidence bands", { @@ -255,13 +256,13 @@ test_that("gg_rfsrc survival: by + conf.int produces grouped confidence bands", gg_dta <- gg_rfsrc(rfsrc_veteran, by = "trt", conf.int = 0.95) - expect_is(gg_dta, "gg_rfsrc") + expect_s3_class(gg_dta, "gg_rfsrc") expect_true("group" %in% colnames(gg_dta)) expect_true("lower" %in% colnames(gg_dta)) # Plot gg_plt <- plot.gg_rfsrc(gg_dta) - expect_is(gg_plt, "ggplot") + expect_s3_class(gg_plt, "ggplot") }) test_that("gg_rfsrc survival: surv_type = 'chf' (cumulative hazard)", { @@ -277,7 +278,7 @@ test_that("gg_rfsrc survival: surv_type = 'chf' (cumulative hazard)", { ) gg_dta <- gg_rfsrc(rfsrc_veteran, surv_type = "chf") - expect_is(gg_dta, "gg_rfsrc") + expect_s3_class(gg_dta, "gg_rfsrc") }) test_that("gg_rfsrc survival: surv_type = 'mortality'", { @@ -293,7 +294,7 @@ test_that("gg_rfsrc survival: surv_type = 'mortality'", { ) gg_dta <- gg_rfsrc(rfsrc_veteran, surv_type = "mortality") - expect_is(gg_dta, "gg_rfsrc") + expect_s3_class(gg_dta, "gg_rfsrc") }) test_that("gg_rfsrc survival: unknown surv_type throws error", { @@ -326,7 +327,7 @@ test_that("gg_rfsrc survival: by vector (not column name) works", { by_vec <- veteran$trt gg_dta <- gg_rfsrc(rfsrc_veteran, by = by_vec) - expect_is(gg_dta, "gg_rfsrc") + expect_s3_class(gg_dta, "gg_rfsrc") expect_true("group" %in% colnames(gg_dta)) }) @@ -358,8 +359,8 @@ test_that("gg_rfsrc survival: oob = FALSE uses in-bag predictions", { ) gg_dta <- 
gg_rfsrc(rfsrc_veteran, oob = FALSE) - expect_is(gg_dta, "gg_rfsrc") - expect_is(gg_dta, "surv") + expect_s3_class(gg_dta, "gg_rfsrc") + expect_s3_class(gg_dta, "surv") }) test_that("gg_rfsrc survival: conf.int with custom bs.sample", { @@ -375,7 +376,7 @@ test_that("gg_rfsrc survival: conf.int with custom bs.sample", { ) gg_dta <- gg_rfsrc(rfsrc_veteran, conf.int = 0.95, bs.sample = 20) - expect_is(gg_dta, "gg_rfsrc") + expect_s3_class(gg_dta, "gg_rfsrc") expect_true("lower" %in% colnames(gg_dta)) }) @@ -392,7 +393,7 @@ test_that("gg_rfsrc survival: by + conf.int with custom bs.sample", { ) gg_dta <- gg_rfsrc(rfsrc_veteran, by = "trt", conf.int = 0.95, bs.sample = 20) - expect_is(gg_dta, "gg_rfsrc") + expect_s3_class(gg_dta, "gg_rfsrc") expect_true("group" %in% colnames(gg_dta)) expect_true("lower" %in% colnames(gg_dta)) }) @@ -411,7 +412,7 @@ test_that("gg_rfsrc survival: conf.int with two-element level_set", { # Two-element conf.int (lower and upper directly) gg_dta <- gg_rfsrc(rfsrc_veteran, conf.int = c(0.025, 0.975)) - expect_is(gg_dta, "gg_rfsrc") + expect_s3_class(gg_dta, "gg_rfsrc") }) test_that("gg_rfsrc classification: by argument adds group column", { @@ -426,7 +427,7 @@ test_that("gg_rfsrc classification: by argument adds group column", { by_vec <- iris$Petal.Width > median(iris$Petal.Width) gg_dta <- gg_rfsrc(rfsrc_iris, by = by_vec) - expect_is(gg_dta, "gg_rfsrc") + expect_s3_class(gg_dta, "gg_rfsrc") expect_true("group" %in% colnames(gg_dta)) }) @@ -445,6 +446,28 @@ test_that("gg_rfsrc regression: by = vector works", { by_vec <- boston$chas gg_dta <- gg_rfsrc(rfsrc_boston, by = by_vec) - expect_is(gg_dta, "gg_rfsrc") + expect_s3_class(gg_dta, "gg_rfsrc") expect_true("group" %in% colnames(gg_dta)) }) + +test_that("plot.gg_rfsrc notch=FALSE suppresses notch in boxplot", { + # The notch argument used to be hardcoded to TRUE; callers had no way to + # suppress it. Passing notch = FALSE must not error and must return a ggplot. 
+ set.seed(42) + rfsrc_iris <- randomForestSRC::rfsrc( + Species ~ ., data = iris, ntree = 50, save.memory = TRUE + ) + gg_dta <- gg_rfsrc(rfsrc_iris) + gg_plt <- plot.gg_rfsrc(gg_dta, notch = FALSE) + expect_s3_class(gg_plt, "ggplot") + + data(Boston, package = "MASS") + Boston$chas <- as.logical(Boston$chas) # nolint: object_name_linter + set.seed(42) + rfsrc_boston <- randomForestSRC::rfsrc( + medv ~ ., data = Boston, ntree = 50, save.memory = TRUE + ) + gg_dta_regr <- gg_rfsrc(rfsrc_boston) + gg_plt_regr <- plot.gg_rfsrc(gg_dta_regr, notch = FALSE) + expect_s3_class(gg_plt_regr, "ggplot") +}) diff --git a/tests/testthat/test_gg_roc.R b/tests/testthat/test_gg_roc.R index 1a3cd1d3..bc7b0e60 100644 --- a/tests/testthat/test_gg_roc.R +++ b/tests/testthat/test_gg_roc.R @@ -1,5 +1,4 @@ # testthat for gg_roc function -context("gg_roc tests") test_that("gg_roc classifications", { ## Load the cached forest @@ -9,105 +8,122 @@ test_that("gg_roc classifications", { forest = TRUE, importance = TRUE, save.memory = TRUE) - + # Test the cached forest type - expect_is(rfsrc_iris, "rfsrc") - + expect_s3_class(rfsrc_iris, "rfsrc") + # Test the forest family expect_match(rfsrc_iris$family, "class") - + ## Create the correct gg_roc object which_outcome <- 1 gg_dta <- gg_roc(rfsrc_iris, which_outcome) - + # Test object type - expect_is(gg_dta, "gg_roc") - + expect_s3_class(gg_dta, "gg_roc") + # Test classification dimensions expect_equal(ncol(gg_dta), 3) - + # Test data is correctly pulled from randomForest obect. unts <- sort(unique(rfsrc_iris$predicted.oob[, which_outcome])) - + ## Test plotting the gg_roc object gg_obj <- plot.gg_roc(gg_dta) - + # Test return is s ggplot object - expect_is(gg_obj, "ggplot") - + expect_s3_class(gg_obj, "ggplot") + # Try test set prediction. gg_dta <- gg_roc(rfsrc_iris, which_outcome, oob = FALSE) # Try test set prediction. 
gg_plt <- plot.gg_roc(rfsrc_iris) - + # Test object type - expect_is(gg_dta, "gg_roc") - + expect_s3_class(gg_dta, "gg_roc") + # Test classification dimensions expect_equal(ncol(gg_dta), 3) - + # Test data is correctly pulled from randomForest obect. unts <- sort(unique(rfsrc_iris$predicted[, which_outcome])) - + ## Test plotting the gg_roc object gg_obj <- plot.gg_roc(gg_dta) - + # Test return is s ggplot object - expect_is(gg_obj, "ggplot") - - expect_is(plot.gg_roc(rfsrc_iris), "ggplot") + expect_s3_class(gg_obj, "ggplot") + + expect_s3_class(plot.gg_roc(rfsrc_iris), "ggplot") expect_error(gg_roc.randomForest(rfsrc_iris)) - expect_error(gg_roc.rfrsrc(rf_iris)) + expect_error(gg_roc.rfsrc(rf_iris)) }) test_that("gg_roc randomForest classifications", { ## Load the cached forest rf_iris <- randomForest(Species ~ ., data = iris) - + # Test the cached forest type - expect_is(rf_iris, "randomForest") - + expect_s3_class(rf_iris, "randomForest") + # Test the forest family expect_match(rf_iris$type, "classification") - + ## Create the correct gg_roc object which_outcome <- 1 gg_dta <- gg_roc(rf_iris, which_outcome) - + # Test object type - expect_is(gg_dta, "gg_roc") - + expect_s3_class(gg_dta, "gg_roc") + ## Test plotting the gg_roc object gg_obj <- plot.gg_roc(gg_dta) - + # Test return is s ggplot object - expect_is(gg_obj, "ggplot") - + expect_s3_class(gg_obj, "ggplot") + # Try test set prediction. 
gg_dta <- gg_roc(rf_iris, which_outcome, oob = FALSE) - + # Test object type - expect_is(gg_dta, "gg_roc") + expect_s3_class(gg_dta, "gg_roc") # Test classification dimensions expect_equal(ncol(gg_dta), 3) - + ## Test plotting the gg_roc object gg_obj <- plot.gg_roc(gg_dta) - + # Test return is s ggplot object - expect_is(gg_obj, "ggplot") - - expect_is(plot.gg_roc(rf_iris), "ggplot") - - expect_error(gg_roc.rfrsrc(rf_iris)) + expect_s3_class(gg_obj, "ggplot") + + expect_s3_class(plot.gg_roc(rf_iris), "ggplot") + + expect_error(gg_roc.rfsrc(rf_iris)) +}) + +test_that("gg_roc default oob=TRUE works without explicit argument", { + # Regression test: gg_roc() was crashing with + # "argument 'oob' is missing, with no default" when oob was not supplied. + set.seed(42) + rfsrc_iris <- randomForestSRC::rfsrc(Species ~ ., data = iris, ntree = 50) + + # All three outcomes must work without passing oob + expect_s3_class(gg_roc(rfsrc_iris, which_outcome = 1), "gg_roc") + expect_s3_class(gg_roc(rfsrc_iris, which_outcome = 2), "gg_roc") + expect_s3_class(gg_roc(rfsrc_iris, which_outcome = 3), "gg_roc") + + # OOB default should equal oob = TRUE explicitly + gg_default <- gg_roc(rfsrc_iris, which_outcome = 1) + gg_explicit <- gg_roc(rfsrc_iris, which_outcome = 1, oob = TRUE) + expect_equal(gg_default, gg_explicit) }) test_that("gg_roc regression", { data(Boston, package = "MASS") boston <- Boston - + boston$chas <- as.logical(boston$chas) - + ## Load the cached forest rfsrc_boston <- randomForestSRC::rfsrc( @@ -117,17 +133,17 @@ test_that("gg_roc regression", { importance = TRUE, tree.err = TRUE, save.memory = TRUE) - + # Test the cached forest type - expect_is(rfsrc_boston, "rfsrc") - + expect_s3_class(rfsrc_boston, "rfsrc") + # Test the forest family expect_match(rfsrc_boston$family, "regr") - + ## Create the correct gg_roc object expect_error(gg_roc(rfsrc_boston)) expect_error(plot.gg_roc(rfsrc_boston)) - + }) test_that("calc_roc", { @@ -137,34 +153,34 @@ test_that("calc_roc", 
{ forest = TRUE, importance = TRUE, save.memory = TRUE) - + # Test the cached forest type - expect_is(rfsrc_iris, "rfsrc") - + expect_s3_class(rfsrc_iris, "rfsrc") + # Test the forest family expect_match(rfsrc_iris$family, "class") - + gg_dta <- calc_roc.rfsrc(rfsrc_iris, rfsrc_iris$yvar, which_outcome = 1, oob = TRUE) - + # Test the cached forest type - expect_is(gg_dta, "data.frame") - + expect_s3_class(gg_dta, "data.frame") + expect_equal(ncol(gg_dta), 3) - + # Test oob=FALSE gg_dta <- calc_roc.rfsrc(rfsrc_iris, rfsrc_iris$yvar, which_outcome = 1, oob = FALSE) - + # Test the cached forest type - expect_is(gg_dta, "data.frame") - + expect_s3_class(gg_dta, "data.frame") + expect_equal(ncol(gg_dta), 3) - + # test the auc calculator auc <- calc_auc(gg_dta) expect_true(auc > .9) @@ -174,12 +190,12 @@ test_that("calc_roc", { rfsrc_iris$yvar, which_outcome = 2, oob = TRUE) - + # Test the cached forest type - expect_is(gg_dta, "data.frame") - + expect_s3_class(gg_dta, "data.frame") + expect_equal(ncol(gg_dta), 3) - + # test the auc calculator auc <- calc_auc(gg_dta) expect_true(auc > .9) @@ -189,12 +205,12 @@ test_that("calc_roc", { rfsrc_iris$yvar, which_outcome = 3, oob = TRUE) - + # Test the cached forest type - expect_is(gg_dta, "data.frame") - + expect_s3_class(gg_dta, "data.frame") + expect_equal(ncol(gg_dta), 3) - + # test the auc calculator auc <- calc_auc(gg_dta) expect_true(auc > .9) diff --git a/tests/testthat/test_gg_survival.R b/tests/testthat/test_gg_survival.R index 4b55f9fd..c25d0c31 100644 --- a/tests/testthat/test_gg_survival.R +++ b/tests/testthat/test_gg_survival.R @@ -1,5 +1,4 @@ # testthat for gg_survival function -context("gg_survival tests") test_that("gg_survival classifications", { expect_error(gg_survival(data = iris)) @@ -9,10 +8,10 @@ test_that("gg_survival classifications", { test_that("gg_survival survival", { # ## Load the cached forest data(pbc, package = "randomForestSRC") - + # Test the cached forest type - expect_is(pbc, "data.frame") 
- + expect_s3_class(pbc, "data.frame") + # Test object type gg_dta <- gg_survival( interval = "days", @@ -21,25 +20,25 @@ test_that("gg_survival survival", { data = pbc, conf.int = .95 ) - - expect_is(gg_dta, "gg_survival") - + + expect_s3_class(gg_dta, "gg_survival") + ## Test plotting the gg_error object gg_plt <- plot.gg_survival(gg_dta) - + # Test return is s ggplot object - expect_is(gg_plt, "ggplot") - - expect_is(plot(gg_dta, error = "bars"), "ggplot") - expect_is(plot(gg_dta, error = "none"), "ggplot") - expect_is(plot(gg_dta, error = "lines"), "ggplot") - expect_is(plot(gg_dta, type = "surv"), "ggplot") - expect_is(plot(gg_dta, type = "cum_haz"), "ggplot") - expect_is(plot(gg_dta, type = "density"), "ggplot") - expect_is(plot(gg_dta, type = "mid_int"), "ggplot") - expect_is(plot(gg_dta, type = "life"), "ggplot") - expect_is(plot(gg_dta, type = "hazard"), "ggplot") - expect_is(plot(gg_dta, type = "proplife"), "ggplot") + expect_s3_class(gg_plt, "ggplot") + + expect_s3_class(plot(gg_dta, error = "bars"), "ggplot") + expect_s3_class(plot(gg_dta, error = "none"), "ggplot") + expect_s3_class(plot(gg_dta, error = "lines"), "ggplot") + expect_s3_class(plot(gg_dta, type = "surv"), "ggplot") + expect_s3_class(plot(gg_dta, type = "cum_haz"), "ggplot") + expect_s3_class(plot(gg_dta, type = "density"), "ggplot") + expect_s3_class(plot(gg_dta, type = "mid_int"), "ggplot") + expect_s3_class(plot(gg_dta, type = "life"), "ggplot") + expect_s3_class(plot(gg_dta, type = "hazard"), "ggplot") + expect_s3_class(plot(gg_dta, type = "proplife"), "ggplot") # Test object type gg_dta <- gg_survival( interval = "days", @@ -49,16 +48,16 @@ test_that("gg_survival survival", { conf.int = .95, type = "nelson" ) - - expect_is(gg_dta, "gg_survival") - + + expect_s3_class(gg_dta, "gg_survival") + ## Test plotting the gg_error object gg_plt <- plot.gg_survival(gg_dta) - + # Test return is s ggplot object - expect_is(gg_plt, "ggplot") - - + expect_s3_class(gg_plt, "ggplot") + + # Test object 
type gg_dta <- gg_survival( interval = "days", @@ -66,32 +65,32 @@ test_that("gg_survival survival", { data = pbc, conf.int = .95 ) - - expect_is(gg_dta, "gg_survival") - + + expect_s3_class(gg_dta, "gg_survival") + ## Test plotting the gg_error object gg_plt <- plot.gg_survival(gg_dta) - + # Test return is s ggplot object - expect_is(gg_plt, "ggplot") - - expect_is(plot(gg_dta, error = "bars"), "ggplot") - expect_is(plot(gg_dta, error = "none"), "ggplot") - expect_is(plot(gg_dta, error = "lines"), "ggplot") - expect_is(plot(gg_dta, type = "surv"), "ggplot") - expect_is(plot(gg_dta, type = "cum_haz"), "ggplot") - expect_is(plot(gg_dta, type = "density"), "ggplot") - expect_is(plot(gg_dta, type = "mid_int"), "ggplot") - expect_is(plot(gg_dta, type = "life"), "ggplot") - expect_is(plot(gg_dta, type = "hazard"), "ggplot") - expect_is(plot(gg_dta, type = "proplife"), "ggplot") - + expect_s3_class(gg_plt, "ggplot") + + expect_s3_class(plot(gg_dta, error = "bars"), "ggplot") + expect_s3_class(plot(gg_dta, error = "none"), "ggplot") + expect_s3_class(plot(gg_dta, error = "lines"), "ggplot") + expect_s3_class(plot(gg_dta, type = "surv"), "ggplot") + expect_s3_class(plot(gg_dta, type = "cum_haz"), "ggplot") + expect_s3_class(plot(gg_dta, type = "density"), "ggplot") + expect_s3_class(plot(gg_dta, type = "mid_int"), "ggplot") + expect_s3_class(plot(gg_dta, type = "life"), "ggplot") + expect_s3_class(plot(gg_dta, type = "hazard"), "ggplot") + expect_s3_class(plot(gg_dta, type = "proplife"), "ggplot") + }) test_that("gg_survival regression", { ## Load the data data(Boston, package = "MASS") - + ## Create the correct gg_error object expect_error(gg_survival(data = Boston)) }) diff --git a/tests/testthat/test_gg_variable.R b/tests/testthat/test_gg_variable.R index 9dc92c76..e026d3ff 100644 --- a/tests/testthat/test_gg_variable.R +++ b/tests/testthat/test_gg_variable.R @@ -1,8 +1,7 @@ # testthat for gg_variable function -context("gg_variable tests") # Survival formula helper 
(rfsrc requires Surv to be in local scope) -Surv <- survival::Surv +Surv <- survival::Surv # nolint: object_name_linter test_that("gg_variable classifications", { ## Load the cached forest @@ -12,58 +11,58 @@ test_that("gg_variable classifications", { forest = TRUE, importance = TRUE, save.memory = TRUE) - + # Test the cached forest type - expect_is(rfsrc_iris, "rfsrc") - + expect_s3_class(rfsrc_iris, "rfsrc") + # Test the forest family expect_equal(rfsrc_iris$family, "class") - + ## Create the correct gg_error object gg_dta <- gg_variable(rfsrc_iris) - + # Test object type - expect_is(gg_dta, "gg_variable") + expect_s3_class(gg_dta, "gg_variable") ## Ensure non-OOB predictions keep predictor columns gg_full <- gg_variable(rfsrc_iris, oob = FALSE) expect_true(all(rfsrc_iris$xvar.names %in% names(gg_full))) - + ## Test plotting the gg_error object gg_plt <- plot(gg_dta, xvar = "Petal.Width") - + # Test return is s ggplot object - expect_is(gg_plt, "ggplot") - + expect_s3_class(gg_plt, "ggplot") + ## Test plotting the gg_error object gg_plt <- plot.gg_variable(gg_dta, xvar = rfsrc_iris$xvar.names) - + # Test return is s ggplot object - expect_is(gg_plt, "list") + expect_type(gg_plt, "list") expect_equal(length(gg_plt), length(rfsrc_iris$xvar.names)) for (ind in seq_along(rfsrc_iris$xvar.names)) - expect_is(gg_plt[[ind]], "ggplot") + expect_s3_class(gg_plt[[ind]], "ggplot") ## Test plotting the gg_error object gg_plt <- plot.gg_variable(gg_dta, xvar = rfsrc_iris$xvar.names, panel = TRUE) - + # Test return is s ggplot object - expect_is(gg_plt, "ggplot") - + expect_s3_class(gg_plt, "ggplot") + rf_iris <- randomForest::randomForest(Species ~ ., data = iris) - + ## Create the correct gg_error object gg_dta <- gg_variable(rf_iris) - + # Test object type - expect_is(gg_dta, "gg_variable") - + expect_s3_class(gg_dta, "gg_variable") + ## Test plotting the gg_error object gg_plt <- plot(gg_dta, xvar = "Petal.Width") - + # Test return is s ggplot object - expect_is(gg_plt, 
"ggplot") + expect_s3_class(gg_plt, "ggplot") gg_plt <- plot(gg_dta) ## Ensure we can rebuild training data for subset() calls @@ -72,18 +71,18 @@ test_that("gg_variable classifications", { data = iris_two, ntree = 60) gg_subset <- gg_variable(rf_subset) - expect_is(gg_subset, "gg_variable") + expect_s3_class(gg_subset, "gg_variable") expect_true(all(c("Sepal.Length", "Sepal.Width") %in% names(gg_subset))) - + }) test_that("gg_variable regression", { data(Boston, package = "MASS") boston <- Boston - + boston$chas <- as.logical(boston$chas) - + ## Load the cached forest rfsrc_boston <- randomForestSRC::rfsrc( @@ -93,53 +92,53 @@ test_that("gg_variable regression", { importance = TRUE, tree.err = TRUE, save.memory = TRUE) - + # Test the cached forest type - expect_is(rfsrc_boston, "rfsrc") - + expect_s3_class(rfsrc_boston, "rfsrc") + ## Create the correct gg_error object gg_dta <- gg_variable(rfsrc_boston) - + # Test object type - expect_is(gg_dta, "gg_variable") - + expect_s3_class(gg_dta, "gg_variable") + ## Test plotting the gg_error object gg_plt <- plot.gg_variable(gg_dta) - + # Test return is s ggplot object - expect_is(gg_plt, "list") + expect_type(gg_plt, "list") expect_equal(length(gg_plt), length(rfsrc_boston$xvar.names)) for (ind in seq_along(rfsrc_boston$xvar.names)) - expect_is(gg_plt[[ind]], "ggplot") - - + expect_s3_class(gg_plt[[ind]], "ggplot") + + ## Test plotting the gg_error object expect_warning(gg_plt <- plot.gg_variable(gg_dta, panel = TRUE)) - expect_is(gg_plt, "ggplot") - - + expect_s3_class(gg_plt, "ggplot") + + data(Boston, package = "MASS") rf_boston <- randomForest::randomForest(medv ~ ., data = Boston) gg_dta <- gg_variable(rf_boston) - + # Test object type - expect_is(gg_dta, "gg_variable") - + expect_s3_class(gg_dta, "gg_variable") + expect_warning(gg_plt <- plot(gg_dta, panel = TRUE)) - expect_is(gg_plt, "ggplot") - + expect_s3_class(gg_plt, "ggplot") + }) test_that("gg_variable survival handles late time requests", { data(veteran, 
package = "randomForestSRC") - Surv <- survival::Surv + Surv <- survival::Surv # nolint: object_name_linter rfsrc_veteran <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran, ntree = 50, nsplit = 5) late_time <- max(rfsrc_veteran$time.interest) + 50 expect_silent(gg_dta <- gg_variable(rfsrc_veteran, time = late_time)) - expect_is(gg_dta, "gg_variable") + expect_s3_class(gg_dta, "gg_variable") }) test_that("gg_variable survival: single time, single continuous variable plot", { @@ -153,10 +152,10 @@ test_that("gg_variable survival: single time, single continuous variable plot", ) gg_dta <- gg_variable(rfsrc_veteran, time = 90) - expect_is(gg_dta, "gg_variable") + expect_s3_class(gg_dta, "gg_variable") gg_plt <- plot(gg_dta, xvar = "age") - expect_is(gg_plt, "ggplot") + expect_s3_class(gg_plt, "ggplot") }) test_that("gg_variable survival: single time, panel plot", { @@ -171,7 +170,7 @@ test_that("gg_variable survival: single time, panel plot", { gg_dta <- gg_variable(rfsrc_veteran, time = 90) gg_plt <- plot(gg_dta, xvar = c("age", "diagtime"), panel = TRUE) - expect_is(gg_plt, "ggplot") + expect_s3_class(gg_plt, "ggplot") }) test_that("gg_variable survival: multiple times, single variable facets", { @@ -185,16 +184,16 @@ test_that("gg_variable survival: multiple times, single variable facets", { ) gg_dta <- gg_variable(rfsrc_veteran, time = c(30, 90, 365)) - expect_is(gg_dta, "gg_variable") + expect_s3_class(gg_dta, "gg_variable") # Single xvar → plot.gg_variable returns a single ggplot (not a list) gg_plt <- plot(gg_dta, xvar = "age") - expect_is(gg_plt, "ggplot") + expect_s3_class(gg_plt, "ggplot") # Multiple xvars → returns a list gg_plt2 <- plot(gg_dta, xvar = c("age", "diagtime")) - expect_is(gg_plt2, "list") - expect_is(gg_plt2[[1]], "ggplot") + expect_type(gg_plt2, "list") + expect_s3_class(gg_plt2[[1]], "ggplot") }) test_that("gg_variable survival: multiple times, panel plot", { @@ -209,7 +208,7 @@ test_that("gg_variable survival: multiple times, 
panel plot", { gg_dta <- gg_variable(rfsrc_veteran, time = c(30, 90)) gg_plt <- plot(gg_dta, xvar = c("age", "diagtime"), panel = TRUE) - expect_is(gg_plt, "ggplot") + expect_s3_class(gg_plt, "ggplot") }) test_that("gg_variable survival: points=FALSE smooth=TRUE options", { @@ -225,10 +224,10 @@ test_that("gg_variable survival: points=FALSE smooth=TRUE options", { gg_dta <- gg_variable(rfsrc_veteran, time = 90) gg_plt <- plot(gg_dta, xvar = "age", points = FALSE, smooth = TRUE) - expect_is(gg_plt, "ggplot") + expect_s3_class(gg_plt, "ggplot") gg_plt2 <- plot(gg_dta, xvar = "age", points = TRUE, smooth = FALSE) - expect_is(gg_plt2, "ggplot") + expect_s3_class(gg_plt2, "ggplot") }) test_that("gg_variable survival: oob=FALSE uses in-bag predictions", { @@ -242,7 +241,7 @@ test_that("gg_variable survival: oob=FALSE uses in-bag predictions", { ) gg_dta <- gg_variable(rfsrc_veteran, time = 90, oob = FALSE) - expect_is(gg_dta, "gg_variable") + expect_s3_class(gg_dta, "gg_variable") }) test_that("plot.gg_variable regression: points=FALSE smooth=TRUE", { @@ -260,10 +259,10 @@ test_that("plot.gg_variable regression: points=FALSE smooth=TRUE", { gg_dta <- gg_variable(rfsrc_boston) gg_plt <- plot(gg_dta, xvar = "lstat", points = FALSE, smooth = TRUE) - expect_is(gg_plt, "ggplot") + expect_s3_class(gg_plt, "ggplot") gg_plt2 <- plot(gg_dta, xvar = "lstat", points = FALSE, smooth = FALSE) - expect_is(gg_plt2, "ggplot") + expect_s3_class(gg_plt2, "ggplot") }) test_that("plot.gg_variable regression: factor x variable triggers boxplot", { @@ -282,7 +281,7 @@ test_that("plot.gg_variable regression: factor x variable triggers boxplot", { gg_dta$chas <- factor(gg_dta$chas) gg_plt <- plot(gg_dta, xvar = "chas") - expect_is(gg_plt, "ggplot") + expect_s3_class(gg_plt, "ggplot") }) test_that("plot.gg_variable regression: panel with two continuous variables", { @@ -299,7 +298,7 @@ test_that("plot.gg_variable regression: panel with two continuous variables", { gg_dta <- 
gg_variable(rfsrc_boston) gg_plt <- plot(gg_dta, xvar = c("lstat", "rm"), panel = TRUE) - expect_is(gg_plt, "ggplot") + expect_s3_class(gg_plt, "ggplot") }) test_that("plot.gg_variable regression: panel points=FALSE smooth=TRUE", { @@ -317,7 +316,7 @@ test_that("plot.gg_variable regression: panel points=FALSE smooth=TRUE", { gg_dta <- gg_variable(rfsrc_boston) gg_plt <- plot(gg_dta, xvar = c("lstat", "rm"), panel = TRUE, points = FALSE, smooth = TRUE) - expect_is(gg_plt, "ggplot") + expect_s3_class(gg_plt, "ggplot") }) test_that("plot.gg_variable classification: panel with multiple continuous vars", { @@ -330,7 +329,7 @@ test_that("plot.gg_variable classification: panel with multiple continuous vars" gg_dta <- gg_variable(rfsrc_iris) gg_plt <- plot(gg_dta, xvar = c("Petal.Width", "Sepal.Width"), panel = TRUE) - expect_is(gg_plt, "ggplot") + expect_s3_class(gg_plt, "ggplot") }) test_that("plot.gg_variable classification: smooth and no-points path", { @@ -343,7 +342,7 @@ test_that("plot.gg_variable classification: smooth and no-points path", { gg_dta <- gg_variable(rfsrc_iris) gg_plt <- plot(gg_dta, xvar = "Petal.Width", points = FALSE, smooth = TRUE) - expect_is(gg_plt, "ggplot") + expect_s3_class(gg_plt, "ggplot") }) test_that("plot.gg_variable: missing xvar returns list for all predictors", { @@ -360,7 +359,7 @@ test_that("plot.gg_variable: missing xvar returns list for all predictors", { gg_dta <- gg_variable(rfsrc_boston) gg_plt <- plot(gg_dta) - expect_is(gg_plt, "list") + expect_type(gg_plt, "list") expect_gt(length(gg_plt), 0) - expect_is(gg_plt[[1]], "ggplot") + expect_s3_class(gg_plt[[1]], "ggplot") }) diff --git a/tests/testthat/test_gg_vimp.R b/tests/testthat/test_gg_vimp.R index 99134256..59ad08b6 100644 --- a/tests/testthat/test_gg_vimp.R +++ b/tests/testthat/test_gg_vimp.R @@ -1,5 +1,4 @@ # testthat for gg_vimp function -context("gg_vimp tests") test_that("gg_vimp classifications", { ## Load the cached forest @@ -11,111 +10,111 @@ test_that("gg_vimp 
classifications", { tree.err = TRUE ) # Test the cached forest type - expect_is(rfsrc_iris, "rfsrc") - + expect_s3_class(rfsrc_iris, "rfsrc") + # Test the forest family expect_equal(rfsrc_iris$family, "class") - + ## Create the correct gg_error object gg_dta <- gg_vimp(rfsrc_iris) - + # Test object type - expect_is(gg_dta, "gg_vimp") - + expect_s3_class(gg_dta, "gg_vimp") + ## Test plotting the gg_error object gg_plt <- plot(gg_dta) - + # Test return is s ggplot object - expect_is(gg_plt, "ggplot") - + expect_s3_class(gg_plt, "ggplot") + # Grab only one class... by number. gg_dta <- gg_vimp(rfsrc_iris, which.outcome = 2) - + # Test object type - expect_is(gg_dta, "gg_vimp") + expect_s3_class(gg_dta, "gg_vimp") ## Test plotting the gg_error object gg_plt <- plot(gg_dta) - + # Test return is s ggplot object - expect_is(gg_plt, "ggplot") + expect_s3_class(gg_plt, "ggplot") # Grab only one class... by number - for the overall model. gg_dta <- gg_vimp(rfsrc_iris, which.outcome = 0) - + # Test object type - expect_is(gg_dta, "gg_vimp") + expect_s3_class(gg_dta, "gg_vimp") ## Test plotting the gg_error object gg_plt <- plot(gg_dta) - + # Test return is s ggplot object - expect_is(gg_plt, "ggplot") + expect_s3_class(gg_plt, "ggplot") # Grab only one class... by name - for the overall model. gg_dta <- gg_vimp(rfsrc_iris, which.outcome = "all") - + # Test object type - expect_is(gg_dta, "gg_vimp") + expect_s3_class(gg_dta, "gg_vimp") ## Test plotting the gg_error object gg_plt <- plot(gg_dta) - + # Test return is s ggplot object - expect_is(gg_plt, "ggplot") + expect_s3_class(gg_plt, "ggplot") # Grab only one class... by name - for the overall model. 
gg_dta <- gg_vimp(rfsrc_iris, which.outcome = "setosa") - + # Test object type - expect_is(gg_dta, "gg_vimp") + expect_s3_class(gg_dta, "gg_vimp") ## Test plotting the gg_error object gg_plt <- plot(gg_dta) - + # Test return is s ggplot object - expect_is(gg_plt, "ggplot") - + expect_s3_class(gg_plt, "ggplot") + # Grab only one class... by name - that doesn't exist. expect_error(gg_vimp(rfsrc_iris, which.outcome = "nothing special")) - + # Grab only one class... by number - that doesn't exist. expect_error(gg_vimp(rfsrc_iris, which.outcome = 200)) - + ## Single class/ iris2 <- iris iris2$spec <- factor(as.character(iris2$Species) == "setosa") iris2 <- iris2[, -which(colnames(iris2) == "Species")] - + rf <- rfsrc(spec ~ ., iris2, importance = TRUE) - + gg_dta <- gg_vimp(rf) - - expect_is(gg_dta, "gg_vimp") - + + expect_s3_class(gg_dta, "gg_vimp") + # Test passing in the wrong object expect_error(gg_vimp(gg_dta)) expect_error(gg_vimp.rfsrc(gg_dta)) - + ## RandomForest case rf_iris <- randomForest::randomForest(Species ~ ., data = iris) - + gg_dta <- gg_vimp(rf_iris) - - expect_is(gg_dta, "gg_vimp") + + expect_s3_class(gg_dta, "gg_vimp") rf_iris_noimp <- randomForest::randomForest(Species ~ ., data = iris, importance = FALSE) rf_iris_noimp$importance <- NULL expect_warning(gg_dta <- gg_vimp(rf_iris_noimp)) - expect_is(gg_dta, "gg_vimp") - + expect_s3_class(gg_dta, "gg_vimp") + # Test passing in the wrong object expect_error(gg_vimp(gg_dta)) expect_error(gg_vimp.rfsrc(gg_dta)) - - + + gg_dta <- gg_vimp(rf_iris, which.outcome = 1) - - expect_is(gg_dta, "gg_vimp") - - - expect_is(gg_dta, "gg_vimp") + + expect_s3_class(gg_dta, "gg_vimp") + + + expect_s3_class(gg_dta, "gg_vimp") # Test passing in the wrong object expect_error(gg_vimp(gg_dta)) expect_error(gg_vimp.rfsrc(gg_dta)) @@ -153,19 +152,19 @@ test_that("gg_vimp survival", { } # Convert age to years pbc$age <- pbc$age / 364.24 - + pbc$years <- pbc$days / 364.24 pbc <- pbc[, -which(colnames(pbc) == "days")] 
pbc$treatment <- as.numeric(pbc$treatment) pbc$treatment[which(pbc$treatment == 1)] <- "DPCA" pbc$treatment[which(pbc$treatment == 2)] <- "placebo" pbc$treatment <- factor(pbc$treatment) - + cat("pbc: rfsrc\n") dta_train <- pbc[-which(is.na(pbc$treatment)), ] # Create a test set from the remaining patients pbc_test <- pbc[which(is.na(pbc$treatment)), ] - + rfsrc_pbc <- randomForestSRC::rfsrc( Surv(years, status) ~ ., dta_train, @@ -175,89 +174,92 @@ test_that("gg_vimp survival", { tree.err = TRUE ) # Test the cached forest type - expect_is(rfsrc_pbc, "rfsrc") - + expect_s3_class(rfsrc_pbc, "rfsrc") + ## Create the correct gg_error object gg_dta <- gg_vimp(rfsrc_pbc) - + # Test object type - expect_is(gg_dta, "gg_vimp") - + expect_s3_class(gg_dta, "gg_vimp") + # Test varselect is the same expect_equal(gg_dta$vimp, as.vector(sort(rfsrc_pbc$importance, decreasing = TRUE))) - + ## Test plotting the gg_error object gg_plt <- plot(gg_dta) - + # Test return is s ggplot object - expect_is(gg_plt, "ggplot") - + expect_s3_class(gg_plt, "ggplot") + ## Test plotting the gg_error object gg_plt <- plot(gg_dta, nvar = 5) - + # Test return is s ggplot object - expect_is(gg_plt, "ggplot") - - expect_is(plot(gg_dta, relative = TRUE), "ggplot") - + expect_s3_class(gg_plt, "ggplot") + + expect_s3_class(plot(gg_dta, relative = TRUE), "ggplot") + # Test cutting the size down - expect_is(gg_dta <- gg_vimp(rfsrc_pbc, nvar = 10), "gg_vimp") + expect_s3_class(gg_dta <- gg_vimp(rfsrc_pbc, nvar = 10), "gg_vimp") expect_equal(nrow(gg_dta), 10) - expect_is(plot(gg_dta), "ggplot") - + expect_s3_class(plot(gg_dta), "ggplot") + # Test the relative vimp output and plotting - expect_is(gg_dta <- - gg_vimp(rfsrc_pbc, relative = TRUE), "gg_vimp") - expect_is(plot(gg_dta), "ggplot") - - expect_is(gg_dta <- - gg_vimp(rfsrc_pbc, nvar = 10, relative = TRUE), - "gg_vimp") - expect_is(plot(gg_dta), "ggplot") - + expect_s3_class( + gg_dta <- gg_vimp(rfsrc_pbc, relative = TRUE), + "gg_vimp" + ) + 
expect_s3_class(plot(gg_dta), "ggplot") + + expect_s3_class( + gg_dta <- gg_vimp(rfsrc_pbc, nvar = 10, relative = TRUE), + "gg_vimp" + ) + expect_s3_class(plot(gg_dta), "ggplot") + # Test importance calculations. # If the forest does not have importance rfsrc_pbc$importance <- NULL expect_warning(gg_dta <- gg_vimp(rfsrc_pbc)) - expect_is(gg_dta, "gg_vimp") - expect_is(plot(gg_dta), "ggplot") - + expect_s3_class(gg_dta, "gg_vimp") + expect_s3_class(plot(gg_dta), "ggplot") + }) test_that("gg_vimp regression", { ## Load the cached forest data(Boston, package = "MASS") - - Boston$chas <- as.logical(Boston$chas) - - rfsrc_boston <- randomForestSRC::rfsrc(medv ~ ., data = Boston, + boston <- Boston + boston$chas <- as.logical(boston$chas) + + rfsrc_boston <- randomForestSRC::rfsrc(medv ~ ., data = boston, importance = TRUE) # Test the cached forest type - expect_is(rfsrc_boston, "rfsrc") - + expect_s3_class(rfsrc_boston, "rfsrc") + ## Create the correct gg_error object gg_dta <- gg_vimp(rfsrc_boston) - + # Test object type - expect_is(gg_dta, "gg_vimp") - + expect_s3_class(gg_dta, "gg_vimp") + # Test varselect is the same expect_equal(gg_dta$vimp, as.vector(sort(rfsrc_boston$importance, decreasing = TRUE))) - + ## Test plotting the gg_error object gg_plt <- plot(gg_dta) - + # Test return is s ggplot object - expect_is(gg_plt, "ggplot") + expect_s3_class(gg_plt, "ggplot") ## Test plotting the gg_error object gg_plt <- plot(gg_dta, relative = TRUE) - + # Test return is s ggplot object - expect_is(gg_plt, "ggplot") - - + expect_s3_class(gg_plt, "ggplot") + + cls <- sapply(Boston, class) # lbls <- @@ -291,7 +293,7 @@ test_that("gg_vimp regression", { # medv "Median value of homes ($1000s)." 
) - + # Build a table for data description dta_labs <- data.frame(cbind( @@ -299,11 +301,11 @@ test_that("gg_vimp regression", { Description = lbls, type = cls )) - + # Build a named vector for labeling figures later/ st_labs <- as.character(dta_labs$Description) names(st_labs) <- names(cls) - + ## Test plotting the rfsrc object gg_plt <- plot.gg_vimp( rfsrc_boston, @@ -311,20 +313,39 @@ test_that("gg_vimp regression", { relative = TRUE, bars = rfsrc_boston$xvar.names ) - expect_is(gg_plt, "ggplot") - + expect_s3_class(gg_plt, "ggplot") + rf_boston <- randomForest::randomForest(medv ~ ., Boston) gg_dta <- gg_vimp(rf_boston) # Test varselect is the same expect_equal(gg_dta$vimp, as.vector(sort(rf_boston$importance, decreasing = TRUE))) - + gg_plt <- plot(gg_dta) - expect_is(gg_plt, "ggplot") + expect_s3_class(gg_plt, "ggplot") rf_boston_noimp <- randomForest::randomForest(medv ~ ., Boston, importance = FALSE) rf_boston_noimp$importance <- NULL expect_warning(gg_dta <- gg_vimp(rf_boston_noimp)) - expect_is(gg_dta, "gg_vimp") - + expect_s3_class(gg_dta, "gg_vimp") + +}) + +test_that("gg_vimp.randomForest regression: vimp column present even when importance is IncNodePurity", { + # Guard test: when randomForest stores importance as IncNodePurity (not X.IncMSE), + # gg_vimp must still produce a 'vimp' column so plot.gg_vimp and the positive + # flag work correctly. 
+ data(Boston, package = "MASS") + # importance = FALSE (default) → $importance has only IncNodePurity + rf_boston <- randomForest::randomForest(medv ~ ., data = Boston, + importance = FALSE) + gg_dta <- gg_vimp(rf_boston) + + expect_s3_class(gg_dta, "gg_vimp") + expect_true("vimp" %in% colnames(gg_dta), + info = "vimp column must exist regardless of original column name") + expect_false(any(is.na(gg_dta$vimp)), + info = "vimp values must not be NA") + expect_true("positive" %in% colnames(gg_dta)) + expect_s3_class(plot(gg_dta), "ggplot") }) diff --git a/tests/testthat/test_ggrandomforests_news.R b/tests/testthat/test_ggrandomforests_news.R index 24a69abb..da587d9d 100644 --- a/tests/testthat/test_ggrandomforests_news.R +++ b/tests/testthat/test_ggrandomforests_news.R @@ -1,5 +1,4 @@ # Tests for ggrandomforests.news -context("ggrandomforests.news tests") test_that("ggrandomforests.news NEWS file exists in package", { newsfile <- file.path(system.file(package = "ggRandomForests"), "NEWS") diff --git a/tests/testthat/test_kaplan_nelson.R b/tests/testthat/test_kaplan_nelson.R new file mode 100644 index 00000000..9b1c3bc9 --- /dev/null +++ b/tests/testthat/test_kaplan_nelson.R @@ -0,0 +1,228 @@ +# Unit tests for kaplan(), nelson(), and bootstrap_survival() + +## ---- Shared survival data -------------------------------------------------- + +data(pbc, package = "randomForestSRC") +pbc_dta <- pbc +pbc_dta$time <- pbc_dta$days / 364.25 + +## ---- kaplan() -------------------------------------------------------------- + +test_that("kaplan returns a gg_survival data frame", { + gg_dta <- kaplan(interval = "time", censor = "status", data = pbc_dta) + expect_s3_class(gg_dta, "gg_survival") + expect_s3_class(gg_dta, "data.frame") +}) + +test_that("kaplan output has required columns", { + gg_dta <- kaplan(interval = "time", censor = "status", data = pbc_dta) + required_cols <- c("time", "n", "cens", "dead", "surv", "se", + "lower", "upper", "cum_haz", + "hazard", "density", 
"mid_int", "life", "proplife") + expect_true(all(required_cols %in% colnames(gg_dta))) +}) + +test_that("kaplan retains only event rows (dead > 0)", { + gg_dta <- kaplan(interval = "time", censor = "status", data = pbc_dta) + expect_true(all(gg_dta$dead > 0)) +}) + +test_that("kaplan survival is monotonically non-increasing", { + gg_dta <- kaplan(interval = "time", censor = "status", data = pbc_dta) + expect_true(all(diff(gg_dta$surv) <= 0)) +}) + +test_that("kaplan survival is bounded in [0, 1]", { + gg_dta <- kaplan(interval = "time", censor = "status", data = pbc_dta) + expect_true(all(gg_dta$surv >= 0)) + expect_true(all(gg_dta$surv <= 1)) +}) + +test_that("kaplan with stratification adds groups column", { + pbc_strat <- pbc_dta + pbc_strat$treatment <- factor(pbc_strat$treatment) + gg_dta <- kaplan(interval = "time", censor = "status", + data = pbc_strat, by = "treatment") + expect_true("groups" %in% colnames(gg_dta)) + # Both treatment levels should appear + expect_true(length(unique(gg_dta$groups)) >= 2L) +}) + +test_that("kaplan life column is non-decreasing and proplife is in [0, 1]", { + # Regression test: the trapezoidal formula L(t_i) = L(t_{i-1}) + (S(t_{i-1}) + S(t_i))/2 * Δt + # must produce a monotonically non-decreasing cumulative integral. + # The old Adams-Bashforth formula could produce this too, but values were + # numerically wrong (over-estimated). + gg_dta <- kaplan(interval = "time", censor = "status", data = pbc_dta) + expect_true(all(diff(gg_dta$life) >= 0), + info = "life must be non-decreasing (cumulative area under S(t))") + expect_true(all(gg_dta$proplife >= 0), + info = "proplife must be >= 0") + expect_true(all(gg_dta$proplife <= 1 + .Machine$double.eps^0.5), + info = "proplife must be <= 1 (area under S(t) <= t * 1)") +}) + +test_that("kaplan with character (non-factor) by uses unique() labels", { + # .label_strata() has two code paths: levels() for factors, unique() for + # character/numeric. This test exercises the unique() path. 
+ pbc_strat <- pbc_dta + pbc_strat$trt_chr <- as.character(pbc_strat$treatment) + pbc_strat <- pbc_strat[!is.na(pbc_strat$trt_chr), ] + + gg_dta <- kaplan(interval = "time", censor = "status", + data = pbc_strat, by = "trt_chr") + expect_true("groups" %in% colnames(gg_dta)) + expect_true(length(unique(gg_dta$groups)) >= 2L) +}) + +test_that("kaplan plot returns a ggplot", { + gg_dta <- kaplan(interval = "time", censor = "status", data = pbc_dta) + expect_s3_class(plot(gg_dta), "ggplot") +}) + +test_that("kaplan plot with error = 'none' returns a ggplot", { + gg_dta <- kaplan(interval = "time", censor = "status", data = pbc_dta) + expect_s3_class(plot(gg_dta, error = "none"), "ggplot") +}) + +## ---- nelson() -------------------------------------------------------------- + +test_that("nelson returns a gg_survival data frame", { + gg_dta <- nelson(interval = "time", censor = "status", data = pbc_dta) + expect_s3_class(gg_dta, "gg_survival") + expect_s3_class(gg_dta, "data.frame") +}) + +test_that("nelson output has required columns", { + gg_dta <- nelson(interval = "time", censor = "status", data = pbc_dta) + required_cols <- c("time", "n", "cens", "dead", "surv", "se", + "lower", "upper", "cum_haz", + "hazard", "density", "mid_int", "life", "proplife") + expect_true(all(required_cols %in% colnames(gg_dta))) +}) + +test_that("nelson retains only event rows (dead > 0)", { + gg_dta <- nelson(interval = "time", censor = "status", data = pbc_dta) + expect_true(all(gg_dta$dead > 0)) +}) + +test_that("nelson cum_haz is non-decreasing", { + gg_dta <- nelson(interval = "time", censor = "status", data = pbc_dta) + expect_true(all(diff(gg_dta$cum_haz) >= 0)) +}) + +test_that("nelson with stratification adds groups column", { + pbc_strat <- pbc_dta + pbc_strat$treatment <- factor(pbc_strat$treatment) + gg_dta <- nelson(interval = "time", censor = "status", + data = pbc_strat, by = "treatment") + expect_true("groups" %in% colnames(gg_dta)) + 
expect_true(length(unique(gg_dta$groups)) >= 2L) +}) + +test_that("nelson and kaplan agree on survival estimates", { + kap <- kaplan(interval = "time", censor = "status", data = pbc_dta) + nel <- nelson(interval = "time", censor = "status", data = pbc_dta) + # Both should produce estimates at the same time points + common_times <- intersect(kap$time, nel$time) + expect_true(length(common_times) > 0L) + kap_surv <- kap$surv[kap$time %in% common_times] + nel_surv <- nel$surv[nel$time %in% common_times] + # KM and NA estimates are equivalent for the same data + expect_equal(kap_surv, nel_surv, tolerance = 1e-6) +}) + +test_that("nelson life column is non-decreasing and proplife is in [0, 1]", { + gg_dta <- nelson(interval = "time", censor = "status", data = pbc_dta) + expect_true(all(diff(gg_dta$life) >= 0), + info = "life must be non-decreasing") + expect_true(all(gg_dta$proplife >= 0)) + expect_true(all(gg_dta$proplife <= 1 + .Machine$double.eps^0.5)) +}) + +test_that("nelson with character (non-factor) by uses unique() labels", { + pbc_strat <- pbc_dta + pbc_strat$trt_chr <- as.character(pbc_strat$treatment) + pbc_strat <- pbc_strat[!is.na(pbc_strat$trt_chr), ] + + gg_dta <- nelson(interval = "time", censor = "status", + data = pbc_strat, by = "trt_chr") + expect_true("groups" %in% colnames(gg_dta)) + expect_true(length(unique(gg_dta$groups)) >= 2L) +}) + +test_that("nelson plot returns a ggplot", { + gg_dta <- nelson(interval = "time", censor = "status", data = pbc_dta) + expect_s3_class(plot(gg_dta), "ggplot") +}) + +## ---- bootstrap_survival() -------------------------------------------------- + +# Helper: build wide-form survival data (obs x time-point columns) that +# bootstrap_survival() expects. gg_rfsrc() without conf.int pivots to long +# form, so we construct the wide matrix directly from the rfsrc internals. 
+.make_wide_surv <- function(rfsrc_obj) { + wide <- data.frame(rfsrc_obj$survival.oob) + colnames(wide) <- as.character(rfsrc_obj$time.interest) + wide$obs_id <- seq_len(nrow(wide)) + wide$event <- as.logical(rfsrc_obj$yvar[, 2L]) + wide +} + +test_that("bootstrap_survival returns a data frame with correct columns (95% CI)", { + data(pbc, package = "randomForestSRC") + pbc_sub <- pbc[, c("days", "status", "treatment", "age", "bili", "albumin")] + pbc_sub$time <- pbc_sub$days / 364.25 + + set.seed(1L) + rfsrc_pbc <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = pbc_sub, ntree = 50L) + + wide_dta <- .make_wide_surv(rfsrc_pbc) + n_timepts <- length(rfsrc_pbc$time.interest) + level_set <- c(0.025, 0.975) + + set.seed(42L) + result <- ggRandomForests:::bootstrap_survival(wide_dta, 50L, level_set) + + expect_s3_class(result, "data.frame") + expect_equal(colnames(result), c("value", "lower", "upper", "median", "mean")) + expect_equal(nrow(result), n_timepts) + expect_true(all(result$lower <= result$mean + 1e-10)) + expect_true(all(result$mean <= result$upper + 1e-10)) +}) + +test_that("bootstrap_survival lower bound <= median <= upper bound", { + data(pbc, package = "randomForestSRC") + pbc_sub <- pbc[, c("days", "status", "treatment", "age", "bili", "albumin")] + pbc_sub$time <- pbc_sub$days / 364.25 + + set.seed(2L) + rfsrc_pbc <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = pbc_sub, ntree = 50L) + + wide_dta <- .make_wide_surv(rfsrc_pbc) + level_set <- c(0.025, 0.975) + + set.seed(42L) + result <- ggRandomForests:::bootstrap_survival(wide_dta, 50L, level_set) + + expect_true(all(result$lower <= result$median + 1e-10)) + expect_true(all(result$median <= result$upper + 1e-10)) +}) + +test_that("bootstrap_survival time points match rfsrc$time.interest", { + data(pbc, package = "randomForestSRC") + pbc_sub <- pbc[, c("days", "status", "treatment", "age", "bili", "albumin")] + pbc_sub$time <- pbc_sub$days / 364.25 + + set.seed(3L) + rfsrc_pbc <- 
randomForestSRC::rfsrc(Surv(time, status) ~ ., data = pbc_sub, ntree = 50L) + + wide_dta <- .make_wide_surv(rfsrc_pbc) + expected_times <- rfsrc_pbc$time.interest + level_set <- c(0.025, 0.975) + + set.seed(42L) + result <- ggRandomForests:::bootstrap_survival(wide_dta, 50L, level_set) + + expect_equal(result$value, expected_times) +}) diff --git a/tests/testthat/test_quantile_pts.R b/tests/testthat/test_quantile_pts.R index 1a854400..e5f48284 100644 --- a/tests/testthat/test_quantile_pts.R +++ b/tests/testthat/test_quantile_pts.R @@ -1,11 +1,10 @@ # testthat for quantile_pts function -context("quantile_pts tests") test_that("cutting a vector at evenly space points", { data(Boston, package = "MASS") boston <- Boston boston$chas <- as.logical(boston$chas) - + rfsrc_boston <- randomForestSRC::rfsrc( medv ~ ., @@ -20,26 +19,26 @@ test_that("cutting a vector at evenly space points", { quantile_pts(rfsrc_boston$xvar$rm, groups = 6, intervals = TRUE) - - expect_is(rm_pts, "numeric") + + expect_true(is.numeric(rm_pts)) expect_true(length(rm_pts) >= 2) expect_true(rm_pts[1] < min(rfsrc_boston$xvar$rm)) expect_equal(max(rm_pts), max(rfsrc_boston$xvar$rm)) expect_true(all(diff(rm_pts) > 0)) - + # Use cut to create the intervals rm_grp <- cut(rfsrc_boston$xvar$rm, breaks = rm_pts) - expect_is(rm_grp, "factor") + expect_s3_class(rm_grp, "factor") expect_equal(length(rm_grp), length(rfsrc_boston$xvar$rm)) expect_equal(length(levels(rm_grp)), length(rm_pts) - 1) - + rm_pts <- quantile_pts(rfsrc_boston$xvar$rm, groups = 6) - - expect_is(rm_pts, "numeric") + + expect_true(is.numeric(rm_pts)) expect_equal(length(rm_pts), 6) # First quantile coincides with the minimum value expect_equal(min(rfsrc_boston$xvar$rm), min(rm_pts), tolerance = 1.e-7) - + # Test the number of points for lots of groups. 
rm_pts <- quantile_pts(rfsrc_boston$xvar$rm, groups = nrow(rfsrc_boston$xvar) + 2) diff --git a/tests/testthat/test_randomForest_helpers.R b/tests/testthat/test_randomForest_helpers.R index 28c0763d..426209da 100644 --- a/tests/testthat/test_randomForest_helpers.R +++ b/tests/testthat/test_randomForest_helpers.R @@ -1,4 +1,3 @@ -context("randomForest helper coverage") skip_if_not_installed("randomForest") @@ -18,7 +17,7 @@ test_that(".rf_recover_model_frame rebuilds subsetted data", { keep.forest = TRUE ) info <- ggRandomForests:::`.rf_recover_model_frame`(rf_subset) - expect_is(info, "list") + expect_type(info, "list") expect_true(all(c("Sepal.Length", "Sepal.Width") %in% names(info$model_frame))) expect_equal(info$response_name, "Species") }) diff --git a/tests/testthat/test_shift.R b/tests/testthat/test_shift.R index 5bd04680..8c8fa418 100644 --- a/tests/testthat/test_shift.R +++ b/tests/testthat/test_shift.R @@ -1,11 +1,11 @@ # testthat for shift function -context("shift tests") test_that("lead or lag a vector", { - expect_that(shift(1:10, 2), is_identical_to(c(3:10, NA, NA))) - expect_that(shift(1:10, -2), is_identical_to(c(NA, NA, 1:8))) - expect_that(shift(1:10, 0), is_identical_to(1:10)) - expect_that(shift(1:10, 0), is_identical_to(1:10)) - expect_that(shift(1:10, 1:2), is_identical_to(cbind(c(2:10, NA), - c(3:10, NA, NA)))) + expect_identical(shift(1:10, 2), c(3:10, NA, NA)) + expect_identical(shift(1:10, -2), c(NA, NA, 1:8)) + expect_identical(shift(1:10, 0), 1:10) + expect_identical( + shift(1:10, 1:2), + cbind(c(2:10, NA), c(3:10, NA, NA)) + ) }) diff --git a/tests/testthat/test_snapshots.R b/tests/testthat/test_snapshots.R new file mode 100644 index 00000000..30e32072 --- /dev/null +++ b/tests/testthat/test_snapshots.R @@ -0,0 +1,158 @@ +# Visual regression tests using vdiffr. +# +# These tests generate reference SVGs on the first run and compare on +# subsequent runs. To regenerate snapshots (e.g. 
after an intentional +# visual change) run: testthat::snapshot_accept() +# +# All models are built with set.seed() to ensure reproducible plots. + +# Skip the entire file if vdiffr is not available (e.g. on CRAN). +if (!requireNamespace("vdiffr", quietly = TRUE)) { + skip("vdiffr not installed") +} + +# Guard: only register snapshot tests when explicitly opted in. +# Set VDIFFR_RUN_TESTS=true to generate or compare visual baselines. +# This avoids failures on fresh checkouts (no _snaps/ directory) and in CI. +# +# To generate baselines locally: +# 1. Run Sys.setenv(VDIFFR_RUN_TESTS = "true") +# 2. Run devtools::test(filter = "snapshots") +# 3. Call testthat::snapshot_accept() +# 4. Commit tests/testthat/_snaps/ to the repo +if (identical(Sys.getenv("VDIFFR_RUN_TESTS"), "true")) { + +## ---- Shared fixtures ------------------------------------------------------- + +# Classification — iris +local({ + set.seed(42L) + rfsrc_iris <- randomForestSRC::rfsrc( + Species ~ ., + data = iris, + importance = TRUE, + tree.err = TRUE, + ntree = 100L + ) + + test_that("snapshot: gg_vimp classification", { + gg_dta <- gg_vimp(rfsrc_iris) + vdiffr::expect_doppelganger("gg_vimp classification rfsrc", plot(gg_dta)) + }) + + test_that("snapshot: gg_error classification", { + gg_dta <- gg_error(rfsrc_iris) + vdiffr::expect_doppelganger("gg_error classification rfsrc", plot(gg_dta)) + }) + + test_that("snapshot: gg_roc classification rfsrc", { + gg_dta <- gg_roc(rfsrc_iris, which_outcome = 1L) + vdiffr::expect_doppelganger("gg_roc classification rfsrc", plot(gg_dta)) + }) + + test_that("snapshot: gg_rfsrc classification", { + gg_dta <- gg_rfsrc(rfsrc_iris) + vdiffr::expect_doppelganger("gg_rfsrc classification rfsrc", plot(gg_dta)) + }) +}) + +# Regression — Boston housing +local({ + data(Boston, package = "MASS") + boston <- Boston + boston$chas <- as.logical(boston$chas) + + set.seed(42L) + rfsrc_boston <- randomForestSRC::rfsrc( + medv ~ ., + data = boston, + importance = TRUE, + tree.err
= TRUE, + ntree = 100L + ) + + test_that("snapshot: gg_vimp regression", { + gg_dta <- gg_vimp(rfsrc_boston) + vdiffr::expect_doppelganger("gg_vimp regression rfsrc", plot(gg_dta)) + }) + + test_that("snapshot: gg_error regression", { + gg_dta <- gg_error(rfsrc_boston) + vdiffr::expect_doppelganger("gg_error regression rfsrc", plot(gg_dta)) + }) + + test_that("snapshot: gg_rfsrc regression", { + gg_dta <- gg_rfsrc(rfsrc_boston) + vdiffr::expect_doppelganger("gg_rfsrc regression rfsrc", plot(gg_dta)) + }) + + test_that("snapshot: gg_variable regression single xvar", { + gg_rfsrc_dta <- gg_rfsrc(rfsrc_boston) + xvar <- "lstat" + gg_dta <- gg_variable(rfsrc_boston, time = NULL) + vdiffr::expect_doppelganger( + "gg_variable regression lstat", + plot(gg_dta, xvar = xvar) + ) + }) +}) + +# Survival — pbc +local({ + data(pbc, package = "randomForestSRC") + pbc$time <- pbc$days / 364.25 + pbc_sub <- pbc[, c("time", "status", "treatment", "age", "bili", "albumin")] + + set.seed(42L) + rfsrc_pbc <- randomForestSRC::rfsrc( + Surv(time, status) ~ ., + data = pbc_sub, + importance = TRUE, + tree.err = TRUE, + ntree = 100L + ) + + test_that("snapshot: gg_vimp survival", { + gg_dta <- gg_vimp(rfsrc_pbc) + vdiffr::expect_doppelganger("gg_vimp survival rfsrc", plot(gg_dta)) + }) + + test_that("snapshot: gg_error survival", { + gg_dta <- gg_error(rfsrc_pbc) + vdiffr::expect_doppelganger("gg_error survival rfsrc", plot(gg_dta)) + }) + + test_that("snapshot: gg_rfsrc survival no CI", { + gg_dta <- gg_rfsrc(rfsrc_pbc) + vdiffr::expect_doppelganger("gg_rfsrc survival no ci", plot(gg_dta)) + }) + + test_that("snapshot: gg_rfsrc survival with bootstrap CI", { + set.seed(1L) + gg_dta <- gg_rfsrc(rfsrc_pbc, conf.int = 0.95, bs_samples = 50L) + vdiffr::expect_doppelganger("gg_rfsrc survival bootstrap ci", plot(gg_dta)) + }) +}) + +# randomForest classification — iris +local({ + set.seed(42L) + rf_iris <- randomForest::randomForest(Species ~ ., data = iris) + + test_that("snapshot: gg_vimp 
randomForest classification", { + gg_dta <- gg_vimp(rf_iris) + vdiffr::expect_doppelganger("gg_vimp classification rf", plot(gg_dta)) + }) + + test_that("snapshot: gg_error randomForest classification", { + gg_dta <- gg_error(rf_iris) + vdiffr::expect_doppelganger("gg_error classification rf", plot(gg_dta)) + }) + + test_that("snapshot: gg_roc randomForest classification", { + gg_dta <- gg_roc(rf_iris, which_outcome = 1L) + vdiffr::expect_doppelganger("gg_roc classification rf", plot(gg_dta)) + }) +}) + +} # end CI guard diff --git a/tests/testthat/test_surv_partial.R b/tests/testthat/test_surv_partial.R index 601d02a5..d1efb63b 100644 --- a/tests/testthat/test_surv_partial.R +++ b/tests/testthat/test_surv_partial.R @@ -1,8 +1,7 @@ # Tests for surv_partial.rfsrc -context("surv_partial.rfsrc tests") # Survival formula helper (rfsrc requires Surv to be in local scope) -Surv <- survival::Surv +Surv <- survival::Surv # nolint: object_name_linter test_that("surv_partial.rfsrc returns list with one element per variable", { skip_if_not_installed("randomForestSRC") @@ -104,3 +103,75 @@ test_that("surv_partial.rfsrc npts argument limits unique x values", { expect_type(result, "list") expect_length(result, 1) }) + +## ---- Shared fixture (built once, reused below) ---------------------------- + +local({ + skip_if_not_installed("randomForestSRC") + data(veteran, package = "randomForestSRC") + set.seed(42) + v.obj <- randomForestSRC::rfsrc( + Surv(time, status) ~ ., + data = veteran, + ntree = 50, + nsplit = 5 + ) + + test_that("surv_partial.rfsrc dta element has x and yhat columns", { + result <- surv_partial.rfsrc(v.obj, var_list = "age", partial.type = "mort") + dta <- result[[1]]$dta + expect_true(!is.null(dta)) + # get.partial.plot.data returns a list with $x (predictor values) and + # $yhat (matrix of partial predictions, one column per time point) + expect_true("x" %in% names(dta)) + expect_true("yhat" %in% names(dta)) + }) + + test_that("surv_partial.rfsrc npts limits 
evaluation points", { + npts_requested <- 5L + result <- surv_partial.rfsrc(v.obj, + var_list = "age", + partial.type = "mort", + npts = npts_requested + ) + dta <- result[[1]]$dta + # The number of evaluation points should be <= npts_requested + expect_true(length(dta$x) <= npts_requested) + }) + + test_that("surv_partial.rfsrc mort and surv partial.types return different yhat scales", { + res_mort <- surv_partial.rfsrc(v.obj, var_list = "age", partial.type = "mort") + res_surv <- surv_partial.rfsrc(v.obj, var_list = "age", partial.type = "surv") + + yhat_mort <- res_mort[[1]]$dta$yhat + yhat_surv <- res_surv[[1]]$dta$yhat + + # Mortality (expected events) and survival (probability) differ in scale; + # survival probabilities are bounded [0, 1]; mortality values are unbounded + if (is.matrix(yhat_surv)) { + expect_true(all(yhat_surv >= 0 & yhat_surv <= 1 + 1e-8)) + } + # The two types should not produce identical predictions + expect_false(identical(yhat_mort, yhat_surv)) + }) + + test_that("surv_partial.rfsrc verbose: prints variable name during computation", { + expect_output( + surv_partial.rfsrc(v.obj, var_list = "age", partial.type = "mort"), + regexp = "age" + ) + }) + + test_that("surv_partial.rfsrc errors on invalid variable name", { + expect_error( + surv_partial.rfsrc(v.obj, var_list = "nonexistent_var", partial.type = "mort") + ) + }) + + test_that("surv_partial.rfsrc result names match requested var_list order", { + vars <- c("karno", "age", "diagtime") + result <- surv_partial.rfsrc(v.obj, var_list = vars, partial.type = "mort") + names_out <- vapply(result, function(x) x$name, character(1L)) + expect_equal(names_out, vars) + }) +}) diff --git a/tests/testthat/test_varpro_feature_names.R b/tests/testthat/test_varpro_feature_names.R index 29739c37..8413bc5b 100644 --- a/tests/testthat/test_varpro_feature_names.R +++ b/tests/testthat/test_varpro_feature_names.R @@ -1,5 +1,4 @@ # Tests for varpro_feature_names -context("varpro_feature_names tests")
test_that("varpro_feature_names returns exact matches unchanged", { dataset <- data.frame(age = 1:5, sex = 1:5, weight = 1:5)