Skip to content

Commit dadf6e5

Browse files
authored
Merge pull request #32 from hyunjimoon/nonHMC
Support for variational, (optimize)
2 parents 841c692 + 5d5b341 commit dadf6e5

6 files changed

Lines changed: 73 additions & 10 deletions

File tree

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@ S3method("[",SBC_datasets)
44
S3method("[",SBC_results)
55
S3method(SBC_fit,SBC_backend_brms)
66
S3method(SBC_fit,SBC_backend_cmdstan_sample)
7+
S3method(SBC_fit,SBC_backend_cmdstan_variational)
78
S3method(SBC_fit,SBC_backend_rstan_sample)
89
S3method(SBC_fit_to_diagnostics,CmdStanMCMC)
910
S3method(SBC_fit_to_diagnostics,brmsfit)
1011
S3method(SBC_fit_to_diagnostics,default)
1112
S3method(SBC_fit_to_diagnostics,stanfit)
1213
S3method(SBC_fit_to_draws_matrix,CmdStanMCMC)
14+
S3method(SBC_fit_to_draws_matrix,CmdStanVB)
1315
S3method(SBC_fit_to_draws_matrix,brmsfit)
1416
S3method(SBC_fit_to_draws_matrix,default)
1517
S3method(check_all_SBC_diagnostics,SBC_results)
@@ -34,6 +36,7 @@ S3method(summary,SBC_results)
3436
export(SBC_backend_brms)
3537
export(SBC_backend_brms_from_generator)
3638
export(SBC_backend_cmdstan_sample)
39+
export(SBC_backend_cmdstan_variational)
3740
export(SBC_backend_rstan_sample)
3841
export(SBC_datasets)
3942
export(SBC_diagnostic_messages)

R/.Rapp.history

Whitespace-only changes.

R/backends.R

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,48 @@ SBC_fit_to_diagnostics.CmdStanMCMC <- function(fit, fit_output, fit_messages, fi
267267
res
268268
}
269269

270+
#' Backend based on variational approximation via `cmdstanr`.
271+
#'
272+
#' @param model an object of class `CmdStanModel` (as created by `cmdstanr::cmdstan_model`)
273+
#' @param ... other arguments passed to the `$variational()` method of the model. The `data` and
274+
#' `parallel_chains` arguments cannot be set this way as they need to be controlled by the SBC
275+
#' package.
276+
#' @export
277+
SBC_backend_cmdstan_variational <- function(model, ...) {
278+
stopifnot(inherits(model, "CmdStanModel"))
279+
if(length(model$exe_file()) == 0) {
280+
stop("The model has to be already compiled, call $compile() first.")
281+
}
282+
args <- list(...)
283+
unacceptable_params <- c("data")
284+
if(any(names(args) %in% unacceptable_params)) {
285+
stop(paste0("Parameters ", paste0("'", unacceptable_params, "'", collapse = ", "),
286+
" cannot be provided when defining a backend as they need to be set ",
287+
"by the SBC package"))
288+
}
289+
structure(list(model = model, args = args), class = "SBC_backend_cmdstan_variational")
290+
}
291+
292+
#' @export
293+
SBC_fit.SBC_backend_cmdstan_variational <- function(backend, generated, cores) {
294+
fit <- do.call(backend$model$variational,
295+
combine_args(backend$args,
296+
list(
297+
data = generated)))
298+
299+
if(all(fit$return_codes() != 0)) {
300+
stop("Variational inference did not finish succesfully")
301+
}
302+
303+
fit
304+
}
305+
306+
#' @export
307+
SBC_fit_to_draws_matrix.CmdStanVB <- function(fit) {
308+
fit$draws(format = "draws_matrix")
309+
310+
}
311+
270312
# For internal use, creates brms backend.
271313
new_SBC_backend_brms <- function(compiled_model,
272314
args

R/results.R

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -212,13 +212,13 @@ length.SBC_results <- function(x) {
212212
#'
213213
#' Parallel processing is supported via the `future` package, for most uses, it is most sensible
214214
#' to just call `plan(multisession)` once in your R session and all
215-
#' cores your computer has will be used. For more details refer to the documentation
215+
#' cores your computer will be used. For more details refer to the documentation
216216
#' of the `future` package.
217217
#'
218218
#' @param datasets an object of class `SBC_datasets`
219219
#' @param backend the model + sampling algorithm. The built-in backends can be constructed
220-
#' using `SBC_backend_cmdstan_sample()`, `SBC_backend_rstan_sample()` and `SBC_backend_brms()`.
221-
#' (more to come). The backend is an S3 class supporting at least the `SBC_fit`,
220+
#' using `SBC_backend_cmdstan_sample()`, `SBC_backend_cmdstan_variational()`, `SBC_backend_rstan_sample()` and `SBC_backend_brms()`.
221+
#' (more to come: issue 31, 38, 39). The backend is an S3 class supporting at least the `SBC_fit`,
222222
#' `SBC_fit_to_draws_matrix` methods.
223223
#' @param cores_per_fit how many cores should the backend be allowed to use for a single fit?
224224
#' Defaults to the maximum number that does not produce more parallel chains
@@ -271,7 +271,6 @@ compute_results <- function(datasets, backend,
271271
generated = datasets$generated[[i]]
272272
)
273273
}
274-
275274
if(is.null(gen_quants)) {
276275
future.globals <- FALSE
277276
} else {
@@ -288,7 +287,6 @@ compute_results <- function(datasets, backend,
288287
future.globals = future.globals,
289288
future.chunk.size = chunk_size)
290289

291-
292290
# Combine, check and summarise
293291
fits <- rep(list(NULL), length(datasets))
294292
outputs <- rep(list(NULL), length(datasets))
@@ -307,7 +305,9 @@ compute_results <- function(datasets, backend,
307305
stats_list[[i]] <- results_raw[[i]]$stats
308306
stats_list[[i]]$dataset_id <- i
309307
backend_diagnostics_list[[i]] <- results_raw[[i]]$backend_diagnostics
310-
backend_diagnostics_list[[i]]$dataset_id <- i
308+
if(!is.null(results_raw[[i]]$backend_diagnostics)){
309+
backend_diagnostics_list[[i]]$dataset_id <- i
310+
}
311311
}
312312
else {
313313
if(n_errors < max_errors_to_show) {

man/SBC_backend_cmdstan_variational.Rd

Lines changed: 18 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/compute_results.Rd

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)