Skip to content

Commit 042663b

Browse files
authored
Merge pull request #499 from utkarshpawade/fix/df-with-chain2array-unequal-iterations
Fix `df_with_chain2array()` silently recycling data with unequal chain lengths
2 parents d15d028 + d1b7c77 commit 042663b

File tree

5 files changed

+51
-3
lines changed

5 files changed

+51
-3
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# bayesplot (development version)
22

3+
* Validate equal chain lengths in `validate_df_with_chain()`, reject missing
4+
chain labels, and renumber data-frame chain labels internally when converting
5+
to arrays.
36
* Added unit tests for previously untested edge cases in `param_range()`, `param_glue()`, and `tidyselect_parameters()` (no-match, partial-match, and negation behavior).
47
* Bumped minimum version for `rstantools` from `>= 1.5.0` to `>= 2.0.0`.
58
* Use `rlang::warn()` and `rlang::inform()` for selected PPC user messages instead of base `warning()` and `message()`.

R/helpers-mcmc.R

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,13 @@ validate_df_with_chain <- function(x) {
210210
x$chain <- NULL
211211
}
212212
x$Chain <- as.integer(x$Chain)
213+
if (anyNA(x$Chain)) {
214+
abort("Chain values must not be NA.")
215+
}
216+
rows_per_chain <- table(x$Chain)
217+
if (length(unique(rows_per_chain)) != 1) {
218+
abort("All chains must have the same number of iterations.")
219+
}
213220
x
214221
}
215222

@@ -218,11 +225,14 @@ validate_df_with_chain <- function(x) {
218225
df_with_chain2array <- function(x) {
219226
x <- validate_df_with_chain(x)
220227
chain <- x$Chain
228+
# Renumber arbitrary chain labels to the contiguous 1:N indices used internally.
229+
chain <- match(chain, sort(unique(chain)))
221230
n_chain <- length(unique(chain))
222231
a <- x[, !colnames(x) %in% "Chain", drop = FALSE]
223232
parnames <- colnames(a)
224233
a <- as.matrix(a)
225-
x <- array(NA, dim = c(ceiling(nrow(a) / n_chain), n_chain, ncol(a)))
234+
n_iter <- nrow(a) %/% n_chain
235+
x <- array(NA, dim = c(n_iter, n_chain, ncol(a)))
226236
for (j in seq_len(n_chain)) {
227237
x[, j, ] <- a[chain == j,, drop=FALSE]
228238
}

R/mcmc-overview.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
#' frame with one column per parameter (if only a single chain or all chains
2525
#' have already been merged), or a data frame with one column per parameter plus
2626
#' an additional column `"Chain"` that contains the chain number (an integer)
27-
#' corresponding to each row in the data frame.
27+
#' corresponding to each row in the data frame. When a `"Chain"` column is
28+
#' supplied, each chain must have the same number of iterations. Chain labels
29+
#' are used to identify groups and are renumbered internally to `1:N`.
2830
#' * __draws__: Any of the `draws` formats supported by the
2931
#' \pkg{posterior} package.
3032
#'

man/MCMC-overview.Rd

Lines changed: 3 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-helpers-mcmc.R

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,13 +106,43 @@ test_that("validate_df_with_chain works", {
106106
tbl <- tibble::tibble(parameter=rnorm(n=40), Chain=rep(1:4, each=10))
107107
a <- validate_df_with_chain(tbl)
108108
expect_type(a$Chain, "integer")
109+
110+
missing_chain_df <- data.frame(
111+
Chain = c(1L, 1L, NA_integer_, NA_integer_),
112+
V1 = rnorm(4),
113+
V2 = rnorm(4)
114+
)
115+
expect_error(validate_df_with_chain(missing_chain_df),
116+
"Chain values must not be NA")
109117
})
110118

111119
test_that("df_with_chain2array works", {
112120
a <- df_with_chain2array(dframe_multiple_chains)
113121
expect_mcmc_array(a)
114122

115123
expect_error(df_with_chain2array(dframe), "is_df_with_chain")
124+
125+
# Unequal chain lengths should error via validate_df_with_chain
126+
unequal_df <- data.frame(
127+
Chain = c(1L, 1L, 1L, 1L, 2L, 2L, 2L),
128+
V1 = rnorm(7),
129+
V2 = rnorm(7)
130+
)
131+
expect_error(validate_df_with_chain(unequal_df),
132+
"All chains must have the same number of iterations")
133+
expect_error(df_with_chain2array(unequal_df),
134+
"All chains must have the same number of iterations")
135+
136+
renumbered_df <- data.frame(
137+
Chain = c(2L, 2L, 3L, 3L),
138+
V1 = 1:4,
139+
V2 = 5:8
140+
)
141+
a <- df_with_chain2array(renumbered_df)
142+
expect_equal(dim(a), c(2, 2, 2))
143+
expect_identical(unname(a[, 1, "V1"]), c(1L, 2L))
144+
expect_identical(unname(a[, 2, "V1"]), c(3L, 4L))
145+
expect_identical(as.character(dimnames(a)$Chain), c("1", "2"))
116146
})
117147

118148

@@ -305,6 +335,7 @@ test_that("diagnostic_factor.rhat works", {
305335
)
306336
expect_identical(levels(r), c("low", "ok", "high"))
307337
})
338+
308339
test_that("diagnostic_factor.neff_ratio works", {
309340
ratios <- new_neff_ratio(c(low = 0.05, low = 0.01,
310341
ok = 0.2, ok = 0.49,

0 commit comments

Comments
 (0)