Skip to content

Commit 2a916b1

Browse files
committed
Swapped to lf properly
1 parent 53ecde4 commit 2a916b1

File tree

2 files changed

+171
-169
lines changed

2 files changed

+171
-169
lines changed

.gitattributes

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
* text=auto
22
data/* binary
3-
src/* text=lf
4-
R/* text=lf
3+
src/* text eol=lf
4+
R/* text eol=lf
5+
*.rda binary
6+

R/waic.R

Lines changed: 167 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -1,167 +1,167 @@
1-
#' Widely applicable information criterion (WAIC)
2-
#'
3-
#' The `waic()` methods can be used to compute WAIC from the pointwise
4-
#' log-likelihood. However, we recommend LOO-CV using PSIS (as implemented by
5-
#' the [loo()] function) because PSIS provides useful diagnostics as well as
6-
#' effective sample size and Monte Carlo estimates.
7-
#'
8-
#' @export waic waic.array waic.matrix waic.function
9-
#' @inheritParams loo
10-
#'
11-
#' @return A named list (of class `c("waic", "loo")`) with components:
12-
#'
13-
#' \describe{
14-
#' \item{`estimates`}{
15-
#' A matrix with two columns (`"Estimate"`, `"SE"`) and three
16-
#' rows (`"elpd_waic"`, `"p_waic"`, `"waic"`). This contains
17-
#' point estimates and standard errors of the expected log pointwise predictive
18-
#' density (`elpd_waic`), the effective number of parameters
19-
#' (`p_waic`) and the information criterion `waic` (which is just
20-
#' `-2 * elpd_waic`, i.e., converted to deviance scale).
21-
#' }
22-
#' \item{`pointwise`}{
23-
#' A matrix with three columns (and number of rows equal to the number of
24-
#' observations) containing the pointwise contributions of each of the above
25-
#' measures (`elpd_waic`, `p_waic`, `waic`).
26-
#' }
27-
#' }
28-
#'
29-
#' @seealso
30-
#' * The __loo__ package [vignettes](https://mc-stan.org/loo/articles/) and
31-
#' Vehtari, Gelman, and Gabry (2017) and Vehtari, Simpson, Gelman, Yao,
32-
#' and Gabry (2024) for more details on why we prefer `loo()` to `waic()`.
33-
#' * [loo_compare()] for comparing models on approximate LOO-CV or WAIC.
34-
#'
35-
#' @references
36-
#' Watanabe, S. (2010). Asymptotic equivalence of Bayes cross validation and
37-
#' widely applicable information criterion in singular learning theory.
38-
#' *Journal of Machine Learning Research* **11**, 3571-3594.
39-
#'
40-
#' @template loo-and-psis-references
41-
#'
42-
#' @examples
43-
#' ### Array and matrix methods
44-
#' LLarr <- example_loglik_array()
45-
#' dim(LLarr)
46-
#'
47-
#' LLmat <- example_loglik_matrix()
48-
#' dim(LLmat)
49-
#'
50-
#' waic_arr <- waic(LLarr)
51-
#' waic_mat <- waic(LLmat)
52-
#' identical(waic_arr, waic_mat)
53-
#'
54-
#'
55-
#' \dontrun{
56-
#' log_lik1 <- extract_log_lik(stanfit1)
57-
#' log_lik2 <- extract_log_lik(stanfit2)
58-
#' (waic1 <- waic(log_lik1))
59-
#' (waic2 <- waic(log_lik2))
60-
#' print(compare(waic1, waic2), digits = 2)
61-
#' }
62-
#'
63-
waic <- function(x, ...) {
64-
UseMethod("waic")
65-
}
66-
67-
#' @export
68-
#' @templateVar fn waic
69-
#' @template array
70-
#'
71-
waic.array <- function(x, ...) {
72-
waic.matrix(llarray_to_matrix(x), ...)
73-
}
74-
75-
#' @export
76-
#' @templateVar fn waic
77-
#' @template matrix
78-
#'
79-
waic.matrix <- function(x, ...) {
80-
ll <- validate_ll(x)
81-
lldim <- dim(ll)
82-
lpd <- matrixStats::colLogSumExps(ll) - log(nrow(ll)) # colLogMeanExps
83-
p_waic <- matrixStats::colVars(ll)
84-
elpd_waic <- lpd - p_waic
85-
waic <- -2 * elpd_waic
86-
pointwise <- cbind(elpd_waic, p_waic, waic)
87-
88-
throw_pwaic_warnings(pointwise[, "p_waic"], digits = 1)
89-
return(waic_object(pointwise, dims = lldim))
90-
}
91-
92-
93-
#' @export
94-
#' @templateVar fn waic
95-
#' @template function
96-
#' @param draws,data,... For the function method only. See the
97-
#' **Methods (by class)** section below for details on these arguments.
98-
#'
99-
waic.function <-
100-
function(x,
101-
...,
102-
data = NULL,
103-
draws = NULL) {
104-
stopifnot(is.data.frame(data) || is.matrix(data), !is.null(draws))
105-
106-
.llfun <- validate_llfun(x)
107-
N <- dim(data)[1]
108-
S <- length(as.vector(.llfun(data_i = data[1,, drop=FALSE], draws = draws, ...)))
109-
waic_list <- lapply(seq_len(N), FUN = function(i) {
110-
ll_i <- .llfun(data_i = data[i,, drop=FALSE], draws = draws, ...)
111-
ll_i <- as.vector(ll_i)
112-
lpd_i <- logMeanExp(ll_i)
113-
p_waic_i <- var(ll_i)
114-
elpd_waic_i <- lpd_i - p_waic_i
115-
c(elpd_waic = elpd_waic_i, p_waic = p_waic_i)
116-
})
117-
pointwise <- do.call(rbind, waic_list)
118-
pointwise <- cbind(pointwise, waic = -2 * pointwise[, "elpd_waic"])
119-
120-
throw_pwaic_warnings(pointwise[, "p_waic"], digits = 1)
121-
waic_object(pointwise, dims = c(S, N))
122-
}
123-
124-
125-
#' @export
126-
dim.waic <- function(x) {
127-
attr(x, "dims")
128-
}
129-
130-
#' @rdname waic
131-
#' @export
132-
is.waic <- function(x) {
133-
inherits(x, "waic") && is.loo(x)
134-
}
135-
136-
137-
# internal ----------------------------------------------------------------
138-
139-
# structure the object returned by the waic methods
140-
waic_object <- function(pointwise, dims) {
141-
estimates <- table_of_estimates(pointwise)
142-
out <- nlist(estimates, pointwise)
143-
# maintain backwards compatibility
144-
old_nms <- c("elpd_waic", "p_waic", "waic", "se_elpd_waic", "se_p_waic", "se_waic")
145-
out <- c(out, setNames(as.list(estimates), old_nms))
146-
structure(
147-
out,
148-
dims = dims,
149-
class = c("waic", "loo")
150-
)
151-
}
152-
153-
# waic warnings
154-
# @param p 'p_waic' estimates
155-
throw_pwaic_warnings <- function(p, digits = 1, warn = TRUE) {
156-
badp <- p > 0.4
157-
if (any(badp)) {
158-
count <- sum(badp)
159-
prop <- count / length(badp)
160-
msg <- paste0("\n", count, " (", .fr(100 * prop, digits),
161-
"%) p_waic estimates greater than 0.4. ",
162-
"We recommend trying loo instead.")
163-
if (warn) .warn(msg) else cat(msg, "\n")
164-
}
165-
invisible(NULL)
166-
}
167-
1+
#' Widely applicable information criterion (WAIC)
2+
#'
3+
#' The `waic()` methods can be used to compute WAIC from the pointwise
4+
#' log-likelihood. However, we recommend LOO-CV using PSIS (as implemented by
5+
#' the [loo()] function) because PSIS provides useful diagnostics as well as
6+
#' effective sample size and Monte Carlo estimates.
7+
#'
8+
#' @export waic waic.array waic.matrix waic.function
9+
#' @inheritParams loo
10+
#'
11+
#' @return A named list (of class `c("waic", "loo")`) with components:
12+
#'
13+
#' \describe{
14+
#' \item{`estimates`}{
15+
#' A matrix with two columns (`"Estimate"`, `"SE"`) and three
16+
#' rows (`"elpd_waic"`, `"p_waic"`, `"waic"`). This contains
17+
#' point estimates and standard errors of the expected log pointwise predictive
18+
#' density (`elpd_waic`), the effective number of parameters
19+
#' (`p_waic`) and the information criterion `waic` (which is just
20+
#' `-2 * elpd_waic`, i.e., converted to deviance scale).
21+
#' }
22+
#' \item{`pointwise`}{
23+
#' A matrix with three columns (and number of rows equal to the number of
24+
#' observations) containing the pointwise contributions of each of the above
25+
#' measures (`elpd_waic`, `p_waic`, `waic`).
26+
#' }
27+
#' }
28+
#'
29+
#' @seealso
30+
#' * The __loo__ package [vignettes](https://mc-stan.org/loo/articles/) and
31+
#' Vehtari, Gelman, and Gabry (2017) and Vehtari, Simpson, Gelman, Yao,
32+
#' and Gabry (2024) for more details on why we prefer `loo()` to `waic()`.
33+
#' * [loo_compare()] for comparing models on approximate LOO-CV or WAIC.
34+
#'
35+
#' @references
36+
#' Watanabe, S. (2010). Asymptotic equivalence of Bayes cross validation and
37+
#' widely applicable information criterion in singular learning theory.
38+
#' *Journal of Machine Learning Research* **11**, 3571-3594.
39+
#'
40+
#' @template loo-and-psis-references
41+
#'
42+
#' @examples
43+
#' ### Array and matrix methods
44+
#' LLarr <- example_loglik_array()
45+
#' dim(LLarr)
46+
#'
47+
#' LLmat <- example_loglik_matrix()
48+
#' dim(LLmat)
49+
#'
50+
#' waic_arr <- waic(LLarr)
51+
#' waic_mat <- waic(LLmat)
52+
#' identical(waic_arr, waic_mat)
53+
#'
54+
#'
55+
#' \dontrun{
56+
#' log_lik1 <- extract_log_lik(stanfit1)
57+
#' log_lik2 <- extract_log_lik(stanfit2)
58+
#' (waic1 <- waic(log_lik1))
59+
#' (waic2 <- waic(log_lik2))
60+
#' print(compare(waic1, waic2), digits = 2)
61+
#' }
62+
#'
63+
waic <- function(x, ...) {
64+
UseMethod("waic")
65+
}
66+
67+
#' @export
68+
#' @templateVar fn waic
69+
#' @template array
70+
#'
71+
waic.array <- function(x, ...) {
72+
waic.matrix(llarray_to_matrix(x), ...)
73+
}
74+
75+
#' @export
76+
#' @templateVar fn waic
77+
#' @template matrix
78+
#'
79+
waic.matrix <- function(x, ...) {
80+
ll <- validate_ll(x)
81+
lldim <- dim(ll)
82+
lpd <- matrixStats::colLogSumExps(ll) - log(nrow(ll)) # colLogMeanExps
83+
p_waic <- matrixStats::colVars(ll)
84+
elpd_waic <- lpd - p_waic
85+
waic <- -2 * elpd_waic
86+
pointwise <- cbind(elpd_waic, p_waic, waic)
87+
88+
throw_pwaic_warnings(pointwise[, "p_waic"], digits = 1)
89+
return(waic_object(pointwise, dims = lldim))
90+
}
91+
92+
93+
#' @export
94+
#' @templateVar fn waic
95+
#' @template function
96+
#' @param draws,data,... For the function method only. See the
97+
#' **Methods (by class)** section below for details on these arguments.
98+
#'
99+
waic.function <-
100+
function(x,
101+
...,
102+
data = NULL,
103+
draws = NULL) {
104+
stopifnot(is.data.frame(data) || is.matrix(data), !is.null(draws))
105+
106+
.llfun <- validate_llfun(x)
107+
N <- dim(data)[1]
108+
S <- length(as.vector(.llfun(data_i = data[1,, drop=FALSE], draws = draws, ...)))
109+
waic_list <- lapply(seq_len(N), FUN = function(i) {
110+
ll_i <- .llfun(data_i = data[i,, drop=FALSE], draws = draws, ...)
111+
ll_i <- as.vector(ll_i)
112+
lpd_i <- logMeanExp(ll_i)
113+
p_waic_i <- var(ll_i)
114+
elpd_waic_i <- lpd_i - p_waic_i
115+
c(elpd_waic = elpd_waic_i, p_waic = p_waic_i)
116+
})
117+
pointwise <- do.call(rbind, waic_list)
118+
pointwise <- cbind(pointwise, waic = -2 * pointwise[, "elpd_waic"])
119+
120+
throw_pwaic_warnings(pointwise[, "p_waic"], digits = 1)
121+
waic_object(pointwise, dims = c(S, N))
122+
}
123+
124+
125+
#' @export
126+
dim.waic <- function(x) {
127+
attr(x, "dims")
128+
}
129+
130+
#' @rdname waic
131+
#' @export
132+
is.waic <- function(x) {
133+
inherits(x, "waic") && is.loo(x)
134+
}
135+
136+
137+
# internal ----------------------------------------------------------------
138+
139+
# structure the object returned by the waic methods
140+
waic_object <- function(pointwise, dims) {
141+
estimates <- table_of_estimates(pointwise)
142+
out <- nlist(estimates, pointwise)
143+
# maintain backwards compatibility
144+
old_nms <- c("elpd_waic", "p_waic", "waic", "se_elpd_waic", "se_p_waic", "se_waic")
145+
out <- c(out, setNames(as.list(estimates), old_nms))
146+
structure(
147+
out,
148+
dims = dims,
149+
class = c("waic", "loo")
150+
)
151+
}
152+
153+
# waic warnings
154+
# @param p 'p_waic' estimates
155+
throw_pwaic_warnings <- function(p, digits = 1, warn = TRUE) {
156+
badp <- p > 0.4
157+
if (any(badp)) {
158+
count <- sum(badp)
159+
prop <- count / length(badp)
160+
msg <- paste0("\n", count, " (", .fr(100 * prop, digits),
161+
"%) p_waic estimates greater than 0.4. ",
162+
"We recommend trying loo instead.")
163+
if (warn) .warn(msg) else cat(msg, "\n")
164+
}
165+
invisible(NULL)
166+
}
167+

0 commit comments

Comments
 (0)