Skip to content

Commit 60f012d

Browse files
authored
Merge pull request #342 from florence-bockting/update-loo-print
Create new kfold.print method
2 parents 1ad507e + 97c2f90 commit 60f012d

File tree

8 files changed

+129
-0
lines changed

8 files changed

+129
-0
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ S3method(print,compare.loo)
5656
S3method(print,compare.loo_ss)
5757
S3method(print,importance_sampling)
5858
S3method(print,importance_sampling_loo)
59+
S3method(print,kfold)
5960
S3method(print,loo)
6061
S3method(print,pareto_k_table)
6162
S3method(print,pseudobma_bb_weights)

R/print.R

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,28 @@ print.importance_sampling <- function(x, digits = 1, plot_k = FALSE, ...) {
105105
invisible(x)
106106
}
107107

108+
#' @export
109+
#' @rdname print.loo
110+
print.kfold <- function(x, digits = 1, plot_k = FALSE, ...) {
111+
print.loo(x, digits = digits, ...)
112+
113+
if ("diagnostics" %in% names(x)) {
114+
cat("------\n")
115+
S <- dim(x)[1]
116+
k_threshold <- ps_khat_threshold(S)
117+
if (length(pareto_k_ids(x, threshold = k_threshold))) {
118+
cat("\n")
119+
}
120+
print(pareto_k_table(x), digits = digits)
121+
cat(.k_help())
122+
123+
if (plot_k) {
124+
graphics::plot(x, ...)
125+
}
126+
}
127+
return(invisible(x))
128+
}
129+
108130
# internal ----------------------------------------------------------------
109131

110132
#' Print dimensions of log-likelihood or log-weights matrix

man/print.loo.Rd

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/_snaps/print_plot.md

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,54 @@
6666

6767
WAoAAAACAAQFAAACAwAAAAAOAAAAAT+2J8YDcP5s
6868

69+
# print.loo supports kfold with pareto-k diagnostics - calibrated
70+
71+
Code
72+
print(kfold1)
73+
Output
74+
75+
Based on 10-fold cross-validation.
76+
77+
Estimate SE
78+
elpd_kfold -285.0 9.2
79+
p_kfold 2.5 0.6
80+
kfoldic 570.0 18.4
81+
------
82+
83+
All Pareto k estimates are good (k < 0.7).
84+
See help('pareto-k-diagnostic') for details.
85+
86+
# print.loo supports kfold with pareto-k diagnostics - miscalibrated
87+
88+
Code
89+
print(kfold1)
90+
Output
91+
92+
Based on 10-fold cross-validation.
93+
94+
Estimate SE
95+
elpd_kfold -5556.6 701.0
96+
p_kfold 358.2 108.5
97+
kfoldic 11113.1 1401.9
98+
------
99+
100+
Pareto k diagnostic values:
101+
Count Pct. Min. ESS
102+
(-Inf, 0.7] (good) 245 93.5% 24
103+
(0.7, 1] (bad) 8 3.1% <NA>
104+
(1, Inf) (very bad) 9 3.4% <NA>
105+
See help('pareto-k-diagnostic') for details.
106+
107+
# print.loo supports kfold without pareto-k diagnostics
108+
109+
Code
110+
print(kfold1)
111+
Output
112+
113+
Based on 10-fold cross-validation.
114+
115+
Estimate SE
116+
elpd_kfold -5556.6 701.0
117+
p_kfold 358.2 108.5
118+
kfoldic 11113.1 1401.9
119+
7.49 KB
Binary file not shown.
9.58 KB
Binary file not shown.
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
## Test data for testing print method of `kfold` object
2+
3+
### Case 1: All pareto-k values are good
4+
5+
```{r}
6+
set.seed(123)
7+
dat <- dplyr::tibble(
8+
x = rnorm(200),
9+
y = 2 + 1.5 * x + rnorm(200, sd = 1)
10+
)
11+
12+
fit <- brm(y ~ x, data = dat, seed = 42)
13+
kfold1 <- kfold(fit)
14+
saveRDS(kfold, "kfold-calibrated.Rds")
15+
```
16+
17+
### Case 2: Some pareto-k values are problematic
18+
19+
```{r}
20+
data(roaches, package = "rstanarm")
21+
roaches$sqrt_roach1 <- sqrt(roaches$roach1)
22+
23+
fit_p <- brm(y ~ sqrt_roach1 + treatment + senior + offset(log(exposure2)),
24+
data = roaches,
25+
family = poisson,
26+
prior = prior(normal(0,1), class = b),
27+
refresh = 0)
28+
29+
kfold2 <- kfold(fit_p)
30+
saveRDS(kfold2, "kfold-miscalibrated.Rds")
31+
```

tests/testthat/test_print_plot.R

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,24 @@ test_that("mcse_loo returns NA when it should", {
163163
test_that("mcse_loo errors if not psis_loo object", {
164164
expect_error(mcse_loo(psis1), "psis_loo")
165165
})
166+
167+
# print.loo kfold objects --------------------------------------------------
168+
169+
test_that("print.loo supports kfold with pareto-k diagnostics - calibrated", {
170+
kfold1 <- readRDS("data-for-tests/kfold-calibrated.Rds")
171+
172+
expect_snapshot(print(kfold1))
173+
})
174+
175+
test_that("print.loo supports kfold with pareto-k diagnostics - miscalibrated", {
176+
kfold1 <- readRDS("data-for-tests/kfold-miscalibrated.Rds")
177+
178+
expect_snapshot(print(kfold1))
179+
})
180+
181+
test_that("print.loo supports kfold without pareto-k diagnostics", {
182+
kfold1 <- readRDS("data-for-tests/kfold-miscalibrated.Rds")
183+
kfold1$diagnostics <- NULL
184+
185+
expect_snapshot(print(kfold1))
186+
})

0 commit comments

Comments
 (0)