Skip to content
Closed
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Imports:
checkmate,
matrixStats (>= 0.52),
parallel,
posterior (>= 1.5.0),
posterior (>= 1.7.0),
stats
Suggests:
bayesplot (>= 1.7.0),
Expand Down
5 changes: 2 additions & 3 deletions R/loo-package.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#' Efficient LOO-CV and WAIC for Bayesian models
#'
#' @docType package
#' @name loo-package
#' @keywords internal
#'
#' @importFrom stats sd var quantile setNames weights rnorm qnorm
#' @importFrom matrixStats logSumExp colLogSumExps colSums2 colVars colMaxs
Expand Down Expand Up @@ -89,4 +88,4 @@
#' for the generalized Pareto distribution. *Technometrics* **51**,
#' 316-325.
#'
NULL
"_PACKAGE"
63 changes: 17 additions & 46 deletions R/psis.R
Original file line number Diff line number Diff line change
Expand Up @@ -212,62 +212,33 @@ do_psis_i <- function(log_ratios_i, tail_len_i, ...) {
S <- length(log_ratios_i)
# shift log ratios for safer exponentation
lw_i <- log_ratios_i - max(log_ratios_i)
khat <- Inf

if (enough_tail_samples(tail_len_i)) {
ord <- sort.int(lw_i, index.return = TRUE)
tail_ids <- seq(S - tail_len_i + 1, S)
lw_tail <- ord$x[tail_ids]
if (abs(max(lw_tail) - min(lw_tail)) < .Machine$double.eps / 100) {
warning(
"Can't fit generalized Pareto distribution ",
"because all tail values are the same.",
call. = FALSE
)
} else {
cutoff <- ord$x[min(tail_ids) - 1] # largest value smaller than tail values
smoothed <- psis_smooth_tail(lw_tail, cutoff)
khat <- smoothed$k
lw_i[ord$ix[tail_ids]] <- smoothed$tail
}
if (length(unique(utils::tail(sort(log_ratios_i), tail_len_i))) == 1) {
warning(
"Can't fit generalized Pareto distribution ",
"because all tail values are the same.",
call. = FALSE
)
}

smoothed <- suppressWarnings(posterior::ps_tail(
x = lw_i,
ndraws_tail = tail_len_i,
tail = "right",
are_log_weights = TRUE
))

lw_i <- smoothed$x
khat <- smoothed$k

# truncate at max of raw wts (i.e., 0 since max has been subtracted)
lw_i[lw_i > 0] <- 0
# shift log weights back so that the smallest log weights remain unchanged
lw_i <- lw_i + max(log_ratios_i)

list(log_weights = lw_i, pareto_k = khat)
list(log_weights = lw_i, pareto_k = if (is.na(khat)) Inf else khat)
}

#' PSIS tail smoothing for a single vector
#'
#' @noRd
#' @param x Vector of tail elements already sorted in ascending order.
#' @return A named list containing:
#' * `tail`: vector same size as `x` containing the logs of the
#' order statistics of the generalized pareto distribution.
#' * `k`: scalar shape parameter estimate.
#'
psis_smooth_tail <- function(x, cutoff) {
len <- length(x)
exp_cutoff <- exp(cutoff)

# save time not sorting since x already sorted
fit <- gpdfit(exp(x) - exp_cutoff, sort_x = FALSE)
k <- fit$k
sigma <- fit$sigma
if (is.finite(k)) {
p <- (seq_len(len) - 0.5) / len
qq <- qgpd(p, k, sigma) + exp_cutoff
tail <- log(qq)
} else {
tail <- x
}
list(tail = tail, k = k)
}


#' Calculate tail lengths to use for fitting the GPD
#'
#' The number of weights (i.e., tail length) used to fit the generalized Pareto
Expand Down
1 change: 1 addition & 0 deletions man/loo-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 6 additions & 9 deletions tests/testthat/test_psis.R
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,11 @@ test_that("do_psis_i throws warning if all tail values the same", {
"all tail values are the same"
)
expect_equal(val$pareto_k, Inf)
})

test_that("psis_smooth_tail returns original tail values if k is infinite", {
# skip on M1 Mac until we figure out why this test fails only on M1 Mac
skip_if(Sys.info()[["sysname"]] == "Darwin" && R.version$arch == "aarch64")

xx <- c(1, 2, 3, 4, 4, 4, 4, 4, 4, 4, 4)
val <- suppressWarnings(psis_smooth_tail(xx, 3))
expect_equal(val$tail, xx)
expect_equal(val$k, Inf)
xx <- c(4, 1, 4, 4, 4, 4, 4, 2, 3, 4, 4)
expect_warning(
val <- do_psis_i(xx, tail_len_i = 6),
"all tail values are the same"
)
expect_equal(val$pareto_k, Inf)
})
Loading