Skip to content

Commit bb4e492

Browse files
thisisnicjonkeane
andauthored
GH-49534: [R] Implement dplyr recode_values(), replace_values(), and replace_when() (#49536)
### Rationale for this change Implement new dplyr functions ### What changes are included in this PR? Implement them ### Are these changes tested? Yeah ### Are there any user-facing changes? Moar functions ### AI Use Code generated using Claude, with plenty of input from me. I've gone through it in detail and refactored lots, but it needs a last pass before it's ready for review. * GitHub Issue: #49534 Lead-authored-by: Nic Crane <thisisnic@gmail.com> Co-authored-by: Jonathan Keane <jkeane@gmail.com> Signed-off-by: Nic Crane <thisisnic@gmail.com>
1 parent fe298b4 commit bb4e492

File tree

8 files changed

+521
-39
lines changed

8 files changed

+521
-39
lines changed

r/NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,7 @@ importFrom(bit64,str.integer64)
432432
importFrom(glue,glue)
433433
importFrom(methods,as)
434434
importFrom(purrr,as_mapper)
435+
importFrom(purrr,compact)
435436
importFrom(purrr,flatten)
436437
importFrom(purrr,imap)
437438
importFrom(purrr,imap_chr)

r/R/arrow-package.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
#' @importFrom stats quantile median na.omit na.exclude na.pass na.fail
1919
#' @importFrom R6 R6Class
2020
#' @importFrom purrr as_mapper map map2 map_chr map2_chr map_dbl map_dfr map_int map_lgl keep imap imap_chr
21-
#' @importFrom purrr flatten reduce walk
21+
#' @importFrom purrr compact flatten reduce walk
2222
#' @importFrom assertthat assert_that is.string
2323
#' @importFrom rlang list2 %||% is_false abort dots_n warn enquo quo_is_null enquos is_integerish quos quo
2424
#' @importFrom rlang eval_tidy new_data_mask syms env new_environment env_bind set_names exec

r/R/dplyr-funcs-conditional.R

Lines changed: 229 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,169 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
#' Parse logical condition formulas
19+
#'
20+
#' Converts condition ~ value formulas into Arrow expressions. Unlike
21+
#' [parse_value_mapping()], the LHS must be a logical expression (not a value
22+
#' to match against).
23+
#'
24+
#' @param formulas A list of two-sided formulas where LHS is a logical condition
25+
#' and RHS is the value to use when TRUE (e.g., `x > 5 ~ "high"`).
26+
#' @param mask The data mask for evaluating formula expressions.
27+
#'
28+
#' @return A list with `query` (list of logical expressions) and `value`
29+
#' (list of replacement expressions).
30+
#'
31+
#' @keywords internal
32+
#' @noRd
33+
parse_condition_formulas <- function(formulas, mask) {
34+
fn <- call_name(rlang::caller_call())
35+
# Compact NULL entries (allows conditional formulas like: if (cond) x ~ y)
36+
formulas <- compact(formulas)
37+
n <- length(formulas)
38+
query <- vector("list", n)
39+
value <- vector("list", n)
40+
# Process each formula: condition ~ value
41+
for (i in seq_len(n)) {
42+
f <- formulas[[i]]
43+
if (!is_formula(f, lhs = TRUE)) {
44+
validation_error(paste0("Each argument to ", fn, "() must be a two-sided formula"))
45+
}
46+
# f[[2]] is LHS (logical condition), f[[3]] is RHS (value when TRUE)
47+
query[[i]] <- arrow_eval(f[[2]], mask)
48+
value[[i]] <- arrow_eval(f[[3]], mask)
49+
# Validate LHS is logical (unlike parse_value_mapping which does equality matching)
50+
if (!call_binding("is.logical", query[[i]])) {
51+
validation_error(paste0("Left side of each formula in ", fn, "() must be a logical expression"))
52+
}
53+
}
54+
list(query = query, value = value)
55+
}
56+
57+
#' Create case_when Expression from query/value lists
58+
#' @param query List of logical Arrow Expressions.
59+
#' @param value List of value Arrow Expressions.
60+
#' @return An Arrow Expression representing the case_when.
61+
#' @keywords internal
62+
#' @noRd
63+
build_case_when_expr <- function(query, value) {
64+
Expression$create(
65+
"case_when",
66+
args = c(
67+
Expression$create(
68+
"make_struct",
69+
args = query,
70+
options = list(field_names = as.character(seq_along(query)))
71+
),
72+
value
73+
)
74+
)
75+
}
76+
77+
#' Build a match expression for x against a value (scalar, NA, or vector).
78+
#' @param x Arrow Expression for the column to match against.
79+
#' @param match_value Value to match - R scalar, vector, or NA. Expressions
80+
#' are compared with equality.
81+
#' @return Arrow Expression that is TRUE when x matches match_value.
82+
#' @keywords internal
83+
#' @noRd
84+
build_match_expr <- function(x, match_value) {
85+
# Expressions or length-1 non-NA: use equality directly
86+
if (inherits(match_value, "Expression") || length(match_value) == 1 && !is.na(match_value)) {
87+
return(x == match_value)
88+
}
89+
90+
# R scalar NA requires is.na() since x == NA returns NA in Arrow
91+
if (length(match_value) == 1) {
92+
return(call_binding("is.na", x))
93+
}
94+
95+
# R vector: use %in%, handling NA separately if present
96+
has_na <- any(is.na(match_value))
97+
non_na_values <- match_value[!is.na(match_value)]
98+
99+
if (length(non_na_values) == 0) {
100+
call_binding("is.na", x)
101+
} else if (has_na) {
102+
call_binding("%in%", x, non_na_values) | call_binding("is.na", x)
103+
} else {
104+
call_binding("%in%", x, match_value)
105+
}
106+
}
107+
108+
#' Build query/value lists from parallel from/to vectors.
109+
#' NA values in `from` use is.na() for matching.
110+
#' @param x Arrow Expression for the column to match against.
111+
#' @param from Vector of values to match.
112+
#' @param to Vector of replacement values (recycled to length of `from`).
113+
#' @return list(query, value) for use with build_case_when_expr().
114+
#' @keywords internal
115+
#' @noRd
116+
parse_from_to_mapping <- function(x, from, to) {
117+
n <- length(from)
118+
to <- vctrs::vec_recycle(to, n)
119+
query <- map(from, ~ build_match_expr(x, .x))
120+
value <- map(to, Expression$scalar)
121+
list(query = query, value = value)
122+
}
123+
124+
#' Build query/value lists from value ~ replacement formulas.
125+
#' NA values on LHS use is.na() for matching.
126+
#' @param x Arrow Expression for the column to match against.
127+
#' @param formulas List of two-sided formulas (value ~ replacement).
128+
#' @param mask Data mask for evaluating formula expressions.
129+
#' @param fn Calling function name (for error messages).
130+
#' @return list(query, value) for use with build_case_when_expr().
131+
#' @keywords internal
132+
#' @noRd
133+
parse_formula_mapping <- function(x, formulas, mask, fn) {
134+
# Compact NULL entries (allows conditional formulas like: if (cond) x ~ y)
135+
formulas <- compact(formulas)
136+
n <- length(formulas)
137+
query <- vector("list", n)
138+
value <- vector("list", n)
139+
for (i in seq_len(n)) {
140+
f <- formulas[[i]]
141+
if (!is_formula(f, lhs = TRUE)) {
142+
validation_error(paste0("Each argument to ", fn, "() must be a two-sided formula"))
143+
}
144+
# f[[2]] is LHS (value to match), f[[3]] is RHS (replacement)
145+
lhs <- arrow_eval(f[[2]], mask)
146+
query[[i]] <- build_match_expr(x, lhs)
147+
value[[i]] <- arrow_eval(f[[3]], mask)
148+
}
149+
list(query = query, value = value)
150+
}
151+
152+
#' Dispatch to formula or from/to parser based on which args are provided.
153+
#' Returns list(query, value) or NULL if no mappings.
154+
#' @param x Arrow Expression for the column to match against.
155+
#' @param formulas List of two-sided formulas (value ~ replacement).
156+
#' @param from Vector of values to match (alternative to formulas).
157+
#' @param to Vector of replacement values (used with `from`).
158+
#' @param mask The data mask for evaluating formula expressions.
159+
#' @keywords internal
160+
#' @noRd
161+
parse_value_mapping <- function(x, formulas = list(), from = NULL, to = NULL, mask) {
162+
fn <- call_name(rlang::caller_call())
163+
# Mutually exclusive interfaces
164+
if (length(formulas) > 0 && !is.null(from)) {
165+
validation_error(paste0("Can't use both `...` and `from`/`to` in ", fn, "()"))
166+
}
167+
168+
if (length(formulas) > 0) {
169+
parse_formula_mapping(x, formulas, mask, fn)
170+
} else if (!is.null(from)) {
171+
if (is.null(to)) {
172+
validation_error("`to` must be provided when using `from`")
173+
}
174+
parse_from_to_mapping(x, from, to)
175+
} else {
176+
# No mappings provided
177+
NULL
178+
}
179+
}
180+
18181
register_bindings_conditional <- function() {
19182
register_binding("%in%", function(x, table) {
20183
# We use `is_in` here, unlike with Arrays, which use `is_in_meta_binary`
@@ -133,44 +296,79 @@ register_bindings_conditional <- function() {
133296
}
134297

135298
formulas <- list2(...)
136-
n <- length(formulas)
137-
if (n == 0) {
299+
if (length(formulas) == 0) {
138300
validation_error("No cases provided")
139301
}
140-
query <- vector("list", n)
141-
value <- vector("list", n)
142-
mask <- caller_env()
143-
for (i in seq_len(n)) {
144-
f <- formulas[[i]]
145-
if (!inherits(f, "formula")) {
146-
validation_error("Each argument to case_when() must be a two-sided formula")
147-
}
148-
query[[i]] <- arrow_eval(f[[2]], mask)
149-
value[[i]] <- arrow_eval(f[[3]], mask)
150-
if (!call_binding("is.logical", query[[i]])) {
151-
validation_error("Left side of each formula in case_when() must be a logical expression")
152-
}
153-
}
302+
parsed <- parse_condition_formulas(formulas, caller_env())
303+
query <- parsed$query
304+
value <- parsed$value
154305
if (!is.null(.default)) {
155306
if (length(.default) != 1) {
156-
validation_error(paste0("`.default` must have size 1, not size ", length(.default), "."))
307+
arrow_not_supported("`.default` must be size 1; vectors of length > 1")
157308
}
158-
159-
query[n + 1] <- TRUE
160-
value[n + 1] <- .default
309+
n <- length(query)
310+
query[[n + 1]] <- TRUE
311+
value[[n + 1]] <- .default
161312
}
162-
Expression$create(
163-
"case_when",
164-
args = c(
165-
Expression$create(
166-
"make_struct",
167-
args = query,
168-
options = list(field_names = as.character(seq_along(query)))
169-
),
170-
value
171-
)
172-
)
313+
build_case_when_expr(query, value)
173314
},
174315
notes = "`.ptype` and `.size` arguments not supported"
175316
)
317+
318+
register_binding("dplyr::replace_when", function(x, ...) {
319+
formulas <- list2(...)
320+
if (length(formulas) == 0) {
321+
return(x)
322+
}
323+
parsed <- parse_condition_formulas(formulas, caller_env())
324+
query <- parsed$query
325+
value <- parsed$value
326+
n <- length(query)
327+
query[[n + 1]] <- TRUE
328+
value[[n + 1]] <- x
329+
build_case_when_expr(query, value)
330+
})
331+
332+
register_binding("dplyr::replace_values", function(x, ..., from = NULL, to = NULL) {
333+
parsed <- parse_value_mapping(x, list2(...), from, to, caller_env())
334+
if (is.null(parsed)) {
335+
return(x)
336+
}
337+
query <- parsed$query
338+
value <- parsed$value
339+
n <- length(query)
340+
query[[n + 1]] <- TRUE
341+
value[[n + 1]] <- x
342+
build_case_when_expr(query, value)
343+
})
344+
345+
register_binding(
346+
"dplyr::recode_values",
347+
function(x, ..., from = NULL, to = NULL, default = NULL, unmatched = "default", ptype = NULL) {
348+
if (!is.null(ptype)) {
349+
arrow_not_supported("`recode_values()` with `ptype` specified")
350+
}
351+
if (unmatched != "default") {
352+
arrow_not_supported('`recode_values()` with `unmatched` other than "default"')
353+
}
354+
355+
parsed <- parse_value_mapping(x, list2(...), from, to, caller_env())
356+
if (is.null(parsed)) {
357+
validation_error("`...` can't be empty")
358+
}
359+
query <- parsed$query
360+
value <- parsed$value
361+
362+
if (!is.null(default)) {
363+
if (length(default) != 1) {
364+
arrow_not_supported("`default` must be size 1; vectors of length > 1")
365+
}
366+
n <- length(query)
367+
query[[n + 1]] <- TRUE
368+
value[[n + 1]] <- Expression$scalar(default)
369+
}
370+
build_case_when_expr(query, value)
371+
},
372+
notes = "`ptype` argument and `unmatched = \"error\"` not supported"
373+
)
176374
}

r/R/dplyr-funcs-doc.R

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
#'
2222
#' The `arrow` package contains methods for 38 `dplyr` table functions, many of
2323
#' which are "verbs" that do transformations to one or more tables.
24-
#' The package also has mappings of 226 R functions to the corresponding
24+
#' The package also has mappings of 229 R functions to the corresponding
2525
#' functions in the Arrow compute library. These allow you to write code inside
2626
#' of `dplyr` methods that call R functions, including many in packages like
2727
#' `stringr` and `lubridate`, and they will get translated to Arrow and run
@@ -214,6 +214,9 @@
214214
#' * [`if_else()`][dplyr::if_else()]
215215
#' * [`n()`][dplyr::n()]
216216
#' * [`n_distinct()`][dplyr::n_distinct()]
217+
#' * [`recode_values()`][dplyr::recode_values()]: `ptype` argument and `unmatched = "error"` not supported
218+
#' * [`replace_values()`][dplyr::replace_values()]
219+
#' * [`replace_when()`][dplyr::replace_when()]
217220
#' * [`when_all()`][dplyr::when_all()]
218221
#' * [`when_any()`][dplyr::when_any()]
219222
#'

r/man/acero.Rd

Lines changed: 5 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

r/man/read_json_arrow.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

r/man/schema.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)