Skip to content

Commit 72e6d0a

Browse files
committed
feat: add walnuts implementation
1 parent 0ea5d36 commit 72e6d0a

9 files changed

Lines changed: 413 additions & 15 deletions

File tree

src/adapt_strategy.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,11 +326,12 @@ where
326326
start: &State<M, P>,
327327
end: &State<M, P>,
328328
divergence_info: Option<&DivergenceInfo>,
329+
num_substeps: u64,
329330
) {
330331
self.collector1
331-
.register_leapfrog(math, start, end, divergence_info);
332+
.register_leapfrog(math, start, end, divergence_info, num_substeps);
332333
self.collector2
333-
.register_leapfrog(math, start, end, divergence_info);
334+
.register_leapfrog(math, start, end, divergence_info, num_substeps);
334335
}
335336

336337
fn register_draw(&mut self, math: &mut M, state: &State<M, P>, info: &crate::nuts::SampleInfo) {
@@ -458,6 +459,7 @@ mod test {
458459
target_integration_time: None,
459460
extra_doublings: 0,
460461
max_energy_error: 1000.0,
462+
walnuts_options: None,
461463
};
462464

463465
let rng = {

src/chain.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ where
158158
&mut self.hamiltonian,
159159
&self.options,
160160
&mut self.collector,
161+
self.draw_count < 70,
161162
)?;
162163
let mut position: Box<[f64]> = vec![0f64; math.dim()].into();
163164
state.write_position(math, &mut position);
@@ -226,6 +227,17 @@ pub struct NutsStats<P: HasDims, H: Storable<P>, A: Storable<P>, D: Storable<P>>
226227
pub point: D,
227228
#[storable(flatten)]
228229
pub divergence: DivergenceStats,
230+
pub diverging: bool,
231+
#[storable(dims("unconstrained_parameter"))]
232+
pub divergence_start: Option<Vec<f64>>,
233+
#[storable(dims("unconstrained_parameter"))]
234+
pub divergence_start_gradient: Option<Vec<f64>>,
235+
#[storable(dims("unconstrained_parameter"))]
236+
pub divergence_end: Option<Vec<f64>>,
237+
#[storable(dims("unconstrained_parameter"))]
238+
pub divergence_momentum: Option<Vec<f64>>,
239+
non_reversible: Option<bool>,
240+
//pub divergence_message: Option<String>,
229241
#[storable(ignore)]
230242
_phantom: PhantomData<fn() -> P>,
231243
}
@@ -279,6 +291,17 @@ impl<M: Math, R: rand::Rng, A: AdaptStrategy<M>> SamplerStats<M> for NutsChain<M
279291
adapt: adapt_stats,
280292
point: point_stats,
281293
divergence: (div_info, options.divergence, self.draw_count).into(),
294+
diverging: div_info.is_some(),
295+
divergence_start: div_info
296+
.and_then(|d| d.start_location.as_ref().map(|v| v.as_ref().to_vec())),
297+
divergence_start_gradient: div_info
298+
.and_then(|d| d.start_gradient.as_ref().map(|v| v.as_ref().to_vec())),
299+
divergence_end: div_info
300+
.and_then(|d| d.end_location.as_ref().map(|v| v.as_ref().to_vec())),
301+
divergence_momentum: div_info
302+
.and_then(|d| d.start_momentum.as_ref().map(|v| v.as_ref().to_vec())),
303+
//divergence_message: self.divergence_msg.clone(),
304+
non_reversible: div_info.and_then(|d| Some(d.non_reversible)),
282305
_phantom: PhantomData,
283306
}
284307
}

src/dynamics/hamiltonian.rs

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use crate::{
2222
/// a cutoff value or nan.
2323
/// - The logp function caused a recoverable error (eg if an ODE solver
2424
/// failed)
25+
#[non_exhaustive]
2526
#[derive(Debug, Clone)]
2627
pub struct DivergenceInfo {
2728
pub start_momentum: Option<Box<[f64]>>,
@@ -32,6 +33,86 @@ pub struct DivergenceInfo {
3233
pub end_idx_in_trajectory: Option<i64>,
3334
pub start_idx_in_trajectory: Option<i64>,
3435
pub logp_function_error: Option<Arc<dyn std::error::Error + Send + Sync>>,
36+
pub non_reversible: bool,
37+
}
38+
impl DivergenceInfo {
39+
pub fn new() -> Self {
40+
DivergenceInfo {
41+
start_momentum: None,
42+
start_location: None,
43+
start_gradient: None,
44+
end_location: None,
45+
energy_error: None,
46+
end_idx_in_trajectory: None,
47+
start_idx_in_trajectory: None,
48+
logp_function_error: None,
49+
non_reversible: false,
50+
}
51+
}
52+
53+
pub fn new_energy_error_too_large<M: Math>(
54+
math: &mut M,
55+
start: &State<M, impl Point<M>>,
56+
stop: &State<M, impl Point<M>>,
57+
) -> Self {
58+
DivergenceInfo {
59+
logp_function_error: None,
60+
start_location: Some(math.box_array(start.point().position())),
61+
start_gradient: Some(math.box_array(start.point().gradient())),
62+
// TODO
63+
start_momentum: None,
64+
start_idx_in_trajectory: Some(start.index_in_trajectory()),
65+
end_location: Some(math.box_array(&stop.point().position())),
66+
end_idx_in_trajectory: Some(stop.index_in_trajectory()),
67+
// TODO
68+
energy_error: None,
69+
non_reversible: false,
70+
}
71+
}
72+
73+
pub fn new_logp_function_error<M: Math>(
74+
math: &mut M,
75+
start: &State<M, impl Point<M>>,
76+
logp_function_error: Arc<dyn std::error::Error + Send + Sync>,
77+
) -> Self {
78+
DivergenceInfo {
79+
logp_function_error: Some(logp_function_error),
80+
start_location: Some(math.box_array(start.point().position())),
81+
start_gradient: Some(math.box_array(start.point().gradient())),
82+
// TODO
83+
start_momentum: None,
84+
start_idx_in_trajectory: Some(start.index_in_trajectory()),
85+
end_location: None,
86+
end_idx_in_trajectory: None,
87+
energy_error: None,
88+
non_reversible: false,
89+
}
90+
}
91+
92+
pub fn new_not_reversible<M: Math>(math: &mut M, start: &State<M, impl Point<M>>) -> Self {
93+
// TODO add info about what went wrong
94+
DivergenceInfo {
95+
logp_function_error: None,
96+
start_location: Some(math.box_array(start.point().position())),
97+
start_gradient: Some(math.box_array(start.point().gradient())),
98+
// TODO
99+
start_momentum: None,
100+
start_idx_in_trajectory: Some(start.index_in_trajectory()),
101+
end_location: None,
102+
end_idx_in_trajectory: None,
103+
energy_error: None,
104+
non_reversible: true,
105+
}
106+
}
107+
pub fn new_max_step_size_halvings<M: Math>(math: &mut M, num_steps: u64, info: Self) -> Self {
108+
info // TODO
109+
}
110+
}
111+
112+
impl DivergenceInfo {
113+
pub(crate) fn new_non_reversible() -> DivergenceInfo {
114+
todo!()
115+
}
35116
}
36117

37118
/// Per-draw divergence statistics, suitable for storage.
@@ -108,6 +189,15 @@ pub enum Direction {
108189
Backward,
109190
}
110191

192+
impl Direction {
193+
pub fn reverse(&self) -> Self {
194+
match self {
195+
Direction::Forward => Direction::Backward,
196+
Direction::Backward => Direction::Forward,
197+
}
198+
}
199+
}
200+
111201
impl Distribution<Direction> for StandardUniform {
112202
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Direction {
113203
if rng.random::<bool>() {
@@ -167,6 +257,40 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
167257
collector: &mut C,
168258
) -> LeapfrogResult<M, Self::Point>;
169259

260+
fn split_leapfrog<C: Collector<M, Self::Point>>(
261+
&mut self,
262+
math: &mut M,
263+
start: &State<M, Self::Point>,
264+
dir: Direction,
265+
num_steps: u64,
266+
collector: &mut C,
267+
max_error: f64,
268+
) -> LeapfrogResult<M, Self::Point> {
269+
let mut state = start.clone();
270+
271+
let mut min_energy = start.energy();
272+
let mut max_energy = min_energy;
273+
274+
for _ in 0..num_steps {
275+
state = match self.leapfrog(math, &state, dir, num_steps, collector) {
276+
LeapfrogResult::Ok(state) => state,
277+
LeapfrogResult::Divergence(info) => return LeapfrogResult::Divergence(info),
278+
LeapfrogResult::Err(err) => return LeapfrogResult::Err(err),
279+
};
280+
let energy = state.energy();
281+
min_energy = min_energy.min(energy);
282+
max_energy = max_energy.max(energy);
283+
284+
// TODO: walnuts papers says to use abs, but c++ code doesn't?
285+
if max_energy - min_energy > max_error {
286+
let info = DivergenceInfo::new_energy_error_too_large(math, start, &state);
287+
return LeapfrogResult::Divergence(info);
288+
}
289+
}
290+
291+
LeapfrogResult::Ok(state)
292+
}
293+
170294
fn is_turning(
171295
&self,
172296
math: &mut M,
@@ -255,4 +379,5 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
255379
let _ = (math, state, noise, rng, factor);
256380
Ok(())
257381
}
382+
fn max_energy_error(&self) -> f64;
258383
}

src/dynamics/transformed_hamiltonian.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -572,8 +572,9 @@ impl<M: Math, T: Transformation<M>> Hamiltonian<M> for TransformedHamiltonian<M,
572572
start_idx_in_trajectory: Some(start.point().index_in_trajectory()),
573573
end_idx_in_trajectory: None,
574574
energy_error: None,
575+
non_reversible: false,
575576
};
576-
collector.register_leapfrog(math, start, &out, Some(&div_info));
577+
collector.register_leapfrog(math, start, &out, Some(&div_info), step_size_splits);
577578
return LeapfrogResult::Divergence(div_info);
578579
}
579580

@@ -604,12 +605,19 @@ impl<M: Math, T: Transformation<M>> Hamiltonian<M> for TransformedHamiltonian<M,
604605
start_idx_in_trajectory: Some(start.index_in_trajectory()),
605606
end_idx_in_trajectory: Some(out.index_in_trajectory()),
606607
energy_error: Some(energy_error),
608+
non_reversible: false,
607609
};
608-
collector.register_leapfrog(math, start, &out, Some(&divergence_info));
610+
collector.register_leapfrog(
611+
math,
612+
start,
613+
&out,
614+
Some(&divergence_info),
615+
step_size_splits,
616+
);
609617
return LeapfrogResult::Divergence(divergence_info);
610618
}
611619

612-
collector.register_leapfrog(math, start, &out, None);
620+
collector.register_leapfrog(math, start, &out, None, step_size_splits);
613621

614622
LeapfrogResult::Ok(out)
615623
}
@@ -824,4 +832,8 @@ impl<M: Math, T: Transformation<M>> Hamiltonian<M> for TransformedHamiltonian<M,
824832

825833
Ok(())
826834
}
835+
836+
fn max_energy_error(&self) -> f64 {
837+
self.max_energy_error
838+
}
827839
}

src/external_adapt_strategy.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ impl<M: Math, P: Point<M>> Collector<M, P> for DrawCollector<M> {
9797
_start: &State<M, P>,
9898
end: &State<M, P>,
9999
divergence_info: Option<&crate::DivergenceInfo>,
100+
_num_substeps: u64,
100101
) {
101102
if divergence_info.is_some() {
102103
return;

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ pub use dynamics::{DivergenceInfo, KineticEnergyKind};
122122
pub use math::{CpuLogpFunc, CpuMath, CpuMathError, LogpError, Math};
123123
pub use mclmc::{MclmcChain, MclmcInfo, MclmcStats, MclmcTrajectoryKind};
124124
pub use model::Model;
125-
pub use nuts::NutsError;
125+
pub use nuts::{NutsError, WalnutsOptions};
126126

127127
#[allow(deprecated)]
128128
pub use sampler::{

0 commit comments

Comments
 (0)