@@ -21,6 +21,7 @@ use crate::{
2121/// a cutoff value or nan.
2222/// - The logp function caused a recoverable error (eg if an ODE solver
2323/// failed)
24+ #[ non_exhaustive]
2425#[ derive( Debug , Clone ) ]
2526pub struct DivergenceInfo {
2627 pub start_momentum : Option < Box < [ f64 ] > > ,
@@ -31,6 +32,80 @@ pub struct DivergenceInfo {
3132 pub end_idx_in_trajectory : Option < i64 > ,
3233 pub start_idx_in_trajectory : Option < i64 > ,
3334 pub logp_function_error : Option < Arc < dyn std:: error:: Error + Send + Sync > > ,
35+ pub non_reversible : bool ,
36+ }
37+ impl DivergenceInfo {
38+ pub fn new ( ) -> Self {
39+ DivergenceInfo {
40+ start_momentum : None ,
41+ start_location : None ,
42+ start_gradient : None ,
43+ end_location : None ,
44+ energy_error : None ,
45+ end_idx_in_trajectory : None ,
46+ start_idx_in_trajectory : None ,
47+ logp_function_error : None ,
48+ non_reversible : false ,
49+ }
50+ }
51+
52+ pub fn new_energy_error_too_large < M : Math > (
53+ math : & mut M ,
54+ start : & State < M , impl Point < M > > ,
55+ stop : & State < M , impl Point < M > > ,
56+ ) -> Self {
57+ DivergenceInfo {
58+ logp_function_error : None ,
59+ start_location : Some ( math. box_array ( start. point ( ) . position ( ) ) ) ,
60+ start_gradient : Some ( math. box_array ( start. point ( ) . gradient ( ) ) ) ,
61+ // TODO
62+ start_momentum : None ,
63+ start_idx_in_trajectory : Some ( start. index_in_trajectory ( ) ) ,
64+ end_location : Some ( math. box_array ( & stop. point ( ) . position ( ) ) ) ,
65+ end_idx_in_trajectory : Some ( stop. index_in_trajectory ( ) ) ,
66+ // TODO
67+ energy_error : None ,
68+ non_reversible : false ,
69+ }
70+ }
71+
72+ pub fn new_logp_function_error < M : Math > (
73+ math : & mut M ,
74+ start : & State < M , impl Point < M > > ,
75+ logp_function_error : Arc < dyn std:: error:: Error + Send + Sync > ,
76+ ) -> Self {
77+ DivergenceInfo {
78+ logp_function_error : Some ( logp_function_error) ,
79+ start_location : Some ( math. box_array ( start. point ( ) . position ( ) ) ) ,
80+ start_gradient : Some ( math. box_array ( start. point ( ) . gradient ( ) ) ) ,
81+ // TODO
82+ start_momentum : None ,
83+ start_idx_in_trajectory : Some ( start. index_in_trajectory ( ) ) ,
84+ end_location : None ,
85+ end_idx_in_trajectory : None ,
86+ energy_error : None ,
87+ non_reversible : false ,
88+ }
89+ }
90+
91+ pub fn new_not_reversible < M : Math > ( math : & mut M , start : & State < M , impl Point < M > > ) -> Self {
92+ // TODO add info about what went wrong
93+ DivergenceInfo {
94+ logp_function_error : None ,
95+ start_location : Some ( math. box_array ( start. point ( ) . position ( ) ) ) ,
96+ start_gradient : Some ( math. box_array ( start. point ( ) . gradient ( ) ) ) ,
97+ // TODO
98+ start_momentum : None ,
99+ start_idx_in_trajectory : Some ( start. index_in_trajectory ( ) ) ,
100+ end_location : None ,
101+ end_idx_in_trajectory : None ,
102+ energy_error : None ,
103+ non_reversible : true ,
104+ }
105+ }
106+ pub fn new_max_step_size_halvings < M : Math > ( math : & mut M , num_steps : u64 , info : Self ) -> Self {
107+ info // TODO
108+ }
34109}
35110
36111#[ derive( Debug , Copy , Clone ) ]
@@ -39,6 +114,15 @@ pub enum Direction {
39114 Backward ,
40115}
41116
117+ impl Direction {
118+ pub fn reverse ( & self ) -> Self {
119+ match self {
120+ Direction :: Forward => Direction :: Backward ,
121+ Direction :: Backward => Direction :: Forward ,
122+ }
123+ }
124+ }
125+
42126impl Distribution < Direction > for StandardUniform {
43127 fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> Direction {
44128 if rng. random :: < bool > ( ) {
@@ -87,9 +171,44 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
87171 math : & mut M ,
88172 start : & State < M , Self :: Point > ,
89173 dir : Direction ,
174+ step_size_splits : u64 ,
90175 collector : & mut C ,
91176 ) -> LeapfrogResult < M , Self :: Point > ;
92177
178+ fn split_leapfrog < C : Collector < M , Self :: Point > > (
179+ & mut self ,
180+ math : & mut M ,
181+ start : & State < M , Self :: Point > ,
182+ dir : Direction ,
183+ num_steps : u64 ,
184+ collector : & mut C ,
185+ max_error : f64 ,
186+ ) -> LeapfrogResult < M , Self :: Point > {
187+ let mut state = start. clone ( ) ;
188+
189+ let mut min_energy = start. energy ( ) ;
190+ let mut max_energy = min_energy;
191+
192+ for _ in 0 ..num_steps {
193+ state = match self . leapfrog ( math, & state, dir, num_steps, collector) {
194+ LeapfrogResult :: Ok ( state) => state,
195+ LeapfrogResult :: Divergence ( info) => return LeapfrogResult :: Divergence ( info) ,
196+ LeapfrogResult :: Err ( err) => return LeapfrogResult :: Err ( err) ,
197+ } ;
198+ let energy = state. energy ( ) ;
199+ min_energy = min_energy. min ( energy) ;
200+ max_energy = max_energy. max ( energy) ;
201+
202+ // TODO: walnuts papers says to use abs, but c++ code doesn't?
203+ if max_energy - min_energy > max_error {
204+ let info = DivergenceInfo :: new_energy_error_too_large ( math, start, & state) ;
205+ return LeapfrogResult :: Divergence ( info) ;
206+ }
207+ }
208+
209+ LeapfrogResult :: Ok ( state)
210+ }
211+
93212 fn is_turning (
94213 & self ,
95214 math : & mut M ,
@@ -128,4 +247,5 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
128247
129248 fn step_size ( & self ) -> f64 ;
130249 fn step_size_mut ( & mut self ) -> & mut f64 ;
250+ fn max_energy_error ( & self ) -> f64 ;
131251}
0 commit comments