Skip to content

Commit ee9f927

Browse files
committed
feat: clean up walnuts a little bit
1 parent 80b6652 commit ee9f927

10 files changed

Lines changed: 221 additions & 115 deletions

File tree

src/adapt_strategy.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,11 +326,12 @@ where
326326
start: &State<M, P>,
327327
end: &State<M, P>,
328328
divergence_info: Option<&DivergenceInfo>,
329+
num_substeps: u64,
329330
) {
330331
self.collector1
331-
.register_leapfrog(math, start, end, divergence_info);
332+
.register_leapfrog(math, start, end, divergence_info, num_substeps);
332333
self.collector2
333-
.register_leapfrog(math, start, end, divergence_info);
334+
.register_leapfrog(math, start, end, divergence_info, num_substeps);
334335
}
335336

336337
fn register_draw(&mut self, math: &mut M, state: &State<M, P>, info: &crate::nuts::SampleInfo) {

src/chain.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ where
158158
&mut self.hamiltonian,
159159
&self.options,
160160
&mut self.collector,
161+
self.draw_count < 70,
161162
)?;
162163
let mut position: Box<[f64]> = vec![0f64; math.dim()].into();
163164
state.write_position(math, &mut position);
@@ -236,6 +237,7 @@ pub struct NutsStats<P: HasDims, H: Storable<P>, A: Storable<P>, D: Storable<P>>
236237
pub divergence_end: Option<Vec<f64>>,
237238
#[storable(dims("unconstrained_parameter"))]
238239
pub divergence_momentum: Option<Vec<f64>>,
240+
non_reversible: Option<bool>,
239241
//pub divergence_message: Option<String>,
240242
#[storable(ignore)]
241243
_phantom: PhantomData<fn() -> P>,
@@ -304,7 +306,7 @@ impl<M: Math, R: rand::Rng, A: AdaptStrategy<M>> SamplerStats<M> for NutsChain<M
304306
.and_then(|d| d.end_location.as_ref().map(|v| v.as_ref().to_vec())),
305307
divergence_momentum: div_info
306308
.and_then(|d| d.start_momentum.as_ref().map(|v| v.as_ref().to_vec())),
307-
//divergence_message: self.divergence_msg.clone(),
309+
non_reversible: div_info.and_then(|d| Some(d.non_reversible)),
308310
_phantom: PhantomData,
309311
}
310312
}

src/dynamics/hamiltonian.rs

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use crate::{
2121
/// a cutoff value or nan.
2222
/// - The logp function caused a recoverable error (eg if an ODE solver
2323
/// failed)
24+
#[non_exhaustive]
2425
#[derive(Debug, Clone)]
2526
pub struct DivergenceInfo {
2627
pub start_momentum: Option<Box<[f64]>>,
@@ -31,6 +32,80 @@ pub struct DivergenceInfo {
3132
pub end_idx_in_trajectory: Option<i64>,
3233
pub start_idx_in_trajectory: Option<i64>,
3334
pub logp_function_error: Option<Arc<dyn std::error::Error + Send + Sync>>,
35+
pub non_reversible: bool,
36+
}
37+
impl DivergenceInfo {
38+
pub fn new() -> Self {
39+
DivergenceInfo {
40+
start_momentum: None,
41+
start_location: None,
42+
start_gradient: None,
43+
end_location: None,
44+
energy_error: None,
45+
end_idx_in_trajectory: None,
46+
start_idx_in_trajectory: None,
47+
logp_function_error: None,
48+
non_reversible: false,
49+
}
50+
}
51+
52+
pub fn new_energy_error_too_large<M: Math>(
53+
math: &mut M,
54+
start: &State<M, impl Point<M>>,
55+
stop: &State<M, impl Point<M>>,
56+
) -> Self {
57+
DivergenceInfo {
58+
logp_function_error: None,
59+
start_location: Some(math.box_array(start.point().position())),
60+
start_gradient: Some(math.box_array(start.point().gradient())),
61+
// TODO
62+
start_momentum: None,
63+
start_idx_in_trajectory: Some(start.index_in_trajectory()),
64+
end_location: Some(math.box_array(&stop.point().position())),
65+
end_idx_in_trajectory: Some(stop.index_in_trajectory()),
66+
// TODO
67+
energy_error: None,
68+
non_reversible: false,
69+
}
70+
}
71+
72+
pub fn new_logp_function_error<M: Math>(
73+
math: &mut M,
74+
start: &State<M, impl Point<M>>,
75+
logp_function_error: Arc<dyn std::error::Error + Send + Sync>,
76+
) -> Self {
77+
DivergenceInfo {
78+
logp_function_error: Some(logp_function_error),
79+
start_location: Some(math.box_array(start.point().position())),
80+
start_gradient: Some(math.box_array(start.point().gradient())),
81+
// TODO
82+
start_momentum: None,
83+
start_idx_in_trajectory: Some(start.index_in_trajectory()),
84+
end_location: None,
85+
end_idx_in_trajectory: None,
86+
energy_error: None,
87+
non_reversible: false,
88+
}
89+
}
90+
91+
pub fn new_not_reversible<M: Math>(math: &mut M, start: &State<M, impl Point<M>>) -> Self {
92+
// TODO add info about what went wrong
93+
DivergenceInfo {
94+
logp_function_error: None,
95+
start_location: Some(math.box_array(start.point().position())),
96+
start_gradient: Some(math.box_array(start.point().gradient())),
97+
// TODO
98+
start_momentum: None,
99+
start_idx_in_trajectory: Some(start.index_in_trajectory()),
100+
end_location: None,
101+
end_idx_in_trajectory: None,
102+
energy_error: None,
103+
non_reversible: true,
104+
}
105+
}
106+
pub fn new_max_step_size_halvings<M: Math>(math: &mut M, num_steps: u64, info: Self) -> Self {
107+
info // TODO
108+
}
34109
}
35110

36111
#[derive(Debug, Copy, Clone)]
@@ -39,6 +114,15 @@ pub enum Direction {
39114
Backward,
40115
}
41116

117+
impl Direction {
118+
pub fn reverse(&self) -> Self {
119+
match self {
120+
Direction::Forward => Direction::Backward,
121+
Direction::Backward => Direction::Forward,
122+
}
123+
}
124+
}
125+
42126
impl Distribution<Direction> for StandardUniform {
43127
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Direction {
44128
if rng.random::<bool>() {
@@ -87,9 +171,44 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
87171
math: &mut M,
88172
start: &State<M, Self::Point>,
89173
dir: Direction,
174+
step_size_splits: u64,
90175
collector: &mut C,
91176
) -> LeapfrogResult<M, Self::Point>;
92177

178+
fn split_leapfrog<C: Collector<M, Self::Point>>(
179+
&mut self,
180+
math: &mut M,
181+
start: &State<M, Self::Point>,
182+
dir: Direction,
183+
num_steps: u64,
184+
collector: &mut C,
185+
max_error: f64,
186+
) -> LeapfrogResult<M, Self::Point> {
187+
let mut state = start.clone();
188+
189+
let mut min_energy = start.energy();
190+
let mut max_energy = min_energy;
191+
192+
for _ in 0..num_steps {
193+
state = match self.leapfrog(math, &state, dir, num_steps, collector) {
194+
LeapfrogResult::Ok(state) => state,
195+
LeapfrogResult::Divergence(info) => return LeapfrogResult::Divergence(info),
196+
LeapfrogResult::Err(err) => return LeapfrogResult::Err(err),
197+
};
198+
let energy = state.energy();
199+
min_energy = min_energy.min(energy);
200+
max_energy = max_energy.max(energy);
201+
202+
// TODO: walnuts papers says to use abs, but c++ code doesn't?
203+
if max_energy - min_energy > max_error {
204+
let info = DivergenceInfo::new_energy_error_too_large(math, start, &state);
205+
return LeapfrogResult::Divergence(info);
206+
}
207+
}
208+
209+
LeapfrogResult::Ok(state)
210+
}
211+
93212
fn is_turning(
94213
&self,
95214
math: &mut M,
@@ -128,4 +247,5 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
128247

129248
fn step_size(&self) -> f64;
130249
fn step_size_mut(&mut self) -> &mut f64;
250+
fn max_energy_error(&self) -> f64;
131251
}

src/dynamics/transformed_hamiltonian.rs

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ impl<M: Math, T: Transformation<M>> Hamiltonian<M> for TransformedHamiltonian<M,
379379
math: &mut M,
380380
start: &State<M, Self::Point>,
381381
dir: Direction,
382-
step_size_factor: f64,
382+
step_size_splits: u64,
383383
collector: &mut C,
384384
) -> LeapfrogResult<M, Self::Point> {
385385
let mut out = self.pool().new_state(math);
@@ -393,7 +393,7 @@ impl<M: Math, T: Transformation<M>> Hamiltonian<M> for TransformedHamiltonian<M,
393393
Direction::Backward => -1,
394394
};
395395

396-
let epsilon = (sign as f64) * self.step_size * step_size_factor;
396+
let epsilon = (sign as f64) * self.step_size / (step_size_splits as f64);
397397

398398
start
399399
.point()
@@ -417,8 +417,9 @@ impl<M: Math, T: Transformation<M>> Hamiltonian<M> for TransformedHamiltonian<M,
417417
start_idx_in_trajectory: Some(start.point().index_in_trajectory()),
418418
end_idx_in_trajectory: None,
419419
energy_error: None,
420+
non_reversible: false,
420421
};
421-
collector.register_leapfrog(math, start, &out, Some(&div_info));
422+
collector.register_leapfrog(math, start, &out, Some(&div_info), step_size_splits);
422423
return LeapfrogResult::Divergence(div_info);
423424
}
424425

@@ -438,12 +439,19 @@ impl<M: Math, T: Transformation<M>> Hamiltonian<M> for TransformedHamiltonian<M,
438439
start_idx_in_trajectory: Some(start.index_in_trajectory()),
439440
end_idx_in_trajectory: Some(out.index_in_trajectory()),
440441
energy_error: Some(energy_error),
442+
non_reversible: false,
441443
};
442-
collector.register_leapfrog(math, start, &out, Some(&divergence_info));
444+
collector.register_leapfrog(
445+
math,
446+
start,
447+
&out,
448+
Some(&divergence_info),
449+
step_size_splits,
450+
);
443451
return LeapfrogResult::Divergence(divergence_info);
444452
}
445453

446-
collector.register_leapfrog(math, start, &out, None);
454+
collector.register_leapfrog(math, start, &out, None, step_size_splits);
447455

448456
LeapfrogResult::Ok(out)
449457
}
@@ -569,4 +577,8 @@ impl<M: Math, T: Transformation<M>> Hamiltonian<M> for TransformedHamiltonian<M,
569577
fn step_size_mut(&mut self) -> &mut f64 {
570578
&mut self.step_size
571579
}
580+
581+
fn max_energy_error(&self) -> f64 {
582+
self.max_energy_error
583+
}
572584
}

src/external_adapt_strategy.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ impl<M: Math, P: Point<M>> Collector<M, P> for DrawCollector<M> {
9393
_start: &State<M, P>,
9494
end: &State<M, P>,
9595
divergence_info: Option<&crate::DivergenceInfo>,
96+
_num_substeps: u64,
9697
) {
9798
if divergence_info.is_some() {
9899
return;

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ pub use chain::Chain;
120120
pub use dynamics::DivergenceInfo;
121121
pub use math::{CpuLogpFunc, CpuMath, CpuMathError, LogpError, Math};
122122
pub use model::Model;
123-
pub use nuts::NutsError;
123+
pub use nuts::{NutsError, WalnutsOptions};
124124
pub use sampler::{
125125
ChainProgress, DiagGradNutsSettings, LowRankNutsSettings, NutsSettings, Progress,
126126
ProgressCallback, Sampler, SamplerWaitResult, Settings, TransformedNutsSettings,

0 commit comments

Comments
 (0)