@@ -22,6 +22,7 @@ use crate::{
2222/// a cutoff value or nan.
2323/// - The logp function caused a recoverable error (eg if an ODE solver
2424/// failed)
25+ #[ non_exhaustive]
2526#[ derive( Debug , Clone ) ]
2627pub struct DivergenceInfo {
2728 pub start_momentum : Option < Box < [ f64 ] > > ,
@@ -32,6 +33,86 @@ pub struct DivergenceInfo {
3233 pub end_idx_in_trajectory : Option < i64 > ,
3334 pub start_idx_in_trajectory : Option < i64 > ,
3435 pub logp_function_error : Option < Arc < dyn std:: error:: Error + Send + Sync > > ,
36+ pub non_reversible : bool ,
37+ }
38+ impl DivergenceInfo {
39+ pub fn new ( ) -> Self {
40+ DivergenceInfo {
41+ start_momentum : None ,
42+ start_location : None ,
43+ start_gradient : None ,
44+ end_location : None ,
45+ energy_error : None ,
46+ end_idx_in_trajectory : None ,
47+ start_idx_in_trajectory : None ,
48+ logp_function_error : None ,
49+ non_reversible : false ,
50+ }
51+ }
52+
53+ pub fn new_energy_error_too_large < M : Math > (
54+ math : & mut M ,
55+ start : & State < M , impl Point < M > > ,
56+ stop : & State < M , impl Point < M > > ,
57+ ) -> Self {
58+ DivergenceInfo {
59+ logp_function_error : None ,
60+ start_location : Some ( math. box_array ( start. point ( ) . position ( ) ) ) ,
61+ start_gradient : Some ( math. box_array ( start. point ( ) . gradient ( ) ) ) ,
62+ // TODO
63+ start_momentum : None ,
64+ start_idx_in_trajectory : Some ( start. index_in_trajectory ( ) ) ,
65+ end_location : Some ( math. box_array ( & stop. point ( ) . position ( ) ) ) ,
66+ end_idx_in_trajectory : Some ( stop. index_in_trajectory ( ) ) ,
67+ // TODO
68+ energy_error : None ,
69+ non_reversible : false ,
70+ }
71+ }
72+
73+ pub fn new_logp_function_error < M : Math > (
74+ math : & mut M ,
75+ start : & State < M , impl Point < M > > ,
76+ logp_function_error : Arc < dyn std:: error:: Error + Send + Sync > ,
77+ ) -> Self {
78+ DivergenceInfo {
79+ logp_function_error : Some ( logp_function_error) ,
80+ start_location : Some ( math. box_array ( start. point ( ) . position ( ) ) ) ,
81+ start_gradient : Some ( math. box_array ( start. point ( ) . gradient ( ) ) ) ,
82+ // TODO
83+ start_momentum : None ,
84+ start_idx_in_trajectory : Some ( start. index_in_trajectory ( ) ) ,
85+ end_location : None ,
86+ end_idx_in_trajectory : None ,
87+ energy_error : None ,
88+ non_reversible : false ,
89+ }
90+ }
91+
92+ pub fn new_not_reversible < M : Math > ( math : & mut M , start : & State < M , impl Point < M > > ) -> Self {
93+ // TODO add info about what went wrong
94+ DivergenceInfo {
95+ logp_function_error : None ,
96+ start_location : Some ( math. box_array ( start. point ( ) . position ( ) ) ) ,
97+ start_gradient : Some ( math. box_array ( start. point ( ) . gradient ( ) ) ) ,
98+ // TODO
99+ start_momentum : None ,
100+ start_idx_in_trajectory : Some ( start. index_in_trajectory ( ) ) ,
101+ end_location : None ,
102+ end_idx_in_trajectory : None ,
103+ energy_error : None ,
104+ non_reversible : true ,
105+ }
106+ }
107+ pub fn new_max_step_size_halvings < M : Math > ( math : & mut M , num_steps : u64 , info : Self ) -> Self {
108+ info // TODO
109+ }
110+ }
111+
112+ impl DivergenceInfo {
113+ pub ( crate ) fn new_non_reversible ( ) -> DivergenceInfo {
114+ todo ! ( )
115+ }
35116}
36117
37118/// Per-draw divergence statistics, suitable for storage.
@@ -108,6 +189,15 @@ pub enum Direction {
108189 Backward ,
109190}
110191
192+ impl Direction {
193+ pub fn reverse ( & self ) -> Self {
194+ match self {
195+ Direction :: Forward => Direction :: Backward ,
196+ Direction :: Backward => Direction :: Forward ,
197+ }
198+ }
199+ }
200+
111201impl Distribution < Direction > for StandardUniform {
112202 fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> Direction {
113203 if rng. random :: < bool > ( ) {
@@ -167,6 +257,40 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
167257 collector : & mut C ,
168258 ) -> LeapfrogResult < M , Self :: Point > ;
169259
260+ fn split_leapfrog < C : Collector < M , Self :: Point > > (
261+ & mut self ,
262+ math : & mut M ,
263+ start : & State < M , Self :: Point > ,
264+ dir : Direction ,
265+ num_steps : u64 ,
266+ collector : & mut C ,
267+ max_error : f64 ,
268+ ) -> LeapfrogResult < M , Self :: Point > {
269+ let mut state = start. clone ( ) ;
270+
271+ let mut min_energy = start. energy ( ) ;
272+ let mut max_energy = min_energy;
273+
274+ for _ in 0 ..num_steps {
275+ state = match self . leapfrog ( math, & state, dir, num_steps, collector) {
276+ LeapfrogResult :: Ok ( state) => state,
277+ LeapfrogResult :: Divergence ( info) => return LeapfrogResult :: Divergence ( info) ,
278+ LeapfrogResult :: Err ( err) => return LeapfrogResult :: Err ( err) ,
279+ } ;
280+ let energy = state. energy ( ) ;
281+ min_energy = min_energy. min ( energy) ;
282+ max_energy = max_energy. max ( energy) ;
283+
284+ // TODO: walnuts papers says to use abs, but c++ code doesn't?
285+ if max_energy - min_energy > max_error {
286+ let info = DivergenceInfo :: new_energy_error_too_large ( math, start, & state) ;
287+ return LeapfrogResult :: Divergence ( info) ;
288+ }
289+ }
290+
291+ LeapfrogResult :: Ok ( state)
292+ }
293+
170294 fn is_turning (
171295 & self ,
172296 math : & mut M ,
@@ -255,4 +379,5 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
255379 let _ = ( math, state, noise, rng, factor) ;
256380 Ok ( ( ) )
257381 }
382+ fn max_energy_error ( & self ) -> f64 ;
258383}
0 commit comments