Skip to content

Commit 9d6a95b

Browse files
authored
Merge pull request #529 from utkarshpawade/fix/validate-chain-list-colnames-check
Fix validate_chain_list() colnames check to compare all chains
2 parents 14c1c74 + 8637e82 commit 9d6a95b

3 files changed

Lines changed: 15 additions & 6 deletions

File tree

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# bayesplot (development version)
22

3+
* Fixed `validate_chain_list()` colnames check to compare all chains, not just the first two.
34
* Added test verifying `legend_move("none")` behaves equivalently to `legend_none()`.
45
* Added singleton-dimension edge-case tests for exported `_data()` functions.
56
* Validate empty list and zero-row matrix inputs in `nuts_params.list()`.

R/helpers-mcmc.R

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -277,12 +277,8 @@ validate_chain_list <- function(x) {
277277
abort("Each chain should have the same number of iterations.")
278278
}
279279

280-
cnames <- sapply(x, colnames)
281-
if (is.array(cnames)) {
282-
same_params <- identical(cnames[, 1], cnames[, 2])
283-
} else {
284-
same_params <- length(unique(cnames)) == 1
285-
}
280+
cnames <- lapply(x, colnames)
281+
same_params <- all(vapply(cnames[-1], identical, logical(1), cnames[[1]]))
286282
if (!same_params) {
287283
abort(paste(
288284
"The parameters for each chain should be in the same order",

tests/testthat/test-helpers-mcmc.R

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,18 @@ test_that("validate_chain_list works", {
178178
"Each chain should have the same number of iterations")
179179
})
180180

181+
test_that("validate_chain_list detects colnames mismatch in chain 3+", {
182+
ch <- matrix(rnorm(20), nrow = 2, dimnames = list(NULL, c("a", "b", "c", "d", "e",
183+
"f", "g", "h", "i", "j")))
184+
chain3_bad <- ch
185+
colnames(chain3_bad)[1] <- "z"
186+
chains_ok <- list(ch, ch, ch)
187+
chains_bad <- list(ch, ch, chain3_bad)
188+
189+
expect_identical(validate_chain_list(chains_ok), chains_ok)
190+
expect_error(validate_chain_list(chains_bad), "parameters for each chain")
191+
})
192+
181193
test_that("chain_list2array works", {
182194
expect_mcmc_array(chain_list2array(chainlist))
183195
expect_mcmc_array(chain_list2array(chainlist1))

0 commit comments

Comments
 (0)