Skip to content

Commit 85ecdbb

Browse files
perrydvdanielturek
andauthored
Defer initEpsilon() to start of next iteration (#66)
Co-authored-by: Daniel Turek <danielturek@gmail.com>
1 parent 6fd061a commit 85ecdbb

1 file changed

Lines changed: 13 additions & 11 deletions

File tree

nimbleHMC/R/HMC_samplers.R

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -879,6 +879,7 @@ sampler_NUTS <- nimbleFunction(
879879
warningCodes <- array(0, c(max(numWarnings,1), 2))
880880
numDivergences <- 0
881881
numTimesMaxTreeDepth <- 0
882+
init_epsilon_next_iter <- FALSE
882883
ADreset <- 1
883884
## nimbleLists
884885
treebranchNL <- treebranchNL_NUTS ## reference input to buildtree
@@ -913,6 +914,14 @@ sampler_NUTS <- nimbleFunction(
913914
mu <<- log(10*epsilon) ## curiously, Stan sets this for the first round *before* init_stepsize
914915
if(initializeEpsilon & adaptive) initEpsilon()
915916
}
917+
if(init_epsilon_next_iter) {
918+
if(initializeEpsilon) initEpsilon()
919+
Hbar <<- 0
920+
logEpsilonBar <<- 0
921+
stepsizeCounter <<- 0
922+
mu <<- log(10*epsilon)
923+
init_epsilon_next_iter <<- FALSE
924+
}
916925
timesRan <<- timesRan + 1
917926
if(printTimesRan) print('============ times ran = ', timesRan)
918927
if(printEpsilon) print('epsilon = ', epsilon)
@@ -1003,16 +1012,7 @@ sampler_NUTS <- nimbleFunction(
10031012
if(adaptEpsilon) adapt_stepsize(accept_prob)
10041013
update <- FALSE
10051014
if(adaptM) update <- adapt_M()
1006-
if(update & adaptEpsilon) {
1007-
if(initializeEpsilon) {
1008-
inverseTransformStoreCalculate(state_sample$q) ## defensively ensure model states are up to date.
1009-
initEpsilon()
1010-
}
1011-
Hbar <<- 0
1012-
logEpsilonBar <<- 0
1013-
stepsizeCounter <<- 0
1014-
mu <<- log(10*epsilon)
1015-
}
1015+
if(update & adaptEpsilon) init_epsilon_next_iter <<- TRUE
10161016
}
10171017
inverseTransformStoreCalculate(state_sample$q)
10181018
nimCopy(from = model, to = mvSaved, row = 1, nodes = calcNodes, logProb = TRUE)
@@ -1273,6 +1273,7 @@ sampler_NUTS <- nimbleFunction(
12731273
Hbar <<- 0
12741274
logEpsilonBar <<- 0
12751275
stepsizeCounter <<- 0
1276+
init_epsilon_next_iter <<- FALSE
12761277
setSize(warmupSamples, adaptWindow_size, d, fillZeros = FALSE)
12771278
}
12781279
}
@@ -1306,6 +1307,7 @@ sampler_NUTS <- nimbleFunction(
13061307
warningInd <<- 0
13071308
M <<- Morig
13081309
sqrtM <<- sqrt(M)
1310+
ADreset <<- 1
13091311
## the adapt_* variables are initialized in before_chain()
13101312
adaptWindow_size <<- 0
13111313
adapt_initBuffer <<- 0
@@ -1314,7 +1316,7 @@ sampler_NUTS <- nimbleFunction(
13141316
adaptWindow_counter <<- 0
13151317
adaptWindow_iter <<- 0
13161318
stepsizeCounter <<- 0
1317-
ADreset <<- 1
1319+
init_epsilon_next_iter <<- FALSE
13181320
}
13191321
),
13201322
buildDerivs = list(

0 commit comments

Comments
 (0)