@@ -148,6 +148,33 @@ pub struct WarmupRun {
148148 pub metadata : WarmupMetadata ,
149149}
150150
151+ fn validate_initial_state ( initial_state : & DVector < f64 > ) -> Result < ( ) > {
152+ if initial_state. is_empty ( ) {
153+ return Err ( BayesError :: invalid_parameter (
154+ "Initial state must have at least one dimension" ,
155+ ) ) ;
156+ }
157+
158+ if initial_state. iter ( ) . any ( |value| !value. is_finite ( ) ) {
159+ return Err ( BayesError :: invalid_parameter (
160+ "Initial state must contain only finite values" ,
161+ ) ) ;
162+ }
163+
164+ Ok ( ( ) )
165+ }
166+
167+ fn validate_positive_finite_vector ( values : & DVector < f64 > , message : & ' static str ) -> Result < ( ) > {
168+ if values
169+ . iter ( )
170+ . any ( |& value| value <= 0.0 || !value. is_finite ( ) )
171+ {
172+ return Err ( BayesError :: invalid_parameter ( message) ) ;
173+ }
174+
175+ Ok ( ( ) )
176+ }
177+
151178/// Trait for MCMC samplers
152179pub trait Sampler {
153180 /// Sample from the posterior distribution
@@ -264,18 +291,19 @@ where
264291 proposal_std : DVector < f64 > ,
265292 rng : R ,
266293 ) -> Result < Self > {
294+ validate_initial_state ( & initial_state) ?;
295+
267296 if initial_state. len ( ) != proposal_std. len ( ) {
268297 return Err ( BayesError :: dimension_mismatch (
269298 initial_state. len ( ) ,
270299 proposal_std. len ( ) ,
271300 ) ) ;
272301 }
273302
274- if proposal_std. iter ( ) . any ( |& std| std <= 0.0 ) {
275- return Err ( BayesError :: invalid_parameter (
276- "All proposal standard deviations must be positive" ,
277- ) ) ;
278- }
303+ validate_positive_finite_vector (
304+ & proposal_std,
305+ "All proposal standard deviations must be positive and finite" ,
306+ ) ?;
279307
280308 let current_log_posterior = log_posterior ( & initial_state) ;
281309 if !current_log_posterior. is_finite ( ) {
@@ -304,11 +332,10 @@ where
304332 ) ) ;
305333 }
306334
307- if proposal_std. iter ( ) . any ( |& std| std <= 0.0 ) {
308- return Err ( BayesError :: invalid_parameter (
309- "All proposal standard deviations must be positive" ,
310- ) ) ;
311- }
335+ validate_positive_finite_vector (
336+ & proposal_std,
337+ "All proposal standard deviations must be positive and finite" ,
338+ ) ?;
312339
313340 self . proposal_std = proposal_std;
314341 Ok ( ( ) )
@@ -436,6 +463,8 @@ where
436463 initial_state : DVector < f64 > ,
437464 rng : R ,
438465 ) -> Result < Self > {
466+ validate_initial_state ( & initial_state) ?;
467+
439468 if conditional_samplers. len ( ) != initial_state. len ( ) {
440469 return Err ( BayesError :: dimension_mismatch (
441470 conditional_samplers. len ( ) ,
@@ -558,8 +587,12 @@ where
558587 n_leapfrog : usize ,
559588 rng : R ,
560589 ) -> Result < Self > {
561- if step_size <= 0.0 {
562- return Err ( BayesError :: invalid_parameter ( "Step size must be positive" ) ) ;
590+ validate_initial_state ( & initial_state) ?;
591+
592+ if step_size <= 0.0 || !step_size. is_finite ( ) {
593+ return Err ( BayesError :: invalid_parameter (
594+ "Step size must be positive and finite" ,
595+ ) ) ;
563596 }
564597
565598 if n_leapfrog == 0 {
@@ -601,11 +634,10 @@ where
601634 ) ) ;
602635 }
603636
604- if mass_matrix. iter ( ) . any ( |& m| m <= 0.0 ) {
605- return Err ( BayesError :: invalid_parameter (
606- "All mass matrix elements must be positive" ,
607- ) ) ;
608- }
637+ validate_positive_finite_vector (
638+ & mass_matrix,
639+ "All mass matrix elements must be positive and finite" ,
640+ ) ?;
609641
610642 self . mass_matrix = mass_matrix;
611643 Ok ( ( ) )
@@ -1239,16 +1271,140 @@ mod tests {
12391271 }
12401272
12411273 #[ test]
1242- fn test_invalid_parameters ( ) {
1243- let log_posterior = |params : & DVector < f64 > | -> f64 {
1244- let normal = Normal :: new ( 0.0 , 1.0 ) . unwrap ( ) ;
1245- normal. log_pdf ( params[ 0 ] )
1246- } ;
1274+ fn test_metropolis_hastings_rejects_non_positive_and_non_finite_proposal_std ( ) {
1275+ let log_posterior = |params : & DVector < f64 > | -> f64 { -0.5 * params[ 0 ] * params[ 0 ] } ;
1276+ let initial_state = DVector :: from_vec ( vec ! [ 0.0 ] ) ;
12471277
1278+ for invalid_std in [ 0.0 , -1.0 , f64:: NAN , f64:: INFINITY , f64:: NEG_INFINITY ] {
1279+ let error = MetropolisHastings :: new (
1280+ log_posterior,
1281+ initial_state. clone ( ) ,
1282+ DVector :: from_vec ( vec ! [ invalid_std] ) ,
1283+ )
1284+ . err ( )
1285+ . expect ( "invalid proposal standard deviation should be rejected" ) ;
1286+ assert ! ( error. to_string( ) . contains( "positive and finite" ) ) ;
1287+ }
1288+ }
1289+
1290+ #[ test]
1291+ fn test_metropolis_hastings_set_proposal_std_rejects_non_finite_values ( ) {
1292+ let log_posterior = |params : & DVector < f64 > | -> f64 { -0.5 * params[ 0 ] * params[ 0 ] } ;
1293+ let mut sampler = MetropolisHastings :: new (
1294+ log_posterior,
1295+ DVector :: from_vec ( vec ! [ 0.0 ] ) ,
1296+ DVector :: from_vec ( vec ! [ 1.0 ] ) ,
1297+ )
1298+ . unwrap ( ) ;
1299+
1300+ for invalid_std in [ f64:: NAN , f64:: INFINITY , f64:: NEG_INFINITY ] {
1301+ let error = match sampler. set_proposal_std ( DVector :: from_vec ( vec ! [ invalid_std] ) ) {
1302+ Ok ( ( ) ) => panic ! ( "invalid proposal standard deviation should be rejected" ) ,
1303+ Err ( error) => error,
1304+ } ;
1305+ assert ! ( error. to_string( ) . contains( "positive and finite" ) ) ;
1306+ }
1307+ }
1308+
1309+ #[ test]
1310+ fn test_hmc_rejects_non_positive_and_non_finite_step_size ( ) {
1311+ let log_posterior = |params : & DVector < f64 > | -> f64 { -0.5 * params[ 0 ] * params[ 0 ] } ;
1312+ let gradient =
1313+ |params : & DVector < f64 > | -> DVector < f64 > { DVector :: from_vec ( vec ! [ -params[ 0 ] ] ) } ;
12481314 let initial_state = DVector :: from_vec ( vec ! [ 0.0 ] ) ;
1249- let bad_proposal_std = DVector :: from_vec ( vec ! [ 0.0 ] ) ; // Invalid: zero std
12501315
1251- let sampler = MetropolisHastings :: new ( log_posterior, initial_state, bad_proposal_std) ;
1252- assert ! ( sampler. is_err( ) ) ;
1316+ for invalid_step_size in [ 0.0 , -0.1 , f64:: NAN , f64:: INFINITY , f64:: NEG_INFINITY ] {
1317+ let error = HamiltonianMonteCarlo :: new (
1318+ log_posterior,
1319+ gradient,
1320+ initial_state. clone ( ) ,
1321+ invalid_step_size,
1322+ 10 ,
1323+ )
1324+ . err ( )
1325+ . expect ( "invalid HMC step size should be rejected" ) ;
1326+ assert ! ( error. to_string( ) . contains( "positive and finite" ) ) ;
1327+ }
1328+ }
1329+
1330+ #[ test]
1331+ fn test_hmc_set_mass_matrix_rejects_non_positive_and_non_finite_values ( ) {
1332+ let log_posterior = |params : & DVector < f64 > | -> f64 { -0.5 * params[ 0 ] * params[ 0 ] } ;
1333+ let gradient =
1334+ |params : & DVector < f64 > | -> DVector < f64 > { DVector :: from_vec ( vec ! [ -params[ 0 ] ] ) } ;
1335+ let mut sampler = HamiltonianMonteCarlo :: new (
1336+ log_posterior,
1337+ gradient,
1338+ DVector :: from_vec ( vec ! [ 0.0 ] ) ,
1339+ 0.1 ,
1340+ 10 ,
1341+ )
1342+ . unwrap ( ) ;
1343+
1344+ for invalid_mass in [ 0.0 , -1.0 , f64:: NAN , f64:: INFINITY , f64:: NEG_INFINITY ] {
1345+ let error = match sampler. set_mass_matrix ( DVector :: from_vec ( vec ! [ invalid_mass] ) ) {
1346+ Ok ( ( ) ) => panic ! ( "invalid mass matrix should be rejected" ) ,
1347+ Err ( error) => error,
1348+ } ;
1349+ assert ! ( error. to_string( ) . contains( "positive and finite" ) ) ;
1350+ }
1351+ }
1352+
1353+ #[ test]
1354+ fn test_samplers_reject_zero_dimensional_initial_states ( ) {
1355+ let log_posterior = |_params : & DVector < f64 > | -> f64 { 0.0 } ;
1356+ let gradient = |_params : & DVector < f64 > | -> DVector < f64 > { DVector :: zeros ( 0 ) } ;
1357+ let empty_state = DVector :: zeros ( 0 ) ;
1358+
1359+ let mh_error =
1360+ MetropolisHastings :: new ( log_posterior, empty_state. clone ( ) , DVector :: zeros ( 0 ) )
1361+ . err ( )
1362+ . expect ( "zero-dimensional MH initial state should be rejected" ) ;
1363+ assert ! ( mh_error. to_string( ) . contains( "at least one dimension" ) ) ;
1364+
1365+ let hmc_error =
1366+ HamiltonianMonteCarlo :: new ( log_posterior, gradient, empty_state. clone ( ) , 0.1 , 10 )
1367+ . err ( )
1368+ . expect ( "zero-dimensional HMC initial state should be rejected" ) ;
1369+ assert ! ( hmc_error. to_string( ) . contains( "at least one dimension" ) ) ;
1370+
1371+ let samplers: Vec < _ > = Vec :: < fn ( & DVector < f64 > , usize , & mut ThreadRng ) -> f64 > :: new ( ) ;
1372+ let gibbs_error = GibbsSampler :: new ( samplers, empty_state)
1373+ . err ( )
1374+ . expect ( "zero-dimensional Gibbs initial state should be rejected" ) ;
1375+ assert ! ( gibbs_error. to_string( ) . contains( "at least one dimension" ) ) ;
1376+ }
1377+
1378+ #[ test]
1379+ fn test_samplers_reject_non_finite_initial_states ( ) {
1380+ let log_posterior = |params : & DVector < f64 > | -> f64 { -0.5 * params[ 0 ] * params[ 0 ] } ;
1381+ let gradient =
1382+ |params : & DVector < f64 > | -> DVector < f64 > { DVector :: from_vec ( vec ! [ -params[ 0 ] ] ) } ;
1383+
1384+ for invalid_value in [ f64:: NAN , f64:: INFINITY , f64:: NEG_INFINITY ] {
1385+ let initial_state = DVector :: from_vec ( vec ! [ invalid_value] ) ;
1386+
1387+ let mh_error = MetropolisHastings :: new (
1388+ log_posterior,
1389+ initial_state. clone ( ) ,
1390+ DVector :: from_vec ( vec ! [ 1.0 ] ) ,
1391+ )
1392+ . err ( )
1393+ . expect ( "non-finite MH initial state should be rejected" ) ;
1394+ assert ! ( mh_error. to_string( ) . contains( "finite values" ) ) ;
1395+
1396+ let hmc_error =
1397+ HamiltonianMonteCarlo :: new ( log_posterior, gradient, initial_state. clone ( ) , 0.1 , 10 )
1398+ . err ( )
1399+ . expect ( "non-finite HMC initial state should be rejected" ) ;
1400+ assert ! ( hmc_error. to_string( ) . contains( "finite values" ) ) ;
1401+
1402+ let conditional_sampler =
1403+ |_params : & DVector < f64 > , _idx : usize , _rng : & mut ThreadRng | -> f64 { 0.0 } ;
1404+ let gibbs_error = GibbsSampler :: new ( vec ! [ conditional_sampler] , initial_state)
1405+ . err ( )
1406+ . expect ( "non-finite Gibbs initial state should be rejected" ) ;
1407+ assert ! ( gibbs_error. to_string( ) . contains( "finite values" ) ) ;
1408+ }
12531409 }
12541410}
0 commit comments