|
15 | 15 | # specific language governing permissions and limitations |
16 | 16 | # under the License. |
17 | 17 |
|
#' Parse logical condition formulas
#'
#' Converts `condition ~ value` formulas into Arrow expressions. Unlike
#' [parse_value_mapping()], the LHS must be a logical expression (not a value
#' to match against).
#'
#' @param formulas A list of two-sided formulas where LHS is a logical condition
#'   and RHS is the value to use when TRUE (e.g., `x > 5 ~ "high"`). `NULL`
#'   entries are dropped before parsing.
#' @param mask The data mask for evaluating formula expressions.
#'
#' @return A list with `query` (list of logical expressions) and `value`
#'   (list of replacement expressions). Both lists have one entry per
#'   non-`NULL` formula, in the order the formulas were supplied.
#'
#' @keywords internal
#' @noRd
parse_condition_formulas <- function(formulas, mask) {
  # Name of the function that called this helper; only used in error messages
  fn <- call_name(rlang::caller_call())
  # Compact NULL entries (allows conditional formulas like: if (cond) x ~ y)
  formulas <- compact(formulas)
  n <- length(formulas)
  query <- vector("list", n)
  value <- vector("list", n)
  # Process each formula: condition ~ value
  for (i in seq_len(n)) {
    f <- formulas[[i]]
    if (!is_formula(f, lhs = TRUE)) {
      validation_error(paste0("Each argument to ", fn, "() must be a two-sided formula"))
    }
    # f[[2]] is LHS (logical condition), f[[3]] is RHS (value when TRUE).
    # Evaluate into locals first: assigning a NULL result straight into
    # `query[[i]]` / `value[[i]]` would delete that slot and leave the two
    # lists misaligned.
    condition <- arrow_eval(f[[2]], mask)
    replacement <- arrow_eval(f[[3]], mask)
    # Validate LHS is logical (unlike parse_value_mapping which does equality matching)
    if (is.null(condition) || !call_binding("is.logical", condition)) {
      validation_error(paste0("Left side of each formula in ", fn, "() must be a logical expression"))
    }
    if (is.null(replacement)) {
      validation_error(paste0("Right side of each formula in ", fn, "() must not be NULL"))
    }
    query[[i]] <- condition
    value[[i]] <- replacement
  }
  list(query = query, value = value)
}
| 56 | + |
#' Assemble a `case_when` Arrow Expression from parallel query/value lists
#'
#' The conditions are first packed into a single `make_struct` expression whose
#' fields are named "1", "2", ... (one per condition). That struct is passed as
#' the first argument to `case_when`, followed by the values.
#'
#' @param query List of logical Arrow Expressions.
#' @param value List of value Arrow Expressions.
#' @return An Arrow Expression representing the case_when.
#' @keywords internal
#' @noRd
build_case_when_expr <- function(query, value) {
  # One positional field name per condition
  condition_names <- as.character(seq_along(query))
  condition_struct <- Expression$create(
    "make_struct",
    args = query,
    options = list(field_names = condition_names)
  )
  case_when_args <- c(condition_struct, value)
  Expression$create("case_when", args = case_when_args)
}
| 76 | + |
#' Build a match expression for x against a value (scalar, NA, or vector).
#'
#' @param x Arrow Expression for the column to match against.
#' @param match_value Value to match - R scalar, vector, or NA. Expressions
#'   are compared with equality. A zero-length value (including `NULL`) is an
#'   empty match set and therefore matches nothing.
#' @return Arrow Expression that is TRUE when x matches match_value.
#' @keywords internal
#' @noRd
build_match_expr <- function(x, match_value) {
  # Expressions are compared with equality directly
  if (inherits(match_value, "Expression")) {
    return(x == match_value)
  }

  n_values <- length(match_value)

  # Empty match set: nothing can match. Without this guard the zero-length
  # case would fall through to the "all values are NA" branch below and
  # wrongly match missing values.
  if (n_values == 0) {
    return(Expression$scalar(FALSE))
  }

  is_missing <- is.na(match_value)

  if (n_values == 1) {
    # R scalar NA requires is.na() since x == NA returns NA in Arrow
    if (is_missing) {
      return(call_binding("is.na", x))
    }
    return(x == match_value)
  }

  # R vector: use %in%, handling NA separately if present
  has_na <- any(is_missing)
  non_na_values <- match_value[!is_missing]

  if (length(non_na_values) == 0) {
    call_binding("is.na", x)
  } else if (has_na) {
    call_binding("%in%", x, non_na_values) | call_binding("is.na", x)
  } else {
    call_binding("%in%", x, match_value)
  }
}
| 107 | + |
#' Turn parallel `from`/`to` vectors into query/value lists.
#'
#' Each element of `from` becomes one match expression via
#' build_match_expr(), so NA values in `from` are matched with is.na().
#' Each element of `to` is wrapped with `Expression$scalar()`.
#'
#' @param x Arrow Expression for the column to match against.
#' @param from Vector of values to match.
#' @param to Vector of replacement values (recycled to length of `from`).
#' @return list(query, value) for use with build_case_when_expr().
#' @keywords internal
#' @noRd
parse_from_to_mapping <- function(x, from, to) {
  replacements <- vctrs::vec_recycle(to, length(from))
  match_against_x <- function(from_value) {
    build_match_expr(x, from_value)
  }
  list(
    query = map(from, match_against_x),
    value = map(replacements, Expression$scalar)
  )
}
| 123 | + |
#' Build query/value lists from value ~ replacement formulas.
#'
#' The LHS of each formula is evaluated and turned into a match expression
#' with build_match_expr(), so NA values on LHS use is.na() for matching.
#'
#' @param x Arrow Expression for the column to match against.
#' @param formulas List of two-sided formulas (value ~ replacement). `NULL`
#'   entries are dropped before parsing.
#' @param mask Data mask for evaluating formula expressions.
#' @param fn Calling function name (for error messages).
#' @return list(query, value) for use with build_case_when_expr(). Both lists
#'   have one entry per non-`NULL` formula, in the order supplied.
#' @keywords internal
#' @noRd
parse_formula_mapping <- function(x, formulas, mask, fn) {
  # Compact NULL entries (allows conditional formulas like: if (cond) x ~ y)
  formulas <- compact(formulas)
  n <- length(formulas)
  query <- vector("list", n)
  value <- vector("list", n)
  for (i in seq_len(n)) {
    f <- formulas[[i]]
    if (!is_formula(f, lhs = TRUE)) {
      validation_error(paste0("Each argument to ", fn, "() must be a two-sided formula"))
    }
    # f[[2]] is LHS (value to match), f[[3]] is RHS (replacement)
    lhs <- arrow_eval(f[[2]], mask)
    match_expr <- build_match_expr(x, lhs)
    replacement <- arrow_eval(f[[3]], mask)
    # Assigning NULL with `[[<-` deletes the list slot, which would leave
    # `query` and `value` misaligned; reject it explicitly instead.
    if (is.null(replacement)) {
      validation_error(paste0("Right side of each formula in ", fn, "() must not be NULL"))
    }
    query[[i]] <- match_expr
    value[[i]] <- replacement
  }
  list(query = query, value = value)
}
| 151 | + |
#' Dispatch to formula or from/to parser based on which args are provided.
#'
#' The formula interface (`...`) and the `from`/`to` interface are mutually
#' exclusive. `NULL` entries in `formulas` are ignored when deciding which
#' interface was used, so a conditional formula that evaluates to `NULL`
#' (e.g. `if (cond) x ~ y`) does not count as a supplied formula.
#'
#' @param x Arrow Expression for the column to match against.
#' @param formulas List of two-sided formulas (value ~ replacement).
#' @param from Vector of values to match (alternative to formulas).
#' @param to Vector of replacement values (used with `from`).
#' @param mask The data mask for evaluating formula expressions.
#' @return list(query, value) for use with build_case_when_expr(), or NULL if
#'   no mappings were provided.
#' @keywords internal
#' @noRd
parse_value_mapping <- function(x, formulas = list(), from = NULL, to = NULL, mask) {
  # Name of the function that called this helper; only used in error messages
  fn <- call_name(rlang::caller_call())
  # Drop NULL placeholders up front so they are not mistaken for real formulas
  formulas <- compact(formulas)
  has_formulas <- length(formulas) > 0
  has_from <- !is.null(from)
  has_to <- !is.null(to)

  # Mutually exclusive interfaces
  if (has_formulas && (has_from || has_to)) {
    validation_error(paste0("Can't use both `...` and `from`/`to` in ", fn, "()"))
  }
  # `from` and `to` only make sense as a pair; previously a lone `to` was
  # silently ignored.
  if (has_to && !has_from) {
    validation_error("`from` must be provided when using `to`")
  }

  if (has_formulas) {
    parse_formula_mapping(x, formulas, mask, fn)
  } else if (has_from) {
    if (!has_to) {
      validation_error("`to` must be provided when using `from`")
    }
    parse_from_to_mapping(x, from, to)
  } else {
    # No mappings provided
    NULL
  }
}
| 180 | + |
18 | 181 | register_bindings_conditional <- function() { |
19 | 182 | register_binding("%in%", function(x, table) { |
20 | 183 | # We use `is_in` here, unlike with Arrays, which use `is_in_meta_binary` |
@@ -133,44 +296,79 @@ register_bindings_conditional <- function() { |
133 | 296 | } |
134 | 297 |
|
135 | 298 | formulas <- list2(...) |
136 | | - n <- length(formulas) |
137 | | - if (n == 0) { |
| 299 | + if (length(formulas) == 0) { |
138 | 300 | validation_error("No cases provided") |
139 | 301 | } |
140 | | - query <- vector("list", n) |
141 | | - value <- vector("list", n) |
142 | | - mask <- caller_env() |
143 | | - for (i in seq_len(n)) { |
144 | | - f <- formulas[[i]] |
145 | | - if (!inherits(f, "formula")) { |
146 | | - validation_error("Each argument to case_when() must be a two-sided formula") |
147 | | - } |
148 | | - query[[i]] <- arrow_eval(f[[2]], mask) |
149 | | - value[[i]] <- arrow_eval(f[[3]], mask) |
150 | | - if (!call_binding("is.logical", query[[i]])) { |
151 | | - validation_error("Left side of each formula in case_when() must be a logical expression") |
152 | | - } |
153 | | - } |
| 302 | + parsed <- parse_condition_formulas(formulas, caller_env()) |
| 303 | + query <- parsed$query |
| 304 | + value <- parsed$value |
154 | 305 | if (!is.null(.default)) { |
155 | 306 | if (length(.default) != 1) { |
156 | | - validation_error(paste0("`.default` must have size 1, not size ", length(.default), ".")) |
| 307 | + arrow_not_supported("`.default` must be size 1; vectors of length > 1") |
157 | 308 | } |
158 | | - |
159 | | - query[n + 1] <- TRUE |
160 | | - value[n + 1] <- .default |
| 309 | + n <- length(query) |
| 310 | + query[[n + 1]] <- TRUE |
| 311 | + value[[n + 1]] <- .default |
161 | 312 | } |
162 | | - Expression$create( |
163 | | - "case_when", |
164 | | - args = c( |
165 | | - Expression$create( |
166 | | - "make_struct", |
167 | | - args = query, |
168 | | - options = list(field_names = as.character(seq_along(query))) |
169 | | - ), |
170 | | - value |
171 | | - ) |
172 | | - ) |
| 313 | + build_case_when_expr(query, value) |
173 | 314 | }, |
174 | 315 | notes = "`.ptype` and `.size` arguments not supported" |
175 | 316 | ) |
| 317 | + |
| 318 | + register_binding("dplyr::replace_when", function(x, ...) { |
| 319 | + formulas <- list2(...) |
| 320 | + if (length(formulas) == 0) { |
| 321 | + return(x) |
| 322 | + } |
| 323 | + parsed <- parse_condition_formulas(formulas, caller_env()) |
| 324 | + query <- parsed$query |
| 325 | + value <- parsed$value |
| 326 | + n <- length(query) |
| 327 | + query[[n + 1]] <- TRUE |
| 328 | + value[[n + 1]] <- x |
| 329 | + build_case_when_expr(query, value) |
| 330 | + }) |
| 331 | + |
| 332 | + register_binding("dplyr::replace_values", function(x, ..., from = NULL, to = NULL) { |
| 333 | + parsed <- parse_value_mapping(x, list2(...), from, to, caller_env()) |
| 334 | + if (is.null(parsed)) { |
| 335 | + return(x) |
| 336 | + } |
| 337 | + query <- parsed$query |
| 338 | + value <- parsed$value |
| 339 | + n <- length(query) |
| 340 | + query[[n + 1]] <- TRUE |
| 341 | + value[[n + 1]] <- x |
| 342 | + build_case_when_expr(query, value) |
| 343 | + }) |
| 344 | + |
| 345 | + register_binding( |
| 346 | + "dplyr::recode_values", |
| 347 | + function(x, ..., from = NULL, to = NULL, default = NULL, unmatched = "default", ptype = NULL) { |
| 348 | + if (!is.null(ptype)) { |
| 349 | + arrow_not_supported("`recode_values()` with `ptype` specified") |
| 350 | + } |
| 351 | + if (unmatched != "default") { |
| 352 | + arrow_not_supported('`recode_values()` with `unmatched` other than "default"') |
| 353 | + } |
| 354 | + |
| 355 | + parsed <- parse_value_mapping(x, list2(...), from, to, caller_env()) |
| 356 | + if (is.null(parsed)) { |
| 357 | + validation_error("`...` can't be empty") |
| 358 | + } |
| 359 | + query <- parsed$query |
| 360 | + value <- parsed$value |
| 361 | + |
| 362 | + if (!is.null(default)) { |
| 363 | + if (length(default) != 1) { |
| 364 | + arrow_not_supported("`default` must be size 1; vectors of length > 1") |
| 365 | + } |
| 366 | + n <- length(query) |
| 367 | + query[[n + 1]] <- TRUE |
| 368 | + value[[n + 1]] <- Expression$scalar(default) |
| 369 | + } |
| 370 | + build_case_when_expr(query, value) |
| 371 | + }, |
| 372 | + notes = "`ptype` argument and `unmatched = \"error\"` not supported" |
| 373 | + ) |
176 | 374 | } |
0 commit comments