Skip to content

Commit e0f18a8

Browse files
authored
added system of Derived Quantities to MCMC
Added a system for derived quantities to the MCMC. The derived quantity functions initially included are: - `mean` - `variance` - `logProb` - `predictive`
1 parent f03c88f commit e0f18a8

12 files changed

Lines changed: 1564 additions & 94 deletions

packages/nimble/DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ Collate:
134134
MCMC_build.R
135135
MCMC_run.R
136136
MCMC_samplers.R
137+
MCMC_derived.R
137138
MCMC_conjugacy.R
138139
MCMC_autoBlock.R
139140
MCMC_RJ.R

packages/nimble/NAMESPACE

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,11 @@ export(dcar_proper)
8585
export(dcat)
8686
export(dconstraint)
8787
export(ddexp)
88+
export(derived_BASE)
89+
export(derived_logProb)
90+
export(derived_mean)
91+
export(derived_predictive)
92+
export(derived_variance)
8893
export(ddirch)
8994
export(decide)
9095
export(decideAndJump)

packages/nimble/R/MCMC_autoBlock.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ autoBlockClass <- setRefClass(
416416
## msg <- 'using \'posterior_predictive\' sampler may lead to results we don\'t want'
417417
## cat(paste0('\nWARNING: ', msg, '\n\n')); warning(msg)
418418
## }
419-
if(grepl('^conjugate_', ss$name) && getNimbleOption('verifyConjugatePosteriors')) {
419+
if(grepl('^conjugate_', ss$name) && getNimbleOption('MCMCverifyConjugatePosteriors')) {
420420
##msg <- 'conjugate sampler running slow due to checking the posterior'
421421
##cat(paste0('\nWARNING: ', msg, '\n\n')); warning(msg)
422422
warn <- TRUE

packages/nimble/R/MCMC_build.R

Lines changed: 98 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#'
3131
#' \code{initializeModel}: Boolean specifying whether to run the initializeModel routine on the underlying model object, prior to beginning MCMC sampling (default = TRUE).
3232
#'
33+
#' \code{chain}: Integer specifying the MCMC chain number. The chain number is passed to each MCMC sampler's before_chain method. The value for this argument is specified automatically from invocation via runMCMC, and need not be supplied when calling mcmc$run (default = 1).
3334
#' \code{chain}: Integer specifying the MCMC chain number. The chain number is passed to each MCMC sampler's before_chain and after_chain methods. The value for this argument is specified automatically from invocation via runMCMC, and generally need not be supplied when calling mcmc$run (default = 1).
3435
#'
3536
#' \code{time}: Boolean specifying whether to record runtimes of the individual internal MCMC samplers. When \code{time = TRUE}, a vector of runtimes (measured in seconds) can be extracted from the MCMC using the method \code{mcmc$getTimes()} (default = FALSE).
@@ -180,16 +181,47 @@ buildMCMC <- nimbleFunction(
180181
for(i in seq_along(conf$samplerConfs))
181182
samplerFunctions[[i]] <- conf$samplerConfs[[i]]$buildSampler(model=model, mvSaved=mvSaved)
182183
}
184+
185+
## construct mvSamples and mvSamples2
186+
mvSamplesConf <- conf$getMvSamplesConf(1)
187+
mvSamples2Conf <- conf$getMvSamplesConf(2)
188+
mvSamples <- modelValues(mvSamplesConf)
189+
mvSamples2 <- modelValues(mvSamples2Conf)
183190

191+
## build derived quantity intervals
192+
derivedFunctions <- nimbleFunctionList(derived_BASE)
193+
numDerived <- length(conf$derivedConfs)
194+
derivedIntervals <- numeric(max(numDerived, 2)) ## force to be a vector
195+
for(i in seq_along(conf$derivedConfs)) {
196+
derivedIntervals[i] <- conf$derivedConfs[[i]]$interval
197+
}
198+
199+
## build derived quantity function list
200+
derivedFunctions <- nimbleFunctionList(derived_BASE)
201+
## code below allows 'mcmc' to be a setup argument of derived quantity nimbleFunctions
202+
on.exit({
203+
for(i in seq_along(conf$derivedConfs)) {
204+
derivedFunctions[[i]] <- conf$derivedConfs[[i]]$buildDerived(model=model, mcmc=nfRefClassObject)
205+
}
206+
## need to catch the case where buildMCMC errors out early,
207+
## prior to executing the final lines which actually create the nfRefClassObject object:
208+
if(exists('nfRefClassObject', inherits = FALSE)) {
209+
nfRefClassObject[['derivedFunctions']] <- nf_preProcessMemberDataObject(get('derivedFunctions'))
210+
}
211+
}, add = TRUE)
212+
213+
## for naming the derivedList return object from runMCMC
214+
derivedTypes <- sapply(conf$derivedConfs, `[[`, 'name')
215+
if(length(derivedTypes) == 0) derivedTypes <- character()
216+
## used for extracting names of derived quantities,
217+
## having as member data is necessary for compilation
218+
derivedNames <- character(2)
219+
184220
samplerExecutionOrderFromConfPlusTwoZeros <- c(conf$samplerExecutionOrder, 0, 0) ## establish as a vector
185221
monitors <- mcmc_processMonitorNames(model, conf$monitors)
186222
monitors2 <- mcmc_processMonitorNames(model, conf$monitors2)
187223
thinFromConfVec <- c(conf$thin, conf$thin2) ## vector
188224
thinToUseVec <- c(0, 0) ## vector, needs to be member data
189-
mvSamplesConf <- conf$getMvSamplesConf(1)
190-
mvSamples2Conf <- conf$getMvSamplesConf(2)
191-
mvSamples <- modelValues(mvSamplesConf)
192-
mvSamples2 <- modelValues(mvSamples2Conf)
193225
samplerTimes <- c(0,0) ## establish as a vector
194226
progressBarLength <- 52 ## multiples of 4 only
195227
progressBarDefaultSetting <- getNimbleOption('MCMCprogressBar')
@@ -212,6 +244,8 @@ buildMCMC <- nimbleFunction(
212244
thinWAIC <- FALSE
213245
nburnin_extraWAIC <- 0
214246
}
247+
firstRun <- TRUE
248+
setupOutputs(derivedTypes)
215249
},
216250

217251
run = function(
@@ -231,23 +265,34 @@ buildMCMC <- nimbleFunction(
231265
if(niter < 0) stop('cannot specify niter < 0')
232266
if(nburnin < 0) stop('cannot specify nburnin < 0')
233267
if(nburnin > niter) stop('cannot specify nburnin > niter')
234-
thinToUseVec <<- thinFromConfVec
235-
if(thin != -1) thinToUseVec[1] <<- thin
236-
if(thin2 != -1) thinToUseVec[2] <<- thin2
237-
for(iThin in 1:2) {
238-
if(thinToUseVec[iThin] < 1) stop('cannot use thin < 1')
239-
if(thinToUseVec[iThin] != floor(thinToUseVec[iThin])) stop('cannot use non-integer thin')
240-
}
241-
if(initializeModel) my_initializeModel$run()
242-
nimCopy(from = model, to = mvSaved, row = 1, logProb = TRUE)
268+
if(firstRun) reset <- TRUE ## compulsory reset on first run of MCMC
269+
firstRun <<- FALSE
243270
if(reset) {
244-
samplerTimes <<- numeric(length(samplerFunctions) + 1) ## default inititialization to zero
271+
if(initializeModel) my_initializeModel$run()
272+
thinToUseVec <<- thinFromConfVec
273+
if(thin != -1) thinToUseVec[1] <<- thin
274+
if(thin2 != -1) thinToUseVec[2] <<- thin2
275+
for(iThin in 1:2) {
276+
if(thinToUseVec[iThin] < 1) stop('cannot use thin < 1')
277+
if(thinToUseVec[iThin] != floor(thinToUseVec[iThin])) stop('cannot use non-integer thin')
278+
}
279+
for(i in seq_along(derivedFunctions)) {
280+
if(derivedIntervals[i] == 0) {
281+
derivedIntervals[i] <<- thinToUseVec[1]
282+
derivedFunctions[[i]]$set_interval(thinToUseVec[1])
283+
}
284+
}
245285
for(i in seq_along(samplerFunctions)) samplerFunctions[[i]]$reset()
246-
for(i in seq_along(samplerFunctions)) samplerFunctions[[i]]$before_chain(niter, nburnin, chain)
286+
for(i in seq_along(derivedFunctions)) derivedFunctions[[i]]$reset()
287+
for(i in seq_along(samplerFunctions)) samplerFunctions[[i]]$before_chain(niter, nburnin, chain)
288+
for(i in seq_along(derivedFunctions)) derivedFunctions[[i]]$before_chain(niter-nburnin, nburnin, thinToUseVec, chain)
289+
samplerTimes <<- numeric(length(samplerFunctions) + 1) ## default initialization to zero
247290
mvSamples_copyRow <- 0
248291
mvSamples2_copyRow <- 0
249292
} else {
250-
if(nburnin != 0) stop('cannot specify nburnin when using reset = FALSE.')
293+
if(nburnin != 0) stop('cannot specify nburnin when using reset = FALSE.')
294+
if(thin != -1) stop('cannot specify thin when using reset = FALSE.')
295+
if(thin2 != -1) stop('cannot specify thin2 when using reset = FALSE.')
251296
if(dim(samplerTimes)[1] != length(samplerFunctions) + 1) samplerTimes <<- numeric(length(samplerFunctions) + 1) ## first run: default initialization to zero
252297
if (resetMV) {
253298
mvSamples_copyRow <- 0
@@ -257,8 +302,8 @@ buildMCMC <- nimbleFunction(
257302
mvSamples2_copyRow <- getsize(mvSamples2)
258303
}
259304
}
260-
if(onlineWAIC & resetWAIC)
261-
waicFun[[1]]$reset()
305+
nimCopy(from = model, to = mvSaved, row = 1, logProb = TRUE)
306+
if(onlineWAIC & resetWAIC) waicFun[[1]]$reset()
262307
resize(mvSamples, mvSamples_copyRow + floor((niter-nburnin) / thinToUseVec[1]))
263308
resize(mvSamples2, mvSamples2_copyRow + floor((niter-nburnin) / thinToUseVec[2]))
264309
## reinstate samplerExecutionOrder as a runtime argument, once we support non-scalar default values for runtime arguments:
@@ -278,6 +323,7 @@ buildMCMC <- nimbleFunction(
278323
if(niter < 1) return()
279324
for(iter in 1:niter) {
280325
checkInterrupt()
326+
## execute samplerFunctions
281327
if(time) {
282328
for(i in seq_along(samplerExecutionOrderToUse)) {
283329
ind <- samplerExecutionOrderToUse[i]
@@ -289,42 +335,69 @@ buildMCMC <- nimbleFunction(
289335
samplerFunctions[[ind]]$run()
290336
}
291337
}
292-
## adding "accumulators" to MCMC
293-
## https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
294338
if(iter > nburnin) {
295-
sampleNumber <- iter - nburnin
296-
if(sampleNumber %% thinToUseVec[1] == 0) {
339+
## save samples
340+
iterPostBurnin <- iter - nburnin
341+
if(iterPostBurnin %% thinToUseVec[1] == 0) {
297342
mvSamples_copyRow <- mvSamples_copyRow + 1
298343
nimCopy(from = model, to = mvSamples, row = mvSamples_copyRow, nodes = monitors)
299344
}
300-
if(sampleNumber %% thinToUseVec[2] == 0) {
345+
if(iterPostBurnin %% thinToUseVec[2] == 0) {
301346
mvSamples2_copyRow <- mvSamples2_copyRow + 1
302347
nimCopy(from = model, to = mvSamples2, row = mvSamples2_copyRow, nodes = monitors2)
303348
}
349+
## save WAIC
304350
if(enableWAIC & onlineWAIC & iter > nburnin + nburnin_extraWAIC) {
305351
if (!thinWAIC) {
306352
waicFun[[1]]$updateStats()
307-
} else if (sampleNumber %% thinToUseVec[1] == 0){
353+
} else if (iterPostBurnin %% thinToUseVec[1] == 0) {
308354
waicFun[[1]]$updateStats()
309355
}
310356
}
357+
## execute derivedFunctions
358+
for(i in seq_along(derivedFunctions)) {
359+
if(iterPostBurnin %% derivedIntervals[i] == 0) {
360+
derivedFunctions[[i]]$run( iterPostBurnin/derivedIntervals[i] )
361+
}
362+
}
311363
}
364+
## progress bar
312365
if(progressBar & (iter == progressBarNextFloor)) {
313366
cat('-')
314367
progressBarNext <- progressBarNext + progressBarIncrement
315368
progressBarNextFloor <- floor(progressBarNext)
316369
}
317370
}
318371
if(progressBar) print('|')
372+
## after_chain methods
319373
for(i in seq_along(samplerFunctions)) samplerFunctions[[i]]$after_chain()
374+
for(i in seq_along(derivedFunctions)) derivedFunctions[[i]]$after_chain()
320375
returnType(void())
321376
},
322377
methods = list(
323378
getTimes = function() {
324379
returnType(double(1))
325380
return(samplerTimes[1:(length(samplerTimes)-1)])
326381
},
327-
## Old-style post-sampling WAIC calculation.
382+
getNumDerived = function() {
383+
returnType(double())
384+
return(numDerived)
385+
},
386+
getDerivedQuantityResults = function(ind = double()) {
387+
if(ind > numDerived) {
388+
print('there aren\'t that many derived functions')
389+
return(array(0, c(0,0)))
390+
}
391+
returnType(double(2))
392+
return(derivedFunctions[[ind]]$get_results())
393+
},
394+
getDerivedQuantityNames = function(ind = double()) {
395+
if(ind > numDerived) print('there aren\'t that many derived functions')
396+
returnType(character(1))
397+
derivedNames <<- derivedFunctions[[ind]]$get_names()
398+
return(derivedNames)
399+
},
400+
## old-style post-sampling WAIC calculation
328401
calculateWAIC = function(nburnin = integer(default = 0)) {
329402
if(!enableWAIC) {
330403
print('Error: One must set enableWAIC = TRUE in \'configureMCMC\' or \'buildMCMC\'. See \'help(configureMCMC)\' for additional information.')

0 commit comments

Comments
 (0)