From e9962094871e238feae185e1baec015201cd016c Mon Sep 17 00:00:00 2001 From: kholdrex Date: Sun, 31 May 2026 18:18:04 -0500 Subject: [PATCH] feat: add serde support for MCMC outputs --- Cargo.toml | 1 + README.md | 21 ++++++++++++ examples/serde_diagnostics.rs | 31 +++++++++++++++++ src/diagnostics.rs | 64 +++++++++++++++++++++++++++++++++-- src/multi_chain.rs | 19 ++++++++++- 5 files changed, 133 insertions(+), 3 deletions(-) create mode 100644 examples/serde_diagnostics.rs diff --git a/Cargo.toml b/Cargo.toml index 5ece334..1d62016 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ serde = { version = "1.0", features = ["derive"], optional = true } [dev-dependencies] approx = "0.5" criterion = { version = "0.5", features = ["html_reports"] } +serde_json = "1.0" [features] default = [] diff --git a/README.md b/README.md index 24c22fe..b74069d 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,13 @@ Add this to your `Cargo.toml`: bayes-rs = "0.1.0" ``` +Enable the optional `serde` feature when you want to serialize user-facing MCMC output: + +```toml +[dependencies] +bayes-rs = { version = "0.1.0", features = ["serde"] } +``` + ### Simple Example ```rust @@ -215,6 +222,20 @@ let trace_plot = TracePlot::new(&samples, 0)?; // Parameter 0 // Use trace_plot.values and trace_plot.iterations for visualization ``` +With the optional `serde` feature enabled, `McmcDiagnostics`, `McmcDiagnosticSummary`, +`ParameterDiagnosticSummary`, `TracePlot`, and `MultiChainOutput` derive `Serialize` for +JSON or other serde formats: + +```rust +let summary_json = serde_json::to_string_pretty(&output.summary)?; +``` + +Add `serde_json` or another serde format crate to your application to emit a concrete format. +These structs use Rust field names in their serialized form. Treat that JSON shape as a +convenient interchange format for the current API, not as a long-term storage schema. 
+JSON serializers may reject non-finite diagnostics such as `NaN` or `Infinity` from +degenerate chains; handle those cases before persisting JSON output. + ## Real-World Example: Bayesian Linear Regression ```rust diff --git a/examples/serde_diagnostics.rs b/examples/serde_diagnostics.rs new file mode 100644 index 0000000..30e6e04 --- /dev/null +++ b/examples/serde_diagnostics.rs @@ -0,0 +1,31 @@ +#[cfg(feature = "serde")] +use bayes_rs::diagnostics::{McmcDiagnostics, TracePlot}; +#[cfg(feature = "serde")] +use nalgebra::DVector; + +#[cfg(feature = "serde")] +fn main() -> bayes_rs::Result<()> { + let samples = vec![ + DVector::from_vec(vec![0.8, 1.7]), + DVector::from_vec(vec![1.0, 2.0]), + DVector::from_vec(vec![1.2, 2.3]), + DVector::from_vec(vec![1.1, 2.1]), + ]; + + let diagnostics = McmcDiagnostics::from_single_chain(&samples)?; + let trace = TracePlot::new(&samples, 0)?; + + let summary_json = serde_json::to_string_pretty(&diagnostics.summary()) + .expect("diagnostic summary should serialize"); + let trace_json = serde_json::to_string_pretty(&trace).expect("trace plot should serialize"); + + println!("diagnostic summary:\n{summary_json}"); + println!("trace plot:\n{trace_json}"); + + Ok(()) +} + +#[cfg(not(feature = "serde"))] +fn main() { + eprintln!("Run with `cargo run --example serde_diagnostics --features serde` to emit JSON."); +} diff --git a/src/diagnostics.rs b/src/diagnostics.rs index f5ca39a..a9c9bd4 100644 --- a/src/diagnostics.rs +++ b/src/diagnostics.rs @@ -53,7 +53,8 @@ pub const LOW_ESS_THRESHOLD: f64 = 400.0; /// Diagnostics can be created from one chain with [`Self::from_single_chain`] or /// from multiple chains with [`Self::from_multiple_chains`]. Vectors are ordered /// by parameter index. -#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize))] +#[derive(Debug, Clone, PartialEq)] pub struct McmcDiagnostics { /// Effective sample size for each parameter. 
/// @@ -271,6 +272,7 @@ impl McmcDiagnostics { } /// Compact diagnostic summary for one model parameter. +#[cfg_attr(feature = "serde", derive(serde::Serialize))] #[derive(Debug, Clone, Copy, PartialEq)] pub struct ParameterDiagnosticSummary { /// Zero-based parameter index. @@ -284,6 +286,7 @@ pub struct ParameterDiagnosticSummary { } /// Compact diagnostics summary focused on convergence reporting. +#[cfg_attr(feature = "serde", derive(serde::Serialize))] #[derive(Debug, Clone, PartialEq)] pub struct McmcDiagnosticSummary { /// Per-parameter R-hat, ESS, and MCSE values. @@ -607,7 +610,8 @@ fn calculate_quantiles(samples: &[f64]) -> [f64; 5] { } /// Simple trace plot data for visualization. -#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize))] +#[derive(Debug, Clone, PartialEq)] pub struct TracePlot { /// Parameter index represented by this trace. pub parameter_index: usize, @@ -903,4 +907,60 @@ mod tests { assert_eq!(trace_plot.values, vec![1.0, 1.1, 0.9]); assert_eq!(trace_plot.iterations, vec![0, 1, 2]); } + + #[cfg(feature = "serde")] + #[test] + fn test_diagnostics_summary_serializes_to_json() { + let diagnostics = McmcDiagnostics { + effective_sample_size: vec![100.0, LOW_ESS_THRESHOLD - 1.0], + r_hat: Some(vec![1.01, 1.2]), + mc_se: vec![0.1, 0.2], + mean: vec![0.0, 1.0], + std_dev: vec![1.0, 2.0], + quantiles: vec![[0.0; 5], [1.0; 5]], + }; + let summary = diagnostics.summary(); + + let json = serde_json::to_value(&summary).unwrap(); + + assert_eq!(json["has_converged"], false); + assert_eq!(json["parameters"].as_array().unwrap().len(), 2); + assert_eq!(json["parameters"][0]["effective_sample_size"], 100.0); + assert_eq!(json["parameters"][1]["r_hat"], 1.2); + } + + #[cfg(feature = "serde")] + #[test] + fn test_diagnostics_and_trace_plot_serialize_to_json() { + let samples = vec![ + DVector::from_vec(vec![1.0, 2.0]), + DVector::from_vec(vec![1.1, 2.1]), + DVector::from_vec(vec![0.9, 1.9]), + ]; + let diagnostics = 
McmcDiagnostics { + effective_sample_size: vec![3.0, 3.0], + r_hat: None, + mc_se: vec![0.1, 0.2], + mean: vec![1.0, 2.0], + std_dev: vec![0.2, 0.2], + quantiles: vec![[0.8, 0.9, 1.0, 1.1, 1.2], [1.8, 1.9, 2.0, 2.1, 2.2]], + }; + let trace_plot = TracePlot::new(&samples, 1).unwrap(); + + let diagnostics_json = serde_json::to_value(&diagnostics).unwrap(); + let trace_plot_json = serde_json::to_value(&trace_plot).unwrap(); + + assert_eq!( + diagnostics_json["effective_sample_size"], + serde_json::json!([3.0, 3.0]) + ); + assert!(diagnostics_json["r_hat"].is_null()); + assert_eq!(diagnostics_json["quantiles"].as_array().unwrap().len(), 2); + assert_eq!(trace_plot_json["parameter_index"], 1); + assert_eq!( + trace_plot_json["values"], + serde_json::json!([2.0, 2.1, 1.9]) + ); + assert_eq!(trace_plot_json["iterations"], serde_json::json!([0, 1, 2])); + } } diff --git a/src/multi_chain.rs b/src/multi_chain.rs index b6ca6c9..7d11379 100644 --- a/src/multi_chain.rs +++ b/src/multi_chain.rs @@ -53,7 +53,8 @@ use crate::samplers::Sampler; use nalgebra::DVector; /// Raw multi-chain samples plus their diagnostics summary. -#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize))] +#[derive(Debug, Clone, PartialEq)] pub struct MultiChainOutput { /// Samples for each chain, preserving sampler order. 
pub chains: Vec<Vec<DVector<f64>>>, @@ -237,4 +238,20 @@ mod tests { assert_eq!(samplers[0].current_state(), &DVector::from_vec(vec![0.0])); assert_eq!(samplers[1].current_state(), &DVector::from_vec(vec![10.0])); } + + #[cfg(feature = "serde")] + #[test] + fn multi_chain_output_serializes_to_json() { + let mut samplers = [ + DeterministicSampler::new(0.0), + DeterministicSampler::new(10.0), + ]; + let output = run_multiple_chains(&mut samplers, 2, 4).unwrap(); + + let json = serde_json::to_value(&output).unwrap(); + + assert_eq!(json["chains"].as_array().unwrap().len(), 2); + assert_eq!(json["summary"]["parameters"].as_array().unwrap().len(), 1); + assert_eq!(json["diagnostics"]["mean"].as_array().unwrap().len(), 1); + } }