Skip to content

Commit c73be4d

Browse files
committed
Fixed script calling
1 parent faaca77 commit c73be4d

1 file changed

Lines changed: 48 additions & 7 deletions

File tree

touchstone/script.R

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,36 @@ touchstone::branch_install()
66

77
# These synthetic workloads are large enough to expose real slowdowns in the
88
# core `loo()` paths, but still short enough to keep PR feedback reasonably fast.
9+
touchstone::benchmark_run(
10+
expr_before_benchmark = {
11+
suppressPackageStartupMessages(library(loo))
12+
13+
matrix_draws <- 2000L
14+
matrix_obs <- 500L
15+
n_chains <- 4L
16+
stopifnot(matrix_draws %% n_chains == 0L)
17+
18+
set.seed(20260408)
19+
mu <- stats::rnorm(matrix_draws)
20+
sigma <- exp(stats::rnorm(matrix_draws, mean = -0.2, sd = 0.15))
21+
y <- stats::rnorm(matrix_obs, mean = 0.3, sd = 1.2)
22+
23+
log_lik_matrix <- vapply(
24+
y,
25+
FUN = function(y_i) {
26+
stats::dnorm(y_i, mean = mu, sd = sigma, log = TRUE)
27+
},
28+
FUN.VALUE = numeric(matrix_draws)
29+
)
30+
matrix_r_eff <- rep(1, matrix_obs)
31+
32+
},
33+
loo_matrix = {
34+
suppressWarnings(loo(log_lik_matrix, r_eff = matrix_r_eff, cores = 1))
35+
},
36+
n = 10
37+
)
38+
939
touchstone::benchmark_run(
1040
expr_before_benchmark = {
1141
suppressPackageStartupMessages(library(loo))
@@ -35,9 +65,26 @@ touchstone::benchmark_run(
3565
rows <- ((chain - 1L) * n_iter + 1L):(chain * n_iter)
3666
log_lik_array[, chain, ] <- log_lik_matrix[rows, ]
3767
}
68+
},
69+
loo_array = {
70+
suppressWarnings(loo(log_lik_array, r_eff = matrix_r_eff, cores = 1))
71+
},
72+
n = 10
73+
)
74+
75+
touchstone::benchmark_run(
76+
expr_before_benchmark = {
77+
suppressPackageStartupMessages(library(loo))
3878

79+
matrix_draws <- 2000L
3980
function_obs <- 250L
40-
function_data <- data.frame(y = y[seq_len(function_obs)])
81+
82+
set.seed(20260408)
83+
mu <- stats::rnorm(matrix_draws)
84+
sigma <- exp(stats::rnorm(matrix_draws, mean = -0.2, sd = 0.15))
85+
y <- stats::rnorm(function_obs, mean = 0.3, sd = 1.2)
86+
87+
function_data <- data.frame(y = y)
4188
function_draws <- cbind(mu = mu, sigma = sigma)
4289
function_r_eff <- rep(1, function_obs)
4390
llfun <- function(data_i, draws) {
@@ -49,12 +96,6 @@ touchstone::benchmark_run(
4996
)
5097
}
5198
},
52-
loo_matrix = {
53-
suppressWarnings(loo(log_lik_matrix, r_eff = matrix_r_eff, cores = 1))
54-
},
55-
loo_array = {
56-
suppressWarnings(loo(log_lik_array, r_eff = matrix_r_eff, cores = 1))
57-
},
5899
loo_function = {
59100
suppressWarnings(
60101
loo(

0 commit comments

Comments
 (0)