Skip to content

Commit d98263c

Browse files
committed
refactor(registry): change R6 DistributionRegistry to function create_distribution_registry()
1 parent af8deda commit d98263c

9 files changed

Lines changed: 318 additions & 381 deletions

NAMESPACE

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# Generated by roxygen2: do not edit by hand
22

3-
export(DistributionRegistry)
43
export(add_patient_generator)
54
export(check_all_positive)
65
export(check_allowed_params)
@@ -12,13 +11,16 @@ export(check_positive_integer)
1211
export(check_prob_vector)
1312
export(check_run_number)
1413
export(create_asu_trajectory)
14+
export(create_distribution_registry)
1515
export(create_rehab_trajectory)
1616
export(filter_warmup)
1717
export(get_occupancy_stats)
1818
export(model)
1919
export(parameters)
2020
export(runner)
21+
export(transform_to_lnorm)
2122
export(valid_inputs)
23+
export(validate_single_config)
2224
importFrom(R6,R6Class)
2325
importFrom(dplyr,bind_rows)
2426
importFrom(dplyr,filter)

R/create_distribution_registry.R

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
#' Transform mean/sd on original scale to meanlog/sdlog for
2+
#' lognormal
3+
#'
4+
#' @param params A list with elements `mean` and `sd` (on
5+
#' original scale).
6+
#' @return A list with elements `meanlog` and `sdlog`.
7+
#' @export
8+
9+
transform_to_lnorm <- function(params) {
10+
variance <- params$sd^2L
11+
sigma_sq <- log(variance / (params$mean^2L) + 1L)
12+
sdlog <- sqrt(sigma_sq)
13+
meanlog <- log(params$mean) - sigma_sq / 2L
14+
list(meanlog = meanlog, sdlog = sdlog)
15+
}
16+
17+
18+
#' Validate a single sampler config element
19+
#'
20+
#' @param cfg A list expected to contain `class_name` and `params` elements.
21+
#' @return Invisibly `NULL` on success; stops with an error otherwise.
22+
#' @export
23+
24+
validate_single_config <- function(cfg) {
25+
if (!is.list(cfg)) {
26+
stop("Each element of 'config' must be a list.", call. = FALSE)
27+
}
28+
if (is.null(cfg$class_name)) {
29+
stop(
30+
"Each config element must have a 'class_name' entry.",
31+
call. = FALSE
32+
)
33+
}
34+
if (is.null(cfg$params) || !is.list(cfg$params)) {
35+
stop(
36+
"Each config element must have a 'params' list.",
37+
call. = FALSE
38+
)
39+
}
40+
invisible(NULL)
41+
}
42+
43+
44+
#' Create a distribution registry
45+
#'
46+
#' @description
47+
#' Creates a distribution registry that manages and generates
48+
#' parameterised samplers for a variety of probability
49+
#' distributions. Common distributions are included by default,
50+
#' and more can be added.
51+
#'
52+
#' Once defined, you can create sampler objects for each
53+
#' distribution - individually (`dist_create`) or in batches
54+
#' (`dist_create_batch`) - and then easily draw random samples
55+
#' from these objects.
56+
#'
57+
#' @return A list containing registry functions and data.
58+
#' @export
59+
60+
create_distribution_registry <- function() {
61+
62+
# Internal registry storage
63+
registry <- new.env(parent = emptyenv())
64+
65+
# ===== Register default distributions =====
66+
67+
registry[["exponential"]] <- function(mean) {
68+
function(size = 1L) rexp(size, rate = 1L / mean)
69+
}
70+
71+
registry[["uniform"]] <- function(min, max) {
72+
function(size = 1L) runif(size, min = min, max = max)
73+
}
74+
75+
registry[["discrete"]] <- function(values, prob) {
76+
values <- unlist(values)
77+
prob <- unlist(prob)
78+
79+
stopifnot(length(values) == length(prob), prob >= 0L)
80+
81+
if (round(abs(sum(prob) - 1L), 2L) > 0.01) {
82+
stop(
83+
sprintf(
84+
"'prob' must sum to 1 +- 0.01. Sum: %s",
85+
abs(sum(unlist(prob)))
86+
),
87+
call. = FALSE
88+
)
89+
}
90+
91+
function(size = 1L) {
92+
sample(values, size = size, replace = TRUE, prob = prob)
93+
}
94+
}
95+
96+
registry[["normal"]] <- function(mean, sd) {
97+
function(size = 1L) rnorm(size, mean = mean, sd = sd)
98+
}
99+
100+
registry[["lognormal"]] <- function(meanlog = NULL, sdlog = NULL,
101+
mean = NULL, sd = NULL) {
102+
# If meanlog/sdlog provided, use them directly
103+
if (!is.null(meanlog) && !is.null(sdlog)) {
104+
function(size = 1L) {
105+
rlnorm(size, meanlog = meanlog, sdlog = sdlog)
106+
}
107+
108+
} else if (!is.null(mean) && !is.null(sd)) {
109+
110+
# Transform mean/sd to meanlog/sdlog
111+
params <- transform_to_lnorm(list(mean = mean, sd = sd))
112+
function(size = 1L) {
113+
rlnorm(size,
114+
meanlog = params$meanlog,
115+
sdlog = params$sdlog)
116+
}
117+
118+
} else {
119+
stop(
120+
"Please supply either 'meanlog' and 'sdlog', or 'mean' and 'sd' ",
121+
"for a lognormal distribution.",
122+
call. = FALSE
123+
)
124+
}
125+
}
126+
127+
registry[["poisson"]] <- function(lambda) {
128+
function(size = 1L) rpois(size, lambda = lambda)
129+
}
130+
131+
registry[["binomial"]] <- function(size_param, prob) {
132+
function(size = 1L) {
133+
rbinom(size, size = size_param, prob = prob)
134+
}
135+
}
136+
137+
registry[["geometric"]] <- function(prob) {
138+
function(size = 1L) rgeom(size, prob = prob)
139+
}
140+
141+
registry[["beta"]] <- function(shape1, shape2) {
142+
function(size = 1L) {
143+
rbeta(size, shape1 = shape1, shape2 = shape2)
144+
}
145+
}
146+
147+
registry[["gamma"]] <- function(shape, rate) {
148+
function(size = 1L) {
149+
rgamma(size, shape = shape, rate = rate)
150+
}
151+
}
152+
153+
registry[["chisq"]] <- function(df) {
154+
function(size = 1L) rchisq(size, df = df)
155+
}
156+
157+
registry[["student_t"]] <- function(df) {
158+
function(size = 1L) rt(size, df = df)
159+
}
160+
161+
# ===== Public API =====
162+
163+
# Register a new distribution
164+
register <- function(name, generator, overwrite = FALSE) {
165+
if (!overwrite && exists(name, envir = registry, inherits = FALSE)) {
166+
stop(
167+
sprintf(
168+
"Distribution '%s' already exists. Set overwrite = TRUE ",
169+
"to replace it."
170+
),
171+
call. = FALSE
172+
)
173+
}
174+
assign(name, generator, envir = registry)
175+
invisible(TRUE)
176+
}
177+
178+
# Get a registered distribution generator
179+
get_distribution <- function(name) {
180+
if (!exists(name, envir = registry, inherits = FALSE)) {
181+
stop(
182+
sprintf(
183+
paste0(
184+
"Distribution '%s' not found.\n",
185+
"Available distributions:\n\t%s\n",
186+
"Use register() to add new distributions."
187+
),
188+
name,
189+
toString(ls(envir = registry))
190+
),
191+
call. = FALSE
192+
)
193+
}
194+
get(name, envir = registry, inherits = FALSE)
195+
}
196+
197+
# Create a parameterised sampler
198+
create_sampler <- function(name, ...) {
199+
generator <- get_distribution(name)
200+
arg_list <- list(...)
201+
202+
formals_names <- names(formals(generator))
203+
if (!is.null(formals_names)) {
204+
extra_args <- setdiff(names(arg_list), formals_names)
205+
if (length(extra_args) > 0L) {
206+
warning(
207+
sprintf(
208+
"Unused argument(s) for distribution '%s': %s",
209+
name,
210+
toString(extra_args)
211+
),
212+
call. = FALSE
213+
)
214+
}
215+
}
216+
217+
do.call(generator, arg_list)
218+
}
219+
220+
# Create multiple samplers from config
221+
# Each element of `config` is expected to be a list with
222+
# components:
223+
# - class_name: name of the registered distribution
224+
# - params: list of parameters to pass to that distribution's
225+
# generator
226+
create_batch <- function(config) {
227+
if (!is.list(config)) {
228+
stop("config must be a list (named or unnamed).", call. = FALSE)
229+
}
230+
231+
lapply(config, function(cfg) {
232+
validate_single_config(cfg)
233+
do.call(create_sampler, c(list(cfg$class_name), cfg$params))
234+
})
235+
}
236+
237+
# Return public API as a list
238+
list(
239+
register = register,
240+
get = get_distribution,
241+
create = create_sampler,
242+
create_batch = create_batch,
243+
transform_to_lnorm = transform_to_lnorm
244+
)
245+
}

0 commit comments

Comments
 (0)