Skip to content

Commit 0f76fdd

Browse files
authored
Merge pull request #11 from SyntaxSpirits/feat/warmup-workflow
feat: add warmup workflow output
2 parents 59c0371 + 1af254e commit 0f76fdd

1 file changed

Lines changed: 229 additions & 4 deletions

File tree

src/samplers.rs

Lines changed: 229 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,20 +100,89 @@ where
100100
Ok(max_error)
101101
}
102102

103+
/// Metadata describing a warmup/burn-in sampling run.
///
/// The three counts describe the complete sampler schedule: the discarded
/// warmup draws, the retained posterior draws, and their sum as the total
/// number of sampler transitions.
///
/// The fields are public, so a value written as a struct literal is not checked
/// for consistency. [`WarmupMetadata::new`] and [`WarmupMetadata::try_new`]
/// always derive `total_iterations` from the other two counts.
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct WarmupMetadata {
    /// Number of warmup/burn-in transitions that were run before collection.
    pub warmup_count: usize,
    /// Number of posterior samples retained after warmup.
    pub retained_count: usize,
    /// Total sampler transitions run (`warmup_count + retained_count`).
    pub total_iterations: usize,
}

impl WarmupMetadata {
    /// Create metadata for a warmup/burn-in sampling schedule.
    ///
    /// `total_iterations` is set to `warmup_count + retained_count`.
    ///
    /// # Panics
    ///
    /// Panics if `warmup_count + retained_count` does not fit in a `usize`.
    /// Use [`WarmupMetadata::try_new`] to handle that case without panicking.
    #[must_use]
    pub fn new(warmup_count: usize, retained_count: usize) -> Self {
        Self::try_new(warmup_count, retained_count)
            .expect("warmup and retained sample counts exceed usize::MAX")
    }

    /// Create metadata for a warmup/burn-in sampling schedule without panicking.
    ///
    /// Returns `None` when `warmup_count + retained_count` overflows `usize`,
    /// and `Some` metadata with the summed `total_iterations` otherwise.
    #[must_use]
    pub fn try_new(warmup_count: usize, retained_count: usize) -> Option<Self> {
        // Checked addition keeps release builds from silently wrapping the total.
        let total_iterations = warmup_count.checked_add(retained_count)?;

        Some(Self {
            warmup_count,
            retained_count,
            total_iterations,
        })
    }
}
132+
133+
/// Samples produced by a first-class warmup/burn-in workflow.
134+
///
135+
/// `warmup_samples` are the discarded burn-in states and should not be used for
136+
/// posterior summaries. `samples` are the retained posterior draws. Sampler-level
137+
/// running statistics are reset between these phases by [`Sampler::run_with_warmup`]
138+
/// so acceptance statistics reported afterward describe only `samples` when the
139+
/// concrete sampler overrides [`Sampler::reset_statistics`].
140+
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
141+
#[derive(Debug, Clone, PartialEq)]
142+
pub struct WarmupRun {
143+
/// Discarded warmup/burn-in samples, preserving iteration order.
144+
pub warmup_samples: Vec<DVector<f64>>,
145+
/// Retained posterior samples collected after warmup.
146+
pub samples: Vec<DVector<f64>>,
147+
/// Counts for the run schedule.
148+
pub metadata: WarmupMetadata,
149+
}
150+
103151
/// Trait for MCMC samplers
104152
pub trait Sampler {
105153
/// Sample from the posterior distribution
106154
fn sample(&mut self, n_samples: usize) -> Vec<DVector<f64>>;
107155

108-
/// Run warmup iterations, discard those states, then collect posterior samples.
156+
/// Run warmup iterations, keep them separated as discarded states, reset
157+
/// sampler statistics, then collect retained posterior samples.
109158
///
110159
/// Warmup iterations let a Markov chain move away from its initial state before
111160
/// collecting draws for posterior summaries. This method does not perform
112161
/// automatic adaptation; callers should tune sampler parameters separately when
113162
/// their workflow requires it. Statistics such as acceptance rate are reset
114-
/// after warmup, so they describe only the returned samples. Implementations
115-
/// that maintain running statistics must override [`Sampler::reset_statistics`]
116-
/// for this guarantee to hold.
163+
/// before retained sampling, even when `n_warmup` is zero, so they describe
164+
/// only the retained samples. Implementations that maintain running statistics
165+
/// must override [`Sampler::reset_statistics`] for this guarantee to hold.
166+
fn run_with_warmup(&mut self, n_warmup: usize, n_samples: usize) -> WarmupRun {
167+
let mut warmup_samples = Vec::with_capacity(n_warmup);
168+
for _ in 0..n_warmup {
169+
warmup_samples.push(self.step());
170+
}
171+
self.reset_statistics();
172+
173+
let samples = self.sample(n_samples);
174+
WarmupRun {
175+
warmup_samples,
176+
samples,
177+
metadata: WarmupMetadata::new(n_warmup, n_samples),
178+
}
179+
}
180+
181+
/// Run warmup iterations, discard those states, then collect posterior samples.
182+
///
183+
/// This preserves the lightweight behavior of discarding warmup draws without
184+
/// allocating storage for them. Use [`Sampler::run_with_warmup`] when the
185+
/// discarded states and schedule metadata are needed.
117186
fn sample_with_warmup(&mut self, n_warmup: usize, n_samples: usize) -> Vec<DVector<f64>> {
118187
for _ in 0..n_warmup {
119188
self.step();
@@ -659,6 +728,128 @@ mod tests {
659728
use super::*;
660729
use crate::distributions::{Distribution, Normal};
661730

731+
#[derive(Debug, Clone)]
732+
struct CountingSampler {
733+
state: DVector<f64>,
734+
steps: usize,
735+
resets: usize,
736+
}
737+
738+
impl CountingSampler {
739+
fn new() -> Self {
740+
Self {
741+
state: DVector::from_vec(vec![0.0]),
742+
steps: 0,
743+
resets: 0,
744+
}
745+
}
746+
}
747+
748+
impl Sampler for CountingSampler {
749+
fn sample(&mut self, n_samples: usize) -> Vec<DVector<f64>> {
750+
(0..n_samples).map(|_| self.step()).collect()
751+
}
752+
753+
fn step(&mut self) -> DVector<f64> {
754+
self.steps += 1;
755+
self.state[0] += 1.0;
756+
self.state.clone()
757+
}
758+
759+
fn current_state(&self) -> &DVector<f64> {
760+
&self.state
761+
}
762+
763+
fn reset_statistics(&mut self) {
764+
self.resets += 1;
765+
}
766+
}
767+
768+
/// Two warmup steps and three retained steps must land in separate vectors,
/// with matching metadata, five total steps, and exactly one statistics reset.
#[test]
fn test_run_with_warmup_separates_discarded_and_retained_samples() {
    // One-element state vector holding `value`.
    let state = |value: f64| DVector::from_vec(vec![value]);
    let mut sampler = CountingSampler::new();

    let WarmupRun {
        warmup_samples,
        samples,
        metadata,
    } = sampler.run_with_warmup(2, 3);

    assert_eq!(warmup_samples, vec![state(1.0), state(2.0)]);
    assert_eq!(samples, vec![state(3.0), state(4.0), state(5.0)]);

    let expected_metadata = WarmupMetadata {
        warmup_count: 2,
        retained_count: 3,
        total_iterations: 5,
    };
    assert_eq!(metadata, expected_metadata);

    assert_eq!(sampler.steps, 5);
    assert_eq!(sampler.resets, 1);
    assert_eq!(sampler.current_state(), &state(5.0));
}
798+
799+
/// With no warmup, the warmup vector is empty, the retained draws start at the
/// first step, and statistics are still reset once.
#[test]
fn test_run_with_zero_warmup_retains_regular_samples_with_metadata() {
    let mut sampler = CountingSampler::new();
    let expected_samples: Vec<DVector<f64>> = [1.0, 2.0, 3.0]
        .iter()
        .map(|&value| DVector::from_vec(vec![value]))
        .collect();

    let run = sampler.run_with_warmup(0, 3);

    assert!(run.warmup_samples.is_empty());
    assert_eq!(run.samples, expected_samples);
    assert_eq!(run.metadata, WarmupMetadata::new(0, 3));
    assert_eq!(sampler.steps, 3);
    assert_eq!(sampler.resets, 1);
}
818+
819+
/// With no retained draws, every step goes into the warmup vector, the retained
/// vector is empty, and statistics are still reset once.
#[test]
fn test_run_with_warmup_allows_zero_retained_samples() {
    let mut sampler = CountingSampler::new();
    let expected_warmup: Vec<DVector<f64>> = [1.0, 2.0, 3.0]
        .iter()
        .map(|&value| DVector::from_vec(vec![value]))
        .collect();

    let run = sampler.run_with_warmup(3, 0);

    assert_eq!(run.warmup_samples, expected_warmup);
    assert!(run.samples.is_empty());
    assert_eq!(run.metadata, WarmupMetadata::new(3, 0));
    assert_eq!(sampler.steps, 3);
    assert_eq!(sampler.resets, 1);
}
838+
839+
/// `sample_with_warmup` must return only the post-warmup draws while still
/// stepping through the warmup iterations and resetting statistics once.
#[test]
fn test_sample_with_warmup_returns_retained_samples_only() {
    let state = |value: f64| DVector::from_vec(vec![value]);
    let mut sampler = CountingSampler::new();

    let retained = sampler.sample_with_warmup(3, 2);

    // Steps 1 through 3 are warmup, so the returned draws are steps 4 and 5.
    assert_eq!(retained, vec![state(4.0), state(5.0)]);
    assert_eq!(sampler.steps, 5);
    assert_eq!(sampler.resets, 1);
}
852+
662853
#[test]
663854
fn test_metropolis_hastings_creation() {
664855
let log_posterior = |params: &DVector<f64>| -> f64 {
@@ -800,6 +991,40 @@ mod tests {
800991
);
801992
}
802993

994+
/// `run_with_warmup` on a seeded Metropolis-Hastings sampler must match a
/// manual "sample, reset statistics, sample" sequence on an identically seeded
/// sampler: same warmup draws, same retained draws, same acceptance rate.
#[test]
fn test_run_with_warmup_resets_metropolis_hastings_statistics_before_retained_samples() {
    let log_posterior = |params: &DVector<f64>| -> f64 {
        Normal::new(0.0, 1.0).unwrap().log_pdf(params[0])
    };
    let initial_state = DVector::from_vec(vec![0.0]);
    let proposal_std = DVector::from_vec(vec![0.5]);

    // Both samplers are built from the same inputs and the same seed.
    let seeded_sampler = || {
        MetropolisHastings::with_seed(
            log_posterior,
            initial_state.clone(),
            proposal_std.clone(),
            789,
        )
        .unwrap()
    };

    let mut warmup_sampler = seeded_sampler();
    let run = warmup_sampler.run_with_warmup(25, 50);

    let mut manual_sampler = seeded_sampler();
    let expected_warmup = manual_sampler.sample(25);
    manual_sampler.reset_statistics();
    let expected_retained = manual_sampler.sample(50);

    assert_eq!(run.warmup_samples, expected_warmup);
    assert_eq!(run.samples, expected_retained);
    assert_eq!(run.metadata, WarmupMetadata::new(25, 50));
    assert_eq!(
        warmup_sampler.acceptance_rate(),
        manual_sampler.acceptance_rate()
    );
}
1027+
8031028
#[test]
8041029
fn test_gibbs_sampler_creation() {
8051030
let conditional_sampler = |_params: &DVector<f64>,

0 commit comments

Comments
 (0)