Skip to content

Commit 2588e9d

Browse files
committed
feat/refactor/fix(distributionregistry): big commit! (1) add JSON + dict, (2) remove validation (to be replaced) (3) remove old param classes (4) import parameters from JSON (5) add discrete and lognormal to DistributionRegistry (5) improve DistributionRegistry$get() error message (6) move transform_to_lnorm into DistributionRegistry and automatically apply as part of create() (7) remove redunant loop in create_batch()
1 parent 3583ae3 commit 2588e9d

25 files changed

Lines changed: 338 additions & 916 deletions

NAMESPACE

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,15 @@
22

33
export(DistributionRegistry)
44
export(add_patient_generator)
5-
export(check_all_param_names)
6-
export(check_all_positive)
75
export(check_log_file_path)
8-
export(check_nonneg_integer)
9-
export(check_param_names)
10-
export(check_param_values)
11-
export(check_positive_integer)
12-
export(check_prob_vector)
136
export(check_run_number)
14-
export(create_asu_arrivals)
15-
export(create_asu_los)
16-
export(create_asu_routing)
177
export(create_asu_trajectory)
18-
export(create_parameters)
19-
export(create_rehab_arrivals)
20-
export(create_rehab_los)
21-
export(create_rehab_routing)
228
export(create_rehab_trajectory)
239
export(filter_warmup)
2410
export(get_occupancy_stats)
2511
export(model)
12+
export(parameters)
2613
export(runner)
27-
export(sample_routing)
28-
export(transform_to_lnorm)
2914
export(valid_inputs)
3015
importFrom(R6,R6Class)
3116
importFrom(dplyr,bind_rows)
@@ -38,6 +23,7 @@ importFrom(future,multisession)
3823
importFrom(future,plan)
3924
importFrom(future,sequential)
4025
importFrom(future.apply,future_lapply)
26+
importFrom(jsonlite,fromJSON)
4127
importFrom(rlang,.data)
4228
importFrom(simmer,add_generator)
4329
importFrom(simmer,add_resource)

R/distribution_registry.R

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,28 @@ DistributionRegistry <- R6Class("DistributionRegistry", list( # nolint: object_n
3030
self$register("uniform", function(min, max) {
3131
function(size = 1L) runif(size, min = min, max = max)
3232
})
33+
self$register("discrete", function(values, prob) {
34+
values <- unlist(values)
35+
prob <- unlist(prob)
36+
# Validation (as not using a pre-made distribution function)
37+
stopifnot(length(values) == length(prob))
38+
stopifnot(all(prob >= 0))
39+
if (round(abs(sum(prob) - 1), 2) > 0.01) {
40+
stop(sprintf(
41+
"'prob' must sum to 1 ± 0.01. Sum: %s", abs(sum(unlist(prob)))
42+
))
43+
}
44+
# Sampling function
45+
function(size = 1L) sample(
46+
values, size = size, replace = TRUE, prob = prob
47+
)
48+
})
3349
self$register("normal", function(mean, sd) {
3450
function(size = 1L) rnorm(size, mean = mean, sd = sd)
3551
})
52+
self$register("lognormal", function(meanlog, sdlog) {
53+
function(size = 1L) rlnorm(size, meanlog = meanlog, sdlog = sdlog)
54+
})
3655
self$register("poisson", function(lambda) {
3756
function(size = 1L) rpois(size, lambda = lambda)
3857
})
@@ -76,22 +95,59 @@ DistributionRegistry <- R6Class("DistributionRegistry", list( # nolint: object_n
7695
#' @return Generator function for the distribution.
7796
get = function(name) {
7897
if (!(name %in% names(self$registry)))
79-
stop(sprintf("Distribution '%s' not found", name), call. = FALSE)
98+
stop(
99+
sprintf(
100+
c("Distribution '%s' not found.\nAvailable distributions:\n\t%s\n",
101+
"Use register() to add new distributions."),
102+
name, paste(names(self$registry), collapse = ",\n\t")
103+
),
104+
call. = FALSE)
80105
self$registry[[name]]
81106
},
82107

108+
#' @description
109+
#' Convert mean/sd to lognormal parameters, returning the corresponding
110+
#' \code{meanlog} and \code{sdlog} parameters used by R's \code{rlnorm()}.
111+
#'
112+
#' @param params Named list with two elements: mean and sd.
113+
#' @return A named list of the same structure, but with elements
114+
#' \code{meanlog} and \code{sdlog} for each patient type.
115+
transform_to_lnorm = function(params) {
116+
variance <- params$sd^2L
117+
sigma_sq <- log(variance / (params$mean^2L) + 1L)
118+
sdlog <- sqrt(sigma_sq)
119+
meanlog <- log(params$mean) - sigma_sq / 2L
120+
list(meanlog = meanlog, sdlog = sdlog)
121+
},
122+
83123
#' @description
84124
#' Create a parameterised sampler for a distribution.
85125
#'
86126
#' The returned function draws random samples of a specified size from
87127
#' the given distribution with fixed parameters.
88128
#'
129+
#' For "lognormal", if "meanlog" and "sdlog" are present in the parameters,
130+
#' they will be used as-is. If not, but "mean" and "sd" are provided, these
131+
#' will be transformed into "meanlog"/"sdlog" automatically.
132+
#'
89133
#' @param name Distribution name
90134
#' @param ... Parameters for the generator
91135
#' @return A function that draws samples when called.
92136
create = function(name, ...) {
137+
dots <- list(...)
138+
if (name == "lognormal") {
139+
if (!is.null(dots$meanlog) && !is.null(dots$sdlog)) {
140+
dots <- dots
141+
} else if (!is.null(dots$mean) && !is.null(dots$sd)) {
142+
transformed <- self$transform_to_lnorm(dots)
143+
dots <- c(transformed, dots[setdiff(names(dots), c("mean", "sd"))])
144+
} else {
145+
stop("Please supply either 'meanlog' and 'sdlog', or 'mean' and 'sd' ",
146+
"for a lognormal distribution.")
147+
}
148+
}
93149
generator <- self$get(name)
94-
generator(...)
150+
do.call(generator, dots)
95151
},
96152

97153
#' @description
@@ -105,11 +161,7 @@ DistributionRegistry <- R6Class("DistributionRegistry", list( # nolint: object_n
105161
#' 'class_name' and 'params'.
106162
#' @return List of parameterised samplers (named if config is named).
107163
create_batch = function(config) {
108-
if (is.list(config) && is.null(names(config))) {
109-
lapply(config, function(cfg) {
110-
do.call(self$create, c(cfg$class_name, cfg$params))
111-
})
112-
} else if (is.list(config)) {
164+
if (is.list(config)) {
113165
lapply(config, function(cfg) {
114166
do.call(self$create, c(cfg$class_name, cfg$params))
115167
})

0 commit comments

Comments
 (0)