Skip to content

Commit e7d5ee0

Browse files
committed
fix: change mclmc defaults
1 parent 4406e6e commit e7d5ee0

1 file changed

Lines changed: 69 additions & 72 deletions

File tree

src/sampler.rs

Lines changed: 69 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use std::{
2222
};
2323

2424
use crate::{
25-
DiagAdaptExpSettings, Math,
25+
DiagAdaptExpSettings, Math, StepSizeAdaptMethod,
2626
adapt_strategy::{EuclideanAdaptOptions, GlobalStrategy, GlobalStrategyStatsOptions},
2727
chain::{AdaptStrategy, Chain, NutsChain, StatOptions},
2828
dynamics::{KineticEnergyKind, TransformedHamiltonian, TransformedPointStatsOptions},
@@ -316,32 +316,6 @@ fn usize_hint(value: u64, field: &str) -> usize {
316316
.unwrap_or_else(|_| panic!("{field} must be smaller than usize::MAX"))
317317
}
318318

319-
fn default_diag_mclmc_adapt_options() -> EuclideanAdaptOptions<DiagAdaptExpSettings> {
320-
let mut adapt_options = EuclideanAdaptOptions::<DiagAdaptExpSettings>::default();
321-
adapt_options.step_size_window = 0.0;
322-
adapt_options.step_size_settings = crate::stepsize::StepSizeSettings {
323-
adapt_options: crate::stepsize::StepSizeAdaptOptions {
324-
method: crate::stepsize::StepSizeAdaptMethod::Fixed(0.5),
325-
..crate::stepsize::StepSizeAdaptOptions::default()
326-
},
327-
..crate::stepsize::StepSizeSettings::default()
328-
};
329-
adapt_options
330-
}
331-
332-
fn default_low_rank_mclmc_adapt_options() -> EuclideanAdaptOptions<LowRankSettings> {
333-
let mut adapt_options = EuclideanAdaptOptions::<LowRankSettings>::default();
334-
adapt_options.step_size_window = 0.0;
335-
adapt_options.step_size_settings = crate::stepsize::StepSizeSettings {
336-
adapt_options: crate::stepsize::StepSizeAdaptOptions {
337-
method: crate::stepsize::StepSizeAdaptMethod::Fixed(0.5),
338-
..crate::stepsize::StepSizeAdaptOptions::default()
339-
},
340-
..crate::stepsize::StepSizeSettings::default()
341-
};
342-
adapt_options
343-
}
344-
345319
fn default_mclmc_settings<A: Debug + Copy + Default + Serialize>(
346320
adapt_options: A,
347321
num_tune: u64,
@@ -362,21 +336,26 @@ fn default_mclmc_settings<A: Debug + Copy + Default + Serialize>(
362336
store_transformed: false,
363337
adapt_options,
364338
subsample_frequency: 1.0,
365-
dynamic_step_size: false,
339+
dynamic_step_size: true,
366340
trajectory_kind: MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical,
367341
trajectory_switch_fraction: 0.3,
368342
}
369343
}
370344

371345
impl Default for DiagMclmcSettings {
372346
fn default() -> Self {
373-
default_mclmc_settings(default_diag_mclmc_adapt_options(), 400, 6, 1000.0)
347+
let mut adapt_options = EuclideanAdaptOptions::default();
348+
adapt_options.step_size_settings.adapt_options.method = StepSizeAdaptMethod::Fixed(0.5);
349+
default_mclmc_settings(adapt_options, 400, 6, 1000.0)
374350
}
375351
}
376352

377353
impl Default for LowRankMclmcSettings {
378354
fn default() -> Self {
379-
default_mclmc_settings(default_low_rank_mclmc_adapt_options(), 800, 6, 1000.0)
355+
let mut adapt_options = EuclideanAdaptOptions::default();
356+
adapt_options.early_mass_matrix_switch_freq = 20;
357+
adapt_options.step_size_settings.adapt_options.method = StepSizeAdaptMethod::Fixed(0.5);
358+
default_mclmc_settings(adapt_options, 800, 6, 1000.0)
380359
}
381360
}
382361

@@ -474,11 +453,16 @@ impl Settings for DiagMclmcSettings {
474453
mass_matrix: (),
475454
},
476455
hamiltonian: -1,
477-
point: point_stats_options(
478-
self.store_gradient,
479-
self.store_unconstrained,
480-
self.store_transformed,
481-
),
456+
point: {
457+
let store_gradient = self.store_gradient;
458+
let store_unconstrained = self.store_unconstrained;
459+
let store_transformed = self.store_transformed;
460+
TransformedPointStatsOptions {
461+
store_gradient,
462+
store_unconstrained,
463+
store_transformed,
464+
}
465+
},
482466
divergence: crate::dynamics::DivergenceStatsOptions {
483467
store_divergences: self.store_divergences,
484468
},
@@ -592,11 +576,16 @@ impl Settings for LowRankMclmcSettings {
592576
mass_matrix: (),
593577
},
594578
hamiltonian: -1,
595-
point: point_stats_options(
596-
self.store_gradient,
597-
self.store_unconstrained,
598-
self.store_transformed,
599-
),
579+
point: {
580+
let store_gradient = self.store_gradient;
581+
let store_unconstrained = self.store_unconstrained;
582+
let store_transformed = self.store_transformed;
583+
TransformedPointStatsOptions {
584+
store_gradient,
585+
store_unconstrained,
586+
store_transformed,
587+
}
588+
},
600589
divergence: crate::dynamics::DivergenceStatsOptions {
601590
store_divergences: self.store_divergences,
602591
},
@@ -647,18 +636,6 @@ fn nuts_options(settings: &NutsSettings<impl Debug + Copy + Default + Serialize>
647636
}
648637
}
649638

650-
fn point_stats_options(
651-
store_gradient: bool,
652-
store_unconstrained: bool,
653-
store_transformed: bool,
654-
) -> TransformedPointStatsOptions {
655-
TransformedPointStatsOptions {
656-
store_gradient,
657-
store_unconstrained,
658-
store_transformed,
659-
}
660-
}
661-
662639
impl Settings for LowRankNutsSettings {
663640
type Chain<M: Math> = LowRankNutsChain<M>;
664641

@@ -711,11 +688,16 @@ impl Settings for LowRankNutsSettings {
711688
step_size: (),
712689
},
713690
hamiltonian: -1,
714-
point: point_stats_options(
715-
self.store_gradient,
716-
self.store_unconstrained,
717-
self.store_transformed,
718-
),
691+
point: {
692+
let store_gradient = self.store_gradient;
693+
let store_unconstrained = self.store_unconstrained;
694+
let store_transformed = self.store_transformed;
695+
TransformedPointStatsOptions {
696+
store_gradient,
697+
store_unconstrained,
698+
store_transformed,
699+
}
700+
},
719701
divergence: crate::dynamics::DivergenceStatsOptions {
720702
store_divergences: self.store_divergences,
721703
},
@@ -786,11 +768,16 @@ impl Settings for DiagNutsSettings {
786768
step_size: (),
787769
},
788770
hamiltonian: -1,
789-
point: point_stats_options(
790-
self.store_gradient,
791-
self.store_unconstrained,
792-
self.store_transformed,
793-
),
771+
point: {
772+
let store_gradient = self.store_gradient;
773+
let store_unconstrained = self.store_unconstrained;
774+
let store_transformed = self.store_transformed;
775+
TransformedPointStatsOptions {
776+
store_gradient,
777+
store_unconstrained,
778+
store_transformed,
779+
}
780+
},
794781
divergence: crate::dynamics::DivergenceStatsOptions {
795782
store_divergences: self.store_divergences,
796783
},
@@ -859,11 +846,16 @@ impl Settings for FlowNutsSettings {
859846
StatOptions {
860847
adapt: (),
861848
hamiltonian: (),
862-
point: point_stats_options(
863-
self.store_gradient,
864-
self.store_unconstrained,
865-
self.store_transformed,
866-
),
849+
point: {
850+
let store_gradient = self.store_gradient;
851+
let store_unconstrained = self.store_unconstrained;
852+
let store_transformed = self.store_transformed;
853+
TransformedPointStatsOptions {
854+
store_gradient,
855+
store_unconstrained,
856+
store_transformed,
857+
}
858+
},
867859
divergence: crate::dynamics::DivergenceStatsOptions {
868860
store_divergences: self.store_divergences,
869861
},
@@ -948,11 +940,16 @@ impl Settings for FlowMclmcSettings {
948940
StatOptions {
949941
adapt: (),
950942
hamiltonian: (),
951-
point: point_stats_options(
952-
self.store_gradient,
953-
self.store_unconstrained,
954-
self.store_transformed,
955-
),
943+
point: {
944+
let store_gradient = self.store_gradient;
945+
let store_unconstrained = self.store_unconstrained;
946+
let store_transformed = self.store_transformed;
947+
TransformedPointStatsOptions {
948+
store_gradient,
949+
store_unconstrained,
950+
store_transformed,
951+
}
952+
},
956953
divergence: crate::dynamics::DivergenceStatsOptions {
957954
store_divergences: self.store_divergences,
958955
},

0 commit comments

Comments
 (0)