@@ -100,20 +100,89 @@ where
100100 Ok ( max_error)
101101}
102102
103+ /// Metadata describing a warmup/burn-in sampling run.
104+ ///
105+ /// The counts describe the complete sampler schedule: discarded warmup draws,
106+ /// retained posterior draws, and their sum as total sampler transitions.
107+ #[ cfg_attr( feature = "serde" , derive( serde:: Serialize ) ) ]
108+ #[ derive( Debug , Clone , Copy , PartialEq , Eq ) ]
109+ pub struct WarmupMetadata {
110+ /// Number of warmup/burn-in transitions that were run before collection.
111+ pub warmup_count : usize ,
112+ /// Number of posterior samples retained after warmup.
113+ pub retained_count : usize ,
114+ /// Total sampler transitions run (`warmup_count + retained_count`).
115+ pub total_iterations : usize ,
116+ }
117+
118+ impl WarmupMetadata {
119+ /// Create metadata for a warmup/burn-in sampling schedule.
120+ pub fn new ( warmup_count : usize , retained_count : usize ) -> Self {
121+ let total_iterations = warmup_count
122+ . checked_add ( retained_count)
123+ . expect ( "warmup and retained sample counts exceed usize::MAX" ) ;
124+
125+ Self {
126+ warmup_count,
127+ retained_count,
128+ total_iterations,
129+ }
130+ }
131+ }
132+
133+ /// Samples produced by a first-class warmup/burn-in workflow.
134+ ///
135+ /// `warmup_samples` are the discarded burn-in states and should not be used for
136+ /// posterior summaries. `samples` are the retained posterior draws. Sampler-level
137+ /// running statistics are reset between these phases by [`Sampler::run_with_warmup`]
138+ /// so acceptance statistics reported afterward describe only `samples` when the
139+ /// concrete sampler overrides [`Sampler::reset_statistics`].
140+ #[ cfg_attr( feature = "serde" , derive( serde:: Serialize ) ) ]
141+ #[ derive( Debug , Clone , PartialEq ) ]
142+ pub struct WarmupRun {
143+ /// Discarded warmup/burn-in samples, preserving iteration order.
144+ pub warmup_samples : Vec < DVector < f64 > > ,
145+ /// Retained posterior samples collected after warmup.
146+ pub samples : Vec < DVector < f64 > > ,
147+ /// Counts for the run schedule.
148+ pub metadata : WarmupMetadata ,
149+ }
150+
103151/// Trait for MCMC samplers
104152pub trait Sampler {
105153 /// Sample from the posterior distribution
106154 fn sample ( & mut self , n_samples : usize ) -> Vec < DVector < f64 > > ;
107155
108- /// Run warmup iterations, discard those states, then collect posterior samples.
156+ /// Run warmup iterations, keep them separated as discarded states, reset
157+ /// sampler statistics, then collect retained posterior samples.
109158 ///
110159 /// Warmup iterations let a Markov chain move away from its initial state before
111160 /// collecting draws for posterior summaries. This method does not perform
112161 /// automatic adaptation; callers should tune sampler parameters separately when
113162 /// their workflow requires it. Statistics such as acceptance rate are reset
114- /// after warmup, so they describe only the returned samples. Implementations
115- /// that maintain running statistics must override [`Sampler::reset_statistics`]
116- /// for this guarantee to hold.
163+ /// before retained sampling, even when `n_warmup` is zero, so they describe
164+ /// only the retained samples. Implementations that maintain running statistics
165+ /// must override [`Sampler::reset_statistics`] for this guarantee to hold.
166+ fn run_with_warmup ( & mut self , n_warmup : usize , n_samples : usize ) -> WarmupRun {
167+ let mut warmup_samples = Vec :: with_capacity ( n_warmup) ;
168+ for _ in 0 ..n_warmup {
169+ warmup_samples. push ( self . step ( ) ) ;
170+ }
171+ self . reset_statistics ( ) ;
172+
173+ let samples = self . sample ( n_samples) ;
174+ WarmupRun {
175+ warmup_samples,
176+ samples,
177+ metadata : WarmupMetadata :: new ( n_warmup, n_samples) ,
178+ }
179+ }
180+
181+ /// Run warmup iterations, discard those states, then collect posterior samples.
182+ ///
183+ /// This preserves the lightweight behavior of discarding warmup draws without
184+ /// allocating storage for them. Use [`Sampler::run_with_warmup`] when the
185+ /// discarded states and schedule metadata are needed.
117186 fn sample_with_warmup ( & mut self , n_warmup : usize , n_samples : usize ) -> Vec < DVector < f64 > > {
118187 for _ in 0 ..n_warmup {
119188 self . step ( ) ;
@@ -659,6 +728,128 @@ mod tests {
659728 use super :: * ;
660729 use crate :: distributions:: { Distribution , Normal } ;
661730
731+ #[ derive( Debug , Clone ) ]
732+ struct CountingSampler {
733+ state : DVector < f64 > ,
734+ steps : usize ,
735+ resets : usize ,
736+ }
737+
738+ impl CountingSampler {
739+ fn new ( ) -> Self {
740+ Self {
741+ state : DVector :: from_vec ( vec ! [ 0.0 ] ) ,
742+ steps : 0 ,
743+ resets : 0 ,
744+ }
745+ }
746+ }
747+
748+ impl Sampler for CountingSampler {
749+ fn sample ( & mut self , n_samples : usize ) -> Vec < DVector < f64 > > {
750+ ( 0 ..n_samples) . map ( |_| self . step ( ) ) . collect ( )
751+ }
752+
753+ fn step ( & mut self ) -> DVector < f64 > {
754+ self . steps += 1 ;
755+ self . state [ 0 ] += 1.0 ;
756+ self . state . clone ( )
757+ }
758+
759+ fn current_state ( & self ) -> & DVector < f64 > {
760+ & self . state
761+ }
762+
763+ fn reset_statistics ( & mut self ) {
764+ self . resets += 1 ;
765+ }
766+ }
767+
768+ #[ test]
769+ fn test_run_with_warmup_separates_discarded_and_retained_samples ( ) {
770+ let mut sampler = CountingSampler :: new ( ) ;
771+
772+ let run = sampler. run_with_warmup ( 2 , 3 ) ;
773+
774+ assert_eq ! (
775+ run. warmup_samples,
776+ vec![ DVector :: from_vec( vec![ 1.0 ] ) , DVector :: from_vec( vec![ 2.0 ] ) ]
777+ ) ;
778+ assert_eq ! (
779+ run. samples,
780+ vec![
781+ DVector :: from_vec( vec![ 3.0 ] ) ,
782+ DVector :: from_vec( vec![ 4.0 ] ) ,
783+ DVector :: from_vec( vec![ 5.0 ] ) ,
784+ ]
785+ ) ;
786+ assert_eq ! (
787+ run. metadata,
788+ WarmupMetadata {
789+ warmup_count: 2 ,
790+ retained_count: 3 ,
791+ total_iterations: 5 ,
792+ }
793+ ) ;
794+ assert_eq ! ( sampler. steps, 5 ) ;
795+ assert_eq ! ( sampler. resets, 1 ) ;
796+ assert_eq ! ( sampler. current_state( ) , & DVector :: from_vec( vec![ 5.0 ] ) ) ;
797+ }
798+
799+ #[ test]
800+ fn test_run_with_zero_warmup_retains_regular_samples_with_metadata ( ) {
801+ let mut sampler = CountingSampler :: new ( ) ;
802+
803+ let run = sampler. run_with_warmup ( 0 , 3 ) ;
804+
805+ assert ! ( run. warmup_samples. is_empty( ) ) ;
806+ assert_eq ! (
807+ run. samples,
808+ vec![
809+ DVector :: from_vec( vec![ 1.0 ] ) ,
810+ DVector :: from_vec( vec![ 2.0 ] ) ,
811+ DVector :: from_vec( vec![ 3.0 ] ) ,
812+ ]
813+ ) ;
814+ assert_eq ! ( run. metadata, WarmupMetadata :: new( 0 , 3 ) ) ;
815+ assert_eq ! ( sampler. steps, 3 ) ;
816+ assert_eq ! ( sampler. resets, 1 ) ;
817+ }
818+
819+ #[ test]
820+ fn test_run_with_warmup_allows_zero_retained_samples ( ) {
821+ let mut sampler = CountingSampler :: new ( ) ;
822+
823+ let run = sampler. run_with_warmup ( 3 , 0 ) ;
824+
825+ assert_eq ! (
826+ run. warmup_samples,
827+ vec![
828+ DVector :: from_vec( vec![ 1.0 ] ) ,
829+ DVector :: from_vec( vec![ 2.0 ] ) ,
830+ DVector :: from_vec( vec![ 3.0 ] ) ,
831+ ]
832+ ) ;
833+ assert ! ( run. samples. is_empty( ) ) ;
834+ assert_eq ! ( run. metadata, WarmupMetadata :: new( 3 , 0 ) ) ;
835+ assert_eq ! ( sampler. steps, 3 ) ;
836+ assert_eq ! ( sampler. resets, 1 ) ;
837+ }
838+
839+ #[ test]
840+ fn test_sample_with_warmup_returns_retained_samples_only ( ) {
841+ let mut sampler = CountingSampler :: new ( ) ;
842+
843+ let samples = sampler. sample_with_warmup ( 3 , 2 ) ;
844+
845+ assert_eq ! (
846+ samples,
847+ vec![ DVector :: from_vec( vec![ 4.0 ] ) , DVector :: from_vec( vec![ 5.0 ] ) ]
848+ ) ;
849+ assert_eq ! ( sampler. steps, 5 ) ;
850+ assert_eq ! ( sampler. resets, 1 ) ;
851+ }
852+
662853 #[ test]
663854 fn test_metropolis_hastings_creation ( ) {
664855 let log_posterior = |params : & DVector < f64 > | -> f64 {
@@ -800,6 +991,40 @@ mod tests {
800991 ) ;
801992 }
802993
994+ #[ test]
995+ fn test_run_with_warmup_resets_metropolis_hastings_statistics_before_retained_samples ( ) {
996+ let log_posterior = |params : & DVector < f64 > | -> f64 {
997+ let normal = Normal :: new ( 0.0 , 1.0 ) . unwrap ( ) ;
998+ normal. log_pdf ( params[ 0 ] )
999+ } ;
1000+
1001+ let initial_state = DVector :: from_vec ( vec ! [ 0.0 ] ) ;
1002+ let proposal_std = DVector :: from_vec ( vec ! [ 0.5 ] ) ;
1003+
1004+ let mut warmup_sampler = MetropolisHastings :: with_seed (
1005+ log_posterior,
1006+ initial_state. clone ( ) ,
1007+ proposal_std. clone ( ) ,
1008+ 789 ,
1009+ )
1010+ . unwrap ( ) ;
1011+ let run = warmup_sampler. run_with_warmup ( 25 , 50 ) ;
1012+
1013+ let mut manual_sampler =
1014+ MetropolisHastings :: with_seed ( log_posterior, initial_state, proposal_std, 789 ) . unwrap ( ) ;
1015+ let warmup_samples = manual_sampler. sample ( 25 ) ;
1016+ manual_sampler. reset_statistics ( ) ;
1017+ let retained_samples = manual_sampler. sample ( 50 ) ;
1018+
1019+ assert_eq ! ( run. warmup_samples, warmup_samples) ;
1020+ assert_eq ! ( run. samples, retained_samples) ;
1021+ assert_eq ! ( run. metadata, WarmupMetadata :: new( 25 , 50 ) ) ;
1022+ assert_eq ! (
1023+ warmup_sampler. acceptance_rate( ) ,
1024+ manual_sampler. acceptance_rate( )
1025+ ) ;
1026+ }
1027+
8031028 #[ test]
8041029 fn test_gibbs_sampler_creation ( ) {
8051030 let conditional_sampler = |_params : & DVector < f64 > ,
0 commit comments