|
1 | | -#' Widely applicable information criterion (WAIC) |
2 | | -#' |
3 | | -#' The `waic()` methods can be used to compute WAIC from the pointwise |
4 | | -#' log-likelihood. However, we recommend LOO-CV using PSIS (as implemented by |
5 | | -#' the [loo()] function) because PSIS provides useful diagnostics as well as |
6 | | -#' effective sample size and Monte Carlo estimates. |
7 | | -#' |
8 | | -#' @export waic waic.array waic.matrix waic.function |
9 | | -#' @inheritParams loo |
10 | | -#' |
11 | | -#' @return A named list (of class `c("waic", "loo")`) with components: |
12 | | -#' |
13 | | -#' \describe{ |
14 | | -#' \item{`estimates`}{ |
15 | | -#' A matrix with two columns (`"Estimate"`, `"SE"`) and three |
16 | | -#' rows (`"elpd_waic"`, `"p_waic"`, `"waic"`). This contains |
17 | | -#' point estimates and standard errors of the expected log pointwise predictive |
18 | | -#' density (`elpd_waic`), the effective number of parameters |
19 | | -#' (`p_waic`) and the information criterion `waic` (which is just |
20 | | -#' `-2 * elpd_waic`, i.e., converted to deviance scale). |
21 | | -#' } |
22 | | -#' \item{`pointwise`}{ |
23 | | -#' A matrix with three columns (and number of rows equal to the number of |
24 | | -#' observations) containing the pointwise contributions of each of the above |
25 | | -#' measures (`elpd_waic`, `p_waic`, `waic`). |
26 | | -#' } |
27 | | -#' } |
28 | | -#' |
29 | | -#' @seealso |
30 | | -#' * The __loo__ package [vignettes](https://mc-stan.org/loo/articles/) and |
31 | | -#' Vehtari, Gelman, and Gabry (2017) and Vehtari, Simpson, Gelman, Yao, |
32 | | -#' and Gabry (2024) for more details on why we prefer `loo()` to `waic()`. |
33 | | -#' * [loo_compare()] for comparing models on approximate LOO-CV or WAIC. |
34 | | -#' |
35 | | -#' @references |
36 | | -#' Watanabe, S. (2010). Asymptotic equivalence of Bayes cross validation and |
37 | | -#' widely applicable information criterion in singular learning theory. |
38 | | -#' *Journal of Machine Learning Research* **11**, 3571-3594. |
39 | | -#' |
40 | | -#' @template loo-and-psis-references |
41 | | -#' |
42 | | -#' @examples |
43 | | -#' ### Array and matrix methods |
44 | | -#' LLarr <- example_loglik_array() |
45 | | -#' dim(LLarr) |
46 | | -#' |
47 | | -#' LLmat <- example_loglik_matrix() |
48 | | -#' dim(LLmat) |
49 | | -#' |
50 | | -#' waic_arr <- waic(LLarr) |
51 | | -#' waic_mat <- waic(LLmat) |
52 | | -#' identical(waic_arr, waic_mat) |
53 | | -#' |
54 | | -#' |
55 | | -#' \dontrun{ |
56 | | -#' log_lik1 <- extract_log_lik(stanfit1) |
57 | | -#' log_lik2 <- extract_log_lik(stanfit2) |
58 | | -#' (waic1 <- waic(log_lik1)) |
59 | | -#' (waic2 <- waic(log_lik2)) |
60 | | -#' print(compare(waic1, waic2), digits = 2) |
61 | | -#' } |
62 | | -#' |
63 | | -waic <- function(x, ...) { |
64 | | - UseMethod("waic") |
65 | | -} |
66 | | - |
67 | | -#' @export |
68 | | -#' @templateVar fn waic |
69 | | -#' @template array |
70 | | -#' |
71 | | -waic.array <- function(x, ...) { |
72 | | - waic.matrix(llarray_to_matrix(x), ...) |
73 | | -} |
74 | | - |
75 | | -#' @export |
76 | | -#' @templateVar fn waic |
77 | | -#' @template matrix |
78 | | -#' |
79 | | -waic.matrix <- function(x, ...) { |
80 | | - ll <- validate_ll(x) |
81 | | - lldim <- dim(ll) |
82 | | - lpd <- matrixStats::colLogSumExps(ll) - log(nrow(ll)) # colLogMeanExps |
83 | | - p_waic <- matrixStats::colVars(ll) |
84 | | - elpd_waic <- lpd - p_waic |
85 | | - waic <- -2 * elpd_waic |
86 | | - pointwise <- cbind(elpd_waic, p_waic, waic) |
87 | | - |
88 | | - throw_pwaic_warnings(pointwise[, "p_waic"], digits = 1) |
89 | | - return(waic_object(pointwise, dims = lldim)) |
90 | | -} |
91 | | - |
92 | | - |
93 | | -#' @export |
94 | | -#' @templateVar fn waic |
95 | | -#' @template function |
96 | | -#' @param draws,data,... For the function method only. See the |
97 | | -#' **Methods (by class)** section below for details on these arguments. |
98 | | -#' |
99 | | -waic.function <- |
100 | | - function(x, |
101 | | - ..., |
102 | | - data = NULL, |
103 | | - draws = NULL) { |
104 | | - stopifnot(is.data.frame(data) || is.matrix(data), !is.null(draws)) |
105 | | - |
106 | | - .llfun <- validate_llfun(x) |
107 | | - N <- dim(data)[1] |
108 | | - S <- length(as.vector(.llfun(data_i = data[1,, drop=FALSE], draws = draws, ...))) |
109 | | - waic_list <- lapply(seq_len(N), FUN = function(i) { |
110 | | - ll_i <- .llfun(data_i = data[i,, drop=FALSE], draws = draws, ...) |
111 | | - ll_i <- as.vector(ll_i) |
112 | | - lpd_i <- logMeanExp(ll_i) |
113 | | - p_waic_i <- var(ll_i) |
114 | | - elpd_waic_i <- lpd_i - p_waic_i |
115 | | - c(elpd_waic = elpd_waic_i, p_waic = p_waic_i) |
116 | | - }) |
117 | | - pointwise <- do.call(rbind, waic_list) |
118 | | - pointwise <- cbind(pointwise, waic = -2 * pointwise[, "elpd_waic"]) |
119 | | - |
120 | | - throw_pwaic_warnings(pointwise[, "p_waic"], digits = 1) |
121 | | - waic_object(pointwise, dims = c(S, N)) |
122 | | - } |
123 | | - |
124 | | - |
125 | | -#' @export |
126 | | -dim.waic <- function(x) { |
127 | | - attr(x, "dims") |
128 | | -} |
129 | | - |
130 | | -#' @rdname waic |
131 | | -#' @export |
132 | | -is.waic <- function(x) { |
133 | | - inherits(x, "waic") && is.loo(x) |
134 | | -} |
135 | | - |
136 | | - |
137 | | -# internal ---------------------------------------------------------------- |
138 | | - |
139 | | -# structure the object returned by the waic methods |
140 | | -waic_object <- function(pointwise, dims) { |
141 | | - estimates <- table_of_estimates(pointwise) |
142 | | - out <- nlist(estimates, pointwise) |
143 | | - # maintain backwards compatibility |
144 | | - old_nms <- c("elpd_waic", "p_waic", "waic", "se_elpd_waic", "se_p_waic", "se_waic") |
145 | | - out <- c(out, setNames(as.list(estimates), old_nms)) |
146 | | - structure( |
147 | | - out, |
148 | | - dims = dims, |
149 | | - class = c("waic", "loo") |
150 | | - ) |
151 | | -} |
152 | | - |
153 | | -# waic warnings |
154 | | -# @param p 'p_waic' estimates |
155 | | -throw_pwaic_warnings <- function(p, digits = 1, warn = TRUE) { |
156 | | - badp <- p > 0.4 |
157 | | - if (any(badp)) { |
158 | | - count <- sum(badp) |
159 | | - prop <- count / length(badp) |
160 | | - msg <- paste0("\n", count, " (", .fr(100 * prop, digits), |
161 | | - "%) p_waic estimates greater than 0.4. ", |
162 | | - "We recommend trying loo instead.") |
163 | | - if (warn) .warn(msg) else cat(msg, "\n") |
164 | | - } |
165 | | - invisible(NULL) |
166 | | -} |
167 | | - |
| 1 | +#' Widely applicable information criterion (WAIC) |
| 2 | +#' |
| 3 | +#' The `waic()` methods can be used to compute WAIC from the pointwise |
| 4 | +#' log-likelihood. However, we recommend LOO-CV using PSIS (as implemented by |
| 5 | +#' the [loo()] function) because PSIS provides useful diagnostics as well as |
| 6 | +#' effective sample size and Monte Carlo estimates. |
| 7 | +#' |
| 8 | +#' @export waic waic.array waic.matrix waic.function |
| 9 | +#' @inheritParams loo |
| 10 | +#' |
| 11 | +#' @return A named list (of class `c("waic", "loo")`) with components: |
| 12 | +#' |
| 13 | +#' \describe{ |
| 14 | +#' \item{`estimates`}{ |
| 15 | +#' A matrix with two columns (`"Estimate"`, `"SE"`) and three |
| 16 | +#' rows (`"elpd_waic"`, `"p_waic"`, `"waic"`). This contains |
| 17 | +#' point estimates and standard errors of the expected log pointwise predictive |
| 18 | +#' density (`elpd_waic`), the effective number of parameters |
| 19 | +#' (`p_waic`) and the information criterion `waic` (which is just |
| 20 | +#' `-2 * elpd_waic`, i.e., converted to deviance scale). |
| 21 | +#' } |
| 22 | +#' \item{`pointwise`}{ |
| 23 | +#' A matrix with three columns (and number of rows equal to the number of |
| 24 | +#' observations) containing the pointwise contributions of each of the above |
| 25 | +#' measures (`elpd_waic`, `p_waic`, `waic`). |
| 26 | +#' } |
| 27 | +#' } |
| 28 | +#' |
| 29 | +#' @seealso |
| 30 | +#' * The __loo__ package [vignettes](https://mc-stan.org/loo/articles/) and |
| 31 | +#' Vehtari, Gelman, and Gabry (2017) and Vehtari, Simpson, Gelman, Yao, |
| 32 | +#' and Gabry (2024) for more details on why we prefer `loo()` to `waic()`. |
| 33 | +#' * [loo_compare()] for comparing models on approximate LOO-CV or WAIC. |
| 34 | +#' |
| 35 | +#' @references |
| 36 | +#' Watanabe, S. (2010). Asymptotic equivalence of Bayes cross validation and |
| 37 | +#' widely applicable information criterion in singular learning theory. |
| 38 | +#' *Journal of Machine Learning Research* **11**, 3571-3594. |
| 39 | +#' |
| 40 | +#' @template loo-and-psis-references |
| 41 | +#' |
| 42 | +#' @examples |
| 43 | +#' ### Array and matrix methods |
| 44 | +#' LLarr <- example_loglik_array() |
| 45 | +#' dim(LLarr) |
| 46 | +#' |
| 47 | +#' LLmat <- example_loglik_matrix() |
| 48 | +#' dim(LLmat) |
| 49 | +#' |
| 50 | +#' waic_arr <- waic(LLarr) |
| 51 | +#' waic_mat <- waic(LLmat) |
| 52 | +#' identical(waic_arr, waic_mat) |
| 53 | +#' |
| 54 | +#' |
| 55 | +#' \dontrun{ |
| 56 | +#' log_lik1 <- extract_log_lik(stanfit1) |
| 57 | +#' log_lik2 <- extract_log_lik(stanfit2) |
| 58 | +#' (waic1 <- waic(log_lik1)) |
| 59 | +#' (waic2 <- waic(log_lik2)) |
| 60 | +#' print(loo_compare(waic1, waic2), digits = 2) |
| 61 | +#' } |
| 62 | +#' |
waic <- function(x, ...) {
  # S3 generic: the method that runs is chosen from the class of `x`;
  # everything in `...` is made available to that method unchanged.
  UseMethod(generic = "waic")
}
| 66 | + |
| 67 | +#' @export |
| 68 | +#' @templateVar fn waic |
| 69 | +#' @template array |
| 70 | +#' |
waic.array <- function(x, ...) {
  # Run the array through llarray_to_matrix() and let the matrix method do
  # the actual computation; `...` is forwarded to it untouched.
  ll_matrix <- llarray_to_matrix(x)
  waic.matrix(ll_matrix, ...)
}
| 74 | + |
| 75 | +#' @export |
| 76 | +#' @templateVar fn waic |
| 77 | +#' @template matrix |
| 78 | +#' |
waic.matrix <- function(x, ...) {
  log_lik <- validate_ll(x)

  # Column-wise log-mean-exp: log-sum-exp of each column minus log(number of rows).
  lpd <- matrixStats::colLogSumExps(log_lik) - log(nrow(log_lik))
  # Column-wise variance of the log-likelihood values.
  p_waic <- matrixStats::colVars(log_lik)
  elpd_waic <- lpd - p_waic

  # One row per column of `log_lik`; `waic` is elpd_waic multiplied by -2.
  pointwise <- cbind(elpd_waic, p_waic, waic = -2 * elpd_waic)

  throw_pwaic_warnings(pointwise[, "p_waic"], digits = 1)
  waic_object(pointwise, dims = dim(log_lik))
}
| 91 | + |
| 92 | + |
| 93 | +#' @export |
| 94 | +#' @templateVar fn waic |
| 95 | +#' @template function |
| 96 | +#' @param draws,data,... For the function method only. See the |
| 97 | +#' **Methods (by class)** section below for details on these arguments. |
| 98 | +#' |
waic.function <-
  function(x,
           ...,
           data = NULL,
           draws = NULL) {
    # `data` must be a data frame or matrix and `draws` must be supplied.
    stopifnot(is.data.frame(data) || is.matrix(data), !is.null(draws))

    .llfun <- validate_llfun(x)
    n_obs <- nrow(data)

    # Evaluate the log-likelihood function on row `i` of `data` and flatten
    # the result to a plain vector. `...` is the caller's `...`.
    loglik_for_row <- function(i) {
      as.vector(.llfun(data_i = data[i, , drop = FALSE], draws = draws, ...))
    }

    # The first row is evaluated once up front only to learn the length of
    # the log-likelihood vector, which becomes the first entry of `dims`.
    n_draws <- length(loglik_for_row(1))

    # Pointwise contributions for a single row of `data`.
    pointwise_for_row <- function(i) {
      ll_i <- loglik_for_row(i)
      lpd_i <- logMeanExp(ll_i)
      p_waic_i <- var(ll_i)
      c(elpd_waic = lpd_i - p_waic_i, p_waic = p_waic_i)
    }

    # Stack the per-row results, then append the `waic` column (-2 * elpd_waic).
    pointwise <- do.call(rbind, lapply(seq_len(n_obs), pointwise_for_row))
    pointwise <- cbind(pointwise, waic = -2 * pointwise[, "elpd_waic"])

    throw_pwaic_warnings(pointwise[, "p_waic"], digits = 1)
    waic_object(pointwise, dims = c(n_draws, n_obs))
  }
| 123 | + |
| 124 | + |
| 125 | +#' @export |
dim.waic <- function(x) {
  # The dimensions live in the "dims" attribute that waic_object() attaches.
  dims <- attr(x, which = "dims")
  dims
}
| 129 | + |
| 130 | +#' @rdname waic |
| 131 | +#' @export |
is.waic <- function(x) {
  # Both conditions must hold: `x` carries the "waic" class and is.loo()
  # accepts it. is.loo() is only consulted when the class check passes.
  has_waic_class <- inherits(x, what = "waic")
  has_waic_class && is.loo(x)
}
| 135 | + |
| 136 | + |
| 137 | +# internal ---------------------------------------------------------------- |
| 138 | + |
# Assemble the object returned by the waic methods.
#
# @param pointwise Matrix of pointwise contributions; passed to
#   table_of_estimates() to build the summary table.
# @param dims Value stored in the "dims" attribute of the result.
# @return A list of class c("waic", "loo") holding `estimates`, `pointwise`
#   and, for backwards compatibility, each cell of `estimates` as its own
#   top-level element.
waic_object <- function(pointwise, dims) {
  estimates <- table_of_estimates(pointwise)

  # maintain backwards compatibility: as.list() walks `estimates` column by
  # column, so the names below must follow that order.
  legacy <- as.list(estimates)
  names(legacy) <- c(
    "elpd_waic", "p_waic", "waic",
    "se_elpd_waic", "se_p_waic", "se_waic"
  )

  obj <- c(nlist(estimates, pointwise), legacy)
  attr(obj, "dims") <- dims
  class(obj) <- c("waic", "loo")
  obj
}
| 152 | + |
# waic warnings
#
# Emit a message when any 'p_waic' estimate is greater than 0.4.
#
# @param p 'p_waic' estimates
# @param digits Passed to .fr() when formatting the percentage in the message
# @param warn If TRUE the message is raised via .warn(); otherwise it is
#   printed with cat()
# @return NULL, invisibly
throw_pwaic_warnings <- function(p, digits = 1, warn = TRUE) {
  # Missing estimates are never counted as exceeding the threshold. Without
  # this guard an NA in `p` either makes the `if ()` below fail or produces a
  # message reading "NA (NA%)".
  badp <- !is.na(p) & p > 0.4
  count <- sum(badp)
  if (count > 0) {
    # Proportion is relative to all estimates supplied, missing ones included.
    prop <- count / length(badp)
    msg <- paste0("\n", count, " (", .fr(100 * prop, digits),
                  "%) p_waic estimates greater than 0.4. ",
                  "We recommend trying loo instead.")
    if (warn) {
      .warn(msg)
    } else {
      cat(msg, "\n")
    }
  }
  invisible(NULL)
}
| 167 | + |
0 commit comments