Skip to content

Commit d1b7c77

Browse files
committed
Handle invalid/misaligned Chain labels in df_with_chain2array()
1 parent 834284a commit d1b7c77

5 files changed

Lines changed: 34 additions & 3 deletions

File tree

NEWS.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# bayesplot (development version)
22

3-
* Validate equal chain lengths in `validate_df_with_chain()`.
3+
* Validate equal chain lengths in `validate_df_with_chain()`, reject missing
4+
chain labels, and renumber data-frame chain labels internally when converting
5+
to arrays.
46
* Added unit tests for previously untested edge cases in `param_range()`, `param_glue()`, and `tidyselect_parameters()` (no-match, partial-match, and negation behavior).
57
* Bumped minimum version for `rstantools` from `>= 1.5.0` to `>= 2.0.0` .
68
* Use `rlang::warn()` and `rlang::inform()` for selected PPC user messages instead of base `warning()` and `message()`.

R/helpers-mcmc.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,9 @@ validate_df_with_chain <- function(x) {
210210
x$chain <- NULL
211211
}
212212
x$Chain <- as.integer(x$Chain)
213+
if (anyNA(x$Chain)) {
214+
abort("Chain values must not be NA.")
215+
}
213216
rows_per_chain <- table(x$Chain)
214217
if (length(unique(rows_per_chain)) != 1) {
215218
abort("All chains must have the same number of iterations.")
@@ -222,6 +225,8 @@ validate_df_with_chain <- function(x) {
222225
df_with_chain2array <- function(x) {
223226
x <- validate_df_with_chain(x)
224227
chain <- x$Chain
228+
# Renumber arbitrary chain labels to the contiguous 1:N indices used internally.
229+
chain <- match(chain, sort(unique(chain)))
225230
n_chain <- length(unique(chain))
226231
a <- x[, !colnames(x) %in% "Chain", drop = FALSE]
227232
parnames <- colnames(a)

R/mcmc-overview.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
#' frame with one column per parameter (if only a single chain or all chains
2525
#' have already been merged), or a data frame with one column per parameter plus
2626
#' an additional column `"Chain"` that contains the chain number (an integer)
27-
#' corresponding to each row in the data frame.
27+
#' corresponding to each row in the data frame. When a `"Chain"` column is
28+
#' supplied, each chain must have the same number of iterations. Chain labels
29+
#' are used to identify groups and are renumbered internally to `1:N`.
2830
#' * __draws__: Any of the `draws` formats supported by the
2931
#' \pkg{posterior} package.
3032
#'

man/MCMC-overview.Rd

Lines changed: 3 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-helpers-mcmc.R

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,14 @@ test_that("validate_df_with_chain works", {
106106
tbl <- tibble::tibble(parameter=rnorm(n=40), Chain=rep(1:4, each=10))
107107
a <- validate_df_with_chain(tbl)
108108
expect_type(a$Chain, "integer")
109+
110+
missing_chain_df <- data.frame(
111+
Chain = c(1L, 1L, NA_integer_, NA_integer_),
112+
V1 = rnorm(4),
113+
V2 = rnorm(4)
114+
)
115+
expect_error(validate_df_with_chain(missing_chain_df),
116+
"Chain values must not be NA")
109117
})
110118

111119
test_that("df_with_chain2array works", {
@@ -124,6 +132,17 @@ test_that("df_with_chain2array works", {
124132
"All chains must have the same number of iterations")
125133
expect_error(df_with_chain2array(unequal_df),
126134
"All chains must have the same number of iterations")
135+
136+
renumbered_df <- data.frame(
137+
Chain = c(2L, 2L, 3L, 3L),
138+
V1 = 1:4,
139+
V2 = 5:8
140+
)
141+
a <- df_with_chain2array(renumbered_df)
142+
expect_equal(dim(a), c(2, 2, 2))
143+
expect_identical(unname(a[, 1, "V1"]), c(1L, 2L))
144+
expect_identical(unname(a[, 2, "V1"]), c(3L, 4L))
145+
expect_identical(as.character(dimnames(a)$Chain), c("1", "2"))
127146
})
128147

129148

@@ -316,6 +335,7 @@ test_that("diagnostic_factor.rhat works", {
316335
)
317336
expect_identical(levels(r), c("low", "ok", "high"))
318337
})
338+
319339
test_that("diagnostic_factor.neff_ratio works", {
320340
ratios <- new_neff_ratio(c(low = 0.05, low = 0.01,
321341
ok = 0.2, ok = 0.49,

0 commit comments

Comments
 (0)