Skip to content

Commit c062d6c

Browse files
authored
Merge pull request #1121 from stan-dev/as_cmdstan_fit-with-variables
Allow selecting subset of variables when using `as_cmdstan_fit()`
2 parents 34ccd06 + 41a96d9 commit c062d6c

5 files changed

Lines changed: 125 additions & 64 deletions

File tree

R/csv.R

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ read_cmdstan_csv <- function(files,
204204
lp = lp
205205
))
206206
}
207+
user_variables_subset <- FALSE
207208
if (is.null(variables)) { # variables = NULL returns all
208209
variables <- metadata$variables
209210
} else if (!any(nzchar(variables))) { # if variables = "" returns none
@@ -215,6 +216,7 @@ read_cmdstan_csv <- function(files,
215216
paste(res$not_found, collapse = ", "), call. = FALSE)
216217
}
217218
variables <- unrepair_variable_names(res$matching)
219+
user_variables_subset <- TRUE
218220
}
219221
if (is.null(sampler_diagnostics)) {
220222
sampler_diagnostics <- metadata$sampler_diagnostics
@@ -281,8 +283,13 @@ read_cmdstan_csv <- function(files,
281283
draws_list_id <- length(draws) + 1
282284
warmup_draws_list_id <- length(warmup_draws) + 1
283285
if (metadata$method == "pathfinder") {
284-
metadata$variables = union(metadata$sampler_diagnostics, metadata$variables)
285-
variables = union(metadata$sampler_diagnostics, variables)
286+
metadata$variables <- union(metadata$sampler_diagnostics, metadata$variables)
287+
if (!user_variables_subset) {
288+
# because for pathfinder variables and diagnostics are read in together,
289+
# if user hasn't selected a custom subset of variables we need to include
290+
# all diagnostics
291+
variables <- union(metadata$sampler_diagnostics, variables)
292+
}
286293
}
287294
suppressWarnings(
288295
draws[[draws_list_id]] <- data.table::fread(
@@ -489,10 +496,24 @@ read_sample_csv <- function(files,
489496
#' `TRUE` but set to `FALSE` to avoid checking for problems with divergences
490497
#' and treedepth.
491498
#'
492-
as_cmdstan_fit <- function(files, check_diagnostics = TRUE, format = getOption("cmdstanr_draws_format")) {
493-
csv_contents <- read_cmdstan_csv(files, format = format)
499+
as_cmdstan_fit <- function(files,
500+
variables = NULL,
501+
check_diagnostics = TRUE,
502+
format = getOption("cmdstanr_draws_format")) {
503+
csv_contents <- read_cmdstan_csv(files, variables = variables, format = format)
504+
method <- csv_contents$metadata$method
505+
if (!is.null(variables)) {
506+
if (method == "sample") {
507+
variables <- posterior::variables(csv_contents$post_warmup_draws)
508+
} else if (method == "optimize") {
509+
variables <- posterior::variables(csv_contents$point_estimates)
510+
} else { # variational, laplace, pathfinder
511+
variables <- posterior::variables(csv_contents$draws)
512+
}
513+
csv_contents$metadata$variables <- variables
514+
}
494515
switch(
495-
csv_contents$metadata$method,
516+
method,
496517
"sample" = CmdStanMCMC_CSV$new(csv_contents, files, check_diagnostics),
497518
"optimize" = CmdStanMLE_CSV$new(csv_contents, files),
498519
"variational" = CmdStanVB_CSV$new(csv_contents, files),
@@ -638,6 +659,7 @@ for (method in unavailable_methods_CmdStanFit_CSV) {
638659
CmdStanMLE_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV)
639660
CmdStanVB_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV)
640661
CmdStanLaplace_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV)
662+
CmdStanPathfinder_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV)
641663
}
642664

643665

man/read_cmdstan_csv.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/helper-models.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ testing_fit <-
2424
"optimize",
2525
"laplace",
2626
"variational",
27+
"pathfinder",
2728
"generate_quantities"),
2829
seed = 123,
2930
...) {

tests/testthat/test-csv.R

Lines changed: 96 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ fit_logistic_optimize <- testing_fit("logistic", method = "optimize", seed = 123
88
fit_logistic_variational <- testing_fit("logistic", method = "variational", seed = 123)
99
fit_logistic_variational_short <- testing_fit("logistic", method = "variational", output_samples = 100, seed = 123)
1010
fit_logistic_laplace <- testing_fit("logistic", method = "laplace", seed = 123)
11+
fit_logistic_pathfinder <- testing_fit("logistic", method = "pathfinder", seed = 123)
1112

1213
fit_bernoulli_diag_e_no_samples <- testing_fit("bernoulli", method = "sample",
1314
seed = 123, chains = 2, iter_sampling = 0, metric = "diag_e")
@@ -524,64 +525,6 @@ test_that("time from read_cmdstan_csv matches time from fit$time()", {
524525
)
525526
})
526527

527-
test_that("as_cmdstan_fit creates fitted model objects from csv", {
528-
fits <- list(
529-
mle = as_cmdstan_fit(fit_logistic_optimize$output_files()),
530-
vb = as_cmdstan_fit(fit_logistic_variational$output_files()),
531-
laplace = as_cmdstan_fit(fit_logistic_laplace$output_files()),
532-
mcmc = as_cmdstan_fit(fit_logistic_thin_1$output_files())
533-
)
534-
for (class in names(fits)) {
535-
fit <- fits[[class]]
536-
class_name <- if (class == "laplace") "Laplace" else toupper(class)
537-
checkmate::expect_r6(fit, classes = paste0("CmdStan", class_name, "_CSV"))
538-
expect_s3_class(fit$draws(), "draws")
539-
checkmate::expect_numeric(fit$lp())
540-
expect_output(fit$print(), "variable")
541-
expect_length(fit$output_files(), if (class == "mcmc") fit$num_chains() else 1)
542-
expect_s3_class(fit$summary(), "draws_summary")
543-
544-
if (class == "mcmc") {
545-
expect_s3_class(fit$sampler_diagnostics(), "draws_array")
546-
expect_type(fit$inv_metric(), "list")
547-
expect_equal(fit$time()$total, NA_integer_)
548-
expect_s3_class(fit$time()$chains, "data.frame")
549-
}
550-
if (class == "mle") {
551-
checkmate::expect_numeric(fit$mle())
552-
}
553-
if (class == "vb") {
554-
checkmate::expect_numeric(fit$lp_approx())
555-
}
556-
if (class == "laplace") {
557-
checkmate::expect_numeric(fit$lp_approx())
558-
}
559-
560-
for (method in unavailable_methods_CmdStanFit_CSV) {
561-
if (!(method == "time" && class == "mcmc")) {
562-
expect_error(fit[[method]](), "This method is not available")
563-
}
564-
}
565-
}
566-
})
567-
568-
test_that("as_cmdstan_fit can check MCMC diagnostics", {
569-
fit_schools <- suppressMessages(
570-
testing_fit("schools", chains = 2,
571-
adapt_delta = 0.5, max_treedepth = 4,
572-
show_messages = FALSE)
573-
)
574-
expect_message(
575-
as_cmdstan_fit(fit_schools$output_files()),
576-
"transitions ended with a divergence"
577-
)
578-
expect_message(
579-
as_cmdstan_fit(fit_schools$output_files()),
580-
"transitions hit the maximum treedepth"
581-
)
582-
expect_silent(as_cmdstan_fit(fit_schools$output_files(), check_diagnostics = FALSE))
583-
})
584-
585528
test_that("read_cmdstan_csv reads seed correctly", {
586529
opt <- read_cmdstan_csv(fit_bernoulli_optimize$output_files())
587530
vi <- read_cmdstan_csv(fit_bernoulli_variational$output_files())
@@ -900,3 +843,98 @@ test_that("read_cmdstan_csv() works with tilde expansion", {
900843
tildified_path <- file.path("~", fs::path_rel(full_path, "~"))
901844
expect_no_error(read_cmdstan_csv(tildified_path))
902845
})
846+
847+
848+
test_that("as_cmdstan_fit creates fitted model objects from csv", {
849+
fits <- list(
850+
mle = as_cmdstan_fit(fit_logistic_optimize$output_files()),
851+
vb = as_cmdstan_fit(fit_logistic_variational$output_files()),
852+
laplace = as_cmdstan_fit(fit_logistic_laplace$output_files()),
853+
pathfinder = as_cmdstan_fit(fit_logistic_pathfinder$output_files()),
854+
mcmc = as_cmdstan_fit(fit_logistic_thin_1$output_files())
855+
)
856+
857+
for (class in names(fits)) {
858+
fit <- fits[[class]]
859+
if (class == "laplace") {
860+
class_name <- "Laplace"
861+
} else if (class == "pathfinder") {
862+
class_name <- "Pathfinder"
863+
} else {
864+
class_name <- toupper(class)
865+
}
866+
checkmate::expect_r6(fit, classes = paste0("CmdStan", class_name, "_CSV"))
867+
expect_s3_class(fit$draws(), "draws")
868+
checkmate::expect_numeric(fit$lp())
869+
expect_output(fit$print(), "variable")
870+
expect_length(fit$output_files(), if (class == "mcmc") fit$num_chains() else 1)
871+
expect_s3_class(fit$summary(), "draws_summary")
872+
873+
if (class == "mcmc") {
874+
expect_s3_class(fit$sampler_diagnostics(), "draws_array")
875+
expect_type(fit$inv_metric(), "list")
876+
expect_equal(fit$time()$total, NA_integer_)
877+
expect_s3_class(fit$time()$chains, "data.frame")
878+
}
879+
if (class == "mle") {
880+
checkmate::expect_numeric(fit$mle())
881+
}
882+
if (class %in% c("vb", "laplace", "pathfinder")) {
883+
checkmate::expect_numeric(fit$lp_approx())
884+
}
885+
for (method in unavailable_methods_CmdStanFit_CSV) {
886+
if (!(method == "time" && class == "mcmc")) {
887+
expect_error(fit[[method]](), "This method is not available", info = class)
888+
}
889+
}
890+
}
891+
})
892+
893+
test_that("as_cmdstan_fit can check MCMC diagnostics", {
894+
fit_schools <- suppressMessages(
895+
testing_fit("schools", chains = 2,
896+
adapt_delta = 0.5, max_treedepth = 4,
897+
show_messages = FALSE)
898+
)
899+
expect_message(
900+
as_cmdstan_fit(fit_schools$output_files()),
901+
"transitions ended with a divergence"
902+
)
903+
expect_message(
904+
as_cmdstan_fit(fit_schools$output_files()),
905+
"transitions hit the maximum treedepth"
906+
)
907+
expect_silent(as_cmdstan_fit(fit_schools$output_files(), check_diagnostics = FALSE))
908+
})
909+
910+
test_that("as_cmdstan_fit filters variables across methods", {
911+
mcmc_vars <- c("alpha", "beta[2]")
912+
mcmc <- as_cmdstan_fit(fit_logistic_thin_1$output_files(), variables = mcmc_vars)
913+
expect_equal(posterior::variables(mcmc$draws()), mcmc_vars)
914+
expect_equal(mcmc$summary()$variable, mcmc_vars)
915+
expect_equal(mcmc$metadata()$variables, mcmc_vars)
916+
917+
mle_vars <- c("beta[1]", "beta[3]")
918+
mle <- as_cmdstan_fit(fit_logistic_optimize$output_files(), variables = mle_vars)
919+
expect_equal(posterior::variables(mle$draws()), mle_vars)
920+
expect_equal(mle$summary()$variable, mle_vars)
921+
expect_equal(mle$metadata()$variables, mle_vars)
922+
923+
vb_vars <- "beta"
924+
vb <- as_cmdstan_fit(fit_logistic_variational$output_files(), variables = vb_vars)
925+
expect_equal(posterior::variables(vb$draws()), c("beta[1]", "beta[2]", "beta[3]"))
926+
expect_equal(vb$summary()$variable, c("beta[1]", "beta[2]", "beta[3]"))
927+
expect_equal(vb$metadata()$variables, c("beta[1]", "beta[2]", "beta[3]"))
928+
929+
laplace_vars <- "alpha"
930+
laplace <- as_cmdstan_fit(fit_logistic_laplace$output_files(), variables = laplace_vars)
931+
expect_equal(posterior::variables(laplace$draws()), laplace_vars)
932+
expect_equal(laplace$summary()$variable, laplace_vars)
933+
expect_equal(laplace$metadata()$variables, laplace_vars)
934+
935+
pathfinder_vars <- c("alpha", "beta[1]", "beta[3]")
936+
pathfinder <- as_cmdstan_fit(fit_logistic_pathfinder$output_files(), variables = pathfinder_vars)
937+
expect_equal(posterior::variables(pathfinder$draws()), pathfinder_vars)
938+
expect_equal(pathfinder$summary()$variable, pathfinder_vars)
939+
expect_equal(pathfinder$metadata()$variables, pathfinder_vars)
940+
})

vignettes/posterior.Rmd

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ fit$summary(variables = c("mu", "tau"), mean, sd)
4949

5050
To summarize all variables with non-default functions, it is necessary to set explicitly set the variables argument, either to `NULL` or the full vector of variable names.
5151
```{r}
52-
fit$metadata()$model_params
5352
fit$summary(variables = NULL, "mean", "median")
5453
```
5554

0 commit comments

Comments
 (0)