Skip to content

Commit c2d8d1c

Browse files
author
Florence Bockting
committed
feat: update ppc-pit-ecdf-grouped to support correlated method
1 parent 5d969d9 commit c2d8d1c

7 files changed

Lines changed: 819 additions & 50 deletions

File tree

R/ppc-distributions.R

Lines changed: 267 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -657,32 +657,40 @@ ppc_violin_grouped <-
657657
#' @param pit An optional vector of probability integral transformed values for
658658
#' which the ECDF is to be drawn. If NULL, PIT values are computed to `y` with
659659
#' respect to the corresponding values in `yrep`.
660-
#' @param interpolate_adj For `ppc_pit_ecdf()` when `method = "independent"`,
660+
#' @param interpolate_adj For `ppc_pit_ecdf()` and `ppc_pit_ecdf_grouped()`
661+
#' when `method = "independent"`,
661662
#' a boolean defining if the simultaneous confidence bands should be
662663
#' interpolated based on precomputed values rather than computed exactly.
663664
#' Computing the bands may be computationally intensive and the approximation
664665
#' gives a fast method for assessing the ECDF trajectory. The default is to use
665666
#' interpolation if `K` is greater than 200.
666-
#' @param method For `ppc_pit_ecdf()`, the method used to calculate the
667+
#' @param method For `ppc_pit_ecdf()` and `ppc_pit_ecdf_grouped()`, the method
668+
#' used to calculate the
667669
#' uniformity test:
668670
#' * `"independent"`: (default) Assumes independence (Säilynoja et al., 2022).
669671
#' * `"correlated"`: Accounts for correlation (Tesso & Vehtari, 2026).
670-
#' @param test For `ppc_pit_ecdf()` when `method = "correlated"`, which
672+
#' @param test For `ppc_pit_ecdf()` and `ppc_pit_ecdf_grouped()` when
673+
#' `method = "correlated"`, which
671674
#' dependence-aware test to use: `"POT"`, `"PRIT"`, or `"PIET"`.
672675
#' Defaults to `"POT"`.
673-
#' @param gamma For `ppc_pit_ecdf()` when `method = "correlated"`, tolerance
676+
#' @param gamma For `ppc_pit_ecdf()` and `ppc_pit_ecdf_grouped()` when
677+
#' `method = "correlated"`, tolerance
674678
#' threshold controlling how strongly suspicious points are flagged. Larger
675679
#' values (gamma > 0) emphasize points with larger deviations. If `NULL`, automatically
676680
#' determined based on p-value.
677-
#' @param linewidth For `ppc_pit_ecdf()` when `method = "correlated"`, the line width of the ECDF
678-
#' and highlighting points. Defaults to 0.3.
679-
#' @param color For `ppc_pit_ecdf()` when `method = "correlated"`, a vector
681+
#' @param linewidth For `ppc_pit_ecdf()` and `ppc_pit_ecdf_grouped()` when
682+
#' `method = "correlated"`, the line width of the ECDF and highlighting
683+
#' points. Defaults to 0.3.
684+
#' @param color For `ppc_pit_ecdf()` and `ppc_pit_ecdf_grouped()` when
685+
#' `method = "correlated"`, a vector
680686
#' with base color and highlight color for the ECDF plot. Defaults to
681687
#' `c(ecdf = "grey60", highlight = "red")`. The first element is used for
682688
#' the main ECDF line, the second for highlighted suspicious regions.
683-
#' @param help_text For `ppc_pit_ecdf()` when `method = "correlated"`, a boolean
684-
#' defining whether to add informative text to the plot. Defaults to `TRUE`.
685-
#' @param pareto_pit For `ppc_pit_ecdf()`, a boolean defining whether to compute
689+
#' @param help_text For `ppc_pit_ecdf()` and `ppc_pit_ecdf_grouped()` when
690+
#' `method = "correlated"`, a boolean defining whether to add informative
691+
#' text to the plot. Defaults to `TRUE`.
692+
#' @param pareto_pit For `ppc_pit_ecdf()` and `ppc_pit_ecdf_grouped()`, a
693+
#' boolean defining whether to compute
686694
#' the PIT values using Pareto-smoothed importance sampling (if `TRUE` and no pit values are provided).
687695
#' Defaults to `TRUE` when `method = "correlated"` and `test` is `"POT"` or `"PIET"`.
688696
#' Otherwise defaults to `FALSE`. If `TRUE`, requires the specification of `lw` or `psis_object`.
@@ -787,7 +795,6 @@ ppc_pit_ecdf <- function(y,
787795
y <- validate_y(y)
788796
yrep <- validate_predictions(yrep, length(y))
789797

790-
# TODO: posterior::pareto_pit() from https://github.com/stan-dev/posterior/pull/435 - requires that PR merged
791798
pit <- posterior::pareto_pit(x = yrep, y = y, weights = NULL, log = TRUE)
792799
K <- K %||% length(pit)
793800

@@ -809,7 +816,9 @@ ppc_pit_ecdf <- function(y,
809816
}
810817

811818
} else {
812-
# --- Empirical PIT ---
819+
# --- Empirical PIT ---
820+
y <- validate_y(y)
821+
yrep <- validate_predictions(yrep, length(y))
813822
pit <- ppc_data(y, yrep) %>%
814823
group_by(.data$y_id) %>%
815824
dplyr::group_map(
@@ -837,7 +846,6 @@ ppc_pit_ecdf <- function(y,
837846

838847
# Compute the per-observation test statistics (sorted for Shapley values)
839848
# and the combined Cauchy p-value.
840-
# TODO: posterior::uniformity_test() from https://github.com/stan-dev/posterior/pull/435 - requires that PR merged
841849
test_res <- posterior::uniformity_test(pit = pit, test = test)
842850
p_value_CCT <- test_res$pvalue
843851
pointwise_contrib <- test_res$pointwise
@@ -1017,56 +1025,282 @@ ppc_pit_ecdf_grouped <-
10171025
pit = NULL,
10181026
prob = .99,
10191027
plot_diff = FALSE,
1020-
interpolate_adj = NULL) {
1028+
interpolate_adj = NULL,
1029+
method = "independent",
1030+
test = NULL,
1031+
gamma = NULL,
1032+
linewidth = NULL,
1033+
color = NULL,
1034+
help_text = NULL,
1035+
pareto_pit = NULL) {
10211036
check_ignored_arguments(...,
1022-
ok_args = c("K", "pit", "prob", "plot_diff", "interpolate_adj")
1037+
ok_args = c("K", "pareto_pit", "pit", "prob", "plot_diff",
1038+
"interpolate_adj", "method", "test", "gamma",
1039+
"linewidth", "color", "help_text")
1040+
)
1041+
1042+
.warn_ignored <- function(method_name, args) {
1043+
inform(paste0(
1044+
"As method = ", method_name, " specified; ignoring: ",
1045+
paste(args, collapse = ", "), "."
1046+
))
1047+
}
1048+
1049+
method <- match.arg(method, choices = c("independent", "correlated"))
1050+
1051+
switch(method,
1052+
"correlated" = {
1053+
if (!is.null(interpolate_adj)) .warn_ignored("'correlated'", "interpolate_adj")
1054+
test <- match.arg(test %||% "POT", choices = c("POT", "PRIT", "PIET"))
1055+
alpha <- 1 - prob
1056+
gamma <- gamma %||% 0
1057+
linewidth <- linewidth %||% 0.3
1058+
color <- color %||% c(ecdf = "grey60", highlight = "red")
1059+
help_text <- help_text %||% TRUE
1060+
pareto_pit <- pareto_pit %||% is.null(pit) && test %in% c("POT", "PIET")
1061+
},
1062+
"independent" = {
1063+
ignored <- c(
1064+
if (!is.null(test)) "test",
1065+
if (!is.null(gamma)) "gamma",
1066+
if (!is.null(help_text)) "help_text"
1067+
)
1068+
if (length(ignored) > 0) .warn_ignored("'independent'", ignored)
1069+
pareto_pit <- pareto_pit %||% FALSE
1070+
}
10231071
)
10241072

1025-
if (is.null(pit)) {
1073+
if (isTRUE(pareto_pit) && is.null(pit)) {
1074+
suggested_package("rstantools")
1075+
y <- validate_y(y)
1076+
yrep <- validate_predictions(yrep, length(y))
1077+
group <- validate_group(group, length(y))
1078+
pit <- posterior::pareto_pit(x = yrep, y = y, weights = NULL, log = TRUE)
1079+
K <- K %||% length(pit)
1080+
} else if (!is.null(pit)) {
1081+
pit <- validate_pit(pit)
1082+
group <- validate_group(group, length(pit))
1083+
K <- K %||% length(pit)
1084+
ignored <- c(
1085+
if (!missing(y) && !is.null(y)) "y",
1086+
if (!missing(yrep) && !is.null(yrep)) "yrep"
1087+
)
1088+
if (length(ignored) > 0) {
1089+
inform(paste0(
1090+
"As 'pit' specified; ignoring: ",
1091+
paste(ignored, collapse = ", "), "."
1092+
))
1093+
}
1094+
} else {
1095+
y <- validate_y(y)
1096+
yrep <- validate_predictions(yrep, length(y))
1097+
group <- validate_group(group, length(y))
10261098
pit <- ppc_data(y, yrep, group) %>%
10271099
group_by(.data$y_id) %>%
10281100
dplyr::group_map(
10291101
~ mean(.x$value[.x$is_y] > .x$value[!.x$is_y]) +
10301102
runif(1, max = mean(.x$value[.x$is_y] == .x$value[!.x$is_y]))
1031-
) %>%
1103+
) %>%
10321104
unlist()
1033-
if (is.null(K)) {
1034-
K <- min(nrow(yrep) + 1, 1000)
1105+
K <- K %||% min(nrow(yrep) + 1, 1000)
1106+
}
1107+
1108+
data <- data.frame(pit = pit, group = group, stringsAsFactors = FALSE)
1109+
group_levels <- unique(data$group)
1110+
1111+
if (method == "correlated") {
1112+
data_cor <- dplyr::group_by(data, .data$group) %>%
1113+
dplyr::group_map(function(.x, .y) {
1114+
n_obs <- nrow(.x)
1115+
K_g <- K %||% n_obs
1116+
unit_interval <- seq(0, 1, length.out = K_g)
1117+
ecdf_pit_fn <- ecdf(.x$pit)
1118+
x_combined <- sort(unique(c(unit_interval, .x$pit)))
1119+
df_main <- data.frame(
1120+
x = x_combined,
1121+
ecdf_value = ecdf_pit_fn(x_combined) - plot_diff * x_combined,
1122+
group = .y[[1]],
1123+
stringsAsFactors = FALSE
1124+
)
1125+
1126+
test_res <- posterior::uniformity_test(pit = .x$pit, test = test)
1127+
p_value_CCT <- test_res$pvalue
1128+
pointwise_contrib <- test_res$pointwise
1129+
max_contrib <- max(pointwise_contrib)
1130+
if (gamma < 0 || gamma > max_contrib) {
1131+
stop(sprintf(
1132+
"gamma must be in [0, %.2f], but gamma = %s was provided.",
1133+
max_contrib, gamma
1134+
))
1135+
}
1136+
1137+
red <- NULL
1138+
red_points <- NULL
1139+
if (p_value_CCT < alpha) {
1140+
red_idx <- which(pointwise_contrib > gamma)
1141+
if (length(red_idx) > 0) {
1142+
pit_sorted <- sort(.x$pit)
1143+
df_pit <- data.frame(
1144+
pit = pit_sorted,
1145+
ecdf_value = ecdf_pit_fn(pit_sorted),
1146+
stringsAsFactors = FALSE
1147+
)
1148+
df_red <- df_pit[red_idx, , drop = FALSE]
1149+
df_red$segment <- cumsum(c(1, diff(red_idx) != 1))
1150+
seg_sizes <- stats::ave(df_red$pit, df_red$segment, FUN = length)
1151+
df_isolated <- df_red[seg_sizes == 1, , drop = FALSE]
1152+
df_grouped <- df_red[seg_sizes > 1, , drop = FALSE]
1153+
1154+
if (nrow(df_grouped) > 0) {
1155+
red <- do.call(rbind, lapply(
1156+
split(df_grouped, df_grouped$segment),
1157+
function(grp) {
1158+
pit_idx <- match(grp$pit, x_combined)
1159+
idx_range <- seq(min(pit_idx), max(pit_idx))
1160+
data.frame(
1161+
x = x_combined[idx_range],
1162+
ecdf_value = ecdf_pit_fn(x_combined[idx_range]) -
1163+
plot_diff * x_combined[idx_range],
1164+
segment = grp$segment[1],
1165+
group = .y[[1]],
1166+
stringsAsFactors = FALSE
1167+
)
1168+
}
1169+
))
1170+
}
1171+
1172+
if (nrow(df_isolated) > 0) {
1173+
red_points <- data.frame(
1174+
x = df_isolated$pit,
1175+
ecdf_value = df_isolated$ecdf_value - plot_diff * df_isolated$pit,
1176+
group = .y[[1]],
1177+
stringsAsFactors = FALSE
1178+
)
1179+
}
1180+
}
1181+
}
1182+
1183+
ann <- NULL
1184+
if (isTRUE(help_text)) {
1185+
ann <- data.frame(
1186+
group = .y[[1]],
1187+
x = -Inf,
1188+
y = Inf,
1189+
label = sprintf(
1190+
"p[unif]^{%s} == '%s' ~ (alpha == '%.2f')",
1191+
test, fmt_p(p_value_CCT), alpha
1192+
),
1193+
stringsAsFactors = FALSE
1194+
)
1195+
}
1196+
1197+
list(main = df_main, red = red, red_points = red_points, ann = ann)
1198+
})
1199+
1200+
main_df <- dplyr::bind_rows(lapply(data_cor, `[[`, "main"))
1201+
red_df <- dplyr::bind_rows(lapply(data_cor, `[[`, "red"))
1202+
red_points_df <- dplyr::bind_rows(lapply(data_cor, `[[`, "red_points"))
1203+
ann_df <- dplyr::bind_rows(lapply(data_cor, `[[`, "ann"))
1204+
ref_df <- data.frame(
1205+
group = group_levels,
1206+
x = 0,
1207+
y = 0,
1208+
xend = 1,
1209+
yend = if (plot_diff) 0 else 1,
1210+
stringsAsFactors = FALSE
1211+
)
1212+
1213+
p <- ggplot() +
1214+
geom_step(
1215+
data = main_df,
1216+
mapping = aes(x = .data$x, y = .data$ecdf_value, group = .data$group),
1217+
show.legend = FALSE,
1218+
linewidth = linewidth,
1219+
color = color["ecdf"]
1220+
) +
1221+
geom_segment(
1222+
data = ref_df,
1223+
mapping = aes(
1224+
x = .data$x,
1225+
y = .data$y,
1226+
xend = .data$xend,
1227+
yend = .data$yend
1228+
),
1229+
linetype = "dashed",
1230+
color = "darkgrey",
1231+
linewidth = 0.3
1232+
)
1233+
1234+
if (nrow(red_df) > 0) {
1235+
p <- p + geom_step(
1236+
data = red_df,
1237+
mapping = aes(x = .data$x, y = .data$ecdf_value, group = interaction(.data$group, .data$segment)),
1238+
color = color["highlight"],
1239+
linewidth = linewidth + 0.8
1240+
)
10351241
}
1036-
} else {
1037-
inform("'pit' specified so ignoring 'y' and 'yrep' if specified.")
1038-
pit <- validate_pit(pit)
1242+
1243+
if (nrow(red_points_df) > 0) {
1244+
p <- p + geom_point(
1245+
data = red_points_df,
1246+
mapping = aes(x = .data$x, y = .data$ecdf_value),
1247+
color = color["highlight"],
1248+
size = linewidth + 1
1249+
)
1250+
}
1251+
1252+
if (isTRUE(help_text) && nrow(ann_df) > 0) {
1253+
label_size <- 0.7 * bayesplot_theme_get()$text@size / ggplot2::.pt
1254+
p <- p + geom_text(
1255+
data = ann_df,
1256+
mapping = aes(x = .data$x, y = .data$y, label = .data$label),
1257+
hjust = -0.05,
1258+
vjust = 1.5,
1259+
color = "black",
1260+
parse = TRUE,
1261+
size = label_size
1262+
)
1263+
}
1264+
1265+
return(
1266+
p +
1267+
labs(y = if (plot_diff) "ECDF difference" else "ECDF", x = "PIT") +
1268+
yaxis_ticks(FALSE) +
1269+
bayesplot_theme_get() +
1270+
facet_wrap("group") +
1271+
scale_color_ppc() +
1272+
force_axes_in_facets()
1273+
)
10391274
}
1040-
N <- length(pit)
10411275

1042-
gammas <- lapply(unique(group), function(g) {
1043-
N_g <- sum(group == g)
1276+
gammas <- lapply(group_levels, function(g) {
1277+
N_g <- sum(data$group == g)
10441278
adjust_gamma(
10451279
N = N_g,
1046-
K = ifelse(is.null(K), N_g, K),
1280+
K = K %||% N_g,
10471281
prob = prob,
10481282
interpolate_adj = interpolate_adj
10491283
)
10501284
})
1051-
names(gammas) <- unique(group)
1285+
names(gammas) <- group_levels
10521286

1053-
data <- data.frame(pit = pit, group = group) %>%
1054-
group_by(group) %>%
1287+
data <- data %>%
1288+
dplyr::group_by(.data$group) %>%
10551289
dplyr::group_map(
10561290
~ data.frame(
1057-
ecdf_value = ecdf(.x$pit)(seq(0, 1, length.out = ifelse(is.null(K), nrow(.x), K))),
1291+
ecdf_value = ecdf(.x$pit)(seq(0, 1, length.out = K %||% nrow(.x))),
10581292
group = .y[1],
10591293
lims_upper = ecdf_intervals(
10601294
gamma = gammas[[unlist(.y[1])]],
10611295
N = nrow(.x),
1062-
K = ifelse(is.null(K), nrow(.x), K)
1296+
K = K %||% nrow(.x)
10631297
)$upper[-1] / nrow(.x),
10641298
lims_lower = ecdf_intervals(
10651299
gamma = gammas[[unlist(.y[1])]],
10661300
N = nrow(.x),
1067-
K = ifelse(is.null(K), nrow(.x), K)
1301+
K = K %||% nrow(.x)
10681302
)$lower[-1] / nrow(.x),
1069-
x = seq(0, 1, length.out = ifelse(is.null(K), nrow(.x), K))
1303+
x = seq(0, 1, length.out = K %||% nrow(.x))
10701304
)
10711305
) %>%
10721306
dplyr::bind_rows()

R/ppc-loo.R

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,6 @@ ppc_loo_pit_ecdf <- function(y,
550550
lw <- .get_lw(lw, psis_object)
551551
stopifnot(identical(dim(yrep), dim(lw)))
552552

553-
# TODO: posterior::pareto_pit() from https://github.com/stan-dev/posterior/pull/435 - requires that PR merged
554553
pit <- posterior::pareto_pit(x = yrep, y = y, weights = lw, log = TRUE)
555554
K <- K %||% length(pit)
556555

@@ -601,7 +600,6 @@ ppc_loo_pit_ecdf <- function(y,
601600

602601
# Compute the per-observation test statistics (sorted for Shapley values)
603602
# and the combined Cauchy p-value.
604-
# TODO: posterior::uniformity_test() from https://github.com/stan-dev/posterior/pull/435 - requires that PR merged
605603
test_res <- posterior::uniformity_test(pit = pit, test = test)
606604
p_value_CCT <- test_res$pvalue
607605
pointwise_contrib <- test_res$pointwise

0 commit comments

Comments
 (0)