|
1 | | -# TODO: |
2 | | -# - Add hyperlink to `measure_set` |
3 | | -# - Add tests |
4 | | - |
5 | 1 | # Modified after https://github.com/tidymodels/yardstick/blob/main/R/aaa-new.R |
6 | 2 |
|
7 | 3 | #' Construct a new measure function |
8 | 4 | #' @keywords summary_stats |
9 | 5 | #' |
10 | 6 | #' @description |
11 | 7 | #' These functions provide convenient wrappers to create the three types of |
12 | | -#' measure functions in `tidyhydro`: measures of central tendency, variability |
13 | | -#' and symmetry. They add a measure-specific class to `fn` and |
14 | | -#' mimic a behaviour of [metric_set][yardstick::metric_set]. These features |
15 | | -#' are used by measure_set. |
| 8 | +#' descriptive statistics functions in `tidyhydro`: measures of central |
| 9 | +#' tendency, variability and symmetry. They add a descriptive |
| 10 | +#' statistics-specific class to `fn` and mimic the behaviour of |
| 11 | +#' [metrics][yardstick::metrics] from `yardstick`, although they are not |
| 12 | +#' directly compatible with [metric_set][yardstick::metric_set]. |
16 | 13 | #' |
17 | | -#' See [Custom performance |
18 | | -#' metrics](https://www.tidymodels.org/learn/develop/metrics/) for more |
19 | | -#' information about creating custom metrics. |
| 14 | +#' In order to create a measure set, one can use [measure_set]. |
20 | 15 | #' |
21 | 16 | #' @param fn A function. The measure function to which a measure-specific class is attached.
22 | 17 | #' |
| 18 | +#' @seealso [measure_set] |
| 19 | +#' |
23 | 20 | #' @name new-measure |
24 | 21 | NULL |
25 | 22 |
|
@@ -70,8 +67,261 @@ format.measure <- function(x, ...) { |
70 | 67 | "tendency_measure" = "Measure of Central Tendency", |
71 | 68 | "var_measure" = "Measure of Variability", |
72 | 69 | "sym_measure" = "Measure of Distribution Symmetry", |
73 | | - "measure" |
| 70 | + "measure" = "Measure" |
| 71 | + ) |
| 72 | + |
| 73 | + paste("A", measure_type) |
| 74 | +} |
| 75 | + |
| 76 | +# Measure set ------------------------------------------------------------ |
| 77 | + |
| 78 | +#' Combine multiple measures into a single function |
| 79 | +#' @keywords summary_stats |
| 80 | +#' @family descriptive statistics |
| 81 | +#' |
| 82 | +#' @description |
| 83 | +#' This function provides a convenient wrapper to create a measure set, |
| 84 | +#' mimicking the behaviour of [metric_set][yardstick::metric_set]. |
| 85 | +#' |
| 86 | +#' @param ... The bare names of the functions to be included in the measure set. |
| 87 | +#' |
| 88 | +#' @details |
| 89 | +#' All functions must be valid measure functions, i.e. they must be of |
| 90 | +#' class `tendency_measure`, `var_measure` or `sym_measure`, as created with |
| 91 | +#' [new_tendency_measure], [new_var_measure] or [new_sym_measure]. |
| 92 | +#' |
| 93 | +#' Unlike [metric_set][yardstick::metric_set], where it is not allowed to mix |
| 94 | +#' different metric classes, [measure_set] allows mixing different measure |
| 95 | +#' classes. For example, [gm()] can be used with [cv()] because both are |
| 96 | +#' valid measure functions, even though the first one is of class |
| 97 | +#' `tendency_measure` and the second one is of class `var_measure`. |
| 98 | +#' |
| 99 | +#' @examples |
| 100 | +#' \dontrun{ |
| 101 | +#' library(tidyhydro) |
| 102 | +#' |
| 103 | +#' # Multiple descriptive statistics |
| 104 | +#' multi_measure <- measure_set(gm, cv) |
| 105 | +#' |
| 106 | +#' # The returned function has arguments: |
| 107 | +#' # fn(data, truth, na_rm = TRUE, ...) |
| 108 | +#' multi_measure(avacha, obs) |
| 109 | +#' |
| 110 | +#' avacha |> |
| 111 | +#' group_by(month = format(date, "%b")) |> |
| 112 | +#' multi_measure(obs) |
| 113 | +#' } |
| 114 | +#' |
| 115 | +#' @export |
measure_set <- function(...) {
  # Capture the inputs unevaluated so their labels can be recovered below
  captured <- rlang::enquos(...)
  validate_not_empty(captured)

  # Resolve every quosure to its value; each value must be a function
  fns <- lapply(captured, rlang::eval_tidy)
  validate_inputs_are_functions(fns)

  # Name each function after the expression the user typed, then check that
  # every function carries a recognised measure class (mixing is allowed)
  names(fns) <- vapply(captured, get_quo_label, character(1))
  validate_function_class(fns)

  known_classes <- c(
    "tendency_measure",
    "var_measure",
    "sym_measure",
    "measure"
  )
  first_class <- class(fns[[1]])[[1]]

  # Defensive guard: the class validation above is expected to have rejected
  # anything that is not a measure before we get here
  if (!first_class %in% known_classes) {
    cli::cli_abort(
      "{.fn validate_function_class} should have errored on unknown classes.",
      .internal = TRUE
    )
  }

  # Every measure shares one signature, so a single wrapper serves them all
  make_measure_function(fns)
}
| 143 | + |
| 144 | +#' @export |
print.measure_set <- function(x, ...) {
  # Render through the format method, one element per output line
  formatted <- format(x)
  cat(formatted, sep = "\n")

  # Hand the object back invisibly so printing composes with pipes
  invisible(x)
}
| 149 | + |
| 150 | +#' @export |
format.measure_set <- function(x, ...) {
  # The member functions are stored on the set under the "measures" attribute
  member_fns <- attr(x, "measures")
  measure_names <- names(member_fns)

  cli::cli_format_method({
    cli::cli_text("A measure set, consisting of:")

    # One line per member: its name followed by that measure's own description
    for (i in seq_along(member_fns)) {
      measure_format <- format(member_fns[[i]])
      cli::cli_text("- {.fun {measure_names[i]}}: {measure_format}")
    }
  })
}
| 164 | + |
# Abort unless at least one input was supplied.
#
# `x` is the collection of captured inputs; `call` is the environment reported
# as the origin of the error (defaults to the caller's).
validate_not_empty <- function(x, call = rlang::caller_env()) {
  if (!rlang::is_empty(x)) {
    return(invisible(NULL))
  }

  cli::cli_abort(
    "At least 1 function must be supplied to {.code ...}.",
    call = call
  )
}
| 173 | + |
# Abort unless every element of `fns` is a function.
#
# `fns` is the list of evaluated inputs given to `measure_set()`; `call` is the
# environment reported as the origin of the error (defaults to the caller's).
# On failure the message lists the positions of the offending inputs.
validate_inputs_are_functions <- function(fns, call = rlang::caller_env()) {
  is_fun_vec <- vapply(fns, rlang::is_function, logical(1))

  if (all(is_fun_vec)) {
    return(invisible(NULL))
  }

  not_fn <- which(!is_fun_vec)

  # Both lines must travel in ONE character vector: a second positional
  # string would be forwarded through `...` and would not be shown as part
  # of the message.
  cli::cli_abort(
    c(
      "All inputs to {.fn measure_set} must be functions.",
      "x" = "These inputs are not: {not_fn}."
    ),
    call = call
  )
}
| 187 | + |
# Check that every measure function has a recognised primary class. Different
# measure classes may be combined; anything else aborts with a message that
# groups the supplied functions by class. Returns `fns` invisibly on success.
validate_function_class <- function(fns) {
  primary_cls <- vapply(fns, function(fn) class(fn)[1], character(1))
  distinct_cls <- unique(primary_cls)

  # Nothing supplied, nothing to validate
  if (length(distinct_cls) == 0L) {
    return(invisible(fns))
  }

  measure_cls <- c(
    "tendency_measure",
    "var_measure",
    "sym_measure",
    "measure"
  )

  # Any combination of measure classes is acceptable
  if (all(distinct_cls %in% measure_cls)) {
    return(invisible(fns))
  }

  # From here on we only assemble the error message.
  # Function names grouped by their primary class, in `distinct_cls` order
  names_by_cls <- lapply(distinct_cls, function(cls) {
    names(fns)[primary_cls == cls]
  })

  # Shorter labels for display: drop the suffix, call plain functions "other"
  cls_labels <- gsub("_measure", "", distinct_cls)
  cls_labels <- gsub("function", "other", cls_labels)

  other_loc <- which(cls_labels == "other")

  if (length(other_loc) > 0L) {
    # For plain functions, append the name of the defining environment so the
    # user can tell where each one came from
    other_names <- names_by_cls[[other_loc]]

    other_envs <- vapply(
      fns[other_names],
      function(fn) rlang::env_name(rlang::fn_env(fn)),
      character(1)
    )

    names_by_cls[[other_loc]] <- paste0(other_names, " <", other_envs, ">")
  }

  # One "- <class> (<fn>, <fn>, ...)" line per class
  bullet_lines <- vapply(
    seq_along(cls_labels),
    function(k) {
      joined <- paste0(names_by_cls[[k]], collapse = ", ")
      paste0("- ", cls_labels[[k]], " (", joined, ")")
    },
    character(1)
  )

  cli::cli_abort(
    c(
      "x" = "The combination of measure functions must be valid measure types.",
      "i" = "The following measure function types are being mixed:",
      bullet_lines
    ),
    call = rlang::call2("measure_set")
  )
}
| 260 | + |
# Build the function returned by `measure_set()`: a closure over `fns` that
# calls every measure with the same arguments and row-binds the results.
# The closure is tagged with class "measure_set" and keeps `fns` in its
# "measures" attribute.
make_measure_function <- function(fns) {
  combined <- function(data, truth, na_rm = TRUE, ...) {
    # Arguments shared by every measure call; `truth` is captured as typed
    shared_args <- rlang::quos(
      data = data,
      truth = !!rlang::enquo(truth),
      na_rm = na_rm,
      ... = ...
    )

    # One call per measure function, each spliced with the shared arguments.
    # No static-argument clean-up step is applied to these calls.
    measure_calls <- lapply(fns, rlang::call2, !!!shared_args)

    # Evaluate each call, passing the measure's name along for error reporting
    results <- mapply(
      FUN = eval_safely,
      measure_calls,
      names(measure_calls),
      SIMPLIFY = FALSE,
      USE.NAMES = FALSE
    )

    dplyr::bind_rows(results)
  }

  class(combined) <- c("measure_set", class(combined))
  attr(combined, "measures") <- fns
  combined
}
| 293 | + |
# Evaluate `expr` with tidy evaluation; if it fails, re-raise the error with a
# message naming the measure (`expr_nm`) and the original condition attached
# as its parent.
eval_safely <- function(expr, expr_nm, data = NULL, env = rlang::caller_env()) {
  # Wraps the underlying failure so the user sees which measure broke
  rethrow <- function(cnd) {
    cli::cli_abort(
      "Failed to compute {.fn {expr_nm}}.",
      parent = cnd,
      call = rlang::call2("measure_set")
    )
  }

  tryCatch(
    rlang::eval_tidy(expr, data = data, env = env),
    error = rethrow
  )
}
| 308 | + |
# Turn a captured input into the name used for it inside a measure set.
#
# `quo` is a single quosure. The result is its label as a length-1 character
# vector, with a leading namespace qualifier removed, so `pkg::fn` and
# `pkg:::fn` both become "fn".
get_quo_label <- function(quo) {
  out <- rlang::as_label(quo)

  if (length(out) != 1L) {
    cli::cli_abort(
      "{.code as_label(quo)} resulted in a character vector of length >1.",
      .internal = TRUE
    )
  }

  # Only a qualifier at the very start is stripped. Splitting on "::" and
  # keeping the second piece would leave ":fn" for `pkg:::fn`, and would
  # truncate any label that merely contains "::" further in (for example an
  # inline function whose body calls `pkg::fn()`).
  sub("^[[:alnum:]._]+:::?", "", out)
}
0 commit comments