@@ -264,11 +264,22 @@ impl<M: Math> TransformedPoint<M> {
264264 /// Reset the trajectory-tracking fields so that `energy_error()` is measured
265265 /// relative to the current state (e.g. after a partial momentum refresh).
266266 ///
267- /// Sets `kinetic_energy = 0`, `index_in_trajectory = 0`,
268- /// `initial_energy = energy()`, and `step_size_factor = 1.0`.
269- /// Used by the MCLMC sampler after each isokinetic Langevin refresh.
270- pub ( crate ) fn reset_trajectory_energy ( & mut self ) {
271- self . kinetic_energy = 0.0 ;
267+ /// For [`KineticEnergyKind::Microcanonical`]: sets `kinetic_energy = 0` (the
268+ /// accumulated ΔKE accumulator is zeroed for the new trajectory).
269+ /// For [`KineticEnergyKind::Euclidean`] / [`KineticEnergyKind::ExactNormal`]:
270+ /// recomputes `kinetic_energy = ½‖v‖²` from the current (post-refresh) velocity.
271+ ///
272+ /// In both cases sets `index_in_trajectory = 0`, `initial_energy = energy()`,
273+ /// and `step_size_factor = 1.0`.
274+ pub ( crate ) fn reset_trajectory_energy ( & mut self , math : & mut M , kind : KineticEnergyKind ) {
275+ match kind {
276+ KineticEnergyKind :: Microcanonical => {
277+ self . kinetic_energy = 0.0 ;
278+ }
279+ KineticEnergyKind :: Euclidean | KineticEnergyKind :: ExactNormal => {
280+ self . update_kinetic_energy ( math) ;
281+ }
282+ }
272283 self . index_in_trajectory = 0 ;
273284 self . initial_energy = self . energy ( ) ;
274285 self . step_size_factor = 1.0 ;
@@ -467,6 +478,16 @@ impl<M: Math, T: Transformation<M>> TransformedHamiltonian<M, T> {
467478 pub fn set_momentum_decoherence_length ( & mut self , l : Option < f64 > ) {
468479 self . momentum_decoherence_length = l;
469480 }
481+
482+ /// Change the kinetic-energy kind (and thus the leapfrog integrator and
483+ /// momentum distribution) used by this Hamiltonian.
484+ ///
485+ /// When switching from [`KineticEnergyKind::Euclidean`] to
486+ /// [`KineticEnergyKind::Microcanonical`] the caller is responsible for
487+ /// reinitializing the state.
488+ pub fn set_kinetic_energy_kind ( & mut self , kind : KineticEnergyKind ) {
489+ self . kinetic_energy_kind = kind;
490+ }
470491}
471492
472493impl < M : Math > TransformedHamiltonian < M , ExternalTransformation < M > > {
@@ -785,17 +806,43 @@ impl<M: Math, T: Transformation<M>> Hamiltonian<M> for TransformedHamiltonian<M,
785806 } ;
786807
787808 let half_step = self . step_size * factor / 2.0 ;
788- let n = math. dim ( ) as f64 ;
789- let nu = ( ( 2.0 * half_step / momentum_decoherence_length) . exp_m1 ( ) / n) . sqrt ( ) ;
790809
791810 let mut noise = math. new_array ( ) ;
792811 math. array_gaussian ( rng, & mut noise, & self . ones ) ;
793812
794813 let point = state. try_point_mut ( ) . map_err ( |_| {
795814 NutsError :: BadInitGrad ( anyhow:: anyhow!( "State in use during momentum refresh" ) . into ( ) )
796815 } ) ?;
797- math. axpy ( & noise, & mut point. velocity , nu) ;
798- math. array_normalize ( & mut point. velocity ) ;
816+
817+ match self . kinetic_energy_kind {
818+ KineticEnergyKind :: Microcanonical => {
819+ // Isokinetic Langevin (OU on the unit sphere):
820+ // ν = sqrt((exp(2·half_step/L) − 1) / n), n = dim
821+ // p ← (p + ν·z) / ‖p + ν·z‖, z ~ N(0, I)
822+ let n = math. dim ( ) as f64 ;
823+ let nu = ( ( 2.0 * half_step / momentum_decoherence_length) . exp_m1 ( ) / n) . sqrt ( ) ;
824+ math. axpy ( & noise, & mut point. velocity , nu) ;
825+ math. array_normalize ( & mut point. velocity ) ;
826+ }
827+ KineticEnergyKind :: Euclidean | KineticEnergyKind :: ExactNormal => {
828+ // Ornstein–Uhlenbeck for Gaussian momentum p ~ N(0, I):
829+ // α = exp(−half_step / L)
830+ // β = sqrt(1 − α²)
831+ // p_new = α · p + β · z, z ~ N(0, I)
832+ //
833+ // `axpy_out(x, y, a, out)` computes `out = y + a·x`.
834+ // So `axpy_out(&velocity, &zeros, alpha, &mut new_velocity)`
835+ // gives `new_velocity = zeros + alpha·velocity = alpha·velocity`.
836+ let alpha = ( -half_step / momentum_decoherence_length) . exp ( ) ;
837+ let beta = ( 1.0 - alpha * alpha) . sqrt ( ) ;
838+ let mut new_velocity = math. new_array ( ) ;
839+ math. axpy_out ( & point. velocity , & self . zeros , alpha, & mut new_velocity) ;
840+ math. axpy ( & noise, & mut new_velocity, beta) ;
841+ math. copy_into ( & new_velocity, & mut point. velocity ) ;
842+ // Keep kinetic_energy consistent with the updated velocity.
843+ point. update_kinetic_energy ( math) ;
844+ }
845+ }
799846
800847 Ok ( ( ) )
801848 }
0 commit comments