Skip to content

Commit edd71f7

Browse files
authored
Merge pull request #437 from tidymodels/check-param-inputs
Improve checking of parameter inputs to `parameters()` and `grid_*()` functions
2 parents b2dbed8 + 5b4f2ba commit edd71f7

14 files changed

Lines changed: 887 additions & 61 deletions

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77
* Parameters were added for the `tab_pfn` model: `num_estimators()`, `softmax_temperature()`, `balance_probabilities()`, `average_before_softmax()`, and `training_set_limit()`.
88

9+
* `parameters()` and the `grid_*()` functions give more information in the error message when non-parameter objects are passed in (#437).
10+
11+
912
# dials 1.4.2
1013

1114
* `prop_terms()` is a new parameter object used for recipes that do supervised feature selection (#395).

R/grids.R

Lines changed: 90 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,21 @@ grid_regular.parameters <- function(
6868
original = TRUE,
6969
filter = NULL
7070
) {
71-
# test for NA and finalized
72-
# test for empty ...
71+
check_dots_empty()
72+
73+
if (nrow(x) == 0) {
74+
cli::cli_abort("At least one parameter object is required.")
75+
}
76+
# check for unknowns
77+
for (i in seq_along(x$object)) {
78+
check_param(
79+
x$object[[i]],
80+
allow_na = FALSE,
81+
allow_unknown = FALSE,
82+
arg = x$id[i]
83+
)
84+
}
85+
7386
params <- x$object
7487
names(params) <- x$id
7588
grd <- make_regular_grid(
@@ -91,6 +104,22 @@ grid_regular.list <- function(
91104
original = TRUE,
92105
filter = NULL
93106
) {
107+
check_dots_empty()
108+
109+
if (length(x) == 0) {
110+
cli::cli_abort("At least one parameter object is required.")
111+
}
112+
# check for unknowns
113+
param_names <- names(x)
114+
for (i in seq_along(x)) {
115+
check_param(
116+
x[[i]],
117+
allow_na = FALSE,
118+
allow_unknown = FALSE,
119+
arg = param_arg_name(param_names[i], x[[i]], i)
120+
)
121+
}
122+
94123
y <- parameters(x)
95124
params <- y$object
96125
names(params) <- y$id
@@ -114,7 +143,19 @@ grid_regular.param <- function(
114143
original = TRUE,
115144
filter = NULL
116145
) {
117-
y <- parameters(list(x, ...))
146+
# check for unknowns
147+
param_list <- list(x, ...)
148+
param_names <- names(param_list)
149+
for (i in seq_along(param_list)) {
150+
check_param(
151+
param_list[[i]],
152+
allow_na = FALSE,
153+
allow_unknown = FALSE,
154+
arg = param_arg_name(param_names[i], param_list[[i]], i)
155+
)
156+
}
157+
158+
y <- parameters(param_list)
118159
params <- y$object
119160
names(params) <- y$id
120161
grd <- make_regular_grid(
@@ -136,7 +177,7 @@ make_regular_grid <- function(
136177
) {
137178
check_levels(levels, call = call)
138179
check_bool(original, call = call)
139-
validate_params(..., call = call)
180+
140181
filter_quo <- enquo(filter)
141182
param_quos <- quos(...)
142183
params <- map(param_quos, eval_tidy)
@@ -207,8 +248,21 @@ grid_random.parameters <- function(
207248
original = TRUE,
208249
filter = NULL
209250
) {
210-
# test for NA and finalized
211-
# test for empty ...
251+
check_dots_empty()
252+
253+
if (nrow(x) == 0) {
254+
cli::cli_abort("At least one parameter object is required.")
255+
}
256+
# check for unknowns
257+
for (i in seq_along(x$object)) {
258+
check_param(
259+
x$object[[i]],
260+
allow_na = FALSE,
261+
allow_unknown = FALSE,
262+
arg = x$id[i]
263+
)
264+
}
265+
212266
params <- x$object
213267
names(params) <- x$id
214268
grd <- make_random_grid(
@@ -224,6 +278,22 @@ grid_random.parameters <- function(
224278
#' @export
225279
#' @rdname grid_regular
226280
grid_random.list <- function(x, ..., size = 5, original = TRUE, filter = NULL) {
281+
check_dots_empty()
282+
283+
if (length(x) == 0) {
284+
cli::cli_abort("At least one parameter object is required.")
285+
}
286+
# check for unknowns
287+
param_names <- names(x)
288+
for (i in seq_along(x)) {
289+
check_param(
290+
x[[i]],
291+
allow_na = FALSE,
292+
allow_unknown = FALSE,
293+
arg = param_arg_name(param_names[i], x[[i]], i)
294+
)
295+
}
296+
227297
y <- parameters(x)
228298
params <- y$object
229299
names(params) <- y$id
@@ -247,7 +317,19 @@ grid_random.param <- function(
247317
original = TRUE,
248318
filter = NULL
249319
) {
250-
y <- parameters(list(x, ...))
320+
param_list <- list(x, ...)
321+
# check for unknowns
322+
param_names <- names(param_list)
323+
for (i in seq_along(param_list)) {
324+
check_param(
325+
param_list[[i]],
326+
allow_na = FALSE,
327+
allow_unknown = FALSE,
328+
arg = param_arg_name(param_names[i], param_list[[i]], i)
329+
)
330+
}
331+
332+
y <- parameters(param_list)
251333
params <- y$object
252334
names(params) <- y$id
253335
grd <- make_random_grid(
@@ -269,7 +351,7 @@ make_random_grid <- function(
269351
) {
270352
check_number_whole(size, min = 1, call = call)
271353
check_bool(original, call = call)
272-
validate_params(..., call = call)
354+
273355
filter_quo <- enquo(filter)
274356
param_quos <- quos(...)
275357
params <- map(param_quos, eval_tidy)

R/misc.R

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,3 +292,57 @@ check_unique <- function(x, ..., arg = caller_arg(x), call = caller_env()) {
292292
call = call
293293
)
294294
}
295+
296+
check_param <- function(
297+
x,
298+
...,
299+
allow_na = FALSE,
300+
allow_unknown = FALSE,
301+
arg = caller_arg(x),
302+
call = caller_env()
303+
) {
304+
check_dots_empty()
305+
306+
if (allow_na && all(is.na(x))) {
307+
return(invisible(NULL))
308+
}
309+
310+
if (inherits(x, "param")) {
311+
if (allow_unknown || !has_unknowns(x)) {
312+
return(invisible(NULL))
313+
}
314+
315+
cli::cli_abort(
316+
c(
317+
x = "{.arg {arg}} must be a {.cls param} object without unknowns.",
318+
i = "See the {.fn dials::finalize} function."
319+
),
320+
call = call
321+
)
322+
}
323+
324+
what <- if (allow_unknown) {
325+
"a <param> object"
326+
} else {
327+
"a <param> object without unknowns"
328+
}
329+
330+
stop_input_type(
331+
x,
332+
what,
333+
...,
334+
allow_na = allow_na,
335+
arg = arg,
336+
call = call
337+
)
338+
}
339+
340+
param_arg_name <- function(name, x, position) {
341+
if (!is.null(name) && nzchar(name)) {
342+
return(name)
343+
}
344+
if (inherits(x, "param")) {
345+
return(names(x$label))
346+
}
347+
paste("Argument", position)
348+
}

R/parameters.R

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ parameters <- function(x, ...) {
2525
#' @export
2626
#' @rdname parameters
2727
parameters.default <- function(x, ...) {
28+
if (missing(x)) {
29+
cli::cli_abort(
30+
"No input provided. Please supply at least one parameter object."
31+
)
32+
}
2833
cli::cli_abort(
2934
"{.cls parameters} objects cannot be created from {.obj_type_friendly {x}}."
3035
)
@@ -44,9 +49,14 @@ parameters.param <- function(x, ...) {
4449
parameters.list <- function(x, ...) {
4550
check_dots_empty()
4651

47-
elem_param <- purrr::map_lgl(x, inherits, "param")
48-
if (!all(elem_param)) {
49-
cli::cli_abort("The objects should all be {.cls param} objects.")
52+
param_names <- names(x)
53+
for (i in seq_along(x)) {
54+
check_param(
55+
x[[i]],
56+
allow_na = FALSE,
57+
allow_unknown = TRUE,
58+
arg = param_arg_name(param_names[i], x[[i]], i)
59+
)
5060
}
5161
elem_name <- purrr::map_chr(x, \(.x) names(.x$label))
5262
elem_id <- names(x)
@@ -66,30 +76,6 @@ parameters.list <- function(x, ...) {
6676
)
6777
}
6878

69-
param_or_na <- function(x) {
70-
inherits(x, "param") || all(is.na(x))
71-
}
72-
73-
check_list_of_param <- function(x, ..., call = caller_env()) {
74-
check_dots_empty()
75-
if (!is.list(x)) {
76-
cli::cli_abort(
77-
"{.arg object} must be a list of {.cls param} objects.",
78-
call = call
79-
)
80-
}
81-
is_good_boi <- map_lgl(x, param_or_na)
82-
if (!all(is_good_boi)) {
83-
offenders <- which(!is_good_boi)
84-
85-
cli::cli_abort(
86-
"{.arg object} elements in the following positions must be {.code NA} or a
87-
{.cls param} object: {offenders}.",
88-
call = call
89-
)
90-
}
91-
}
92-
9379
#' Construct a new parameter set object
9480
#'
9581
#' @param name,id,source,component,component_id Character strings with the same
@@ -120,7 +106,18 @@ parameters_constr <- function(
120106
check_character(source, call = call)
121107
check_character(component, call = call)
122108
check_character(component_id, call = call)
123-
check_list_of_param(object, call = call)
109+
if (!is.list(object)) {
110+
cli::cli_abort("{.arg object} must be a list.", call = call)
111+
}
112+
for (i in seq_along(object)) {
113+
check_param(
114+
object[[i]],
115+
allow_na = TRUE,
116+
allow_unknown = TRUE,
117+
arg = paste0("object[[", i, "]]"),
118+
call = call
119+
)
120+
}
124121

125122
n_elements <- lengths(list(name, id, source, component, component_id, object))
126123
n_elements_unique <- unique(n_elements)

R/space_filling.R

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,20 @@ grid_space_filling.parameters <- function(
118118
iter = 1000,
119119
original = TRUE
120120
) {
121-
# test for NA and finalized
122-
# test for empty ...
121+
check_dots_empty()
122+
123+
if (nrow(x) == 0) {
124+
cli::cli_abort("At least one parameter object is required.")
125+
}
126+
for (i in seq_along(x$object)) {
127+
check_param(
128+
x$object[[i]],
129+
allow_na = FALSE,
130+
allow_unknown = FALSE,
131+
arg = x$id[i]
132+
)
133+
}
134+
123135
params <- x$object
124136
names(params) <- x$id
125137
grd <- make_sfd(
@@ -145,6 +157,21 @@ grid_space_filling.list <- function(
145157
iter = 1000,
146158
original = TRUE
147159
) {
160+
check_dots_empty()
161+
162+
if (length(x) == 0) {
163+
cli::cli_abort("At least one parameter object is required.")
164+
}
165+
param_names <- names(x)
166+
for (i in seq_along(x)) {
167+
check_param(
168+
x[[i]],
169+
allow_na = FALSE,
170+
allow_unknown = FALSE,
171+
arg = param_arg_name(param_names[i], x[[i]], i)
172+
)
173+
}
174+
148175
y <- parameters(x)
149176
params <- y$object
150177
names(params) <- y$id
@@ -172,7 +199,18 @@ grid_space_filling.param <- function(
172199
type = "any",
173200
original = TRUE
174201
) {
175-
y <- parameters(list(x, ...))
202+
param_list <- list(x, ...)
203+
param_names <- names(param_list)
204+
for (i in seq_along(param_list)) {
205+
check_param(
206+
param_list[[i]],
207+
allow_na = FALSE,
208+
allow_unknown = FALSE,
209+
arg = param_arg_name(param_names[i], param_list[[i]], i)
210+
)
211+
}
212+
213+
y <- parameters(param_list)
176214
params <- y$object
177215
names(params) <- y$id
178216
grd <- make_sfd(
@@ -215,7 +253,7 @@ make_sfd <- function(
215253
check_number_whole(iter, min = 1, call = call)
216254
check_bool(original, call = call)
217255
type <- rlang::arg_match(type, sfd_types)
218-
validate_params(..., call = call)
256+
219257
param_quos <- quos(...)
220258
params <- map(param_quos, eval_tidy)
221259
p <- length(params)

0 commit comments

Comments
 (0)