Skip to content

Commit 1b9a7a1

Browse files
authored
Merge pull request #12 from SyntaxSpirits/fix/sampler-input-validation
2 parents 0f76fdd + a403309 commit 1b9a7a1

1 file changed

Lines changed: 181 additions & 25 deletions

File tree

src/samplers.rs

Lines changed: 181 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,33 @@ pub struct WarmupRun {
148148
pub metadata: WarmupMetadata,
149149
}
150150

151+
fn validate_initial_state(initial_state: &DVector<f64>) -> Result<()> {
152+
if initial_state.is_empty() {
153+
return Err(BayesError::invalid_parameter(
154+
"Initial state must have at least one dimension",
155+
));
156+
}
157+
158+
if initial_state.iter().any(|value| !value.is_finite()) {
159+
return Err(BayesError::invalid_parameter(
160+
"Initial state must contain only finite values",
161+
));
162+
}
163+
164+
Ok(())
165+
}
166+
167+
fn validate_positive_finite_vector(values: &DVector<f64>, message: &'static str) -> Result<()> {
168+
if values
169+
.iter()
170+
.any(|&value| value <= 0.0 || !value.is_finite())
171+
{
172+
return Err(BayesError::invalid_parameter(message));
173+
}
174+
175+
Ok(())
176+
}
177+
151178
/// Trait for MCMC samplers
152179
pub trait Sampler {
153180
/// Sample from the posterior distribution
@@ -264,18 +291,19 @@ where
264291
proposal_std: DVector<f64>,
265292
rng: R,
266293
) -> Result<Self> {
294+
validate_initial_state(&initial_state)?;
295+
267296
if initial_state.len() != proposal_std.len() {
268297
return Err(BayesError::dimension_mismatch(
269298
initial_state.len(),
270299
proposal_std.len(),
271300
));
272301
}
273302

274-
if proposal_std.iter().any(|&std| std <= 0.0) {
275-
return Err(BayesError::invalid_parameter(
276-
"All proposal standard deviations must be positive",
277-
));
278-
}
303+
validate_positive_finite_vector(
304+
&proposal_std,
305+
"All proposal standard deviations must be positive and finite",
306+
)?;
279307

280308
let current_log_posterior = log_posterior(&initial_state);
281309
if !current_log_posterior.is_finite() {
@@ -304,11 +332,10 @@ where
304332
));
305333
}
306334

307-
if proposal_std.iter().any(|&std| std <= 0.0) {
308-
return Err(BayesError::invalid_parameter(
309-
"All proposal standard deviations must be positive",
310-
));
311-
}
335+
validate_positive_finite_vector(
336+
&proposal_std,
337+
"All proposal standard deviations must be positive and finite",
338+
)?;
312339

313340
self.proposal_std = proposal_std;
314341
Ok(())
@@ -436,6 +463,8 @@ where
436463
initial_state: DVector<f64>,
437464
rng: R,
438465
) -> Result<Self> {
466+
validate_initial_state(&initial_state)?;
467+
439468
if conditional_samplers.len() != initial_state.len() {
440469
return Err(BayesError::dimension_mismatch(
441470
conditional_samplers.len(),
@@ -558,8 +587,12 @@ where
558587
n_leapfrog: usize,
559588
rng: R,
560589
) -> Result<Self> {
561-
if step_size <= 0.0 {
562-
return Err(BayesError::invalid_parameter("Step size must be positive"));
590+
validate_initial_state(&initial_state)?;
591+
592+
if step_size <= 0.0 || !step_size.is_finite() {
593+
return Err(BayesError::invalid_parameter(
594+
"Step size must be positive and finite",
595+
));
563596
}
564597

565598
if n_leapfrog == 0 {
@@ -601,11 +634,10 @@ where
601634
));
602635
}
603636

604-
if mass_matrix.iter().any(|&m| m <= 0.0) {
605-
return Err(BayesError::invalid_parameter(
606-
"All mass matrix elements must be positive",
607-
));
608-
}
637+
validate_positive_finite_vector(
638+
&mass_matrix,
639+
"All mass matrix elements must be positive and finite",
640+
)?;
609641

610642
self.mass_matrix = mass_matrix;
611643
Ok(())
@@ -1239,16 +1271,140 @@ mod tests {
12391271
}
12401272

12411273
#[test]
1242-
fn test_invalid_parameters() {
1243-
let log_posterior = |params: &DVector<f64>| -> f64 {
1244-
let normal = Normal::new(0.0, 1.0).unwrap();
1245-
normal.log_pdf(params[0])
1246-
};
1274+
fn test_metropolis_hastings_rejects_non_positive_and_non_finite_proposal_std() {
1275+
let log_posterior = |params: &DVector<f64>| -> f64 { -0.5 * params[0] * params[0] };
1276+
let initial_state = DVector::from_vec(vec![0.0]);
12471277

1278+
for invalid_std in [0.0, -1.0, f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
1279+
let error = MetropolisHastings::new(
1280+
log_posterior,
1281+
initial_state.clone(),
1282+
DVector::from_vec(vec![invalid_std]),
1283+
)
1284+
.err()
1285+
.expect("invalid proposal standard deviation should be rejected");
1286+
assert!(error.to_string().contains("positive and finite"));
1287+
}
1288+
}
1289+
1290+
#[test]
1291+
fn test_metropolis_hastings_set_proposal_std_rejects_non_finite_values() {
1292+
let log_posterior = |params: &DVector<f64>| -> f64 { -0.5 * params[0] * params[0] };
1293+
let mut sampler = MetropolisHastings::new(
1294+
log_posterior,
1295+
DVector::from_vec(vec![0.0]),
1296+
DVector::from_vec(vec![1.0]),
1297+
)
1298+
.unwrap();
1299+
1300+
for invalid_std in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
1301+
let error = match sampler.set_proposal_std(DVector::from_vec(vec![invalid_std])) {
1302+
Ok(()) => panic!("invalid proposal standard deviation should be rejected"),
1303+
Err(error) => error,
1304+
};
1305+
assert!(error.to_string().contains("positive and finite"));
1306+
}
1307+
}
1308+
1309+
#[test]
1310+
fn test_hmc_rejects_non_positive_and_non_finite_step_size() {
1311+
let log_posterior = |params: &DVector<f64>| -> f64 { -0.5 * params[0] * params[0] };
1312+
let gradient =
1313+
|params: &DVector<f64>| -> DVector<f64> { DVector::from_vec(vec![-params[0]]) };
12481314
let initial_state = DVector::from_vec(vec![0.0]);
1249-
let bad_proposal_std = DVector::from_vec(vec![0.0]); // Invalid: zero std
12501315

1251-
let sampler = MetropolisHastings::new(log_posterior, initial_state, bad_proposal_std);
1252-
assert!(sampler.is_err());
1316+
for invalid_step_size in [0.0, -0.1, f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
1317+
let error = HamiltonianMonteCarlo::new(
1318+
log_posterior,
1319+
gradient,
1320+
initial_state.clone(),
1321+
invalid_step_size,
1322+
10,
1323+
)
1324+
.err()
1325+
.expect("invalid HMC step size should be rejected");
1326+
assert!(error.to_string().contains("positive and finite"));
1327+
}
1328+
}
1329+
1330+
#[test]
1331+
fn test_hmc_set_mass_matrix_rejects_non_positive_and_non_finite_values() {
1332+
let log_posterior = |params: &DVector<f64>| -> f64 { -0.5 * params[0] * params[0] };
1333+
let gradient =
1334+
|params: &DVector<f64>| -> DVector<f64> { DVector::from_vec(vec![-params[0]]) };
1335+
let mut sampler = HamiltonianMonteCarlo::new(
1336+
log_posterior,
1337+
gradient,
1338+
DVector::from_vec(vec![0.0]),
1339+
0.1,
1340+
10,
1341+
)
1342+
.unwrap();
1343+
1344+
for invalid_mass in [0.0, -1.0, f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
1345+
let error = match sampler.set_mass_matrix(DVector::from_vec(vec![invalid_mass])) {
1346+
Ok(()) => panic!("invalid mass matrix should be rejected"),
1347+
Err(error) => error,
1348+
};
1349+
assert!(error.to_string().contains("positive and finite"));
1350+
}
1351+
}
1352+
1353+
#[test]
1354+
fn test_samplers_reject_zero_dimensional_initial_states() {
1355+
let log_posterior = |_params: &DVector<f64>| -> f64 { 0.0 };
1356+
let gradient = |_params: &DVector<f64>| -> DVector<f64> { DVector::zeros(0) };
1357+
let empty_state = DVector::zeros(0);
1358+
1359+
let mh_error =
1360+
MetropolisHastings::new(log_posterior, empty_state.clone(), DVector::zeros(0))
1361+
.err()
1362+
.expect("zero-dimensional MH initial state should be rejected");
1363+
assert!(mh_error.to_string().contains("at least one dimension"));
1364+
1365+
let hmc_error =
1366+
HamiltonianMonteCarlo::new(log_posterior, gradient, empty_state.clone(), 0.1, 10)
1367+
.err()
1368+
.expect("zero-dimensional HMC initial state should be rejected");
1369+
assert!(hmc_error.to_string().contains("at least one dimension"));
1370+
1371+
let samplers: Vec<_> = Vec::<fn(&DVector<f64>, usize, &mut ThreadRng) -> f64>::new();
1372+
let gibbs_error = GibbsSampler::new(samplers, empty_state)
1373+
.err()
1374+
.expect("zero-dimensional Gibbs initial state should be rejected");
1375+
assert!(gibbs_error.to_string().contains("at least one dimension"));
1376+
}
1377+
1378+
#[test]
1379+
fn test_samplers_reject_non_finite_initial_states() {
1380+
let log_posterior = |params: &DVector<f64>| -> f64 { -0.5 * params[0] * params[0] };
1381+
let gradient =
1382+
|params: &DVector<f64>| -> DVector<f64> { DVector::from_vec(vec![-params[0]]) };
1383+
1384+
for invalid_value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
1385+
let initial_state = DVector::from_vec(vec![invalid_value]);
1386+
1387+
let mh_error = MetropolisHastings::new(
1388+
log_posterior,
1389+
initial_state.clone(),
1390+
DVector::from_vec(vec![1.0]),
1391+
)
1392+
.err()
1393+
.expect("non-finite MH initial state should be rejected");
1394+
assert!(mh_error.to_string().contains("finite values"));
1395+
1396+
let hmc_error =
1397+
HamiltonianMonteCarlo::new(log_posterior, gradient, initial_state.clone(), 0.1, 10)
1398+
.err()
1399+
.expect("non-finite HMC initial state should be rejected");
1400+
assert!(hmc_error.to_string().contains("finite values"));
1401+
1402+
let conditional_sampler =
1403+
|_params: &DVector<f64>, _idx: usize, _rng: &mut ThreadRng| -> f64 { 0.0 };
1404+
let gibbs_error = GibbsSampler::new(vec![conditional_sampler], initial_state)
1405+
.err()
1406+
.expect("non-finite Gibbs initial state should be rejected");
1407+
assert!(gibbs_error.to_string().contains("finite values"));
1408+
}
12531409
}
12541410
}

0 commit comments

Comments
 (0)