Skip to content

Commit b7e1fd1

Browse files
committed
feat(checkmate): install checkmate + refactor validation using it (#24)
1 parent d98263c commit b7e1fd1

5 files changed

Lines changed: 86 additions & 32 deletions

File tree

DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ Encoding: UTF-8
2323
LazyData: true
2424
RoxygenNote: 7.3.2
2525
Imports:
26+
checkmate,
2627
dplyr,
2728
future,
2829
future.apply,

NAMESPACE

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ export(transform_to_lnorm)
2222
export(valid_inputs)
2323
export(validate_single_config)
2424
importFrom(R6,R6Class)
25+
importFrom(checkmate,assert_character)
26+
importFrom(checkmate,assert_flag)
27+
importFrom(checkmate,assert_int)
28+
importFrom(checkmate,assert_integer)
29+
importFrom(checkmate,assert_list)
30+
importFrom(checkmate,assert_numeric)
2531
importFrom(dplyr,bind_rows)
2632
importFrom(dplyr,filter)
2733
importFrom(dplyr,group_by)

R/simulation-package.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
"_PACKAGE"
33

44
## usethis namespace: start
5+
#' @importFrom checkmate assert_character assert_flag assert_int assert_integer
6+
#' @importFrom checkmate assert_list assert_numeric
57
#' @importFrom dplyr bind_rows filter group_by mutate rowwise ungroup
68
#' @importFrom future availableCores multisession plan sequential
79
#' @importFrom future.apply future_lapply

R/validate_model_inputs.R

Lines changed: 65 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,7 @@ valid_inputs <- function(run_number, param) {
2222
#' @export
2323

2424
check_run_number <- function(run_number) {
25-
if (run_number < 0L || run_number %% 1L != 0L) {
26-
stop("The run number must be a non-negative integer. Provided: ",
27-
run_number, call. = FALSE)
28-
}
25+
assert_int(run_number, lower = 0L, .var.name = "run_number")
2926
}
3027

3128

@@ -38,13 +35,19 @@ check_run_number <- function(run_number) {
3835
#' @export
3936

4037
check_log_file_path <- function(param) {
38+
assert_list(param, names = "unique", .var.name = "param")
39+
4140
log_to_file <- param[["log_to_file"]]
4241
file_path <- param[["file_path"]]
43-
if (isTRUE(log_to_file) && (is.null(file_path) || !nzchar(file_path))) {
44-
stop(
45-
"If 'log_to_file' is TRUE, you must provide a non-NULL, ",
46-
"non-empty 'file_path'.",
47-
call. = FALSE
42+
43+
assert_flag(log_to_file, null.ok = TRUE, .var.name = "log_to_file")
44+
45+
if (isTRUE(log_to_file)) {
46+
assert_character(
47+
file_path,
48+
min.chars = 1L,
49+
len = 1L,
50+
.var.name = "file_path"
4851
)
4952
}
5053
}
@@ -61,6 +64,7 @@ check_log_file_path <- function(param) {
6164
#' @export
6265

6366
check_param_names <- function(param) {
67+
assert_list(param, names = "unique", .var.name = "param")
6468

6569
# Check the distribution names....
6670
# Import JSON with the required names
@@ -100,15 +104,28 @@ check_param_names <- function(param) {
100104

101105
check_prob_vector <- function(vec, name) {
102106
if (!is.numeric(vec)) {
103-
stop('Routing vector "', name, '" must be numeric.', call. = FALSE)
107+
stop(sprintf('Routing vector "%s" must be numeric.', name), call. = FALSE)
104108
}
109+
105110
if (any(vec < 0L | vec > 1L)) {
106-
stop('All values in routing vector "', name, '" must be between 0 and 1.',
107-
call. = FALSE)
111+
stop(
112+
sprintf(
113+
'All values in routing vector "%s" must be between 0 and 1.',
114+
name
115+
),
116+
call. = FALSE
117+
)
108118
}
109-
if (sum(vec) < 0.99 || sum(vec) > 1.01) {
110-
stop('Values in routing vector "', name, '" must sum to 1 (+-0.01).',
111-
call. = FALSE)
119+
120+
sum_vec <- sum(vec)
121+
if (sum_vec < 0.99 || sum_vec > 1.01) {
122+
stop(
123+
sprintf(
124+
'Values in routing vector "%s" must sum to 1 (+-0.01).',
125+
name
126+
),
127+
call. = FALSE
128+
)
112129
}
113130
}
114131

@@ -124,12 +141,7 @@ check_prob_vector <- function(vec, name) {
124141
#' @export
125142

126143
check_positive_integer <- function(x, name) {
127-
if (is.null(x) || x <= 0L || x %% 1L != 0L) {
128-
stop(
129-
sprintf('The parameter "%s" must be an integer greater than 0.', name),
130-
call. = FALSE
131-
)
132-
}
144+
assert_int(x, lower = 1L, .var.name = name)
133145
}
134146

135147
#' Check if all values are positive
@@ -143,11 +155,14 @@ check_positive_integer <- function(x, name) {
143155
#' @export
144156

145157
check_all_positive <- function(x, name) {
146-
if (!is.null(x) && any(unlist(x) <= 0L)) {
147-
stop(
148-
sprintf('All values in "%s" must be greater than 0.', name),
149-
call. = FALSE
150-
)
158+
if (!is.null(x)) {
159+
val <- unlist(x)
160+
if (any(val <= 0L)) {
161+
stop(
162+
sprintf('All values in "%s" must be greater than 0.', name),
163+
call. = FALSE
164+
)
165+
}
151166
}
152167
}
153168

@@ -162,12 +177,7 @@ check_all_positive <- function(x, name) {
162177
#' @export
163178

164179
check_nonneg_integer <- function(x, name) {
165-
if (is.null(x) || x < 0L || x %% 1L != 0L) {
166-
stop(
167-
sprintf('The parameter "%s" must be an integer >= 0.', name),
168-
call. = FALSE
169-
)
170-
}
180+
assert_int(x, lower = 0L, .var.name = name)
171181
}
172182

173183

@@ -181,6 +191,17 @@ check_nonneg_integer <- function(x, name) {
181191
#' @export
182192

183193
check_allowed_params <- function(object_name, actual_names, allowed_names) {
194+
assert_character(
195+
actual_names,
196+
any.missing = FALSE,
197+
.var.name = "actual_names"
198+
)
199+
assert_character(
200+
allowed_names,
201+
any.missing = FALSE,
202+
.var.name = "allowed_names"
203+
)
204+
184205
extra_names <- setdiff(actual_names, allowed_names)
185206
missing_names <- setdiff(allowed_names, actual_names)
186207
if (length(extra_names) > 0L) {
@@ -213,6 +234,7 @@ check_allowed_params <- function(object_name, actual_names, allowed_names) {
213234
#' @export
214235

215236
check_param_values <- function(param) {
237+
assert_list(param, names = "unique", .var.name = "param")
216238

217239
# High-level parameters (runs, simulation run length)
218240
check_positive_integer(param[["number_of_runs"]], "number_of_runs")
@@ -251,10 +273,21 @@ check_param_values <- function(param) {
251273
if (type == "discrete") {
252274
vals <- unlist(params$values)
253275
prob <- unlist(params$prob)
276+
277+
# For discrete distributions, values are typically character route names,
278+
# so we just check they exist and are non-empty
279+
assert_character(
280+
vals,
281+
any.missing = FALSE,
282+
min.chars = 1L,
283+
.var.name = paste0(dist_name, "$params$values")
284+
)
285+
254286
if (length(vals) != length(prob)) {
255287
stop(sprintf("Discrete dist '%s' values and prob length mismatch",
256288
dist_name), call. = FALSE)
257289
}
290+
258291
check_prob_vector(prob, paste0(dist_name, "$params$prob"))
259292
check_allowed_params(
260293
object_name = paste0("param$dist_config$", dist_name, "$params"),

renv.lock

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,18 @@
205205
],
206206
"Hash": "d7e13f49c19103ece9e58ad2d83a7354"
207207
},
208+
"checkmate": {
209+
"Package": "checkmate",
210+
"Version": "2.3.3",
211+
"Source": "Repository",
212+
"Repository": "CRAN",
213+
"Requirements": [
214+
"R",
215+
"backports",
216+
"utils"
217+
],
218+
"Hash": "92b3b84e5789dcde37cb0c409c54553a"
219+
},
208220
"class": {
209221
"Package": "class",
210222
"Version": "7.3-22",

0 commit comments

Comments
 (0)