Skip to content

Commit 37f1227

Browse files
committed
feat: use standard hmc early in mclmc
1 parent 2909ae7 commit 37f1227

4 files changed

Lines changed: 278 additions & 63 deletions

File tree

src/dynamics/transformed_hamiltonian.rs

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -264,11 +264,22 @@ impl<M: Math> TransformedPoint<M> {
264264
/// Reset the trajectory-tracking fields so that `energy_error()` is measured
265265
/// relative to the current state (e.g. after a partial momentum refresh).
266266
///
267-
/// Sets `kinetic_energy = 0`, `index_in_trajectory = 0`,
268-
/// `initial_energy = energy()`, and `step_size_factor = 1.0`.
269-
/// Used by the MCLMC sampler after each isokinetic Langevin refresh.
270-
pub(crate) fn reset_trajectory_energy(&mut self) {
271-
self.kinetic_energy = 0.0;
267+
/// For [`KineticEnergyKind::Microcanonical`]: sets `kinetic_energy = 0` (the
268+
/// accumulated ΔKE accumulator is zeroed for the new trajectory).
269+
/// For [`KineticEnergyKind::Euclidean`] / [`KineticEnergyKind::ExactNormal`]:
270+
/// recomputes `kinetic_energy = ½‖v‖²` from the current (post-refresh) velocity.
271+
///
272+
/// In both cases sets `index_in_trajectory = 0`, `initial_energy = energy()`,
273+
/// and `step_size_factor = 1.0`.
274+
pub(crate) fn reset_trajectory_energy(&mut self, math: &mut M, kind: KineticEnergyKind) {
275+
match kind {
276+
KineticEnergyKind::Microcanonical => {
277+
self.kinetic_energy = 0.0;
278+
}
279+
KineticEnergyKind::Euclidean | KineticEnergyKind::ExactNormal => {
280+
self.update_kinetic_energy(math);
281+
}
282+
}
272283
self.index_in_trajectory = 0;
273284
self.initial_energy = self.energy();
274285
self.step_size_factor = 1.0;
@@ -467,6 +478,16 @@ impl<M: Math, T: Transformation<M>> TransformedHamiltonian<M, T> {
467478
pub fn set_momentum_decoherence_length(&mut self, l: Option<f64>) {
468479
self.momentum_decoherence_length = l;
469480
}
481+
482+
/// Change the kinetic-energy kind (and thus the leapfrog integrator and
483+
/// momentum distribution) used by this Hamiltonian.
484+
///
485+
/// When switching from [`KineticEnergyKind::Euclidean`] to
486+
/// [`KineticEnergyKind::Microcanonical`] the caller is responsible for
487+
/// reinitializing the state.
488+
pub fn set_kinetic_energy_kind(&mut self, kind: KineticEnergyKind) {
489+
self.kinetic_energy_kind = kind;
490+
}
470491
}
471492

472493
impl<M: Math> TransformedHamiltonian<M, ExternalTransformation<M>> {
@@ -785,17 +806,43 @@ impl<M: Math, T: Transformation<M>> Hamiltonian<M> for TransformedHamiltonian<M,
785806
};
786807

787808
let half_step = self.step_size * factor / 2.0;
788-
let n = math.dim() as f64;
789-
let nu = ((2.0 * half_step / momentum_decoherence_length).exp_m1() / n).sqrt();
790809

791810
let mut noise = math.new_array();
792811
math.array_gaussian(rng, &mut noise, &self.ones);
793812

794813
let point = state.try_point_mut().map_err(|_| {
795814
NutsError::BadInitGrad(anyhow::anyhow!("State in use during momentum refresh").into())
796815
})?;
797-
math.axpy(&noise, &mut point.velocity, nu);
798-
math.array_normalize(&mut point.velocity);
816+
817+
match self.kinetic_energy_kind {
818+
KineticEnergyKind::Microcanonical => {
819+
// Isokinetic Langevin (OU on the unit sphere):
820+
// ν = sqrt((exp(2·half_step/L) − 1) / n), n = dim
821+
// p ← (p + ν·z) / ‖p + ν·z‖, z ~ N(0, I)
822+
let n = math.dim() as f64;
823+
let nu = ((2.0 * half_step / momentum_decoherence_length).exp_m1() / n).sqrt();
824+
math.axpy(&noise, &mut point.velocity, nu);
825+
math.array_normalize(&mut point.velocity);
826+
}
827+
KineticEnergyKind::Euclidean | KineticEnergyKind::ExactNormal => {
828+
// Ornstein–Uhlenbeck for Gaussian momentum p ~ N(0, I):
829+
// α = exp(−half_step / L)
830+
// β = sqrt(1 − α²)
831+
// p_new = α · p + β · z, z ~ N(0, I)
832+
//
833+
// `axpy_out(x, y, a, out)` computes `out = y + a·x`.
834+
// So `axpy_out(&velocity, &zeros, alpha, &mut new_velocity)`
835+
// gives `new_velocity = zeros + alpha·velocity = alpha·velocity`.
836+
let alpha = (-half_step / momentum_decoherence_length).exp();
837+
let beta = (1.0 - alpha * alpha).sqrt();
838+
let mut new_velocity = math.new_array();
839+
math.axpy_out(&point.velocity, &self.zeros, alpha, &mut new_velocity);
840+
math.axpy(&noise, &mut new_velocity, beta);
841+
math.copy_into(&new_velocity, &mut point.velocity);
842+
// Keep kinetic_energy consistent with the updated velocity.
843+
point.update_kinetic_energy(math);
844+
}
845+
}
799846

800847
Ok(())
801848
}

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ pub use adapt_strategy::EuclideanAdaptOptions;
120120
pub use chain::Chain;
121121
pub use dynamics::{DivergenceInfo, KineticEnergyKind};
122122
pub use math::{CpuLogpFunc, CpuMath, CpuMathError, LogpError, Math};
123-
pub use mclmc::{MclmcChain, MclmcInfo, MclmcStats};
123+
pub use mclmc::{MclmcChain, MclmcInfo, MclmcStats, MclmcTrajectoryKind};
124124
pub use model::Model;
125125
pub use nuts::NutsError;
126126

0 commit comments

Comments
 (0)