Skip to content

Commit f8ecb8f

Browse files
committed
Better errors for .default vectorised
1 parent a61b958 commit f8ecb8f

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

r/R/dplyr-funcs-conditional.R

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ register_bindings_conditional <- function() {
282282
value <- parsed$value
283283
if (!is.null(.default)) {
284284
if (length(.default) != 1) {
285-
validation_error(paste0("`.default` must have size 1, not size ", length(.default), "."))
285+
arrow_not_supported("`case_when()` with vectorized `.default`")
286286
}
287287
n <- length(query)
288288
query[[n + 1]] <- TRUE
@@ -326,6 +326,9 @@ register_bindings_conditional <- function() {
326326
if (!is.null(ptype)) {
327327
arrow_not_supported("`recode_values()` with `ptype` specified")
328328
}
329+
if (!unmatched %in% c("default", "error")) {
330+
validation_error('`unmatched` must be either "default" or "error"')
331+
}
329332
if (unmatched == "error") {
330333
arrow_not_supported("`recode_values()` with `unmatched = \"error\"`")
331334
}
@@ -338,6 +341,9 @@ register_bindings_conditional <- function() {
338341
value <- parsed$value
339342

340343
if (!is.null(default)) {
344+
if (length(default) != 1) {
345+
arrow_not_supported("`recode_values()` with vectorized `default`")
346+
}
341347
n <- length(query)
342348
query[[n + 1]] <- TRUE
343349
value[[n + 1]] <- Expression$scalar(default)

r/tests/testthat/test-dplyr-funcs-conditional.R

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,8 +296,8 @@ test_that("case_when()", {
296296
)
297297
expect_arrow_eval_error(
298298
case_when(int > 5 ~ 1, .default = c(0, 1)),
299-
"`.default` must have size 1, not size 2",
300-
class = "validation_error"
299+
"`case_when\\(\\)` with vectorized `.default` not supported in Arrow",
300+
class = "arrow_not_supported"
301301
)
302302

303303
expect_arrow_eval_error(
@@ -783,4 +783,14 @@ test_that("recode_values()", {
783783
"`recode_values\\(\\)` with `unmatched = \"error\"` not supported in Arrow",
784784
class = "arrow_not_supported"
785785
)
786+
expect_arrow_eval_error(
787+
recode_values(chr, "a" ~ "A", unmatched = "wat"),
788+
'`unmatched` must be either "default" or "error"',
789+
class = "validation_error"
790+
)
791+
expect_arrow_eval_error(
792+
recode_values(chr, "a" ~ "A", default = c("x", "y")),
793+
"`recode_values\\(\\)` with vectorized `default` not supported in Arrow",
794+
class = "arrow_not_supported"
795+
)
786796
})

0 commit comments

Comments
 (0)