Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ serde = { version = "1.0", features = ["derive"], optional = true }
[dev-dependencies]
approx = "0.5"
criterion = { version = "0.5", features = ["html_reports"] }
serde_json = "1.0"

[features]
default = []
Expand Down
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ Add this to your `Cargo.toml`:
bayes-rs = "0.1.0"
```

Enable the optional `serde` feature when you want to serialize user-facing MCMC output:

```toml
[dependencies]
bayes-rs = { version = "0.1.0", features = ["serde"] }
```

### Simple Example

```rust
Expand Down Expand Up @@ -215,6 +222,20 @@ let trace_plot = TracePlot::new(&samples, 0)?; // Parameter 0
// Use trace_plot.values and trace_plot.iterations for visualization
```

With the optional `serde` feature enabled, `McmcDiagnostics`, `McmcDiagnosticSummary`,
`ParameterDiagnosticSummary`, `TracePlot`, and `MultiChainOutput` derive `Serialize` for
JSON or other serde formats:

```rust
let summary_json = serde_json::to_string_pretty(&output.summary)?;
```

Add `serde_json` or another serde format crate to your application to emit a concrete format.
These structs use Rust field names in their serialized form. Treat that JSON shape as a
convenient interchange format for the current API, not as a long-term storage schema.
JSON serializers may reject non-finite diagnostics such as `NaN` or `Infinity` from
degenerate chains; handle those cases before persisting JSON output.

## Real-World Example: Bayesian Linear Regression

```rust
Expand Down
31 changes: 31 additions & 0 deletions examples/serde_diagnostics.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#[cfg(feature = "serde")]
use bayes_rs::diagnostics::{McmcDiagnostics, TracePlot};
#[cfg(feature = "serde")]
use nalgebra::DVector;

#[cfg(feature = "serde")]
fn main() -> bayes_rs::Result<()> {
let samples = vec![
DVector::from_vec(vec![0.8, 1.7]),
DVector::from_vec(vec![1.0, 2.0]),
DVector::from_vec(vec![1.2, 2.3]),
DVector::from_vec(vec![1.1, 2.1]),
];

let diagnostics = McmcDiagnostics::from_single_chain(&samples)?;
let trace = TracePlot::new(&samples, 0)?;

let summary_json = serde_json::to_string_pretty(&diagnostics.summary())
.expect("diagnostic summary should serialize");
let trace_json = serde_json::to_string_pretty(&trace).expect("trace plot should serialize");

println!("diagnostic summary:\n{summary_json}");
println!("trace plot:\n{trace_json}");

Ok(())
}

#[cfg(not(feature = "serde"))]
fn main() {
eprintln!("Run with `cargo run --example serde_diagnostics --features serde` to emit JSON.");
}
64 changes: 62 additions & 2 deletions src/diagnostics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ pub const LOW_ESS_THRESHOLD: f64 = 400.0;
/// Diagnostics can be created from one chain with [`Self::from_single_chain`] or
/// from multiple chains with [`Self::from_multiple_chains`]. Vectors are ordered
/// by parameter index.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct McmcDiagnostics {
/// Effective sample size for each parameter.
///
Expand Down Expand Up @@ -271,6 +272,7 @@ impl McmcDiagnostics {
}

/// Compact diagnostic summary for one model parameter.
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ParameterDiagnosticSummary {
/// Zero-based parameter index.
Expand All @@ -284,6 +286,7 @@ pub struct ParameterDiagnosticSummary {
}

/// Compact diagnostics summary focused on convergence reporting.
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct McmcDiagnosticSummary {
/// Per-parameter R-hat, ESS, and MCSE values.
Expand Down Expand Up @@ -607,7 +610,8 @@ fn calculate_quantiles(samples: &[f64]) -> [f64; 5] {
}

/// Simple trace plot data for visualization.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct TracePlot {
/// Parameter index represented by this trace.
pub parameter_index: usize,
Expand Down Expand Up @@ -903,4 +907,60 @@ mod tests {
assert_eq!(trace_plot.values, vec![1.0, 1.1, 0.9]);
assert_eq!(trace_plot.iterations, vec![0, 1, 2]);
}

#[cfg(feature = "serde")]
#[test]
fn test_diagnostics_summary_serializes_to_json() {
let diagnostics = McmcDiagnostics {
effective_sample_size: vec![100.0, LOW_ESS_THRESHOLD - 1.0],
r_hat: Some(vec![1.01, 1.2]),
mc_se: vec![0.1, 0.2],
mean: vec![0.0, 1.0],
std_dev: vec![1.0, 2.0],
quantiles: vec![[0.0; 5], [1.0; 5]],
};
let summary = diagnostics.summary();

let json = serde_json::to_value(&summary).unwrap();

assert_eq!(json["has_converged"], false);
assert_eq!(json["parameters"].as_array().unwrap().len(), 2);
assert_eq!(json["parameters"][0]["effective_sample_size"], 100.0);
assert_eq!(json["parameters"][1]["r_hat"], 1.2);
}

#[cfg(feature = "serde")]
#[test]
fn test_diagnostics_and_trace_plot_serialize_to_json() {
let samples = vec![
DVector::from_vec(vec![1.0, 2.0]),
DVector::from_vec(vec![1.1, 2.1]),
DVector::from_vec(vec![0.9, 1.9]),
];
let diagnostics = McmcDiagnostics {
effective_sample_size: vec![3.0, 3.0],
r_hat: None,
mc_se: vec![0.1, 0.2],
mean: vec![1.0, 2.0],
std_dev: vec![0.2, 0.2],
quantiles: vec![[0.8, 0.9, 1.0, 1.1, 1.2], [1.8, 1.9, 2.0, 2.1, 2.2]],
};
let trace_plot = TracePlot::new(&samples, 1).unwrap();

let diagnostics_json = serde_json::to_value(&diagnostics).unwrap();
let trace_plot_json = serde_json::to_value(&trace_plot).unwrap();

assert_eq!(
diagnostics_json["effective_sample_size"],
serde_json::json!([3.0, 3.0])
);
assert!(diagnostics_json["r_hat"].is_null());
assert_eq!(diagnostics_json["quantiles"].as_array().unwrap().len(), 2);
assert_eq!(trace_plot_json["parameter_index"], 1);
assert_eq!(
trace_plot_json["values"],
serde_json::json!([2.0, 2.1, 1.9])
);
assert_eq!(trace_plot_json["iterations"], serde_json::json!([0, 1, 2]));
}
}
19 changes: 18 additions & 1 deletion src/multi_chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ use crate::samplers::Sampler;
use nalgebra::DVector;

/// Raw multi-chain samples plus their diagnostics summary.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct MultiChainOutput {
/// Samples for each chain, preserving sampler order.
pub chains: Vec<Vec<DVector<f64>>>,
Expand Down Expand Up @@ -237,4 +238,20 @@ mod tests {
assert_eq!(samplers[0].current_state(), &DVector::from_vec(vec![0.0]));
assert_eq!(samplers[1].current_state(), &DVector::from_vec(vec![10.0]));
}

#[cfg(feature = "serde")]
#[test]
fn multi_chain_output_serializes_to_json() {
let mut samplers = [
DeterministicSampler::new(0.0),
DeterministicSampler::new(10.0),
];
let output = run_multiple_chains(&mut samplers, 2, 4).unwrap();

let json = serde_json::to_value(&output).unwrap();

assert_eq!(json["chains"].as_array().unwrap().len(), 2);
assert_eq!(json["summary"]["parameters"].as_array().unwrap().len(), 1);
assert_eq!(json["diagnostics"]["mean"].as_array().unwrap().len(), 1);
}
}
Loading