@@ -261,30 +261,6 @@ impl<M: Math> TransformedPoint<M> {
261261 self . kinetic_energy = 0.5 * math. array_vector_dot ( & self . velocity , & self . velocity ) ;
262262 }
263263
264- /// Reset the trajectory-tracking fields so that `energy_error()` is measured
265- /// relative to the current state (e.g. after a partial momentum refresh).
266- ///
267- /// For [`KineticEnergyKind::Microcanonical`]: sets `kinetic_energy = 0` (the
268- /// accumulated ΔKE accumulator is zeroed for the new trajectory).
269- /// For [`KineticEnergyKind::Euclidean`] / [`KineticEnergyKind::ExactNormal`]:
270- /// recomputes `kinetic_energy = ½‖v‖²` from the current (post-refresh) velocity.
271- ///
272- /// In both cases sets `index_in_trajectory = 0`, `initial_energy = energy()`,
273- /// and `step_size_factor = 1.0`.
274- pub ( crate ) fn reset_trajectory_energy ( & mut self , math : & mut M , kind : KineticEnergyKind ) {
275- match kind {
276- KineticEnergyKind :: Microcanonical => {
277- self . kinetic_energy = 0.0 ;
278- }
279- KineticEnergyKind :: Euclidean | KineticEnergyKind :: ExactNormal => {
280- self . update_kinetic_energy ( math) ;
281- }
282- }
283- self . index_in_trajectory = 0 ;
284- self . initial_energy = self . energy ( ) ;
285- self . step_size_factor = 1.0 ;
286- }
287-
288264 fn init_from_untransformed_position < T : Transformation < M > > (
289265 & mut self ,
290266 transformation : & T ,
@@ -438,18 +414,12 @@ pub struct TransformedHamiltonian<M: Math, T: Transformation<M>> {
438414 /// `None` disables the refresh (used by NUTS); `Some(L)` enables it (MCLMC).
439415 momentum_decoherence_length : Option < f64 > ,
440416 transformation : T ,
441- max_energy_error : f64 ,
442417 pub kinetic_energy_kind : KineticEnergyKind ,
443418 pool : StatePool < M , TransformedPoint < M > > ,
444419}
445420
446421impl < M : Math , T : Transformation < M > > TransformedHamiltonian < M , T > {
447- pub fn new (
448- math : & mut M ,
449- max_energy_error : f64 ,
450- transformation : T ,
451- kinetic_energy_kind : KineticEnergyKind ,
452- ) -> Self {
422+ pub fn new ( math : & mut M , transformation : T , kinetic_energy_kind : KineticEnergyKind ) -> Self {
453423 let mut ones = math. new_array ( ) ;
454424 math. fill_array ( & mut ones, 1f64 ) ;
455425 let mut zeros = math. new_array ( ) ;
@@ -461,7 +431,6 @@ impl<M: Math, T: Transformation<M>> TransformedHamiltonian<M, T> {
461431 ones,
462432 zeros,
463433 transformation,
464- max_energy_error,
465434 kinetic_energy_kind,
466435 pool,
467436 }
@@ -559,6 +528,7 @@ impl<M: Math, T: Transformation<M>> Hamiltonian<M> for TransformedHamiltonian<M,
559528 dir : Direction ,
560529 step_size_factor : f64 ,
561530 energy_baseline : f64 ,
531+ max_energy_error : f64 ,
562532 collector : & mut C ,
563533 ) -> LeapfrogResult < M , Self :: Point > {
564534 let mut out = self . pool ( ) . new_state ( math) ;
@@ -620,9 +590,9 @@ impl<M: Math, T: Transformation<M>> Hamiltonian<M> for TransformedHamiltonian<M,
620590 let energy_error = out_point. energy ( ) - energy_baseline;
621591 let bad_energy = match self . kinetic_energy_kind {
622592 KineticEnergyKind :: Euclidean | KineticEnergyKind :: ExactNormal => {
623- energy_error > self . max_energy_error
593+ energy_error > max_energy_error
624594 }
625- KineticEnergyKind :: Microcanonical => energy_error. abs ( ) >= self . max_energy_error ,
595+ KineticEnergyKind :: Microcanonical => energy_error. abs ( ) >= max_energy_error,
626596 } ;
627597 if bad_energy | !energy_error. is_finite ( ) {
628598 let divergence_info = DivergenceInfo {
@@ -718,16 +688,19 @@ impl<M: Math, T: Transformation<M>> Hamiltonian<M> for TransformedHamiltonian<M,
718688 & self ,
719689 math : & mut M ,
720690 state : & mut State < M , Self :: Point > ,
691+ resample_velocity : bool ,
721692 rng : & mut R ,
722693 ) -> Result < ( ) , NutsError > {
723694 let point = state. try_point_mut ( ) . expect ( "State has other references" ) ;
724695
725- // Sample raw isotropic Gaussian momentum.
726- math. array_gaussian ( rng, & mut point. velocity , & self . ones ) ;
696+ if resample_velocity {
697+ // Sample raw isotropic Gaussian momentum.
698+ math. array_gaussian ( rng, & mut point. velocity , & self . ones ) ;
727699
728- // For Microcanonical HMC the momentum must lie on the unit sphere.
729- if self . kinetic_energy_kind == KineticEnergyKind :: Microcanonical {
730- math. array_normalize ( & mut point. velocity ) ;
700+ // For Microcanonical HMC the momentum must lie on the unit sphere.
701+ if self . kinetic_energy_kind == KineticEnergyKind :: Microcanonical {
702+ math. array_normalize ( & mut point. velocity ) ;
703+ }
731704 }
732705
733706 let current_transform_id = self . transformation ( ) . transformation_id ( math) ;
@@ -805,7 +778,8 @@ impl<M: Math, T: Transformation<M>> Hamiltonian<M> for TransformedHamiltonian<M,
805778 & mut self ,
806779 math : & mut M ,
807780 state : & mut State < M , Self :: Point > ,
808- rng : & mut R ,
781+ noise : & M :: Vector ,
782+ _rng : & mut R ,
809783 factor : f64 ,
810784 ) -> Result < ( ) , NutsError > {
811785 let Some ( momentum_decoherence_length) = self . momentum_decoherence_length else {
@@ -814,10 +788,6 @@ impl<M: Math, T: Transformation<M>> Hamiltonian<M> for TransformedHamiltonian<M,
814788
815789 let half_step = self . step_size * factor / 2.0 ;
816790
817- // TODO: Avoid array allocation
818- let mut noise = math. new_array ( ) ;
819- math. array_gaussian ( rng, & mut noise, & self . ones ) ;
820-
821791 let point = state. try_point_mut ( ) . map_err ( |_| {
822792 NutsError :: BadInitGrad ( anyhow:: anyhow!( "State in use during momentum refresh" ) . into ( ) )
823793 } ) ?;
0 commit comments