Skip to content

Commit 5ee52a1

Browse files
Merge branch 'master' into fix/allow-nas-prepare-mcmc-array-250
2 parents 19ceac5 + 9d6a95b commit 5ee52a1

119 files changed

Lines changed: 4695 additions & 961 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ Authors@R: c(person("Jonah", "Gabry", role = c("aut", "cre"), email = "jgabry@gm
1313
person("Teemu", "Sailynoja", role = "ctb"),
1414
person("Aki", "Vehtari", role = "ctb"),
1515
person("Behram", "Ulukır", role = "ctb"),
16-
person("Visruth", "Srimath Kandali", role = "ctb"))
16+
person("Visruth", "Srimath Kandali", role = "ctb"),
17+
person("Mattan S.", "Ben-Shachar", role = "ctb"))
1718
Maintainer: Jonah Gabry <jgabry@gmail.com>
1819
Description: Plotting functions for posterior analysis, MCMC diagnostics,
1920
prior and posterior predictive checks, and other visualizations

NEWS.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
# bayesplot (development version)
22

33
* `prepare_mcmc_array()` now warns instead of erroring on `NA`s in the input.
4+
* Fixed `validate_chain_list()` colnames check to compare all chains, not just the first two.
5+
* Added test verifying `legend_move("none")` behaves equivalently to `legend_none()`.
6+
* Added singleton-dimension edge-case tests for exported `_data()` functions.
7+
* Validate empty list and zero-row matrix inputs in `nuts_params.list()`.
8+
* Validate user-provided `pit` values in `ppc_loo_pit_data()` and `ppc_loo_pit_qq()`, rejecting non-numeric inputs, missing values, and values outside `[0, 1]`.
9+
* New `show_marginal` argument to `ppd_*()` functions to show the PPD - the marginal predictive distribution by @mattansb (#425)
410
* `ppc_ecdf_overlay()`, `ppc_ecdf_overlay_grouped()`, and `ppd_ecdf_overlay()` now always use `geom_step()`. The `discrete` argument is deprecated.
511
* Fixed missing `drop = FALSE` in `nuts_params.CmdStanMCMC()`.
612
* Replace `apply()` with `storage.mode()` for integer-to-numeric matrix conversion in `validate_predictions()`.

R/bayesplot-extractors.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,18 @@ nuts_params.stanreg <-
145145
#' @export
146146
#' @method nuts_params list
147147
nuts_params.list <- function(object, pars = NULL, ...) {
148+
if (length(object) == 0) {
149+
abort("'object' must be a non-empty list.")
150+
}
151+
148152
if (!all(sapply(object, is.matrix))) {
149153
abort("All list elements should be matrices.")
150154
}
151155

156+
if (any(vapply(object, nrow, integer(1)) == 0)) {
157+
abort("All matrices in the list must have at least one row.")
158+
}
159+
152160
dd <- lapply(object, dim)
153161
if (length(unique(dd)) != 1) {
154162
abort("All matrices in the list must have the same dimensions.")

R/helpers-gg.R

Lines changed: 69 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ scale_color_ppc <-
9696
labels = NULL,
9797
...) {
9898
scale_color_manual(
99-
name = name %||% "",
99+
name = name,
100100
values = values %||% get_color(c("dh", "lh")),
101101
labels = labels %||% c(y_label(), yrep_label()),
102102
...
@@ -109,7 +109,7 @@ scale_fill_ppc <-
109109
labels = NULL,
110110
...) {
111111
scale_fill_manual(
112-
name = name %||% "",
112+
name = name,
113113
values = values %||% get_color(c("d", "l")),
114114
labels = labels %||% c(y_label(), yrep_label()),
115115
...
@@ -118,22 +118,77 @@ scale_fill_ppc <-
118118

119119
scale_color_ppd <-
120120
function(name = NULL,
121-
values = get_color("mh"),
122-
labels = ypred_label(),
121+
values = NULL,
122+
labels = NULL,
123+
highlight = TRUE,
124+
show_marginal = FALSE,
123125
...) {
124-
scale_color_ppc(name = name,
125-
values = values,
126-
labels = labels,
127-
...)
126+
if (isTRUE(show_marginal)) {
127+
if (isTRUE(highlight)) {
128+
cl <- c("dh", "lh")
129+
} else {
130+
cl <- c("d", "l")
131+
}
132+
default_values <- setNames(get_color(cl), nm = c("PPD", "ypred"))
133+
} else {
134+
if (isTRUE(highlight)) {
135+
default_values <- get_color("mh")
136+
} else {
137+
default_values <- get_color("m")
138+
}
139+
}
140+
141+
scale_color_ppc(
142+
name = name,
143+
values = values %||% default_values,
144+
labels = labels %||% ypred_label(),
145+
...
146+
)
128147
}
129148

130149
scale_fill_ppd <-
131150
function(name = NULL,
132-
values = get_color("m"),
133-
labels = ypred_label(),
151+
values = NULL,
152+
labels = NULL,
153+
show_marginal = FALSE,
134154
...) {
135-
scale_fill_ppc(name = name,
136-
values = values,
137-
labels = labels,
138-
...)
155+
if (isTRUE(show_marginal)) {
156+
default_values <- c(PPD = "white", ypred = get_color("l"))
157+
} else {
158+
default_values <- get_color("m")
159+
}
160+
161+
scale_fill_ppc(
162+
name = name,
163+
values = values %||% default_values,
164+
labels = labels %||% ypred_label(),
165+
...
166+
)
167+
}
168+
169+
170+
scale_linetype_ppd <-
171+
function(name = NULL,
172+
values = NULL,
173+
labels = NULL,
174+
...) {
175+
scale_linetype_manual(
176+
name = name,
177+
values = values %||% c(PPD = "5111", ypred = "solid"),
178+
labels = labels %||% ypred_label(),
179+
...
180+
)
181+
}
182+
183+
scale_shape_ppd <-
184+
function(name = NULL,
185+
values = NULL,
186+
labels = NULL,
187+
...) {
188+
scale_shape_manual(
189+
name = name,
190+
values = values %||% c(ypred = 21, PPD = 23),
191+
labels = labels %||% ypred_label(),
192+
...
193+
)
139194
}

R/helpers-mcmc.R

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -279,12 +279,8 @@ validate_chain_list <- function(x) {
279279
abort("Each chain should have the same number of iterations.")
280280
}
281281

282-
cnames <- sapply(x, colnames)
283-
if (is.array(cnames)) {
284-
same_params <- identical(cnames[, 1], cnames[, 2])
285-
} else {
286-
same_params <- length(unique(cnames)) == 1
287-
}
282+
cnames <- lapply(x, colnames)
283+
same_params <- all(vapply(cnames[-1], identical, logical(1), cnames[[1]]))
288284
if (!same_params) {
289285
abort(paste(
290286
"The parameters for each chain should be in the same order",

R/helpers-ppc.R

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,20 @@ validate_y <- function(y) {
5858
#' Validate predictions (`yrep` or `ypred`)
5959
#'
6060
#' Checks that `predictions` is a numeric matrix, doesn't have any NAs, and has
61-
#' the correct number of columns.
61+
#' the correct number of columns. If `predictions` is a `posterior::draws`
62+
#' object it is first coerced to a matrix.
6263
#'
63-
#' @param predictions The user's `yrep` or `ypred` object (SxN matrix).
64+
#' @param predictions The user's `yrep` or `ypred` object (SxN matrix or a
65+
#' `posterior::draws` object).
6466
#' @param `n_obs` The number of observations (columns) that `predictions` should
6567
#' have, if applicable.
6668
#' @return Either throws an error or returns a numeric matrix.
6769
#' @noRd
6870
validate_predictions <- function(predictions, n_obs = NULL) {
69-
# sanity checks
71+
if (posterior::is_draws(predictions)) {
72+
predictions <- posterior::as_draws_matrix(predictions)
73+
predictions <- unclass(predictions)
74+
}
7075
stopifnot(is.matrix(predictions), is.numeric(predictions))
7176
if (!is.null(n_obs)) {
7277
stopifnot(length(n_obs) == 1, n_obs == as.integer(n_obs))
@@ -589,4 +594,9 @@ u_scale <- function(x) {
589594
create_rep_ids <- function(ids) paste('italic(y)[rep] (', ids, ")")
590595
y_label <- function() expression(italic(y))
591596
yrep_label <- function() expression(italic(y)[rep])
592-
ypred_label <- function() expression(italic(y)[pred])
597+
ypred_label <- function() {
598+
c(
599+
PPD = "PPD",
600+
ypred = expression(italic(y)[pred])
601+
)
602+
}

R/ppc-loo.R

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,10 @@ ppc_loo_pit_data <-
302302
boundary_correction = TRUE,
303303
grid_len = 512) {
304304
if (!is.null(pit)) {
305-
stopifnot(is.numeric(pit), is_vector_or_1Darray(pit))
305+
pit <- validate_pit(pit)
306+
if (boundary_correction && length(pit) < 2L) {
307+
abort("At least 2 PIT values are required when 'boundary_correction' is TRUE.")
308+
}
306309
inform("'pit' specified so ignoring 'y','yrep','lw' if specified.")
307310
} else {
308311
suggested_package("rstantools")
@@ -348,7 +351,7 @@ ppc_loo_pit_qq <- function(y,
348351

349352
compare <- match.arg(compare)
350353
if (!is.null(pit)) {
351-
stopifnot(is.numeric(pit), is_vector_or_1Darray(pit))
354+
pit <- validate_pit(pit)
352355
inform("'pit' specified so ignoring 'y','yrep','lw' if specified.")
353356
} else {
354357
suggested_package("rstantools")
@@ -795,14 +798,6 @@ ppc_loo_ribbon <-
795798
# Generate boundary corrected values via a linear convolution using a
796799
# 1-D Gaussian window filter. This method uses the "reflection method"
797800
# to estimate these pvalues and helps speed up the code
798-
if (any(is.infinite(x))) {
799-
warn(paste(
800-
"Ignored", sum(is.infinite(x)),
801-
"Non-finite PIT values are invalid for KDE boundary correction method"
802-
))
803-
x <- x[is.finite(x)]
804-
}
805-
806801
if (grid_len < 100) {
807802
grid_len <- 100
808803
}
@@ -819,6 +814,10 @@ ppc_loo_ribbon <-
819814
# 1-D Convolution
820815
bc_pvals <- .linear_convolution(x, bw, grid_counts, grid_breaks, grid_len)
821816

817+
if (all(is.na(bc_pvals))) {
818+
abort("KDE boundary correction produced all NA values.")
819+
}
820+
822821
# Generate vector of x-axis values for plotting based on binned relative freqs
823822
n_breaks <- length(grid_breaks)
824823

0 commit comments

Comments
 (0)