Skip to content

Commit f0799d9

Browse files
authored
Merge pull request #442 from stan-dev/fix-pareto_smooth
fix pareto_smooth in case of constant tail + cleanup
2 parents e3d2b1d + 8ba5a9f commit f0799d9

2 files changed

Lines changed: 86 additions & 49 deletions

File tree

R/pareto_smooth.R

Lines changed: 39 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,10 @@ pareto_smooth.default <- function(x,
361361
x <- smoothed$x
362362
}
363363

364+
if (is.na(k)) {
365+
return(pareto_diags_na(x, return_k, extra_diags))
366+
}
367+
364368
diags_list <- list(khat = k)
365369

366370
if (extra_diags) {
@@ -470,38 +474,42 @@ pareto_convergence_rate.rvar <- function(x, ...) {
470474
#'
471475
#' @export
472476
ps_tail <- function(x,
473-
ndraws_tail,
474-
smooth_draws = TRUE,
475-
tail = c("right", "left"),
476-
are_log_weights = FALSE,
477-
...
478-
) {
477+
ndraws_tail,
478+
smooth_draws = TRUE,
479+
tail = c("right", "left"),
480+
are_log_weights = FALSE,
481+
...
482+
) {
483+
484+
if (ndraws_tail < 5) {
485+
warning_no_call(
486+
"Can't fit generalized Pareto distribution ",
487+
"because ndraws_tail is less than 5."
488+
)
489+
return(list(x = x, k = NA))
490+
}
479491

480492
if (are_log_weights) {
481493
# shift log values for safe exponentiation
482494
x <- x - max(x)
483495
}
484496

485497
tail <- match.arg(tail)
486-
487-
ndraws <- length(x)
488-
tail_ids <- seq(ndraws - ndraws_tail + 1, ndraws)
489-
490498
if (tail == "left") {
491499
x <- -x
492500
}
493501

502+
ndraws <- length(x)
503+
tail_ids <- seq(ndraws - ndraws_tail + 1, ndraws)
504+
494505
ord <- sort.int(x, index.return = TRUE)
495506
draws_tail <- ord$x[tail_ids]
496507

497508
if (is_constant(draws_tail)) {
498-
499509
if (tail == "left") {
500510
x <- -x
501511
}
502-
503-
out <- list(x = x, k = NA)
504-
return(out)
512+
return(list(x = x, k = NA))
505513
}
506514

507515
cutoff <- ord$x[min(tail_ids) - 1] # largest value smaller than tail values
@@ -511,43 +519,22 @@ ps_tail <- function(x,
511519
}
512520

513521
max_tail <- max(draws_tail)
514-
min_tail <- min(draws_tail)
515-
516-
if (ndraws_tail >= 5) {
517-
ord <- sort.int(x, index.return = TRUE)
518-
if (abs(max_tail - min_tail) < .Machine$double.eps / 100) {
519-
warning_no_call(
520-
"Can't fit generalized Pareto distribution ",
521-
"because all tail values are the same."
522-
)
523-
smoothed <- NULL
524-
k <- NA
525-
} else {
526-
# save time not sorting since x already sorted
527-
if (are_log_weights) {
528-
draws_tail <- exp(draws_tail)
529-
cutoff <- exp(cutoff)
530-
}
531-
fit <- gpdfit(draws_tail - cutoff, sort_x = FALSE, ...)
532-
k <- fit$k
533-
sigma <- fit$sigma
534-
if (is.finite(k) && smooth_draws) {
535-
p <- (seq_len(ndraws_tail) - 0.5) / ndraws_tail
536-
smoothed <- qgeneralized_pareto(p = p, mu = cutoff, k = k, sigma = sigma)
537-
if (are_log_weights) {
538-
smoothed <- log(smoothed)
539-
}
540-
} else {
541-
smoothed <- NULL
542-
}
522+
523+
if (are_log_weights) {
524+
draws_tail <- exp(draws_tail)
525+
cutoff <- exp(cutoff)
526+
}
527+
fit <- gpdfit(draws_tail - cutoff, sort_x = FALSE, ...)
528+
k <- fit$k
529+
sigma <- fit$sigma
530+
if (is.finite(k) && smooth_draws) {
531+
p <- (seq_len(ndraws_tail) - 0.5) / ndraws_tail
532+
smoothed <- qgeneralized_pareto(p = p, mu = cutoff, k = k, sigma = sigma)
533+
if (are_log_weights) {
534+
smoothed <- log(smoothed)
543535
}
544536
} else {
545-
warning_no_call(
546-
"Can't fit generalized Pareto distribution ",
547-
"because ndraws_tail is less than 5."
548-
)
549537
smoothed <- NULL
550-
k <- NA
551538
}
552539

553540
# truncate at max of raw draws
@@ -598,7 +585,10 @@ ps_tail <- function(x,
598585
#' @return minimum sample size
599586
#' @export
600587
ps_min_ss <- function(k, ...) {
601-
if (k < 1) {
588+
if (is.na(k)) {
589+
return(NA)
590+
}
591+
if (isTRUE(k < 1)) {
602592
out <- 10^(1 / (1 - max(0, k)))
603593
} else {
604594
out <- Inf

tests/testthat/test-pareto_smooth.R

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,19 @@ test_that("pareto_khat handles constant tail correctly", {
1313
})
1414

1515

16+
test_that("pareto_smooth handles non-constant x with constant tail", {
17+
18+
# x is non-constant but the right tail is constant
19+
x <- c(seq(-3, 3, length.out = 90), rep(5, 10))
20+
21+
expect_no_error(
22+
ps <- pareto_smooth(x, tail = "right", ndraws_tail = 10, return_k = TRUE, verbose = FALSE)
23+
)
24+
expect_true(is.na(ps$diagnostics$khat))
25+
expect_equal(ps$x, x)
26+
27+
})
28+
1629
test_that("pareto_khat handles tail argument", {
1730

1831
# as tau is bounded (0, Inf) the left pareto k should be lower than
@@ -235,3 +248,37 @@ test_that("pareto_smooth works for log_weights", {
235248
expect_true(ps$diagnostics$khat > 0.7)
236249

237250
})
251+
252+
test_that("check ps_tail behavior for ndraws_tail less than 5", {
253+
w <- c(1:25, 1e3, 1e3, 1e3)
254+
lw <- log(w)
255+
256+
# prints correct warning
257+
expect_warning(
258+
tail <- ps_tail(lw, ndraws_tail = 4, tail = "right"),
259+
"Can't fit generalized Pareto distribution because ndraws_tail is less than 5."
260+
)
261+
# output has expected shape and k = NA
262+
expect_equal(names(tail), c("x", "k"))
263+
expect_true(is.na(tail$k))
264+
})
265+
266+
test_that("check ps_tail behavior for constant draws_tail", {
267+
x <- log(replicate(10, 0.3))
268+
269+
tail <- ps_tail(x, ndraws_tail = 10, tail = "left")
270+
271+
# output has expected return values
272+
expect_equal(x, tail$x)
273+
expect_true(is.na(tail$k))
274+
})
275+
276+
277+
test_that("check ps_min_ss behavior special cases", {
278+
# k = NA
279+
expect_true(is.na(ps_min_ss(NA)))
280+
# k > 1
281+
expect_true(is.infinite(ps_min_ss(2)))
282+
# k < 1
283+
expect_equal(ps_min_ss(0.5), 10^(1 / (1 - max(0, 0.5))))
284+
})

0 commit comments

Comments
 (0)