@@ -879,6 +879,7 @@ sampler_NUTS <- nimbleFunction(
879879 warningCodes <- array (0 , c(max(numWarnings ,1 ), 2 ))
880880 numDivergences <- 0
881881 numTimesMaxTreeDepth <- 0
882+ init_epsilon_next_iter <- FALSE
882883 ADreset <- 1
883884 # # nimbleLists
884885 treebranchNL <- treebranchNL_NUTS # # reference input to buildtree
@@ -913,6 +914,14 @@ sampler_NUTS <- nimbleFunction(
913914 mu <<- log(10 * epsilon ) # # curiously, Stan sets this for the first round *before* init_stepsize
914915 if (initializeEpsilon & adaptive ) initEpsilon()
915916 }
917+ if (init_epsilon_next_iter ) {
918+ if (initializeEpsilon ) initEpsilon()
919+ Hbar <<- 0
920+ logEpsilonBar <<- 0
921+ stepsizeCounter <<- 0
922+ mu <<- log(10 * epsilon )
923+ init_epsilon_next_iter <<- FALSE
924+ }
916925 timesRan <<- timesRan + 1
917926 if (printTimesRan ) print(' ============ times ran = ' , timesRan )
918927 if (printEpsilon ) print(' epsilon = ' , epsilon )
@@ -1003,16 +1012,7 @@ sampler_NUTS <- nimbleFunction(
10031012 if (adaptEpsilon ) adapt_stepsize(accept_prob )
10041013 update <- FALSE
10051014 if (adaptM ) update <- adapt_M()
1006- if (update & adaptEpsilon ) {
1007- if (initializeEpsilon ) {
1008- inverseTransformStoreCalculate(state_sample $ q ) # # defensively ensure model states are up to date.
1009- initEpsilon()
1010- }
1011- Hbar <<- 0
1012- logEpsilonBar <<- 0
1013- stepsizeCounter <<- 0
1014- mu <<- log(10 * epsilon )
1015- }
1015+ if (update & adaptEpsilon ) init_epsilon_next_iter <<- TRUE
10161016 }
10171017 inverseTransformStoreCalculate(state_sample $ q )
10181018 nimCopy(from = model , to = mvSaved , row = 1 , nodes = calcNodes , logProb = TRUE )
@@ -1273,6 +1273,7 @@ sampler_NUTS <- nimbleFunction(
12731273 Hbar <<- 0
12741274 logEpsilonBar <<- 0
12751275 stepsizeCounter <<- 0
1276+ init_epsilon_next_iter <<- FALSE
12761277 setSize(warmupSamples , adaptWindow_size , d , fillZeros = FALSE )
12771278 }
12781279 }
@@ -1306,6 +1307,7 @@ sampler_NUTS <- nimbleFunction(
13061307 warningInd <<- 0
13071308 M <<- Morig
13081309 sqrtM <<- sqrt(M )
1310+ ADreset <<- 1
13091311 # # the adapt_* variables are initialized in before_chain()
13101312 adaptWindow_size <<- 0
13111313 adapt_initBuffer <<- 0
@@ -1314,7 +1316,7 @@ sampler_NUTS <- nimbleFunction(
13141316 adaptWindow_counter <<- 0
13151317 adaptWindow_iter <<- 0
13161318 stepsizeCounter <<- 0
1317- ADreset <<- 1
1319+ init_epsilon_next_iter <<- FALSE
13181320 }
13191321 ),
13201322 buildDerivs = list (
0 commit comments