Skip to content
Merged
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

* Parameters were added for the `tab_pfn` model: `num_estimators()`, `softmax_temperature()`, `balance_probabilities()`, `average_before_softmax()`, and `training_set_limit()`.

* `parameters()` and the `grid_*()` functions give more information in the error message when non-parameter objects are passed in (#437).


# dials 1.4.2

* `prop_terms()` is a new parameter object used for recipes that do supervised feature selection (#395).
Expand Down
98 changes: 90 additions & 8 deletions R/grids.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,21 @@ grid_regular.parameters <- function(
original = TRUE,
filter = NULL
) {
# test for NA and finalized
# test for empty ...
check_dots_empty()

if (nrow(x) == 0) {
cli::cli_abort("At least one parameter object is required.")
}
# check for unknowns
for (i in seq_along(x$object)) {
check_param(
x$object[[i]],
allow_na = FALSE,
allow_unknown = FALSE,
arg = x$id[i]
)
}

params <- x$object
names(params) <- x$id
grd <- make_regular_grid(
Expand All @@ -91,6 +104,22 @@ grid_regular.list <- function(
original = TRUE,
filter = NULL
) {
check_dots_empty()

if (length(x) == 0) {
cli::cli_abort("At least one parameter object is required.")
}
# check for unknowns
param_names <- names(x)
for (i in seq_along(x)) {
check_param(
x[[i]],
allow_na = FALSE,
allow_unknown = FALSE,
arg = param_arg_name(param_names[i], x[[i]], i)
)
}

y <- parameters(x)
params <- y$object
names(params) <- y$id
Expand All @@ -114,7 +143,19 @@ grid_regular.param <- function(
original = TRUE,
filter = NULL
) {
y <- parameters(list(x, ...))
# check for unknowns
param_list <- list(x, ...)
param_names <- names(param_list)
for (i in seq_along(param_list)) {
check_param(
param_list[[i]],
allow_na = FALSE,
allow_unknown = FALSE,
arg = param_arg_name(param_names[i], param_list[[i]], i)
)
}

y <- parameters(param_list)
params <- y$object
names(params) <- y$id
grd <- make_regular_grid(
Expand All @@ -136,7 +177,7 @@ make_regular_grid <- function(
) {
check_levels(levels, call = call)
check_bool(original, call = call)
validate_params(..., call = call)

filter_quo <- enquo(filter)
param_quos <- quos(...)
params <- map(param_quos, eval_tidy)
Expand Down Expand Up @@ -207,8 +248,21 @@ grid_random.parameters <- function(
original = TRUE,
filter = NULL
) {
# test for NA and finalized
# test for empty ...
check_dots_empty()

if (nrow(x) == 0) {
cli::cli_abort("At least one parameter object is required.")
}
# check for unknowns
for (i in seq_along(x$object)) {
check_param(
x$object[[i]],
allow_na = FALSE,
allow_unknown = FALSE,
arg = x$id[i]
)
}

params <- x$object
names(params) <- x$id
grd <- make_random_grid(
Expand All @@ -224,6 +278,22 @@ grid_random.parameters <- function(
#' @export
#' @rdname grid_regular
grid_random.list <- function(x, ..., size = 5, original = TRUE, filter = NULL) {
check_dots_empty()

if (length(x) == 0) {
cli::cli_abort("At least one parameter object is required.")
}
# check for unknowns
param_names <- names(x)
for (i in seq_along(x)) {
check_param(
x[[i]],
allow_na = FALSE,
allow_unknown = FALSE,
arg = param_arg_name(param_names[i], x[[i]], i)
)
}

y <- parameters(x)
params <- y$object
names(params) <- y$id
Expand All @@ -247,7 +317,19 @@ grid_random.param <- function(
original = TRUE,
filter = NULL
) {
y <- parameters(list(x, ...))
param_list <- list(x, ...)
# check for unknowns
param_names <- names(param_list)
for (i in seq_along(param_list)) {
check_param(
param_list[[i]],
allow_na = FALSE,
allow_unknown = FALSE,
arg = param_arg_name(param_names[i], param_list[[i]], i)
)
}

y <- parameters(param_list)
params <- y$object
names(params) <- y$id
grd <- make_random_grid(
Expand All @@ -269,7 +351,7 @@ make_random_grid <- function(
) {
check_number_whole(size, min = 1, call = call)
check_bool(original, call = call)
validate_params(..., call = call)

filter_quo <- enquo(filter)
param_quos <- quos(...)
params <- map(param_quos, eval_tidy)
Expand Down
54 changes: 54 additions & 0 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,57 @@ check_unique <- function(x, ..., arg = caller_arg(x), call = caller_env()) {
call = call
)
}

check_param <- function(
x,
...,
allow_na = FALSE,
allow_unknown = FALSE,
arg = caller_arg(x),
call = caller_env()
) {
check_dots_empty()

if (allow_na && all(is.na(x))) {
return(invisible(NULL))
}

if (inherits(x, "param")) {
if (allow_unknown || !has_unknowns(x)) {
return(invisible(NULL))
}

cli::cli_abort(
c(
x = "{.arg {arg}} must be a {.cls param} object without unknowns.",
i = "See the {.fn dials::finalize} function."
),
call = call
)
}

what <- if (allow_unknown) {
"a <param> object"
} else {
"a <param> object without unknowns"
}

stop_input_type(
x,
what,
...,
allow_na = allow_na,
arg = arg,
call = call
)
}

param_arg_name <- function(name, x, position) {
if (!is.null(name) && nzchar(name)) {
return(name)
}
if (inherits(x, "param")) {
return(names(x$label))
}
paste("Argument", position)
}
53 changes: 25 additions & 28 deletions R/parameters.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ parameters <- function(x, ...) {
#' @export
#' @rdname parameters
parameters.default <- function(x, ...) {
if (missing(x)) {
cli::cli_abort(
"No input provided. Please supply at least one parameter object."
)
}
cli::cli_abort(
"{.cls parameters} objects cannot be created from {.obj_type_friendly {x}}."
)
Expand All @@ -44,9 +49,14 @@ parameters.param <- function(x, ...) {
parameters.list <- function(x, ...) {
check_dots_empty()

elem_param <- purrr::map_lgl(x, inherits, "param")
if (!all(elem_param)) {
cli::cli_abort("The objects should all be {.cls param} objects.")
param_names <- names(x)
for (i in seq_along(x)) {
check_param(
x[[i]],
allow_na = FALSE,
allow_unknown = TRUE,
arg = param_arg_name(param_names[i], x[[i]], i)
)
}
elem_name <- purrr::map_chr(x, \(.x) names(.x$label))
elem_id <- names(x)
Expand All @@ -66,30 +76,6 @@ parameters.list <- function(x, ...) {
)
}

param_or_na <- function(x) {
inherits(x, "param") || all(is.na(x))
}

check_list_of_param <- function(x, ..., call = caller_env()) {
check_dots_empty()
if (!is.list(x)) {
cli::cli_abort(
"{.arg object} must be a list of {.cls param} objects.",
call = call
)
}
is_good_boi <- map_lgl(x, param_or_na)
if (!all(is_good_boi)) {
offenders <- which(!is_good_boi)

cli::cli_abort(
"{.arg object} elements in the following positions must be {.code NA} or a
{.cls param} object: {offenders}.",
call = call
)
}
}

#' Construct a new parameter set object
#'
#' @param name,id,source,component,component_id Character strings with the same
Expand Down Expand Up @@ -120,7 +106,18 @@ parameters_constr <- function(
check_character(source, call = call)
check_character(component, call = call)
check_character(component_id, call = call)
check_list_of_param(object, call = call)
if (!is.list(object)) {
cli::cli_abort("{.arg object} must be a list.", call = call)
}
for (i in seq_along(object)) {
check_param(
object[[i]],
allow_na = TRUE,
allow_unknown = TRUE,
arg = paste0("object[[", i, "]]"),
call = call
)
}

n_elements <- lengths(list(name, id, source, component, component_id, object))
n_elements_unique <- unique(n_elements)
Expand Down
46 changes: 42 additions & 4 deletions R/space_filling.R
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,20 @@ grid_space_filling.parameters <- function(
iter = 1000,
original = TRUE
) {
# test for NA and finalized
# test for empty ...
check_dots_empty()

if (nrow(x) == 0) {
cli::cli_abort("At least one parameter object is required.")
}
for (i in seq_along(x$object)) {
check_param(
x$object[[i]],
allow_na = FALSE,
allow_unknown = FALSE,
arg = x$id[i]
)
}

params <- x$object
names(params) <- x$id
grd <- make_sfd(
Expand All @@ -145,6 +157,21 @@ grid_space_filling.list <- function(
iter = 1000,
original = TRUE
) {
check_dots_empty()

if (length(x) == 0) {
cli::cli_abort("At least one parameter object is required.")
}
param_names <- names(x)
for (i in seq_along(x)) {
check_param(
x[[i]],
allow_na = FALSE,
allow_unknown = FALSE,
arg = param_arg_name(param_names[i], x[[i]], i)
)
}

y <- parameters(x)
params <- y$object
names(params) <- y$id
Expand Down Expand Up @@ -172,7 +199,18 @@ grid_space_filling.param <- function(
type = "any",
original = TRUE
) {
y <- parameters(list(x, ...))
param_list <- list(x, ...)
param_names <- names(param_list)
for (i in seq_along(param_list)) {
check_param(
param_list[[i]],
allow_na = FALSE,
allow_unknown = FALSE,
arg = param_arg_name(param_names[i], param_list[[i]], i)
)
}

y <- parameters(param_list)
params <- y$object
names(params) <- y$id
grd <- make_sfd(
Expand Down Expand Up @@ -215,7 +253,7 @@ make_sfd <- function(
check_number_whole(iter, min = 1, call = call)
check_bool(original, call = call)
type <- rlang::arg_match(type, sfd_types)
validate_params(..., call = call)

param_quos <- quos(...)
params <- map(param_quos, eval_tidy)
p <- length(params)
Expand Down
Loading
Loading