Skip to content

Commit e9b5d6c

Browse files
committed
Fix df_with_chain2array() silently recycling data with unequal chain lengths
1 parent d15d028 commit e9b5d6c

3 files changed

Lines changed: 16 additions & 1 deletion

File tree

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# bayesplot (development version)
22

3+
* Fix `df_with_chain2array()` silently recycling data when chains have unequal iterations.
34
* Added unit tests for previously untested edge cases in `param_range()`, `param_glue()`, and `tidyselect_parameters()` (no-match, partial-match, and negation behavior).
45
* Bumped minimum version for `rstantools` from `>= 1.5.0` to `>= 2.0.0` .
56
* Use `rlang::warn()` and `rlang::inform()` for selected PPC user messages instead of base `warning()` and `message()`.

R/helpers-mcmc.R

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,12 @@ df_with_chain2array <- function(x) {
222222
a <- x[, !colnames(x) %in% "Chain", drop = FALSE]
223223
parnames <- colnames(a)
224224
a <- as.matrix(a)
225-
x <- array(NA, dim = c(ceiling(nrow(a) / n_chain), n_chain, ncol(a)))
225+
rows_per_chain <- table(chain)
226+
if (length(unique(rows_per_chain)) != 1) {
227+
abort("All chains must have the same number of iterations.")
228+
}
229+
n_iter <- as.integer(rows_per_chain[[1]])
230+
x <- array(NA, dim = c(n_iter, n_chain, ncol(a)))
226231
for (j in seq_len(n_chain)) {
227232
x[, j, ] <- a[chain == j,, drop=FALSE]
228233
}

tests/testthat/test-helpers-mcmc.R

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,15 @@ test_that("df_with_chain2array works", {
113113
expect_mcmc_array(a)
114114

115115
expect_error(df_with_chain2array(dframe), "is_df_with_chain")
116+
117+
# Unequal chain lengths should error, not silently recycle
118+
unequal_df <- data.frame(
119+
Chain = c(1L, 1L, 1L, 1L, 2L, 2L, 2L),
120+
V1 = rnorm(7),
121+
V2 = rnorm(7)
122+
)
123+
expect_error(df_with_chain2array(unequal_df),
124+
"All chains must have the same number of iterations")
116125
})
117126

118127

0 commit comments

Comments
 (0)