Skip to content

Commit cff3c2c

Browse files
committed
Revert "remove old loo::compare()"
This reverts commit a84154e.
1 parent 4e16d2f commit cff3c2c

4 files changed

Lines changed: 299 additions & 0 deletions

File tree

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ export(.compute_point_estimate)
9999
export(.ndraws)
100100
export(.thin_draws)
101101
export(E_loo)
102+
export(compare)
102103
export(crps)
103104
export(elpd)
104105
export(example_loglik_array)

R/compare.R

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
#' Model comparison (deprecated, old version)
2+
#'
3+
#' **This function is deprecated**. Please use the new [loo_compare()] function
4+
#' instead.
5+
#'
6+
#' @export
7+
#' @param ... At least two objects returned by [loo()] (or [waic()]).
8+
#' @param x A list of at least two objects returned by [loo()] (or
9+
#' [waic()]). This argument can be used as an alternative to
10+
#' specifying the objects in `...`.
11+
#'
12+
#' @return A vector or matrix with class `'compare.loo'` that has its own
13+
#' print method. If exactly two objects are provided in `...` or
14+
#' `x`, then the difference in expected predictive accuracy and the
15+
#' standard error of the difference are returned. If more than two objects are
16+
#' provided then a matrix of summary information is returned (see **Details**).
17+
#'
18+
#' @details
19+
#' When comparing two fitted models, we can estimate the difference in their
20+
#' expected predictive accuracy by the difference in `elpd_loo` or
21+
#' `elpd_waic` (or multiplied by -2, if desired, to be on the
22+
#' deviance scale).
23+
#'
24+
#' *When that difference, `elpd_diff`, is positive then the expected
25+
#' predictive accuracy for the second model is higher. A negative
26+
#' `elpd_diff` favors the first model.*
27+
#'
28+
#' When using `compare()` with more than two models, the values in the
29+
#' `elpd_diff` and `se_diff` columns of the returned matrix are
30+
#' computed by making pairwise comparisons between each model and the model
31+
#' with the best ELPD (i.e., the model in the first row).
32+
#' Although the `elpd_diff` column is equal to the difference in
33+
#' `elpd_loo`, do not expect the `se_diff` column to be equal to the
34+
#' difference in `se_elpd_loo`.
35+
#'
36+
#' To compute the standard error of the difference in ELPD we use a
37+
#' paired estimate to take advantage of the fact that the same set of _N_
38+
#' data points was used to fit both models. These calculations should be most
39+
#' useful when _N_ is large, because then non-normality of the
40+
#' distribution is not such an issue when estimating the uncertainty in these
41+
#' sums. These standard errors, for all their flaws, should give a better
42+
#' sense of uncertainty than what is obtained using the current standard
43+
#' approach of comparing differences of deviances to a Chi-squared
44+
#' distribution, a practice derived for Gaussian linear models or
45+
#' asymptotically, and which only applies to nested models in any case.
46+
#'
47+
#' @template loo-and-psis-references
48+
#'
49+
#' @examples
50+
#' \dontrun{
51+
#' loo1 <- loo(log_lik1)
52+
#' loo2 <- loo(log_lik2)
53+
#' print(compare(loo1, loo2), digits = 3)
54+
#' print(compare(x = list(loo1, loo2)))
55+
#'
56+
#' waic1 <- waic(log_lik1)
57+
#' waic2 <- waic(log_lik2)
58+
#' compare(waic1, waic2)
59+
#' }
60+
#'
61+
compare <- function(..., x = list()) {
  # Deprecated entry point: always warn and point users to loo_compare().
  .Deprecated("loo_compare")

  # Models are supplied either through '...' or as a list in 'x', never both.
  dots <- list(...)
  if (length(dots) > 0L) {
    if (length(x) > 0L) {
      stop("If 'x' is specified then '...' should not be specified.",
           call. = FALSE)
    }
    # Label the models with the expressions passed in '...' only. Deparsing
    # the full matched call would also pick up an explicitly supplied (empty)
    # 'x' argument and produce one label too many.
    nms <- as.character(substitute(list(...)))[-1L]
  } else {
    if (!is.list(x) || length(x) == 0L) {
      stop("'x' must be a list.", call. = FALSE)
    }
    dots <- x
    nms <- names(dots)
    if (length(nms) == 0L) {
      # Unnamed list: fall back to generic labels model1, model2, ...
      nms <- paste0("model", seq_along(dots))
    }
  }

  if (!all(vapply(dots, is.loo, logical(1L)))) {
    stop("All inputs should have class 'loo'.", call. = FALSE)
  }
  if (length(dots) <= 1L) {
    stop("'compare' requires at least two models.", call. = FALSE)
  }

  # Exactly two models: return the named vector c(elpd_diff, se).
  if (length(dots) == 2L) {
    comp <- compare_two_models(dots[[1L]], dots[[2L]])
    class(comp) <- c(class(comp), "old_compare.loo")
    return(comp)
  }

  # More than two models: every model must have the same number of rows in
  # its pointwise matrix.
  n_obs <- vapply(dots, function(fit) nrow(fit$pointwise), integer(1L))
  if (!all(n_obs == n_obs[1L])) {
    stop("Not all models have the same number of data points.", call. = FALSE)
  }

  # One column per model; rows are the flattened 'estimates' matrix, i.e. the
  # first column's values followed by the second column's values, the latter
  # prefixed with "se_".
  est_by_model <- sapply(dots, function(fit) {
    est <- fit$estimates
    setNames(c(est), nm = c(rownames(est), paste0("se_", rownames(est))))
  })
  colnames(est_by_model) <- nms

  # Sort models by decreasing ELPD so that the best model is in the first row.
  elpd_row <- grep("^elpd", rownames(est_by_model))
  ord <- order(est_by_model[elpd_row, ], decreasing = TRUE)
  comp <- t(est_by_model)[ord, ]

  # Arrange columns as: elpd (+ se), p (+ se), ic, se of ic.
  patts <- c("elpd", "p_", "^waic$|^looic$", "^se_waic$|^se_looic$")
  col_ord <- unlist(
    sapply(patts, function(p) grep(p, colnames(comp))),
    use.names = FALSE
  )
  comp <- comp[, col_ord]

  # compute elpd_diff and se_elpd_diff relative to best model
  rnms <- rownames(comp)
  diffs <- mapply(elpd_diffs, dots[ord[1]], dots[ord])
  elpd_diff <- apply(diffs, 2, sum)
  se_diff <- apply(diffs, 2, se_elpd_diff)
  comp <- cbind(elpd_diff = elpd_diff, se_diff = se_diff, comp)
  rownames(comp) <- rnms
  class(comp) <- c("compare.loo", class(comp), "old_compare.loo")
  comp
}
123+
124+
125+
126+
# internal ----------------------------------------------------------------
127+
compare_two_models <- function(loo_a,
                               loo_b,
                               return = c("elpd_diff", "se"),
                               check_dims = TRUE) {
  # Pairwise comparison of two fitted-model objects.
  #
  # loo_a, loo_b: objects carrying a 'pointwise' matrix.
  # return:       accepted for interface compatibility; not used in this body.
  # check_dims:   when TRUE, error out if the two 'pointwise' matrices differ
  #               in their number of rows.
  #
  # Returns a named numeric vector c(elpd_diff, se) with class 'compare.loo'.
  if (check_dims) {
    n_a <- nrow(loo_a$pointwise)
    n_b <- nrow(loo_b$pointwise)
    if (n_a != n_b) {
      stop(
        paste(
          "Models don't have the same number of data points.",
          "\nFound N_1 =", n_a, "and N_2 =", n_b
        ),
        call. = FALSE
      )
    }
  }

  # Sum of the pointwise differences and its standard error.
  pointwise_diffs <- elpd_diffs(loo_a, loo_b)
  structure(
    c(elpd_diff = sum(pointwise_diffs), se = se_elpd_diff(pointwise_diffs)),
    class = "compare.loo"
  )
}

man/compare.Rd

Lines changed: 80 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_compare.R

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,83 @@ test_that("loo_compare returns expected result (3 models)", {
109109
# except rownames) to using 'x' argument
110110
expect_equal(comp1, loo_compare(x = list(w1, w2, w3)), ignore_attr = TRUE)
111111
})
112+
113+
# Tests for deprecated compare() ------------------------------------------
114+
115+
test_that("compare throws deprecation warnings", {
  # Both the two-object and the three-object call forms must warn.
  deprecation_msg <- "Deprecated"
  expect_warning(loo::compare(w1, w2), regexp = deprecation_msg)
  expect_warning(loo::compare(w1, w1, w2), regexp = deprecation_msg)
})
119+
120+
test_that("compare returns expected result (2 models)", {
  deprecation_msg <- "Deprecated"

  # An object compared with itself: both elements are expected to be zero.
  expect_warning(comp1 <- loo::compare(w1, w1), regexp = deprecation_msg)
  expect_snapshot(comp1)
  expect_equal(comp1[1:2], c(elpd_diff = 0, se = 0))

  # Two different objects: check names and class of the result.
  expect_warning(comp2 <- loo::compare(w1, w2), regexp = deprecation_msg)
  expect_snapshot(comp2)
  expect_named(comp2, c("elpd_diff", "se"))
  expect_s3_class(comp2, "compare.loo")

  # specifying objects via ... and via arg x gives equal results
  expect_warning(
    comp_via_list <- loo::compare(x = list(w1, w2)),
    regexp = deprecation_msg
  )
  expect_equal(comp2, comp_via_list)
})
134+
135+
test_that("compare returns expected result (3 models)", {
  deprecation_msg <- "Deprecated"
  expected_cols <- c(
    "elpd_diff", "se_diff",
    "elpd_waic", "se_elpd_waic",
    "p_waic", "se_p_waic",
    "waic", "se_waic"
  )

  w3 <- suppressWarnings(waic(LLarr3))
  expect_warning(comp1 <- loo::compare(w1, w2, w3), regexp = deprecation_msg)

  # Layout of the returned matrix: column order, row labels, and a zero
  # difference in the first row.
  expect_equal(colnames(comp1), expected_cols)
  expect_equal(rownames(comp1), c("w1", "w2", "w3"))
  expect_equal(comp1[1, 1], 0)
  expect_s3_class(comp1, "compare.loo")
  expect_s3_class(comp1, "matrix")
  expect_snapshot_value(comp1, style = "serialize")

  # specifying objects via '...' gives equivalent results (equal
  # except rownames) to using 'x' argument
  expect_warning(
    comp_via_list <- loo::compare(x = list(w1, w2, w3)),
    regexp = deprecation_msg
  )
  expect_equal(comp1, comp_via_list, ignore_attr = TRUE)
})
166+
167+
test_that("compare throws appropriate errors", {
  # The deprecation warning is suppressed throughout; only errors are checked.

  # '...' and 'x' supplied together
  expect_error(
    suppressWarnings(loo::compare(w1, w2, x = list(w1, w2))),
    regexp = "should not be specified"
  )

  # 'x' is not a list
  expect_error(
    suppressWarnings(loo::compare(x = 2)),
    regexp = "must be a list"
  )

  # list element without class 'loo'
  expect_error(
    suppressWarnings(loo::compare(x = list(2))),
    regexp = "should have class 'loo'"
  )

  # fewer than two objects
  expect_error(
    suppressWarnings(loo::compare(x = list(w1))),
    regexp = "requires at least two models"
  )

  # Object built from LLarr2 with the first slice of its third dimension
  # dropped; the expectations below require this to trigger the
  # data-point mismatch error in both the two- and three-object paths.
  w3 <- suppressWarnings(waic(LLarr2[,, -1]))
  size_mismatch_msg <- "same number of data points"
  expect_error(
    suppressWarnings(loo::compare(x = list(w1, w3))),
    regexp = size_mismatch_msg
  )
  expect_error(
    suppressWarnings(loo::compare(x = list(w1, w2, w3))),
    regexp = size_mismatch_msg
  )
})

0 commit comments

Comments
 (0)