Skip to content

Commit 98a7245

Browse files
authored
Merge pull request #448 from utkarshpawade/issue-431-dplyr-deprecations
Fix #431: Replace deprecated dplyr/tidyselect functions
2 parents 40fd9eb + 93afa91 commit 98a7245

File tree

4 files changed

+45
-15
lines changed

4 files changed

+45
-15
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ SystemRequirements: pandoc (>= 1.12.3), pandoc-citeproc
3030
Depends:
3131
R (>= 4.1.0)
3232
Imports:
33-
dplyr (>= 0.8.0),
33+
dplyr (>= 1.0.0),
3434
ggplot2 (>= 3.4.0),
3535
ggridges (>= 0.5.5),
3636
glue,

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# bayesplot (development version)
22

3+
* Replaced deprecated `dplyr` and `tidyselect` functions (`top_n`, `one_of`, `group_indices`) with their modern equivalents to ensure future compatibility. (#431)
34
* Documentation added for all exported `*_data()` functions (#209)
45
* Improved documentation for `binwidth`, `bins`, and `breaks` arguments to clarify they are passed to `ggplot2::geom_area()` and `ggdist::stat_dots()` in addition to `ggplot2::geom_histogram()`
56
* Improved documentation for `freq` argument to clarify it applies to frequency polygons in addition to histograms

R/mcmc-intervals.R

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,7 @@ mcmc_intervals_data <- function(x,
647647

648648
rhat_tbl <- rhat %>%
649649
mcmc_rhat_data() %>%
650-
select(one_of("parameter"),
650+
select(all_of("parameter"),
651651
rhat_value = "value",
652652
rhat_rating = "rating",
653653
rhat_description = "description") %>%
@@ -663,7 +663,7 @@ mcmc_intervals_data <- function(x,
663663
# Don't import `filter`: otherwise, you get a warning when using
664664
# `devtools::load_all(".")` because stats also has a `filter` function
665665

666-
#' @importFrom dplyr inner_join one_of top_n
666+
#' @importFrom dplyr inner_join all_of slice_min
667667
#' @rdname MCMC-intervals
668668
#' @export
669669
mcmc_areas_data <- function(x,
@@ -736,14 +736,14 @@ mcmc_areas_data <- function(x,
736736

737737
# Find the density values closest to the point estimate
738738
point_ests <- intervals %>%
739-
select(one_of("parameter", "m"))
739+
select(all_of(c("parameter", "m")))
740740

741741
point_centers <- data_inner %>%
742742
inner_join(point_ests, by = "parameter") %>%
743743
group_by(.data$parameter) %>%
744744
mutate(diff = abs(.data$m - .data$x)) %>%
745-
dplyr::top_n(1, -.data$diff) %>%
746-
select(one_of("parameter", "x", "m")) %>%
745+
dplyr::slice_min(order_by = .data$diff, n = 1) %>%
746+
select(all_of(c("parameter", "x", "m"))) %>%
747747
rename(center = "x") %>%
748748
ungroup()
749749

@@ -765,15 +765,15 @@ mcmc_areas_data <- function(x,
765765
}
766766

767767
data <- dplyr::bind_rows(data_inner, data_outer, points) %>%
768-
select(one_of("parameter", "interval", "interval_width",
769-
"x", "density", "scaled_density")) %>%
768+
select(all_of(c("parameter", "interval", "interval_width",
769+
"x", "density", "scaled_density"))) %>%
770770
# Density scaled so the highest in entire dataframe has height 1
771771
mutate(plotting_density = .data$density / max(.data$density))
772772

773773
if (rlang::has_name(intervals, "rhat_value")) {
774774
rhat_info <- intervals %>%
775-
select(one_of("parameter", "rhat_value",
776-
"rhat_rating", "rhat_description"))
775+
select(all_of(c("parameter", "rhat_value",
776+
"rhat_rating", "rhat_description")))
777777
data <- inner_join(data, rhat_info, by = "parameter")
778778
}
779779
data
@@ -824,18 +824,15 @@ compute_column_density <- function(df, group_vars, value_var, ...) {
824824
syms()
825825

826826
# Tuck away the subgroups to compute densities on into nested dataframes
827-
sub_df <- dplyr::select(df, !!! group_cols, !! value_var)
828-
829827
group_df <- df %>%
830828
dplyr::select(!!! group_cols, !! value_var) %>%
831829
group_by(!!! group_cols)
832830

833831
by_group <- group_df %>%
834-
split(dplyr::group_indices(group_df)) %>%
832+
dplyr::group_split() %>%
835833
lapply(pull, !! value_var)
836834

837-
nested <- df %>%
838-
dplyr::distinct(!!! group_cols) %>%
835+
nested <- dplyr::group_keys(group_df) %>%
839836
mutate(data = by_group)
840837

841838
nested$density <- lapply(nested$data, compute_interval_density, ...)

tests/testthat/test-mcmc-distributions.R

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,38 @@ test_that("mcmc_dens_chains returns a ggplot object", {
111111
expect_gg(p2)
112112
})
113113

114+
test_that("mcmc_dens_chains_data computes densities per parameter-chain group", {
115+
# Regression test for compute_column_density().
116+
# This path groups by both parameter and chain, so it exercises the
117+
# group_split() + group_keys() replacement introduced in PR #448.
118+
# The goal is to verify that densities are still computed for the
119+
# correct parameter-chain groups, in the correct grouping structure.
120+
dens_data <- mcmc_dens_chains_data(arr, n_dens = 100)
121+
by_group <- split(
122+
dens_data,
123+
interaction(dens_data$parameter, dens_data$chain, drop = TRUE, lex.order = TRUE)
124+
)
125+
126+
raw <- melt_mcmc(prepare_mcmc_array(arr))
127+
raw_by_group <- split(
128+
raw,
129+
interaction(raw$Parameter, raw$Chain, drop = TRUE, lex.order = TRUE)
130+
)
131+
132+
manual_density <- function(df) {
133+
dens <- density(df$Value, from = min(df$Value), to = max(df$Value), n = 100)
134+
data.frame(x = dens$x, density = dens$y)
135+
}
136+
137+
expected <- lapply(raw_by_group, manual_density)
138+
expect_setequal(names(by_group), names(expected))
139+
for (nm in names(expected)) {
140+
expect_equal(by_group[[nm]]$x, expected[[nm]]$x)
141+
expect_equal(by_group[[nm]]$density, expected[[nm]]$density, tolerance = 1e-10)
142+
}
143+
})
144+
145+
114146
test_that("mcmc_dens_chains/mcmc_dens_overlay color chains", {
115147
p1 <- mcmc_dens_chains(arr, pars = "beta[1]", regex_pars = "x\\:",
116148
color_chains = FALSE)

0 commit comments

Comments
 (0)