@@ -37,7 +37,7 @@ use crate::{
3737 Math , NutsError ,
3838 chain:: { AdaptStrategy , Chain , StatOptions } ,
3939 dynamics:: {
40- DivergenceInfo , DivergenceStats , Hamiltonian , KineticEnergyKind , Point , State ,
40+ Direction , DivergenceInfo , DivergenceStats , Hamiltonian , KineticEnergyKind , Point , State ,
4141 TransformedHamiltonian , TransformedPoint ,
4242 } ,
4343 nuts:: { Collector , NutsOptions } ,
@@ -168,6 +168,7 @@ where
168168 math : RefCell < M > ,
169169 stats_options : StatOptions < M , A > ,
170170 last_info : Option < MclmcInfo > ,
171+ tmp_velocity : M :: Vector ,
171172}
172173
173174impl < M , R , A , T > MclmcChain < M , R , A , T >
@@ -191,6 +192,7 @@ where
191192 ) -> Self {
192193 let state = hamiltonian. pool ( ) . new_state ( & mut math) ;
193194 let collector = adapt. new_collector ( & mut math) ;
195+ let tmp_velocity = math. new_array ( ) ;
194196 Self {
195197 hamiltonian,
196198 collector,
@@ -207,6 +209,7 @@ where
207209 math : math. into ( ) ,
208210 stats_options,
209211 last_info : None ,
212+ tmp_velocity,
210213 }
211214 }
212215
@@ -223,12 +226,8 @@ where
223226 . round ( )
224227 . max ( 1.0 ) as u64 ;
225228
226- // ── First partial momentum refresh ───────────────────────────────────
227- self . hamiltonian
228- . partial_momentum_refresh ( math, & mut self . state , & mut self . rng , 1.0 ) ?;
229-
230- // Reset the energy baseline to the post-refresh state so that
231- // energy_error() measures drift over this draw only.
229+ // Reset the kinetic-energy accumulator so that energy() at the draw
230+ // start is a clean baseline for MclmcInfo.energy_change.
232231 {
233232 let kind = self . hamiltonian . kinetic_energy_kind ;
234233
@@ -251,7 +250,12 @@ where
251250
252251 use crate :: dynamics:: LeapfrogResult ;
253252
254- let mut current = self . state . clone ( ) ;
253+ let mut current = self . hamiltonian . copy_state ( math, & self . state ) ;
254+
255+ // Capture the draw-start energy once; used at the end to compute
256+ // MclmcInfo.energy_change independently of the per-step baselines.
257+ let draw_start_energy = current. point ( ) . energy ( ) ;
258+
255259 let mut divergence_info: Option < DivergenceInfo > = None ;
256260 let mut steps_taken = 0u64 ;
257261
@@ -264,14 +268,44 @@ where
264268
265269 let mut remaining = num_steps;
266270 while remaining > 0 {
271+ // debt == 2 indicates a previous unsuccessful leapfrog
272+ if debt != 2 {
273+ math. copy_into ( & current. point ( ) . velocity , & mut self . tmp_velocity ) ;
274+ } else {
275+ // Restore the old velocity and refresh momentum with smaller factor
276+ // TODO: I think we should reuse the original gaussian noise?
277+ math. copy_into (
278+ & self . tmp_velocity ,
279+ & mut current. try_point_mut ( ) . unwrap ( ) . velocity ,
280+ ) ;
281+ }
282+ self . hamiltonian
283+ . partial_momentum_refresh ( math, & mut current, & mut self . rng , factor) ?;
284+
285+ // Use the post-refresh energy as the divergence baseline so that
286+ // the leapfrog's energy_error measures only this single step's
287+ // integration error (O(ε²)), not the cumulative drift from the
288+ // draw start. Without this, many small steps with a tight
289+ // max_energy_error threshold can exhaust all halvings because the
290+ // accumulated baseline already sits at the threshold.
291+ let step_baseline = current. point ( ) . energy ( ) ;
267292 let next = match self . hamiltonian . leapfrog (
268293 math,
269294 & current,
270- crate :: dynamics :: Direction :: Forward ,
295+ Direction :: Forward ,
271296 factor,
297+ step_baseline,
272298 & mut self . collector ,
273299 ) {
274- LeapfrogResult :: Ok ( next) => next,
300+ LeapfrogResult :: Ok ( mut next) => {
301+ self . hamiltonian . partial_momentum_refresh (
302+ math,
303+ & mut next,
304+ & mut self . rng ,
305+ factor,
306+ ) ?;
307+ next
308+ }
275309 LeapfrogResult :: Divergence ( info) => {
276310 if halvings >= max_halvings {
277311 // Genuinely diverged — give up.
@@ -283,6 +317,7 @@ where
283317 factor *= 0.5 ;
284318 halvings += 1 ;
285319 debt = 2 ;
320+ remaining += 1 ;
286321 continue ;
287322 }
288323 LeapfrogResult :: Err ( e) => {
@@ -318,7 +353,7 @@ where
318353 let mut next_state = self . hamiltonian . copy_state ( math, & self . state ) ;
319354 self . hamiltonian
320355 . initialize_trajectory ( math, & mut next_state, & mut self . rng ) ?;
321- let energy_change = current. point ( ) . energy_error ( ) ;
356+ let energy_change = current. point ( ) . energy ( ) - draw_start_energy ;
322357 let info = MclmcInfo {
323358 energy_change,
324359 diverging : true ,
@@ -336,18 +371,14 @@ where
336371 } ;
337372 self . collector . register_draw ( math, & current, & sample_info) ;
338373
339- let energy_change = current. point ( ) . energy_error ( ) ;
374+ let energy_change = current. point ( ) . energy ( ) - draw_start_energy ;
340375
341- // ── Second partial momentum refresh ──────────────────────────────────
342- let mut next_state = if current. try_point_mut ( ) . is_err ( ) {
376+ let next_state = if current. try_point_mut ( ) . is_err ( ) {
343377 self . hamiltonian . copy_state ( math, & current)
344378 } else {
345379 current
346380 } ;
347381
348- self . hamiltonian
349- . partial_momentum_refresh ( math, & mut next_state, & mut self . rng , 1.0 ) ?;
350-
351382 let info = MclmcInfo {
352383 energy_change,
353384 diverging : false ,
0 commit comments