Skip to content

Commit 78fb3a7

Browse files
committed
Refactor diagnostic alias normalization and strengthen tests
1 parent 7ed626c commit 78fb3a7

2 files changed

Lines changed: 29 additions & 10 deletions

File tree

R/mcmc-diagnostics.R

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -425,21 +425,26 @@ diagnostic_points <- function(size = NULL) {
425425
# Functions wrapping around scale_color_manual() and scale_fill_manual(), used to
426426
# color the intervals by rhat value
427427
scale_color_diagnostic <- function(diagnostic = c("rhat", "neff")) {
428-
d <- match.arg(diagnostic)
428+
d <- match.arg(diagnostic, choices = c("rhat", "neff", "neff_ratio"))
429429
diagnostic_color_scale(d, aesthetic = "color")
430430
}
431431

432432
scale_fill_diagnostic <- function(diagnostic = c("rhat", "neff")) {
433-
d <- match.arg(diagnostic)
433+
d <- match.arg(diagnostic, choices = c("rhat", "neff", "neff_ratio"))
434434
diagnostic_color_scale(d, aesthetic = "fill")
435435
}
436436

437-
diagnostic_color_scale <- function(diagnostic = c("rhat", "neff_ratio"),
438-
aesthetic = c("color", "fill")) {
437+
normalize_diagnostic_name <- function(diagnostic) {
439438
diagnostic <- match.arg(diagnostic, choices = c("rhat", "neff", "neff_ratio"))
440439
if (diagnostic == "neff") {
441440
diagnostic <- "neff_ratio"
442441
}
442+
diagnostic
443+
}
444+
445+
diagnostic_color_scale <- function(diagnostic = c("rhat", "neff_ratio"),
446+
aesthetic = c("color", "fill")) {
447+
diagnostic <- normalize_diagnostic_name(diagnostic)
443448
aesthetic <- match.arg(aesthetic)
444449
dc <- diagnostic_colors(diagnostic, aesthetic)
445450
do.call(
@@ -455,10 +460,7 @@ diagnostic_color_scale <- function(diagnostic = c("rhat", "neff_ratio"),
455460

456461
diagnostic_colors <- function(diagnostic = c("rhat", "neff_ratio"),
457462
aesthetic = c("color", "fill")) {
458-
diagnostic <- match.arg(diagnostic, choices = c("rhat", "neff", "neff_ratio"))
459-
if (diagnostic == "neff") {
460-
diagnostic <- "neff_ratio"
461-
}
463+
diagnostic <- normalize_diagnostic_name(diagnostic)
462464
aesthetic <- match.arg(aesthetic)
463465
color_levels <- c("light", "mid", "dark")
464466
if (diagnostic == "neff_ratio") {

tests/testthat/test-mcmc-diagnostics.R

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,27 @@ test_that("'description' & 'rating' columns are correct (#176)", {
8080
})
8181

8282
test_that("diagnostic color helpers handle neff names explicitly", {
83+
# wrappers accept both aliases
8384
expect_no_error(scale_color_diagnostic("neff"))
8485
expect_no_error(scale_fill_diagnostic("neff"))
85-
expect_no_error(diagnostic_color_scale("neff", aesthetic = "color"))
86-
expect_no_error(diagnostic_color_scale("neff_ratio", aesthetic = "color"))
86+
expect_no_error(scale_color_diagnostic("neff_ratio"))
87+
expect_no_error(scale_fill_diagnostic("neff_ratio"))
88+
89+
# aliases map to equivalent scales
90+
color_neff <- scale_color_diagnostic("neff")
91+
color_neff_ratio <- scale_color_diagnostic("neff_ratio")
92+
expect_equal(color_neff$palette(3), color_neff_ratio$palette(3))
93+
expect_equal(color_neff$labels, color_neff_ratio$labels)
94+
95+
fill_neff <- scale_fill_diagnostic("neff")
96+
fill_neff_ratio <- scale_fill_diagnostic("neff_ratio")
97+
expect_equal(fill_neff$palette(3), fill_neff_ratio$palette(3))
98+
expect_equal(fill_neff$labels, fill_neff_ratio$labels)
99+
100+
base_neff <- diagnostic_color_scale("neff", aesthetic = "color")
101+
base_neff_ratio <- diagnostic_color_scale("neff_ratio", aesthetic = "color")
102+
expect_equal(base_neff$palette(3), base_neff_ratio$palette(3))
103+
expect_equal(base_neff$labels, base_neff_ratio$labels)
87104
})
88105

89106
test_that("mcmc_acf & mcmc_acf_bar return a ggplot object", {

0 commit comments

Comments
 (0)