Skip to content

Commit 69bd636

Browse files
committed
fix: incorrect momentum refresh in mclmc
1 parent c3d34d7 commit 69bd636

5 files changed

Lines changed: 86 additions & 27 deletions

File tree

src/dynamics/hamiltonian.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,9 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
151151
/// Perform one leapfrog step.
152152
///
153153
/// `step_size_factor` scales the hamiltonian's base step size for this
154-
/// step only: `actual_ε = hamiltonian.step_size() * step_size_factor`.
155-
/// The actual step size used is stored on the output point as
156-
/// `point.step_size`, so callers can compute importance weights as
157-
/// `log(point.step_size) - energy_error`.
154+
/// step only.
155+
/// `energy_baseline` is the energy value against which the divergence
156+
/// check is evaluated (one-sided `energy_error > max_energy_error` or
/// two-sided `|energy_error| >= max_energy_error`, depending on the implementation).
158157
///
159158
/// Return either an unrecoverable error, a new state or a divergence.
160159
fn leapfrog<C: Collector<M, Self::Point>>(
@@ -163,6 +162,7 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
163162
start: &State<M, Self::Point>,
164163
dir: Direction,
165164
step_size_factor: f64,
165+
energy_baseline: f64,
166166
collector: &mut C,
167167
) -> LeapfrogResult<M, Self::Point>;
168168

src/dynamics/transformed_hamiltonian.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,7 @@ impl<M: Math, T: Transformation<M>> Hamiltonian<M> for TransformedHamiltonian<M,
558558
start: &State<M, Self::Point>,
559559
dir: Direction,
560560
step_size_factor: f64,
561+
energy_baseline: f64,
561562
collector: &mut C,
562563
) -> LeapfrogResult<M, Self::Point> {
563564
let mut out = self.pool().new_state(math);
@@ -616,8 +617,14 @@ impl<M: Math, T: Transformation<M>> Hamiltonian<M> for TransformedHamiltonian<M,
616617

617618
out_point.index_in_trajectory = start.index_in_trajectory() + sign;
618619

619-
let energy_error = out_point.energy_error();
620-
if (energy_error > self.max_energy_error) | !energy_error.is_finite() {
620+
let energy_error = out_point.energy() - energy_baseline;
621+
let bad_energy = match self.kinetic_energy_kind {
622+
KineticEnergyKind::Euclidean | KineticEnergyKind::ExactNormal => {
623+
energy_error > self.max_energy_error
624+
}
625+
KineticEnergyKind::Microcanonical => energy_error.abs() >= self.max_energy_error,
626+
};
627+
if bad_energy | !energy_error.is_finite() {
621628
let divergence_info = DivergenceInfo {
622629
logp_function_error: None,
623630
start_location: Some(math.box_array(start.point().position())),
@@ -807,6 +814,7 @@ impl<M: Math, T: Transformation<M>> Hamiltonian<M> for TransformedHamiltonian<M,
807814

808815
let half_step = self.step_size * factor / 2.0;
809816

817+
// TODO: Avoid array allocation
810818
let mut noise = math.new_array();
811819
math.array_gaussian(rng, &mut noise, &self.ones);
812820

src/mclmc.rs

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ use crate::{
3737
Math, NutsError,
3838
chain::{AdaptStrategy, Chain, StatOptions},
3939
dynamics::{
40-
DivergenceInfo, DivergenceStats, Hamiltonian, KineticEnergyKind, Point, State,
40+
Direction, DivergenceInfo, DivergenceStats, Hamiltonian, KineticEnergyKind, Point, State,
4141
TransformedHamiltonian, TransformedPoint,
4242
},
4343
nuts::{Collector, NutsOptions},
@@ -168,6 +168,7 @@ where
168168
math: RefCell<M>,
169169
stats_options: StatOptions<M, A>,
170170
last_info: Option<MclmcInfo>,
171+
tmp_velocity: M::Vector,
171172
}
172173

173174
impl<M, R, A, T> MclmcChain<M, R, A, T>
@@ -191,6 +192,7 @@ where
191192
) -> Self {
192193
let state = hamiltonian.pool().new_state(&mut math);
193194
let collector = adapt.new_collector(&mut math);
195+
let tmp_velocity = math.new_array();
194196
Self {
195197
hamiltonian,
196198
collector,
@@ -207,6 +209,7 @@ where
207209
math: math.into(),
208210
stats_options,
209211
last_info: None,
212+
tmp_velocity,
210213
}
211214
}
212215

@@ -223,12 +226,8 @@ where
223226
.round()
224227
.max(1.0) as u64;
225228

226-
// ── First partial momentum refresh ───────────────────────────────────
227-
self.hamiltonian
228-
.partial_momentum_refresh(math, &mut self.state, &mut self.rng, 1.0)?;
229-
230-
// Reset the energy baseline to the post-refresh state so that
231-
// energy_error() measures drift over this draw only.
229+
// Reset the kinetic-energy accumulator so that energy() at the draw
230+
// start is a clean baseline for MclmcInfo.energy_change.
232231
{
233232
let kind = self.hamiltonian.kinetic_energy_kind;
234233

@@ -251,7 +250,12 @@ where
251250

252251
use crate::dynamics::LeapfrogResult;
253252

254-
let mut current = self.state.clone();
253+
let mut current = self.hamiltonian.copy_state(math, &self.state);
254+
255+
// Capture the draw-start energy once; used at the end to compute
256+
// MclmcInfo.energy_change independently of the per-step baselines.
257+
let draw_start_energy = current.point().energy();
258+
255259
let mut divergence_info: Option<DivergenceInfo> = None;
256260
let mut steps_taken = 0u64;
257261

@@ -264,14 +268,44 @@ where
264268

265269
let mut remaining = num_steps;
266270
while remaining > 0 {
271+
// debt == 2 indicates a previous unsuccessful leapfrog
272+
if debt != 2 {
273+
math.copy_into(&current.point().velocity, &mut self.tmp_velocity);
274+
} else {
275+
// Restore the pre-refresh velocity and redo the momentum refresh with the reduced factor.
276+
// TODO: Consider reusing the original Gaussian noise on retry instead of drawing fresh noise.
277+
math.copy_into(
278+
&self.tmp_velocity,
279+
&mut current.try_point_mut().unwrap().velocity,
280+
);
281+
}
282+
self.hamiltonian
283+
.partial_momentum_refresh(math, &mut current, &mut self.rng, factor)?;
284+
285+
// Use the post-refresh energy as the divergence baseline so that
286+
// the leapfrog's energy_error measures only this single step's
287+
// integration error (O(ε²)), not the cumulative drift from the
288+
// draw start. Without this, many small steps with a tight
289+
// max_energy_error threshold can exhaust all halvings because the
290+
// accumulated baseline already sits at the threshold.
291+
let step_baseline = current.point().energy();
267292
let next = match self.hamiltonian.leapfrog(
268293
math,
269294
&current,
270-
crate::dynamics::Direction::Forward,
295+
Direction::Forward,
271296
factor,
297+
step_baseline,
272298
&mut self.collector,
273299
) {
274-
LeapfrogResult::Ok(next) => next,
300+
LeapfrogResult::Ok(mut next) => {
301+
self.hamiltonian.partial_momentum_refresh(
302+
math,
303+
&mut next,
304+
&mut self.rng,
305+
factor,
306+
)?;
307+
next
308+
}
275309
LeapfrogResult::Divergence(info) => {
276310
if halvings >= max_halvings {
277311
// Genuinely diverged — give up.
@@ -283,6 +317,7 @@ where
283317
factor *= 0.5;
284318
halvings += 1;
285319
debt = 2;
320+
remaining += 1;
286321
continue;
287322
}
288323
LeapfrogResult::Err(e) => {
@@ -318,7 +353,7 @@ where
318353
let mut next_state = self.hamiltonian.copy_state(math, &self.state);
319354
self.hamiltonian
320355
.initialize_trajectory(math, &mut next_state, &mut self.rng)?;
321-
let energy_change = current.point().energy_error();
356+
let energy_change = current.point().energy() - draw_start_energy;
322357
let info = MclmcInfo {
323358
energy_change,
324359
diverging: true,
@@ -336,18 +371,14 @@ where
336371
};
337372
self.collector.register_draw(math, &current, &sample_info);
338373

339-
let energy_change = current.point().energy_error();
374+
let energy_change = current.point().energy() - draw_start_energy;
340375

341-
// ── Second partial momentum refresh ──────────────────────────────────
342-
let mut next_state = if current.try_point_mut().is_err() {
376+
let next_state = if current.try_point_mut().is_err() {
343377
self.hamiltonian.copy_state(math, &current)
344378
} else {
345379
current
346380
};
347381

348-
self.hamiltonian
349-
.partial_momentum_refresh(math, &mut next_state, &mut self.rng, 1.0)?;
350-
351382
let info = MclmcInfo {
352383
energy_change,
353384
diverging: false,

src/nuts.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,14 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
217217
Direction::Forward => &self.right,
218218
Direction::Backward => &self.left,
219219
};
220-
let end = match hamiltonian.leapfrog(math, start, direction, 1.0, collector) {
220+
let end = match hamiltonian.leapfrog(
221+
math,
222+
start,
223+
direction,
224+
1.0,
225+
start.point().initial_energy(),
226+
collector,
227+
) {
221228
LeapfrogResult::Divergence(info) => return Ok(Err(info)),
222229
LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())),
223230
LeapfrogResult::Ok(end) => end,

src/stepsize/adapt.rs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,14 @@ impl Strategy {
109109

110110
*hamiltonian.step_size_mut() = self.options.initial_step;
111111

112-
let state_next =
113-
hamiltonian.leapfrog(math, &state, Direction::Forward, 1.0, &mut collector);
112+
let state_next = hamiltonian.leapfrog(
113+
math,
114+
&state,
115+
Direction::Forward,
116+
1.0,
117+
state.point().initial_energy(),
118+
&mut collector,
119+
);
114120

115121
let LeapfrogResult::Ok(_) = state_next else {
116122
return Ok(());
@@ -126,7 +132,14 @@ impl Strategy {
126132
for _ in 0..100 {
127133
let mut collector = AcceptanceRateCollector::new();
128134
collector.register_init(math, &state, options);
129-
let state_next = hamiltonian.leapfrog(math, &state, dir, 1.0, &mut collector);
135+
let state_next = hamiltonian.leapfrog(
136+
math,
137+
&state,
138+
dir,
139+
1.0,
140+
state.point().initial_energy(),
141+
&mut collector,
142+
);
130143
let LeapfrogResult::Ok(_) = state_next else {
131144
*hamiltonian.step_size_mut() = self.options.initial_step;
132145
return Ok(());

0 commit comments

Comments
 (0)