Skip to content

Commit 834284a

Browse files
committed
Move unequal chain length check to validate_df_with_chain()
1 parent e9b5d6c commit 834284a

3 files changed

Lines changed: 9 additions & 7 deletions

File tree

NEWS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# bayesplot (development version)
22

3-
* Fix `df_with_chain2array()` silently recycling data when chains have unequal iterations.
3+
* Validate equal chain lengths in `validate_df_with_chain()`.
44
* Added unit tests for previously untested edge cases in `param_range()`, `param_glue()`, and `tidyselect_parameters()` (no-match, partial-match, and negation behavior).
55
* Bumped minimum version for `rstantools` from `>= 1.5.0` to `>= 2.0.0` .
66
* Use `rlang::warn()` and `rlang::inform()` for selected PPC user messages instead of base `warning()` and `message()`.

R/helpers-mcmc.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,10 @@ validate_df_with_chain <- function(x) {
210210
x$chain <- NULL
211211
}
212212
x$Chain <- as.integer(x$Chain)
213+
rows_per_chain <- table(x$Chain)
214+
if (length(unique(rows_per_chain)) != 1) {
215+
abort("All chains must have the same number of iterations.")
216+
}
213217
x
214218
}
215219

@@ -222,11 +226,7 @@ df_with_chain2array <- function(x) {
222226
a <- x[, !colnames(x) %in% "Chain", drop = FALSE]
223227
parnames <- colnames(a)
224228
a <- as.matrix(a)
225-
rows_per_chain <- table(chain)
226-
if (length(unique(rows_per_chain)) != 1) {
227-
abort("All chains must have the same number of iterations.")
228-
}
229-
n_iter <- as.integer(rows_per_chain[[1]])
229+
n_iter <- nrow(a) %/% n_chain
230230
x <- array(NA, dim = c(n_iter, n_chain, ncol(a)))
231231
for (j in seq_len(n_chain)) {
232232
x[, j, ] <- a[chain == j,, drop=FALSE]

tests/testthat/test-helpers-mcmc.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,14 @@ test_that("df_with_chain2array works", {
114114

115115
expect_error(df_with_chain2array(dframe), "is_df_with_chain")
116116

117-
# Unequal chain lengths should error, not silently recycle
117+
# Unequal chain lengths should error via validate_df_with_chain
118118
unequal_df <- data.frame(
119119
Chain = c(1L, 1L, 1L, 1L, 2L, 2L, 2L),
120120
V1 = rnorm(7),
121121
V2 = rnorm(7)
122122
)
123+
expect_error(validate_df_with_chain(unequal_df),
124+
"All chains must have the same number of iterations")
123125
expect_error(df_with_chain2array(unequal_df),
124126
"All chains must have the same number of iterations")
125127
})

0 commit comments

Comments
 (0)