@@ -30,9 +30,28 @@ DistributionRegistry <- R6Class("DistributionRegistry", list( # nolint: object_n
3030 self $ register(" uniform" , function (min , max ) {
3131 function (size = 1L ) runif(size , min = min , max = max )
3232 })
33+ self $ register(" discrete" , function (values , prob ) {
34+ values <- unlist(values )
35+ prob <- unlist(prob )
36+ # Validation (as not using a pre-made distribution function)
37+ stopifnot(length(values ) == length(prob ))
38+ stopifnot(all(prob > = 0 ))
39+ if (round(abs(sum(prob ) - 1 ), 2 ) > 0.01 ) {
40+ stop(sprintf(
41+ " 'prob' must sum to 1 ± 0.01. Sum: %s" , abs(sum(unlist(prob )))
42+ ))
43+ }
44+ # Sampling function
45+ function (size = 1L ) sample(
46+ values , size = size , replace = TRUE , prob = prob
47+ )
48+ })
3349 self $ register(" normal" , function (mean , sd ) {
3450 function (size = 1L ) rnorm(size , mean = mean , sd = sd )
3551 })
52+ self $ register(" lognormal" , function (meanlog , sdlog ) {
53+ function (size = 1L ) rlnorm(size , meanlog = meanlog , sdlog = sdlog )
54+ })
3655 self $ register(" poisson" , function (lambda ) {
3756 function (size = 1L ) rpois(size , lambda = lambda )
3857 })
@@ -76,22 +95,59 @@ DistributionRegistry <- R6Class("DistributionRegistry", list( # nolint: object_n
7695 # ' @return Generator function for the distribution.
7796 get = function (name ) {
7897 if (! (name %in% names(self $ registry )))
79- stop(sprintf(" Distribution '%s' not found" , name ), call. = FALSE )
98+ stop(
99+ sprintf(
100+ c(" Distribution '%s' not found.\n Available distributions:\n\t %s\n " ,
101+ " Use register() to add new distributions." ),
102+ name , paste(names(self $ registry ), collapse = " ,\n\t " )
103+ ),
104+ call. = FALSE )
80105 self $ registry [[name ]]
81106 },
82107
108+ # ' @description
109+ # ' Convert mean/sd to lognormal parameters, returning the corresponding
110+ # ' \code{meanlog} and \code{sdlog} parameters used by R's \code{rlnorm()}.
111+ # '
112+ # ' @param params Named list with two elements: mean and sd.
113+ # ' @return A named list of the same structure, but with elements
114+ # ' \code{meanlog} and \code{sdlog} for each patient type.
115+ transform_to_lnorm = function (params ) {
116+ variance <- params $ sd ^ 2L
117+ sigma_sq <- log(variance / (params $ mean ^ 2L ) + 1L )
118+ sdlog <- sqrt(sigma_sq )
119+ meanlog <- log(params $ mean ) - sigma_sq / 2L
120+ list (meanlog = meanlog , sdlog = sdlog )
121+ },
122+
83123 # ' @description
84124 # ' Create a parameterised sampler for a distribution.
85125 # '
86126 # ' The returned function draws random samples of a specified size from
87127 # ' the given distribution with fixed parameters.
88128 # '
129+ # ' For "lognormal", if "meanlog" and "sdlog" are present in the parameters,
130+ # ' they will be used as-is. If not, but "mean" and "sd" are provided, these
131+ # ' will be transformed into "meanlog"/"sdlog" automatically.
132+ # '
89133 # ' @param name Distribution name
90134 # ' @param ... Parameters for the generator
91135 # ' @return A function that draws samples when called.
92136 create = function (name , ... ) {
137+ dots <- list (... )
138+ if (name == " lognormal" ) {
139+ if (! is.null(dots $ meanlog ) && ! is.null(dots $ sdlog )) {
140+ dots <- dots
141+ } else if (! is.null(dots $ mean ) && ! is.null(dots $ sd )) {
142+ transformed <- self $ transform_to_lnorm(dots )
143+ dots <- c(transformed , dots [setdiff(names(dots ), c(" mean" , " sd" ))])
144+ } else {
145+ stop(" Please supply either 'meanlog' and 'sdlog', or 'mean' and 'sd' " ,
146+ " for a lognormal distribution." )
147+ }
148+ }
93149 generator <- self $ get(name )
94- generator( ... )
150+ do.call( generator , dots )
95151 },
96152
97153 # ' @description
@@ -105,11 +161,7 @@ DistributionRegistry <- R6Class("DistributionRegistry", list( # nolint: object_n
105161 # ' 'class_name' and 'params'.
106162 # ' @return List of parameterised samplers (named if config is named).
107163 create_batch = function (config ) {
108- if (is.list(config ) && is.null(names(config ))) {
109- lapply(config , function (cfg ) {
110- do.call(self $ create , c(cfg $ class_name , cfg $ params ))
111- })
112- } else if (is.list(config )) {
164+ if (is.list(config )) {
113165 lapply(config , function (cfg ) {
114166 do.call(self $ create , c(cfg $ class_name , cfg $ params ))
115167 })
0 commit comments