diff --git a/NAMESPACE b/NAMESPACE index 99acd297..22061aa7 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -7,6 +7,9 @@ S3method(apply_transformations,matrix) S3method(diagnostic_factor,neff_ratio) S3method(diagnostic_factor,rhat) S3method(log_posterior,CmdStanMCMC) +S3method(log_posterior,draws_array) +S3method(log_posterior,draws_df) +S3method(log_posterior,draws_matrix) S3method(log_posterior,stanfit) S3method(log_posterior,stanreg) S3method(melt_mcmc,matrix) @@ -21,6 +24,9 @@ S3method(num_iters,mcmc_array) S3method(num_params,data.frame) S3method(num_params,mcmc_array) S3method(nuts_params,CmdStanMCMC) +S3method(nuts_params,draws_array) +S3method(nuts_params,draws_df) +S3method(nuts_params,draws_matrix) S3method(nuts_params,list) S3method(nuts_params,stanfit) S3method(nuts_params,stanreg) diff --git a/NEWS.md b/NEWS.md index fceb6a8f..590c3f08 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,6 @@ # bayesplot (development version) +* Added `log_posterior()` and `nuts_params()` methods for `draws_array`, `draws_df`, and `draws_matrix` objects. * `ppc_ecdf_overlay()`, `ppc_ecdf_overlay_grouped()`, and `ppd_ecdf_overlay()` now always use `geom_step()`. The `discrete` argument is deprecated. * Fixed missing `drop = FALSE` in `nuts_params.CmdStanMCMC()`. * Replace `apply()` with `storage.mode()` for integer-to-numeric matrix conversion in `validate_predictions()`. diff --git a/R/bayesplot-extractors.R b/R/bayesplot-extractors.R index 79f7966e..b10b9af4 100644 --- a/R/bayesplot-extractors.R +++ b/R/bayesplot-extractors.R @@ -103,14 +103,37 @@ log_posterior.stanreg <- function(object, inc_warmup = FALSE, ...) { #' @export #' @method log_posterior CmdStanMCMC log_posterior.CmdStanMCMC <- function(object, inc_warmup = FALSE, ...) { - lp <- object$draws("lp__", inc_warmup = inc_warmup) - lp <- reshape2::melt(lp) + log_posterior.draws_array(object$draws("lp__", inc_warmup = inc_warmup), ...) +} + +#' @rdname bayesplot-extractors +#' @export +#' @method log_posterior draws_array +log_posterior.draws_array <- function(object, ...) { + if (!"lp__" %in% posterior::variables(object)) { + abort("draws object does not contain an 'lp__' variable.") + } + lp <- reshape2::melt(object[, , "lp__", drop = FALSE]) lp$variable <- NULL lp <- dplyr::rename_with(lp, capitalize_first) validate_df_classes(lp[, c("Chain", "Iteration", "Value")], c("integer", "integer", "numeric")) } +#' @rdname bayesplot-extractors +#' @export +#' @method log_posterior draws_df +log_posterior.draws_df <- function(object, ...) { + log_posterior.draws_array(posterior::as_draws_array(object), ...) +} + +#' @rdname bayesplot-extractors +#' @export +#' @method log_posterior draws_matrix +log_posterior.draws_matrix <- function(object, ...) { + log_posterior.draws_array(posterior::as_draws_array(object), ...) +} + #' @rdname bayesplot-extractors #' @export @@ -173,10 +196,28 @@ nuts_params.list <- function(object, pars = NULL, ...) { #' @export #' @method nuts_params CmdStanMCMC nuts_params.CmdStanMCMC <- function(object, pars = NULL, ...) { - arr <- object$sampler_diagnostics() - if (!is.null(pars)) { - arr <- arr[,, pars, drop = FALSE] + nuts_params.draws_array(object$sampler_diagnostics(), pars = pars, ...) +} + +#' @rdname bayesplot-extractors +#' @export +#' @method nuts_params draws_array +nuts_params.draws_array <- function(object, pars = NULL, ...) { + vars <- posterior::variables(object) + if (is.null(pars)) { + pars <- grep("__$", vars, value = TRUE) + pars <- setdiff(pars, "lp__") + if (!length(pars)) { + abort("draws object does not contain any NUTS sampler diagnostic variables (names ending in '__').") + } + } else { + missing_pars <- setdiff(pars, vars) + if (length(missing_pars)) { + abort(paste0("Variables not found in draws object: ", + paste(missing_pars, collapse = ", "), ".")) + } } + arr <- object[, , pars, drop = FALSE] out <- reshape2::melt(arr) colnames(out)[colnames(out) == "variable"] <- "parameter" out <- dplyr::rename_with(out, capitalize_first) @@ -184,6 +225,20 @@ nuts_params.CmdStanMCMC <- function(object, pars = NULL, ...) { c("integer", "integer", "factor", "numeric")) } +#' @rdname bayesplot-extractors +#' @export +#' @method nuts_params draws_df +nuts_params.draws_df <- function(object, pars = NULL, ...) { + nuts_params.draws_array(posterior::as_draws_array(object), pars = pars, ...) +} + +#' @rdname bayesplot-extractors +#' @export +#' @method nuts_params draws_matrix +nuts_params.draws_matrix <- function(object, pars = NULL, ...) { + nuts_params.draws_array(posterior::as_draws_array(object), pars = pars, ...) +} + #' @rdname bayesplot-extractors #' @export diff --git a/man/bayesplot-extractors.Rd b/man/bayesplot-extractors.Rd index 593a2b0e..c1f41435 100644 --- a/man/bayesplot-extractors.Rd +++ b/man/bayesplot-extractors.Rd @@ -9,10 +9,16 @@ \alias{log_posterior.stanfit} \alias{log_posterior.stanreg} \alias{log_posterior.CmdStanMCMC} +\alias{log_posterior.draws_array} +\alias{log_posterior.draws_df} +\alias{log_posterior.draws_matrix} \alias{nuts_params.stanfit} \alias{nuts_params.stanreg} \alias{nuts_params.list} \alias{nuts_params.CmdStanMCMC} +\alias{nuts_params.draws_array} +\alias{nuts_params.draws_df} +\alias{nuts_params.draws_matrix} \alias{rhat.stanfit} \alias{rhat.stanreg} \alias{rhat.CmdStanMCMC} @@ -35,6 +41,12 @@ neff_ratio(object, ...) \method{log_posterior}{CmdStanMCMC}(object, inc_warmup = FALSE, ...) +\method{log_posterior}{draws_array}(object, ...) + +\method{log_posterior}{draws_df}(object, ...) + +\method{log_posterior}{draws_matrix}(object, ...) + \method{nuts_params}{stanfit}(object, pars = NULL, inc_warmup = FALSE, ...) \method{nuts_params}{stanreg}(object, pars = NULL, inc_warmup = FALSE, ...) @@ -43,6 +55,12 @@ neff_ratio(object, ...) \method{nuts_params}{CmdStanMCMC}(object, pars = NULL, ...) +\method{nuts_params}{draws_array}(object, pars = NULL, ...) + +\method{nuts_params}{draws_df}(object, pars = NULL, ...) + +\method{nuts_params}{draws_matrix}(object, pars = NULL, ...) + \method{rhat}{stanfit}(object, pars = NULL, ...) \method{rhat}{stanreg}(object, pars = NULL, regex_pars = NULL, ...) diff --git a/tests/testthat/test-extractors.R b/tests/testthat/test-extractors.R index 355e1418..54e50ace 100644 --- a/tests/testthat/test-extractors.R +++ b/tests/testthat/test-extractors.R @@ -145,3 +145,55 @@ test_that("cmdstanr methods work", { expect_equal(range(np_one$Chain), c(1, 2)) expect_true(all(np_one$Value == 0)) }) + + +# draws object methods ---------------------------------------------------- +make_draws_array <- function(iter = 50, chains = 2) { + vars <- c("mu", "sigma", "lp__", "accept_stat__", "stepsize__", + "treedepth__", "n_leapfrog__", "divergent__", "energy__") + arr <- array(stats::rnorm(iter * chains * length(vars)), + dim = c(iter, chains, length(vars)), + dimnames = list(NULL, NULL, vars)) + posterior::as_draws_array(arr) +} + +test_that("log_posterior methods for draws objects return correct structure", { + d <- make_draws_array(iter = 50, chains = 2) + + lp_arr <- log_posterior(d) + expect_identical(colnames(lp_arr), c("Chain", "Iteration", "Value")) + expect_equal(length(unique(lp_arr$Iteration)), 50) + expect_equal(length(unique(lp_arr$Chain)), 2) + + lp_df <- log_posterior(posterior::as_draws_df(d)) + lp_mat <- log_posterior(posterior::as_draws_matrix(d)) + expect_equal(lp_df$Value, lp_arr$Value) + expect_equal(lp_mat$Value, lp_arr$Value) +}) + +test_that("nuts_params methods for draws objects return correct structure", { + d <- make_draws_array(iter = 50, chains = 2) + + np <- nuts_params(d) + expect_identical(colnames(np), c("Chain", "Iteration", "Parameter", "Value")) + expect_setequal( + levels(np$Parameter), + c("accept_stat__", "stepsize__", "treedepth__", + "n_leapfrog__", "divergent__", "energy__") + ) + expect_false("lp__" %in% levels(np$Parameter)) + + np_one <- nuts_params(d, pars = "divergent__") + expect_identical(levels(np_one$Parameter), "divergent__") + + np_df <- nuts_params(posterior::as_draws_df(d)) + expect_equal(np_df$Value, np$Value) +}) + +test_that("draws-object extractors error on missing variables", { + d <- make_draws_array() + bare <- d[, , c("mu", "sigma"), drop = FALSE] + expect_error(log_posterior(bare), "lp__") + expect_error(nuts_params(bare), "sampler diagnostic") + expect_error(nuts_params(d, pars = "nope__"), "nope__") +})