|
12 | 12 | #' @template args-transformations |
13 | 13 | #' @template args-facet_args |
14 | 14 | #' @template args-density-controls |
15 | | -#' @param ... Currently ignored. |
| 15 | +#' @param ... For dot plots, optional additional arguments to pass to [ggdist::stat_dots()]. |
16 | 16 | #' @param alpha Passed to the geom to control the transparency. |
17 | 17 | #' |
18 | 18 | #' @template return-ggplot |
|
25 | 25 | #' \item{`mcmc_dens()`}{ |
26 | 26 | #' Kernel density plots of posterior draws with all chains merged. |
27 | 27 | #' } |
| 28 | +#' \item{`mcmc_dots()`}{ |
| 29 | +#' Dot plots of posterior draws with all chains merged. |
| 30 | +#' } |
28 | 31 | #' \item{`mcmc_hist_by_chain()`}{ |
29 | 32 | #' Histograms of posterior draws with chains separated via faceting. |
30 | 33 | #' } |
31 | 34 | #' \item{`mcmc_dens_overlay()`}{ |
32 | 35 | #' Kernel density plots of posterior draws with chains separated but |
33 | 36 | #' overlaid on a single plot. |
34 | 37 | #' } |
| 38 | +#' \item{`mcmc_dots_by_chain()`}{ |
| 39 | +#' Dot plots of posterior draws with chains separated via faceting. |
| 40 | +#' } |
35 | 41 | #' \item{`mcmc_violin()`}{ |
36 | 42 | #' The density estimate of each chain is plotted as a violin with |
37 | 43 | #' horizontal lines at notable quantiles. |
|
77 | 83 | #' mcmc_hist(x, transformations = list(sigma = log)) |
78 | 84 | #' |
79 | 85 | #' # separate histograms by chain |
80 | | -#' color_scheme_set("pink") |
| 86 | +#' color_scheme_set("orange") |
81 | 87 | #' mcmc_hist_by_chain(x, regex_pars = "beta") |
82 | 88 | #' } |
83 | 89 | #' |
84 | 90 | #' ################# |
85 | 91 | #' ### Densities ### |
86 | 92 | #' ################# |
87 | | -#' |
| 93 | +#' color_scheme_set("purple") |
88 | 94 | #' mcmc_dens(x, pars = c("sigma", "beta[2]"), |
89 | 95 | #' facet_args = list(nrow = 2)) |
90 | 96 | #' \donttest{ |
|
98 | 104 | #' } |
99 | 105 | #' # separate chains as violin plots |
100 | 106 | #' color_scheme_set("green") |
101 | | -#' mcmc_violin(x) + panel_bg(color = "gray20", size = 2, fill = "gray30") |
| 107 | +#' mcmc_violin(x) + panel_bg(color = "gray20", linewidth = 2, fill = "gray30") |
| 108 | +#' |
| 109 | +#' |
| 110 | +#' ################# |
| 111 | +#' ### Dot Plots ### |
| 112 | +#' ################# |
| 113 | +#' |
| 114 | +#' # dot plots of some parameters |
| 115 | +#' color_scheme_set("pink") |
| 116 | +#' mcmc_dots(x, pars = c("alpha", "beta[2]")) |
| 117 | +#' |
| 118 | +#' \donttest{ |
| 119 | +#' color_scheme_set("teal") |
| 120 | +#' # separate dot plots by chain |
| 121 | +#' mcmc_dots_by_chain(x, regex_pars = "beta") |
| 122 | +#' |
| 123 | +#' # custom facet labels (will change row labels to e.g. "Chain: 1" instead of just "1") |
| 124 | +#' chain_labeller <- ggplot2::labeller(.rows = ggplot2::label_both) |
| 125 | +#' mcmc_dots_by_chain(x, regex_pars = "beta", facet_args = list(labeller = chain_labeller)) |
| 126 | +#' } |
102 | 127 | #' |
103 | 128 | NULL |
104 | 129 |
|
@@ -371,7 +396,67 @@ mcmc_violin <- function( |
371 | 396 | ) |
372 | 397 | } |
373 | 398 |
|
| 399 | +#' @rdname MCMC-distributions |
| 400 | +#' @export |
| 401 | +#' @template args-dots |
| 402 | +mcmc_dots <- function( |
| 403 | + x, |
| 404 | + pars = character(), |
| 405 | + regex_pars = character(), |
| 406 | + transformations = list(), |
| 407 | + ..., |
| 408 | + facet_args = list(), |
| 409 | + binwidth = NA, |
| 410 | + alpha = 1, |
| 411 | + quantiles = 100 |
| 412 | +) { |
| 413 | + check_ignored_arguments(..., ok_args = c("dotsize", "layout", "stackratio", "overflow")) |
| 414 | + |
| 415 | + suggested_package("ggdist") |
374 | 416 |
|
| 417 | + .mcmc_dots( |
| 418 | + x, |
| 419 | + pars = pars, |
| 420 | + regex_pars = regex_pars, |
| 421 | + transformations = transformations, |
| 422 | + binwidth = binwidth, |
| 423 | + facet_args = facet_args, |
| 424 | + alpha = alpha, |
| 425 | + quantiles = quantiles, |
| 426 | + ... |
| 427 | + ) |
| 428 | +} |
| 429 | + |
| 430 | +#' @rdname MCMC-distributions |
| 431 | +#' @export |
| 432 | +mcmc_dots_by_chain <- function( |
| 433 | + x, |
| 434 | + pars = character(), |
| 435 | + regex_pars = character(), |
| 436 | + transformations = list(), |
| 437 | + ..., |
| 438 | + facet_args = list(), |
| 439 | + binwidth = NA, |
| 440 | + alpha = 1, |
| 441 | + quantiles = 100 |
| 442 | +) { |
| 443 | + check_ignored_arguments(..., ok_args = c("dotsize", "layout", "stackratio", "overflow")) |
| 444 | + |
| 445 | + suggested_package("ggdist") |
| 446 | + |
| 447 | + .mcmc_dots( |
| 448 | + x, |
| 449 | + pars = pars, |
| 450 | + regex_pars = regex_pars, |
| 451 | + transformations = transformations, |
| 452 | + binwidth = binwidth, |
| 453 | + facet_args = facet_args, |
| 454 | + by_chain = TRUE, |
| 455 | + alpha = alpha, |
| 456 | + quantiles = quantiles, |
| 457 | + ... |
| 458 | + ) |
| 459 | +} |
375 | 460 |
|
376 | 461 |
|
377 | 462 | # internal ----------------------------------------------------------------- |
@@ -558,3 +643,65 @@ mcmc_violin <- function( |
558 | 643 | yaxis_title(on = n_param == 1 && violin) + |
559 | 644 | xaxis_title(on = n_param == 1) |
560 | 645 | } |
| 646 | + |
| 647 | + |
| 648 | +.mcmc_dots <- function( |
| 649 | + x, |
| 650 | + pars = character(), |
| 651 | + regex_pars = character(), |
| 652 | + transformations = list(), |
| 653 | + facet_args = list(), |
| 654 | + binwidth = NA, |
| 655 | + by_chain = FALSE, |
| 656 | + alpha = 1, |
| 657 | + quantiles = NA, |
| 658 | + ... |
| 659 | +) { |
| 660 | + x <- prepare_mcmc_array(x, pars, regex_pars, transformations) |
| 661 | + |
| 662 | + if (by_chain && !has_multiple_chains(x)) { |
| 663 | + STOP_need_multiple_chains() |
| 664 | + } |
| 665 | + |
| 666 | + data <- melt_mcmc(x, value.name = "value") |
| 667 | + n_param <- num_params(data) |
| 668 | + |
| 669 | + graph <- ggplot(data, aes(x = .data$value)) + |
| 670 | + ggdist::stat_dots( |
| 671 | + binwidth = binwidth, |
| 672 | + quantiles = quantiles, |
| 673 | + fill = get_color("mid"), |
| 674 | + color = get_color("mid_highlight"), |
| 675 | + alpha = alpha, |
| 676 | + ... |
| 677 | + ) |
| 678 | + |
| 679 | + facet_args[["scales"]] <- facet_args[["scales"]] %||% "free" |
| 680 | + if (!by_chain) { |
| 681 | + if (n_param > 1) { |
| 682 | + facet_args[["facets"]] <- vars(.data$Parameter) |
| 683 | + graph <- graph + do.call("facet_wrap", facet_args) |
| 684 | + } |
| 685 | + } else { |
| 686 | + facet_args[["rows"]] <- vars(.data$Chain) |
| 687 | + if (n_param > 1) { |
| 688 | + facet_args[["cols"]] <- vars(.data$Parameter) |
| 689 | + } |
| 690 | + graph <- graph + |
| 691 | + do.call("facet_grid", facet_args) + |
| 692 | + force_x_axis_in_facets() |
| 693 | + } |
| 694 | + |
| 695 | + if (n_param == 1) { |
| 696 | + graph <- graph + xlab(levels(data$Parameter)) |
| 697 | + } |
| 698 | + |
| 699 | + graph + |
| 700 | + dont_expand_y_axis(c(0.005, 0)) + |
| 701 | + bayesplot_theme_get() + |
| 702 | + yaxis_text(FALSE) + |
| 703 | + yaxis_title(FALSE) + |
| 704 | + yaxis_ticks(FALSE) + |
| 705 | + theme(axis.line.y = element_blank()) + |
| 706 | + xaxis_title(on = n_param == 1) |
| 707 | +} |
0 commit comments