Skip to content

Commit a4c7ea7

Browse files
authored
Merge pull request #501 from utkarshpawade/perf/mcmc-areas-data-redundant-processing
Eliminate redundant data processing in mcmc_areas_data()
2 parents 3e80030 + 7656c7e commit a4c7ea7

File tree

2 files changed

+67
-59
lines changed

2 files changed

+67
-59
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# bayesplot (development version)
22

3+
* Eliminate redundant data processing in `mcmc_areas_data()` by reusing the prepared MCMC array for both interval and density computation.
34
* Validate equal chain lengths in `validate_df_with_chain()`, reject missing
45
chain labels, and renumber data-frame chain labels internally when converting
56
to arrays.

R/mcmc-intervals.R

Lines changed: 66 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -599,64 +599,13 @@ mcmc_intervals_data <- function(x,
599599
prob <- probs[1]
600600
prob_outer <- probs[2]
601601

602-
x <- prepare_mcmc_array(x, pars, regex_pars, transformations)
603-
x <- merge_chains(x)
604-
605-
data_long <- melt_mcmc(x) %>%
602+
data_long <- melt_mcmc(
603+
merge_chains(prepare_mcmc_array(x, pars, regex_pars, transformations))
604+
) %>%
606605
dplyr::as_tibble() %>%
607606
rlang::set_names(tolower)
608607

609-
probs <- c(0.5 - prob_outer / 2,
610-
0.5 - prob / 2,
611-
0.5 + prob / 2,
612-
0.5 + prob_outer / 2)
613-
614-
point_est <- match.arg(point_est)
615-
m_func <- if (point_est == "mean") mean else median
616-
617-
data <- data_long %>%
618-
group_by(.data$parameter) %>%
619-
summarise(
620-
outer_width = prob_outer,
621-
inner_width = prob,
622-
point_est = point_est,
623-
ll = unname(quantile(.data$value, probs[1])),
624-
l = unname(quantile(.data$value, probs[2])),
625-
m = m_func(.data$value),
626-
h = unname(quantile(.data$value, probs[3])),
627-
hh = unname(quantile(.data$value, probs[4]))
628-
)
629-
630-
if (point_est == "none") {
631-
data$m <- NULL
632-
}
633-
634-
color_by_rhat <- isTRUE(length(rhat) > 0)
635-
636-
if (color_by_rhat) {
637-
rhat <- drop_NAs_and_warn(new_rhat(rhat))
638-
639-
if (length(rhat) != nrow(data)) {
640-
abort(paste(
641-
"'rhat' has length", length(rhat),
642-
"but 'x' has", nrow(data), "parameters."
643-
))
644-
}
645-
646-
rhat <- set_names(rhat, data$parameter)
647-
648-
rhat_tbl <- rhat %>%
649-
mcmc_rhat_data() %>%
650-
select(all_of("parameter"),
651-
rhat_value = "value",
652-
rhat_rating = "rating",
653-
rhat_description = "description") %>%
654-
mutate(parameter = factor(.data$parameter, levels(data$parameter)))
655-
656-
data <- dplyr::inner_join(data, rhat_tbl, by = "parameter")
657-
}
658-
659-
data
608+
compute_intervals(data_long, prob, prob_outer, point_est, rhat)
660609
}
661610

662611

@@ -691,17 +640,17 @@ mcmc_areas_data <- function(x,
691640
point_est <- match.arg(point_est)
692641
temp_point_est <- if (point_est == "none") "median" else point_est
693642

694-
intervals <- mcmc_intervals_data(x, pars, regex_pars, transformations,
695-
prob = probs[1], prob_outer = probs[2],
696-
point_est = temp_point_est, rhat = rhat)
697-
698643
x <- prepare_mcmc_array(x, pars, regex_pars, transformations)
699644
x <- merge_chains(x)
700645

701646
data_long <- melt_mcmc(x) %>%
702647
dplyr::as_tibble() %>%
703648
rlang::set_names(tolower)
704649

650+
intervals <- compute_intervals(data_long, prob = probs[1],
651+
prob_outer = probs[2],
652+
point_est = temp_point_est, rhat = rhat)
653+
705654
# Compute the density intervals
706655
data_inner <- data_long %>%
707656
compute_column_density(
@@ -901,3 +850,61 @@ check_interval_widths <- function(prob, prob_outer) {
901850
}
902851
sort(c(prob, prob_outer))
903852
}
853+
854+
# Internal helper shared by mcmc_intervals_data() and mcmc_areas_data()
855+
compute_intervals <- function(data_long, prob, prob_outer,
856+
point_est = c("median", "mean", "none"),
857+
rhat = numeric()) {
858+
859+
probs <- c(0.5 - prob_outer / 2,
860+
0.5 - prob / 2,
861+
0.5 + prob / 2,
862+
0.5 + prob_outer / 2)
863+
864+
point_est <- match.arg(point_est)
865+
m_func <- if (point_est == "mean") mean else median
866+
867+
data <- data_long %>%
868+
group_by(.data$parameter) %>%
869+
summarise(
870+
outer_width = prob_outer,
871+
inner_width = prob,
872+
point_est = point_est,
873+
ll = unname(quantile(.data$value, probs[1])),
874+
l = unname(quantile(.data$value, probs[2])),
875+
m = m_func(.data$value),
876+
h = unname(quantile(.data$value, probs[3])),
877+
hh = unname(quantile(.data$value, probs[4]))
878+
)
879+
880+
if (point_est == "none") {
881+
data$m <- NULL
882+
}
883+
884+
color_by_rhat <- isTRUE(length(rhat) > 0)
885+
886+
if (color_by_rhat) {
887+
rhat <- drop_NAs_and_warn(new_rhat(rhat))
888+
889+
if (length(rhat) != nrow(data)) {
890+
abort(paste(
891+
"'rhat' has length", length(rhat),
892+
"but 'x' has", nrow(data), "parameters."
893+
))
894+
}
895+
896+
rhat <- set_names(rhat, data$parameter)
897+
898+
rhat_tbl <- rhat %>%
899+
mcmc_rhat_data() %>%
900+
select(all_of("parameter"),
901+
rhat_value = "value",
902+
rhat_rating = "rating",
903+
rhat_description = "description") %>%
904+
mutate(parameter = factor(.data$parameter, levels(data$parameter)))
905+
906+
data <- dplyr::inner_join(data, rhat_tbl, by = "parameter")
907+
}
908+
909+
data
910+
}

0 commit comments

Comments
 (0)