Skip to content

Commit e892d9b

Browse files
committed
fix: mclmc fix various small issues
1 parent 69bd636 commit e892d9b

7 files changed

Lines changed: 158 additions & 194 deletions

File tree

src/adapt_strategy.rs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -446,15 +446,9 @@ mod test {
446446
GlobalStrategy::<_, DiagAdaptStrategy<_>>::new(&mut math, options, num_tune, 0u64);
447447

448448
let mass_matrix = DiagMassMatrix::new(&mut math, true);
449-
let max_energy_error = 1000f64;
450449

451450
let hamiltonian: TransformedHamiltonian<_, DiagMassMatrix<CpuMath<NormalLogp>>> =
452-
TransformedHamiltonian::new(
453-
&mut math,
454-
max_energy_error,
455-
mass_matrix,
456-
KineticEnergyKind::Euclidean,
457-
);
451+
TransformedHamiltonian::new(&mut math, mass_matrix, KineticEnergyKind::Euclidean);
458452

459453
let options = NutsOptions {
460454
maxdepth: 10u64,
@@ -463,6 +457,7 @@ mod test {
463457
store_divergences: false,
464458
target_integration_time: None,
465459
extra_doublings: 0,
460+
max_energy_error: 1000.0,
466461
};
467462

468463
let rng = {

src/dynamics/hamiltonian.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
163163
dir: Direction,
164164
step_size_factor: f64,
165165
energy_baseline: f64,
166+
max_energy_error: f64,
166167
collector: &mut C,
167168
) -> LeapfrogResult<M, Self::Point>;
168169

@@ -195,6 +196,7 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
195196
&self,
196197
math: &mut M,
197198
state: &mut State<M, Self::Point>,
199+
resaple_velocity: bool,
198200
rng: &mut R,
199201
) -> Result<(), NutsError>;
200202

@@ -246,10 +248,11 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
246248
&mut self,
247249
math: &mut M,
248250
state: &mut State<M, Self::Point>,
251+
noise: &M::Vector,
249252
rng: &mut R,
250253
factor: f64,
251254
) -> Result<(), NutsError> {
252-
let _ = (math, state, rng, factor);
255+
let _ = (math, state, noise, rng, factor);
253256
Ok(())
254257
}
255258
}

src/dynamics/transformed_hamiltonian.rs

Lines changed: 14 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -261,30 +261,6 @@ impl<M: Math> TransformedPoint<M> {
261261
self.kinetic_energy = 0.5 * math.array_vector_dot(&self.velocity, &self.velocity);
262262
}
263263

264-
/// Reset the trajectory-tracking fields so that `energy_error()` is measured
265-
/// relative to the current state (e.g. after a partial momentum refresh).
266-
///
267-
/// For [`KineticEnergyKind::Microcanonical`]: sets `kinetic_energy = 0` (the
268-
/// accumulated ΔKE accumulator is zeroed for the new trajectory).
269-
/// For [`KineticEnergyKind::Euclidean`] / [`KineticEnergyKind::ExactNormal`]:
270-
/// recomputes `kinetic_energy = ½‖v‖²` from the current (post-refresh) velocity.
271-
///
272-
/// In both cases sets `index_in_trajectory = 0`, `initial_energy = energy()`,
273-
/// and `step_size_factor = 1.0`.
274-
pub(crate) fn reset_trajectory_energy(&mut self, math: &mut M, kind: KineticEnergyKind) {
275-
match kind {
276-
KineticEnergyKind::Microcanonical => {
277-
self.kinetic_energy = 0.0;
278-
}
279-
KineticEnergyKind::Euclidean | KineticEnergyKind::ExactNormal => {
280-
self.update_kinetic_energy(math);
281-
}
282-
}
283-
self.index_in_trajectory = 0;
284-
self.initial_energy = self.energy();
285-
self.step_size_factor = 1.0;
286-
}
287-
288264
fn init_from_untransformed_position<T: Transformation<M>>(
289265
&mut self,
290266
transformation: &T,
@@ -438,18 +414,12 @@ pub struct TransformedHamiltonian<M: Math, T: Transformation<M>> {
438414
/// `None` disables the refresh (used by NUTS); `Some(L)` enables it (MCLMC).
439415
momentum_decoherence_length: Option<f64>,
440416
transformation: T,
441-
max_energy_error: f64,
442417
pub kinetic_energy_kind: KineticEnergyKind,
443418
pool: StatePool<M, TransformedPoint<M>>,
444419
}
445420

446421
impl<M: Math, T: Transformation<M>> TransformedHamiltonian<M, T> {
447-
pub fn new(
448-
math: &mut M,
449-
max_energy_error: f64,
450-
transformation: T,
451-
kinetic_energy_kind: KineticEnergyKind,
452-
) -> Self {
422+
pub fn new(math: &mut M, transformation: T, kinetic_energy_kind: KineticEnergyKind) -> Self {
453423
let mut ones = math.new_array();
454424
math.fill_array(&mut ones, 1f64);
455425
let mut zeros = math.new_array();
@@ -461,7 +431,6 @@ impl<M: Math, T: Transformation<M>> TransformedHamiltonian<M, T> {
461431
ones,
462432
zeros,
463433
transformation,
464-
max_energy_error,
465434
kinetic_energy_kind,
466435
pool,
467436
}
@@ -559,6 +528,7 @@ impl<M: Math, T: Transformation<M>> Hamiltonian<M> for TransformedHamiltonian<M,
559528
dir: Direction,
560529
step_size_factor: f64,
561530
energy_baseline: f64,
531+
max_energy_error: f64,
562532
collector: &mut C,
563533
) -> LeapfrogResult<M, Self::Point> {
564534
let mut out = self.pool().new_state(math);
@@ -620,9 +590,9 @@ impl<M: Math, T: Transformation<M>> Hamiltonian<M> for TransformedHamiltonian<M,
620590
let energy_error = out_point.energy() - energy_baseline;
621591
let bad_energy = match self.kinetic_energy_kind {
622592
KineticEnergyKind::Euclidean | KineticEnergyKind::ExactNormal => {
623-
energy_error > self.max_energy_error
593+
energy_error > max_energy_error
624594
}
625-
KineticEnergyKind::Microcanonical => energy_error.abs() >= self.max_energy_error,
595+
KineticEnergyKind::Microcanonical => energy_error.abs() >= max_energy_error,
626596
};
627597
if bad_energy | !energy_error.is_finite() {
628598
let divergence_info = DivergenceInfo {
@@ -718,16 +688,19 @@ impl<M: Math, T: Transformation<M>> Hamiltonian<M> for TransformedHamiltonian<M,
718688
&self,
719689
math: &mut M,
720690
state: &mut State<M, Self::Point>,
691+
resample_velocity: bool,
721692
rng: &mut R,
722693
) -> Result<(), NutsError> {
723694
let point = state.try_point_mut().expect("State has other references");
724695

725-
// Sample raw isotropic Gaussian momentum.
726-
math.array_gaussian(rng, &mut point.velocity, &self.ones);
696+
if resample_velocity {
697+
// Sample raw isotropic Gaussian momentum.
698+
math.array_gaussian(rng, &mut point.velocity, &self.ones);
727699

728-
// For Microcanonical HMC the momentum must lie on the unit sphere.
729-
if self.kinetic_energy_kind == KineticEnergyKind::Microcanonical {
730-
math.array_normalize(&mut point.velocity);
700+
// For Microcanonical HMC the momentum must lie on the unit sphere.
701+
if self.kinetic_energy_kind == KineticEnergyKind::Microcanonical {
702+
math.array_normalize(&mut point.velocity);
703+
}
731704
}
732705

733706
let current_transform_id = self.transformation().transformation_id(math);
@@ -805,7 +778,8 @@ impl<M: Math, T: Transformation<M>> Hamiltonian<M> for TransformedHamiltonian<M,
805778
&mut self,
806779
math: &mut M,
807780
state: &mut State<M, Self::Point>,
808-
rng: &mut R,
781+
noise: &M::Vector,
782+
_rng: &mut R,
809783
factor: f64,
810784
) -> Result<(), NutsError> {
811785
let Some(momentum_decoherence_length) = self.momentum_decoherence_length else {
@@ -814,10 +788,6 @@ impl<M: Math, T: Transformation<M>> Hamiltonian<M> for TransformedHamiltonian<M,
814788

815789
let half_step = self.step_size * factor / 2.0;
816790

817-
// TODO: Avoid array allocation
818-
let mut noise = math.new_array();
819-
math.array_gaussian(rng, &mut noise, &self.ones);
820-
821791
let point = state.try_point_mut().map_err(|_| {
822792
NutsError::BadInitGrad(anyhow::anyhow!("State in use during momentum refresh").into())
823793
})?;

0 commit comments

Comments
 (0)