diff --git a/examples/conjugate_models.rs b/examples/conjugate_models.rs new file mode 100644 index 0000000..cc8ef35 --- /dev/null +++ b/examples/conjugate_models.rs @@ -0,0 +1,71 @@ +//! Conjugate Bayesian model examples. +//! +//! These examples use closed-form posterior updates, then build the matching +//! bayes-rs distributions for posterior summaries and simple plug-in predictive +//! checks. The plug-in checks are intentionally lightweight examples; they are +//! not substitutes for the full posterior-predictive distributions that these +//! conjugate pairs also provide in closed form. They are deterministic and quick +//! enough to use as executable docs. + +use bayes_rs::distributions::{Beta, Binomial, DiscreteDistribution, Distribution, Gamma, Poisson}; + +fn beta_binomial_posterior( + prior_alpha: f64, + prior_beta: f64, + successes: u64, + trials: u64, +) -> bayes_rs::Result { + let failures = trials + .checked_sub(successes) + .ok_or_else(|| bayes_rs::BayesError::invalid_parameter("successes cannot exceed trials"))?; + + Beta::new(prior_alpha + successes as f64, prior_beta + failures as f64) +} + +fn gamma_poisson_posterior( + prior_shape: f64, + prior_rate: f64, + counts: &[u64], +) -> bayes_rs::Result { + let observed_events: u64 = counts.iter().sum(); + // bayes-rs parameterizes Gamma as shape/rate, so exposure increments rate. + Gamma::new( + prior_shape + observed_events as f64, + prior_rate + counts.len() as f64, + ) +} + +fn main() -> bayes_rs::Result<()> { + let conversion_posterior = beta_binomial_posterior(2.0, 2.0, 42, 120)?; + let posterior_mean = conversion_posterior.mean(); + let predictive_successes = Binomial::new(25, posterior_mean)?; + + println!("Beta-binomial conversion-rate update"); + println!(" Posterior mean success probability: {posterior_mean:.3}"); + println!( + " Plug-in predictive probability of at least 8 successes in 25 trials: {:.3}", + (8..=25) + .map(|successes| predictive_successes.pmf(successes)) + .sum::() + ); + + let daily_defects = [3, 4, 2, 5, 4, 1, 3]; + let defect_rate_posterior = gamma_poisson_posterior(1.5, 1.0, &daily_defects)?; + let posterior_rate = defect_rate_posterior.mean(); + let next_day_defects = Poisson::new(posterior_rate)?; + + println!("Gamma-Poisson count-rate update"); + println!(" Posterior mean event rate: {posterior_rate:.3}"); + println!( + " Plug-in predictive probability of at most 2 events tomorrow: {:.3}", + (0..=2) + .map(|count| next_day_defects.pmf(count)) + .sum::() + ); + println!( + " Posterior density at rate 3.0: {:.3}", + defect_rate_posterior.pdf(3.0) + ); + + Ok(()) +} diff --git a/tests/conjugate_models_test.rs b/tests/conjugate_models_test.rs new file mode 100644 index 0000000..7890758 --- /dev/null +++ b/tests/conjugate_models_test.rs @@ -0,0 +1,50 @@ +use approx::assert_relative_eq; +use bayes_rs::distributions::{Beta, Binomial, DiscreteDistribution, Gamma, Poisson}; + +#[test] +fn beta_binomial_update_matches_closed_form_posterior() { + let prior_alpha = 2.0; + let prior_beta = 2.0; + let successes = 42; + let trials = 120; + let failures = trials - successes; + + let posterior = Beta::new(prior_alpha + successes as f64, prior_beta + failures as f64) + .expect("posterior parameters should be valid"); + + assert_relative_eq!(posterior.alpha(), 44.0); + assert_relative_eq!(posterior.beta(), 80.0); + assert_relative_eq!(posterior.mean(), 44.0 / 124.0); + + let plug_in_predictive = + Binomial::new(25, posterior.mean()).expect("posterior mean is a probability"); + let probability_at_least_eight: f64 = (8..=25).map(|k| plug_in_predictive.pmf(k)).sum(); + + assert!(probability_at_least_eight > 0.71); + assert!(probability_at_least_eight < 0.72); +} + +#[test] +fn gamma_poisson_update_matches_closed_form_posterior() { + let prior_shape = 1.5; + let prior_rate = 1.0; + let counts = [3_u64, 4, 2, 5, 4, 1, 3]; + let observed_events: u64 = counts.iter().sum(); + + let posterior = Gamma::new( + prior_shape + observed_events as f64, + prior_rate + counts.len() as f64, + ) + .expect("posterior parameters should be valid"); + + assert_relative_eq!(posterior.shape(), 23.5); + assert_relative_eq!(posterior.rate(), 8.0); + assert_relative_eq!(posterior.mean(), 23.5 / 8.0); + + let plug_in_predictive = + Poisson::new(posterior.mean()).expect("posterior mean is a valid rate"); + let probability_at_most_two: f64 = (0..=2).map(|k| plug_in_predictive.pmf(k)).sum(); + + assert!(probability_at_most_two > 0.42); + assert!(probability_at_most_two < 0.44); +}