@@ -92,8 +92,7 @@ crps.matrix <- function(x, x2, y, ..., permutations = 1) {
9292# ' @rdname crps
9393# ' @export
9494crps.numeric <- function (x , x2 , y , ... , permutations = 1 ) {
95- stopifnot(length(x ) == length(x2 ),
96- length(y ) == 1 )
95+ stopifnot(length(x ) == length(x2 ), length(y ) == 1 )
9796 crps.matrix(as.matrix(x ), as.matrix(x2 ), y , permutations )
9897}
9998
@@ -106,23 +105,32 @@ crps.numeric <- function(x, x2, y, ..., permutations = 1) {
106105# ' @param cores The number of cores to use for parallelization of `[psis()]`.
107106# ' See [psis()] for details.
108107loo_crps.matrix <-
109- function (x ,
110- x2 ,
111- y ,
112- log_lik ,
113- ... ,
114- permutations = 1 ,
115- r_eff = 1 ,
116- cores = getOption(" mc.cores" , 1 )) {
117- validate_crps_input(x , x2 , y , log_lik )
118- repeats <- replicate(permutations ,
119- EXX_loo_compute(x , x2 , log_lik , r_eff = r_eff , ... ),
120- simplify = F )
121- EXX <- Reduce(`+` , repeats ) / permutations
122- psis_obj <- psis(- log_lik , r_eff = r_eff , cores = cores )
123- EXy <- E_loo(abs(sweep(x , 2 , y )), psis_obj , log_ratios = - log_lik , ... )$ value
124- crps_output(.crps_fun(EXX , EXy ))
125- }
108+ function (
109+ x ,
110+ x2 ,
111+ y ,
112+ log_lik ,
113+ ... ,
114+ permutations = 1 ,
115+ r_eff = 1 ,
116+ cores = getOption(" mc.cores" , 1 )
117+ ) {
118+ validate_crps_input(x , x2 , y , log_lik )
119+ repeats <- replicate(
120+ permutations ,
121+ EXX_loo_compute(x , x2 , log_lik , r_eff = r_eff , ... ),
122+ simplify = F
123+ )
124+ EXX <- Reduce(`+` , repeats ) / permutations
125+ psis_obj <- psis(- log_lik , r_eff = r_eff , cores = cores )
126+ EXy <- E_loo(
127+ abs(sweep(x , 2 , y )),
128+ psis_obj ,
129+ log_ratios = - log_lik ,
130+ ...
131+ )$ value
132+ crps_output(.crps_fun(EXX , EXy ))
133+ }
126134
127135
128136# ' @rdname crps
@@ -138,8 +146,7 @@ scrps.matrix <- function(x, x2, y, ..., permutations = 1) {
138146# ' @rdname crps
139147# ' @export
140148scrps.numeric <- function (x , x2 , y , ... , permutations = 1 ) {
141- stopifnot(length(x ) == length(x2 ),
142- length(y ) == 1 )
149+ stopifnot(length(x ) == length(x2 ), length(y ) == 1 )
143150 scrps.matrix(as.matrix(x ), as.matrix(x2 ), y , permutations )
144151}
145152
@@ -155,40 +162,54 @@ loo_scrps.matrix <-
155162 ... ,
156163 permutations = 1 ,
157164 r_eff = 1 ,
158- cores = getOption(" mc.cores" , 1 )) {
159- validate_crps_input(x , x2 , y , log_lik )
160- repeats <- replicate(permutations ,
161- EXX_loo_compute(x , x2 , log_lik , r_eff = r_eff , ... ),
162- simplify = F )
163- EXX <- Reduce(`+` , repeats ) / permutations
164- psis_obj <- psis(- log_lik , r_eff = r_eff , cores = cores )
165- EXy <- E_loo(abs(sweep(x , 2 , y )), psis_obj , log_ratios = - log_lik , ... )$ value
166- crps_output(.crps_fun(EXX , EXy , scale = TRUE ))
167- }
165+ cores = getOption(" mc.cores" , 1 )
166+ ) {
167+ validate_crps_input(x , x2 , y , log_lik )
168+ repeats <- replicate(
169+ permutations ,
170+ EXX_loo_compute(x , x2 , log_lik , r_eff = r_eff , ... ),
171+ simplify = F
172+ )
173+ EXX <- Reduce(`+` , repeats ) / permutations
174+ psis_obj <- psis(- log_lik , r_eff = r_eff , cores = cores )
175+ EXy <- E_loo(
176+ abs(sweep(x , 2 , y )),
177+ psis_obj ,
178+ log_ratios = - log_lik ,
179+ ...
180+ )$ value
181+ crps_output(.crps_fun(EXX , EXy , scale = TRUE ))
182+ }
168183
169184# ------------ Internals ----------------
170185
171-
172186EXX_compute <- function (x , x2 ) {
173187 S <- nrow(x )
174- colMeans(abs(x - x2 [sample(1 : S ),]))
188+ colMeans(abs(x - x2 [sample(1 : S ), ]))
175189}
176190
177191
178192EXX_loo_compute <- function (x , x2 , log_lik , r_eff = 1 , ... ) {
179193 S <- nrow(x )
180- shuffle <- sample (1 : S )
181- x2 <- x2 [shuffle ,]
182- log_lik2 <- log_lik [shuffle ,]
183- psis_obj_joint <- psis(- log_lik - log_lik2 , r_eff = r_eff )
184- E_loo(abs(x - x2 ), psis_obj_joint , log_ratios = - log_lik - log_lik2 , ... )$ value
194+ shuffle <- sample(1 : S )
195+ x2 <- x2 [shuffle , ]
196+ log_lik2 <- log_lik [shuffle , ]
197+ psis_obj_joint <- psis(- log_lik - log_lik2 , r_eff = r_eff )
198+ E_loo(
199+ abs(x - x2 ),
200+ psis_obj_joint ,
201+ log_ratios = - log_lik - log_lik2 ,
202+ ...
203+ )$ value
185204}
186205
187206
188207# ' Function to compute crps and scrps
189208# ' @noRd
190209.crps_fun <- function (EXX , EXy , scale = FALSE ) {
191- if (scale ) return (- EXy / EXX - 0.5 * log(EXX ))
210+ if (scale ) {
211+ return (- EXy / EXX - 0.5 * log(EXX ))
212+ }
192213 0.5 * EXX - EXy
193214}
194215
@@ -208,11 +229,12 @@ crps_output <- function(crps_pw) {
208229# ' Check that predictive draws and observed data are of compatible shape
209230# ' @noRd
210231validate_crps_input <- function (x , x2 , y , log_lik = NULL ) {
211- stopifnot(is.numeric(x ),
212- is.numeric(x2 ),
213- is.numeric(y ),
214- identical(dim(x ), dim(x2 )),
215- ncol(x ) == length(y ),
216- ifelse(is.null(log_lik ), TRUE , identical(dim(log_lik ), dim(x )))
217- )
232+ stopifnot(
233+ is.numeric(x ),
234+ is.numeric(x2 ),
235+ is.numeric(y ),
236+ identical(dim(x ), dim(x2 )),
237+ ncol(x ) == length(y ),
238+ ifelse(is.null(log_lik ), TRUE , identical(dim(log_lik ), dim(x )))
239+ )
218240}
0 commit comments