|
| 1 | +#!/usr/bin/env julia |
| 2 | +#= |
| 3 | +ContinuousTransition Rules Benchmark Script |
| 4 | +
|
| 5 | +Run this script to benchmark the ContinuousTransition rules performance. |
| 6 | +Can be executed on different branches to compare optimizations. |
| 7 | +
|
| 8 | +Usage: |
| 9 | + julia --project=. benchmark/continuous_transition_bench.jl |
| 10 | + julia --project=. benchmark/continuous_transition_bench.jl quick # Quick mode (only small dims) |
| 11 | +
|
| 12 | +Output: Performance table showing timings for each rule and dimension. |
| 13 | +=# |
| 14 | + |
| 15 | +using Pkg |
| 16 | +Pkg.activate(dirname(@__DIR__)) |
| 17 | + |
| 18 | +using BenchmarkTools |
| 19 | +using ReactiveMP |
| 20 | +using BayesBase |
| 21 | +using ExponentialFamily |
| 22 | +using Random |
| 23 | +using LinearAlgebra |
| 24 | +using Distributions |
| 25 | +using Printf |
| 26 | + |
| 27 | +import ReactiveMP: CTMeta, @call_rule, @call_marginalrule |
| 28 | + |
| 29 | +# ============================================================================ |
| 30 | +# Test Data Generation |
| 31 | +# ============================================================================ |
| 32 | + |
| 33 | +function create_benchmark_data(dx, dy) |
| 34 | + rng = MersenneTwister(42) |
| 35 | + da = dx * dy |
| 36 | + |
| 37 | + transformation = a -> reshape(a, dy, dx) |
| 38 | + meta = CTMeta(transformation) |
| 39 | + |
| 40 | + Lx = rand(rng, dx, dx) |
| 41 | + Ly = rand(rng, dy, dy) |
| 42 | + La = rand(rng, da, da) |
| 43 | + |
| 44 | + μx, Σx = rand(rng, dx), Lx * Lx' + dx * I |
| 45 | + μy, Σy = rand(rng, dy), Ly * Ly' + dy * I |
| 46 | + μa, Σa = rand(rng, da), La * La' + da * I |
| 47 | + |
| 48 | + q_y = MvNormalMeanCovariance(μy, Σy) |
| 49 | + q_x = MvNormalMeanCovariance(μx, Σx) |
| 50 | + q_a = MvNormalMeanCovariance(μa, Σa) |
| 51 | + q_W = Wishart(dy + 1, Matrix{Float64}(I, dy, dy)) |
| 52 | + q_y_x = MvNormalMeanCovariance([μy; μx], [Σy zeros(dy, dx); zeros(dx, dy) Σx]) |
| 53 | + |
| 54 | + m_y = MvNormalMeanCovariance(μy, Σy) |
| 55 | + m_x = MvNormalMeanCovariance(μx, Σx) |
| 56 | + |
| 57 | + return (meta = meta, q_y = q_y, q_x = q_x, q_a = q_a, q_W = q_W, q_y_x = q_y_x, m_y = m_y, m_x = m_x) |
| 58 | +end |
| 59 | + |
| 60 | +# ============================================================================ |
| 61 | +# Benchmark Functions |
| 62 | +# ============================================================================ |
| 63 | + |
| 64 | +function bench_a_structured(data) |
| 65 | + @call_rule ContinuousTransition(:a, Marginalisation) (q_y_x = data.q_y_x, q_a = data.q_a, q_W = data.q_W, meta = data.meta) |
| 66 | +end |
| 67 | + |
| 68 | +function bench_a_meanfield(data) |
| 69 | + @call_rule ContinuousTransition(:a, Marginalisation) (q_y = data.q_y, q_x = data.q_x, q_a = data.q_a, q_W = data.q_W, meta = data.meta) |
| 70 | +end |
| 71 | + |
| 72 | +function bench_marginal_y_x(data) |
| 73 | + @call_marginalrule ContinuousTransition(:y_x) (m_y = data.m_y, m_x = data.m_x, q_a = data.q_a, q_W = data.q_W, meta = data.meta) |
| 74 | +end |
| 75 | + |
| 76 | +# ============================================================================ |
| 77 | +# Benchmark Runner |
| 78 | +# ============================================================================ |
| 79 | + |
| 80 | +function run_benchmarks(; quick_mode = false) |
| 81 | + println() |
| 82 | + println("=" ^ 80) |
| 83 | + println(" ContinuousTransition Rules Benchmark") |
| 84 | + println(" Branch: ", strip(read(`git rev-parse --abbrev-ref HEAD`, String))) |
| 85 | + println(" Commit: ", strip(read(`git rev-parse --short HEAD`, String))) |
| 86 | + println("=" ^ 80) |
| 87 | + println() |
| 88 | + |
| 89 | + if quick_mode |
| 90 | + test_dims = [(10, 10), (20, 20)] |
| 91 | + println(" Mode: QUICK (limited dimensions)") |
| 92 | + else |
| 93 | + test_dims = [(5, 5), (10, 10), (20, 20), (30, 30), (40, 40)] |
| 94 | + println(" Mode: FULL") |
| 95 | + end |
| 96 | + println() |
| 97 | + |
| 98 | + # Results storage |
| 99 | + results = Dict{String, Vector{Tuple{Int, Int, Float64}}}("a_structured" => [], "a_meanfield" => [], "marginal_y_x" => []) |
| 100 | + |
| 101 | + for (dx, dy) in test_dims |
| 102 | + println("-" ^ 60) |
| 103 | + @printf(" Benchmarking: dx=%d, dy=%d (da=%d)\n", dx, dy, dx*dy) |
| 104 | + println("-" ^ 60) |
| 105 | + |
| 106 | + data = create_benchmark_data(dx, dy) |
| 107 | + |
| 108 | + # Warm-up calls |
| 109 | + bench_a_structured(data) |
| 110 | + bench_a_meanfield(data) |
| 111 | + bench_marginal_y_x(data) |
| 112 | + |
| 113 | + # Benchmark a.jl structured |
| 114 | + t = @belapsed bench_a_structured($data) |
| 115 | + push!(results["a_structured"], (dx, dy, t * 1e6)) |
| 116 | + @printf(" a.jl Structured: %10.2f μs\n", t * 1e6) |
| 117 | + |
| 118 | + # Benchmark a.jl mean-field |
| 119 | + t = @belapsed bench_a_meanfield($data) |
| 120 | + push!(results["a_meanfield"], (dx, dy, t * 1e6)) |
| 121 | + @printf(" a.jl Mean-field: %10.2f μs\n", t * 1e6) |
| 122 | + |
| 123 | + # Benchmark marginals.jl |
| 124 | + t = @belapsed bench_marginal_y_x($data) |
| 125 | + push!(results["marginal_y_x"], (dx, dy, t * 1e6)) |
| 126 | + @printf(" marginals.jl y_x: %10.2f μs\n", t * 1e6) |
| 127 | + |
| 128 | + println() |
| 129 | + end |
| 130 | + |
| 131 | + # Print summary table |
| 132 | + println("=" ^ 80) |
| 133 | + println(" SUMMARY TABLE (times in μs)") |
| 134 | + println("=" ^ 80) |
| 135 | + println() |
| 136 | + |
| 137 | + # Header |
| 138 | + @printf(" %-12s", "Dimensions") |
| 139 | + for rule in ["a_structured", "a_meanfield", "marginal_y_x"] |
| 140 | + @printf(" | %14s", replace(rule, "_" => " ")) |
| 141 | + end |
| 142 | + println() |
| 143 | + println(" " * "-" ^ 12, " | ", "-" ^ 14, " | ", "-" ^ 14, " | ", "-" ^ 14) |
| 144 | + |
| 145 | + # Data rows |
| 146 | + for i in eachindex(test_dims) |
| 147 | + dx, dy = test_dims[i] |
| 148 | + @printf(" %4d × %-4d ", dx, dy) |
| 149 | + @printf(" | %14.2f", results["a_structured"][i][3]) |
| 150 | + @printf(" | %14.2f", results["a_meanfield"][i][3]) |
| 151 | + @printf(" | %14.2f", results["marginal_y_x"][i][3]) |
| 152 | + println() |
| 153 | + end |
| 154 | + |
| 155 | + println() |
| 156 | + println("=" ^ 80) |
| 157 | + println(" Benchmark Complete") |
| 158 | + println("=" ^ 80) |
| 159 | + println() |
| 160 | + |
| 161 | + return results |
| 162 | +end |
| 163 | + |
| 164 | +# ============================================================================ |
| 165 | +# Main |
| 166 | +# ============================================================================ |
| 167 | + |
| 168 | +quick_mode = length(ARGS) > 0 && ARGS[1] == "quick" |
| 169 | +run_benchmarks(quick_mode = quick_mode) |
0 commit comments