diff --git a/NEWS.md b/NEWS.md index c8f7d131..a58443f8 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,6 @@ # bayesplot (development version) +* Fixed `validate_chain_list()` colnames check to compare all chains, not just the first two. * Replace `apply()` with `storage.mode()` for integer-to-numeric matrix conversion in `validate_predictions()`. * Fixed `is_chain_list()` to correctly reject empty lists instead of silently returning `TRUE`. * Added unit tests for `mcmc_areas_ridges_data()`, `mcmc_parcoord_data()`, and `mcmc_trace_data()`. diff --git a/R/helpers-mcmc.R b/R/helpers-mcmc.R index 0f3970d9..179ec020 100644 --- a/R/helpers-mcmc.R +++ b/R/helpers-mcmc.R @@ -277,12 +277,8 @@ validate_chain_list <- function(x) { abort("Each chain should have the same number of iterations.") } - cnames <- sapply(x, colnames) - if (is.array(cnames)) { - same_params <- identical(cnames[, 1], cnames[, 2]) - } else { - same_params <- length(unique(cnames)) == 1 - } + cnames <- lapply(x, colnames) + same_params <- all(vapply(cnames[-1], identical, logical(1), cnames[[1]])) if (!same_params) { abort(paste( "The parameters for each chain should be in the same order", diff --git a/tests/testthat/test-helpers-mcmc.R b/tests/testthat/test-helpers-mcmc.R index 409b0b99..84347f23 100644 --- a/tests/testthat/test-helpers-mcmc.R +++ b/tests/testthat/test-helpers-mcmc.R @@ -178,6 +178,18 @@ test_that("validate_chain_list works", { "Each chain should have the same number of iterations") }) +test_that("validate_chain_list detects colnames mismatch in chain 3+", { + ch <- matrix(rnorm(20), nrow = 2, dimnames = list(NULL, c("a", "b", "c", "d", "e", + "f", "g", "h", "i", "j"))) + chain3_bad <- ch + colnames(chain3_bad)[1] <- "z" + chains_ok <- list(ch, ch, ch) + chains_bad <- list(ch, ch, chain3_bad) + + expect_identical(validate_chain_list(chains_ok), chains_ok) + expect_error(validate_chain_list(chains_bad), "parameters for each chain") +}) + test_that("chain_list2array works", { expect_mcmc_array(chain_list2array(chainlist)) expect_mcmc_array(chain_list2array(chainlist1))