Skip to content

Commit 471afa8

Browse files
authored
Merge pull request #508 from utkarshpawade/test/mcmc-data-functions-coverage
Add test coverage for mcmc_areas_ridges_data, mcmc_parcoord_data, mcmc_trace_data
2 parents 015be80 + 72ec17e commit 471afa8

File tree

4 files changed

+129
-41
lines changed

4 files changed

+129
-41
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# bayesplot (development version)
22

3+
* Added unit tests for `mcmc_areas_ridges_data()`, `mcmc_parcoord_data()`, and `mcmc_trace_data()`.
34
* Added unit tests for `ppc_error_data()` and `ppc_loo_pit_data()` covering output structure, argument handling, and edge cases.
45
* Added vignette sections demonstrating `*_data()` companion functions for building custom ggplot2 visualizations (#435)
56
* Extract `drop_singleton_values()` helper in `mcmc_nuts_treedepth()` to remove duplicated filtering logic.

tests/testthat/test-mcmc-intervals.R

Lines changed: 55 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,5 @@
11
source(test_path("data-for-mcmc-tests.R"))
22

3-
test_that("mcmc_intervals_data computes quantiles", {
4-
xs <- melt_mcmc(merge_chains(prepare_mcmc_array(arr, pars = "beta[1]")))
5-
d <- mcmc_intervals_data(arr, pars = "beta[1]",
6-
prob = .3, prob_outer = .5)
7-
8-
qs <- unlist(d[, c("ll", "l", "m", "h", "hh")])
9-
by_hand <- quantile(xs$Value, c(.25, .35, .5, .65, .75))
10-
expect_equal(qs, by_hand, ignore_attr = TRUE)
11-
12-
expect_equal(d$parameter, factor("beta[1]"))
13-
expect_equal(d$outer_width, .5)
14-
expect_equal(d$inner_width, .3)
15-
expect_equal(d$point_est, "median")
16-
17-
d2 <- mcmc_areas_data(arr, pars = "beta[1]", prob = .3, prob_outer = .5)
18-
sets <- split(d2, d2$interval)
19-
20-
expect_equal(range(sets$inner$x), c(d$l, d$h))
21-
expect_equal(range(sets$outer$x), c(d$ll, d$hh))
22-
})
23-
24-
test_that("mcmc_intervals_data computes point estimates", {
25-
xs <- melt_mcmc(merge_chains(prepare_mcmc_array(arr, pars = "beta[2]")))
26-
d <- mcmc_intervals_data(arr, pars = "beta[2]",
27-
prob = .3, prob_outer = .5, point_est = "mean")
28-
29-
expect_equal(d$m, mean(xs$Value), ignore_attr = TRUE)
30-
expect_equal(d$parameter, factor("beta[2]"))
31-
expect_equal(d$point_est, "mean")
32-
33-
d <- mcmc_intervals_data(arr, pars = "(Intercept)",
34-
prob = .3, prob_outer = .5,
35-
point_est = "none")
36-
expect_true(!("m" %in% names(d)))
37-
expect_equal(d$point_est, "none")
38-
})
39-
403
test_that("mcmc_intervals returns a ggplot object", {
414
expect_gg(mcmc_intervals(arr, pars = "beta[1]", regex_pars = "x\\:"))
425
expect_gg(mcmc_intervals(arr1chain, pars = "beta[1]", regex_pars = "Intercept"))
@@ -115,6 +78,45 @@ test_that("mcmc_intervals/areas with rhat", {
11578
}
11679
})
11780

81+
# _data() tests ----------------------------------------------------------------
82+
83+
test_that("mcmc_intervals_data computes quantiles", {
84+
xs <- melt_mcmc(merge_chains(prepare_mcmc_array(arr, pars = "beta[1]")))
85+
d <- mcmc_intervals_data(arr, pars = "beta[1]",
86+
prob = .3, prob_outer = .5)
87+
88+
qs <- unlist(d[, c("ll", "l", "m", "h", "hh")])
89+
by_hand <- quantile(xs$Value, c(.25, .35, .5, .65, .75))
90+
expect_equal(qs, by_hand, ignore_attr = TRUE)
91+
92+
expect_equal(d$parameter, factor("beta[1]"))
93+
expect_equal(d$outer_width, .5)
94+
expect_equal(d$inner_width, .3)
95+
expect_equal(d$point_est, "median")
96+
97+
d2 <- mcmc_areas_data(arr, pars = "beta[1]", prob = .3, prob_outer = .5)
98+
sets <- split(d2, d2$interval)
99+
100+
expect_equal(range(sets$inner$x), c(d$l, d$h))
101+
expect_equal(range(sets$outer$x), c(d$ll, d$hh))
102+
})
103+
104+
test_that("mcmc_intervals_data computes point estimates", {
105+
xs <- melt_mcmc(merge_chains(prepare_mcmc_array(arr, pars = "beta[2]")))
106+
d <- mcmc_intervals_data(arr, pars = "beta[2]",
107+
prob = .3, prob_outer = .5, point_est = "mean")
108+
109+
expect_equal(d$m, mean(xs$Value), ignore_attr = TRUE)
110+
expect_equal(d$parameter, factor("beta[2]"))
111+
expect_equal(d$point_est, "mean")
112+
113+
d <- mcmc_intervals_data(arr, pars = "(Intercept)",
114+
prob = .3, prob_outer = .5,
115+
point_est = "none")
116+
expect_true(!("m" %in% names(d)))
117+
expect_equal(d$point_est, "none")
118+
})
119+
118120
test_that("mcmc_areas_data computes density", {
119121
areas_data <- mcmc_areas_data(arr, point_est = "none")
120122
areas_data <- areas_data[areas_data$interval_width == 1, ]
@@ -153,7 +155,7 @@ test_that("compute_column_density can use density options (#118)", {
153155
expect_error(mcmc_areas_data(arr, kernel = stop()))
154156
})
155157

156-
test_that("probabilities outside of [0,1] cause an error", {
158+
test_that("mcmc_intervals_data errors for probabilities outside of [0,1]", {
157159
expect_error(mcmc_intervals_data(arr, prob = -0.1),
158160
"must be in \\[0,1\\]")
159161
expect_error(mcmc_intervals_data(arr, prob = 1.1),
@@ -164,14 +166,28 @@ test_that("probabilities outside of [0,1] cause an error", {
164166
"must be in \\[0,1\\]")
165167
})
166168

167-
test_that("inconsistent probabilities raise warning (#138)", {
169+
test_that("mcmc_intervals_data warns for inconsistent probabilities (#138)", {
168170
expect_warning(
169171
mcmc_intervals_data(arr, prob = .9, prob_outer = .8),
170172
"`prob_outer` .* is less than `prob`"
171173
)
172174
})
173175

174176

177+
test_that("mcmc_areas_ridges_data returns correct structure", {
178+
d <- mcmc_areas_ridges_data(arr, pars = c("beta[1]", "sigma"), prob = 0.5, prob_outer = 0.9)
179+
expect_s3_class(d, "data.frame")
180+
expect_named(
181+
d,
182+
c(
183+
"parameter", "interval", "interval_width", "x", "density",
184+
"scaled_density", "plotting_density"
185+
)
186+
)
187+
expect_setequal(unique(d$interval), c("inner", "outer"))
188+
expect_false("point" %in% d$interval)
189+
expect_equal(unique(as.character(d$parameter)), c("beta[1]", "sigma"))
190+
})
175191

176192

177193
# Visual tests -----------------------------------------------------------------

tests/testthat/test-mcmc-scatter-and-parcoord.R

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,6 @@ test_that("pairs_condition message if multiple args specified", {
313313
})
314314

315315

316-
317316
# mcmc_parcoord -----------------------------------------------------------
318317
test_that("mcmc_parcoord returns a ggplot object", {
319318
expect_gg(mcmc_parcoord(arr, pars = c("(Intercept)", "sigma")))
@@ -351,7 +350,6 @@ test_that("mcmc_parcoord throws correct warnings and errors", {
351350
)
352351
})
353352

354-
355353
# parcoord_style_np -------------------------------------------------------
356354
test_that("parcoord_style_np returns correct structure", {
357355
style <- parcoord_style_np()
@@ -375,6 +373,42 @@ test_that("parcoord_style_np throws correct errors", {
375373
)
376374
})
377375

376+
# mcmc_parcoord_data -------------------------------------------------
377+
378+
test_that("mcmc_parcoord_data returns expected structure", {
379+
d <- mcmc_parcoord_data(arr, pars = c("(Intercept)", "sigma"))
380+
expect_s3_class(d, "data.frame")
381+
expect_named(d, c("Draw", "Parameter", "Value", "Divergent"))
382+
383+
draws_by_parameter <- split(d$Draw, d$Parameter)
384+
expected_draws <- seq_len(dim(arr)[1] * dim(arr)[2])
385+
expect_equal(draws_by_parameter[[1]], expected_draws)
386+
expect_equal(draws_by_parameter[[2]], expected_draws)
387+
})
388+
389+
test_that("mcmc_parcoord_data sets Divergent to 0 when np is NULL", {
390+
d <- mcmc_parcoord_data(arr, pars = c("(Intercept)", "sigma"))
391+
expect_true(all(d$Divergent == 0))
392+
})
393+
394+
test_that("mcmc_parcoord_data joins divergence information from np", {
395+
fake_np <- data.frame(
396+
Iteration = rep(seq_len(dim(arr)[1]), each = dim(arr)[2]),
397+
Chain = rep(seq_len(dim(arr)[2]), times = dim(arr)[1]),
398+
Parameter = factor("divergent__"),
399+
Value = as.integer(rep(c(0, 1, 0, 1), times = dim(arr)[1]))
400+
)
401+
d <- mcmc_parcoord_data(arr, pars = c("(Intercept)", "sigma"), np = fake_np)
402+
403+
expect_false(anyNA(d$Divergent))
404+
expect_equal(sum(d$Divergent == 1), 400)
405+
expect_equal(sum(d$Divergent == 0), 400)
406+
})
407+
408+
test_that("mcmc_parcoord_data errors with fewer than 2 parameters", {
409+
expect_error(mcmc_parcoord_data(arr, pars = "sigma"), "at least two")
410+
})
411+
378412

379413
# Visual tests -----------------------------------------------------------------
380414

tests/testthat/test-mcmc-traces.R

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,44 @@ test_that("mcmc_trace 'np' argument works", {
100100
"No divergences to plot.")
101101
})
102102

103+
# mcmc_trace_data ----------------------------------------------------
104+
105+
test_that("mcmc_trace_data returns plotting data with expected columns", {
106+
d <- mcmc_trace_data(arr, pars = "beta[1]")
107+
expect_s3_class(d, "tbl_df")
108+
expect_named(
109+
d,
110+
c(
111+
"parameter", "value", "value_rank", "iteration", "chain",
112+
"n_chains", "n_iterations", "n_parameters", "highlight", "warmup"
113+
)
114+
)
115+
expect_equal(nrow(d), dim(arr)[1] * dim(arr)[2])
116+
})
103117

118+
test_that("mcmc_trace_data highlight argument works", {
119+
d <- mcmc_trace_data(arr, pars = "beta[1]", highlight = 2)
120+
expect_true(all(d$highlight[d$chain == 2]))
121+
expect_true(all(!d$highlight[d$chain != 2]))
122+
})
123+
124+
test_that("mcmc_trace_data warmup labeling works", {
125+
d <- mcmc_trace_data(arr, pars = "beta[1]", n_warmup = 20)
126+
expect_true(all(d$warmup[d$iteration <= 20]))
127+
expect_true(all(!d$warmup[d$iteration > 20]))
128+
})
129+
130+
test_that("mcmc_trace_data iter1 shifts iterations", {
131+
d <- mcmc_trace_data(arr, pars = "beta[1]", iter1 = 100)
132+
expect_true(min(d$iteration) == 101)
133+
})
134+
135+
test_that("mcmc_trace_data computes value_rank within each parameter", {
136+
d <- mcmc_trace_data(arr, pars = c("beta[1]", "beta[2]"))
137+
observed_ranks <- split(d$value_rank, d$parameter)
138+
expected_ranks <- lapply(split(d$value, d$parameter), rank, ties.method = "average")
139+
expect_equal(observed_ranks, expected_ranks)
140+
})
104141

105142

106143
# Visual tests -----------------------------------------------------------------

0 commit comments

Comments
 (0)