Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions examples/conjugate_models.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
//! Conjugate Bayesian model examples.
//!
//! These examples use closed-form posterior updates, then build the matching
//! bayes-rs distributions for posterior summaries and simple plug-in predictive
//! checks. The plug-in checks are intentionally lightweight examples; they are
//! not substitutes for the full posterior-predictive distributions that these
//! conjugate pairs also provide in closed form. They are deterministic and quick
//! enough to use as executable docs.

use bayes_rs::distributions::{Beta, Binomial, DiscreteDistribution, Distribution, Gamma, Poisson};

fn beta_binomial_posterior(
prior_alpha: f64,
prior_beta: f64,
successes: u64,
trials: u64,
) -> bayes_rs::Result<Beta> {
let failures = trials
.checked_sub(successes)
.ok_or_else(|| bayes_rs::BayesError::invalid_parameter("successes cannot exceed trials"))?;

Beta::new(prior_alpha + successes as f64, prior_beta + failures as f64)
}

fn gamma_poisson_posterior(
prior_shape: f64,
prior_rate: f64,
counts: &[u64],
) -> bayes_rs::Result<Gamma> {
let observed_events: u64 = counts.iter().sum();
// bayes-rs parameterizes Gamma as shape/rate, so exposure increments rate.
Gamma::new(
prior_shape + observed_events as f64,
prior_rate + counts.len() as f64,
)
}

fn main() -> bayes_rs::Result<()> {
let conversion_posterior = beta_binomial_posterior(2.0, 2.0, 42, 120)?;
let posterior_mean = conversion_posterior.mean();
let predictive_successes = Binomial::new(25, posterior_mean)?;

println!("Beta-binomial conversion-rate update");
println!(" Posterior mean success probability: {posterior_mean:.3}");
println!(
" Plug-in predictive probability of at least 8 successes in 25 trials: {:.3}",
(8..=25)
.map(|successes| predictive_successes.pmf(successes))
.sum::<f64>()
);

let daily_defects = [3, 4, 2, 5, 4, 1, 3];
let defect_rate_posterior = gamma_poisson_posterior(1.5, 1.0, &daily_defects)?;
let posterior_rate = defect_rate_posterior.mean();
let next_day_defects = Poisson::new(posterior_rate)?;

println!("Gamma-Poisson count-rate update");
println!(" Posterior mean event rate: {posterior_rate:.3}");
println!(
" Plug-in predictive probability of at most 2 events tomorrow: {:.3}",
(0..=2)
.map(|count| next_day_defects.pmf(count))
.sum::<f64>()
);
println!(
" Posterior density at rate 3.0: {:.3}",
defect_rate_posterior.pdf(3.0)
);

Ok(())
}
50 changes: 50 additions & 0 deletions tests/conjugate_models_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use approx::assert_relative_eq;
use bayes_rs::distributions::{Beta, Binomial, DiscreteDistribution, Gamma, Poisson};

#[test]
fn beta_binomial_update_matches_closed_form_posterior() {
let prior_alpha = 2.0;
let prior_beta = 2.0;
let successes = 42;
let trials = 120;
let failures = trials - successes;

let posterior = Beta::new(prior_alpha + successes as f64, prior_beta + failures as f64)
.expect("posterior parameters should be valid");

assert_relative_eq!(posterior.alpha(), 44.0);
assert_relative_eq!(posterior.beta(), 80.0);
assert_relative_eq!(posterior.mean(), 44.0 / 124.0);

let plug_in_predictive =
Binomial::new(25, posterior.mean()).expect("posterior mean is a probability");
let probability_at_least_eight: f64 = (8..=25).map(|k| plug_in_predictive.pmf(k)).sum();

assert!(probability_at_least_eight > 0.71);
assert!(probability_at_least_eight < 0.72);
}

#[test]
fn gamma_poisson_update_matches_closed_form_posterior() {
let prior_shape = 1.5;
let prior_rate = 1.0;
let counts = [3_u64, 4, 2, 5, 4, 1, 3];
let observed_events: u64 = counts.iter().sum();

let posterior = Gamma::new(
prior_shape + observed_events as f64,
prior_rate + counts.len() as f64,
)
.expect("posterior parameters should be valid");

assert_relative_eq!(posterior.shape(), 23.5);
assert_relative_eq!(posterior.rate(), 8.0);
assert_relative_eq!(posterior.mean(), 23.5 / 8.0);

let plug_in_predictive =
Poisson::new(posterior.mean()).expect("posterior mean is a valid rate");
let probability_at_most_two: f64 = (0..=2).map(|k| plug_in_predictive.pmf(k)).sum();

assert!(probability_at_most_two > 0.42);
assert!(probability_at_most_two < 0.44);
}
Loading