Skip to content

Commit 9ed5aef

Browse files
authored
Merge pull request #571 from ReactiveBayes/optimize-double-loops-ct
Optimize double loops in CT node
2 parents 8b3a7d0 + 719a2e7 commit 9ed5aef

5 files changed

Lines changed: 302 additions & 16 deletions

File tree

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
#!/usr/bin/env julia
2+
#=
3+
ContinuousTransition Rules Benchmark Script
4+
5+
Run this script to benchmark the ContinuousTransition rules performance.
6+
Can be executed on different branches to compare optimizations.
7+
8+
Usage:
9+
julia --project=. benchmark/continuous_transition_bench.jl
10+
julia --project=. benchmark/continuous_transition_bench.jl quick # Quick mode (only small dims)
11+
12+
Output: Performance table showing timings for each rule and dimension.
13+
=#
14+
15+
using Pkg
16+
Pkg.activate(dirname(@__DIR__))
17+
18+
using BenchmarkTools
19+
using ReactiveMP
20+
using BayesBase
21+
using ExponentialFamily
22+
using Random
23+
using LinearAlgebra
24+
using Distributions
25+
using Printf
26+
27+
import ReactiveMP: CTMeta, @call_rule, @call_marginalrule
28+
29+
# ============================================================================
30+
# Test Data Generation
31+
# ============================================================================
32+
33+
function create_benchmark_data(dx, dy)
34+
rng = MersenneTwister(42)
35+
da = dx * dy
36+
37+
transformation = a -> reshape(a, dy, dx)
38+
meta = CTMeta(transformation)
39+
40+
Lx = rand(rng, dx, dx)
41+
Ly = rand(rng, dy, dy)
42+
La = rand(rng, da, da)
43+
44+
μx, Σx = rand(rng, dx), Lx * Lx' + dx * I
45+
μy, Σy = rand(rng, dy), Ly * Ly' + dy * I
46+
μa, Σa = rand(rng, da), La * La' + da * I
47+
48+
q_y = MvNormalMeanCovariance(μy, Σy)
49+
q_x = MvNormalMeanCovariance(μx, Σx)
50+
q_a = MvNormalMeanCovariance(μa, Σa)
51+
q_W = Wishart(dy + 1, Matrix{Float64}(I, dy, dy))
52+
q_y_x = MvNormalMeanCovariance([μy; μx], [Σy zeros(dy, dx); zeros(dx, dy) Σx])
53+
54+
m_y = MvNormalMeanCovariance(μy, Σy)
55+
m_x = MvNormalMeanCovariance(μx, Σx)
56+
57+
return (meta = meta, q_y = q_y, q_x = q_x, q_a = q_a, q_W = q_W, q_y_x = q_y_x, m_y = m_y, m_x = m_x)
58+
end
59+
60+
# ============================================================================
61+
# Benchmark Functions
62+
# ============================================================================
63+
64+
function bench_a_structured(data)
65+
@call_rule ContinuousTransition(:a, Marginalisation) (q_y_x = data.q_y_x, q_a = data.q_a, q_W = data.q_W, meta = data.meta)
66+
end
67+
68+
function bench_a_meanfield(data)
69+
@call_rule ContinuousTransition(:a, Marginalisation) (q_y = data.q_y, q_x = data.q_x, q_a = data.q_a, q_W = data.q_W, meta = data.meta)
70+
end
71+
72+
function bench_marginal_y_x(data)
73+
@call_marginalrule ContinuousTransition(:y_x) (m_y = data.m_y, m_x = data.m_x, q_a = data.q_a, q_W = data.q_W, meta = data.meta)
74+
end
75+
76+
# ============================================================================
77+
# Benchmark Runner
78+
# ============================================================================
79+
80+
function run_benchmarks(; quick_mode = false)
81+
println()
82+
println("=" ^ 80)
83+
println(" ContinuousTransition Rules Benchmark")
84+
println(" Branch: ", strip(read(`git rev-parse --abbrev-ref HEAD`, String)))
85+
println(" Commit: ", strip(read(`git rev-parse --short HEAD`, String)))
86+
println("=" ^ 80)
87+
println()
88+
89+
if quick_mode
90+
test_dims = [(10, 10), (20, 20)]
91+
println(" Mode: QUICK (limited dimensions)")
92+
else
93+
test_dims = [(5, 5), (10, 10), (20, 20), (30, 30), (40, 40)]
94+
println(" Mode: FULL")
95+
end
96+
println()
97+
98+
# Results storage
99+
results = Dict{String, Vector{Tuple{Int, Int, Float64}}}("a_structured" => [], "a_meanfield" => [], "marginal_y_x" => [])
100+
101+
for (dx, dy) in test_dims
102+
println("-" ^ 60)
103+
@printf(" Benchmarking: dx=%d, dy=%d (da=%d)\n", dx, dy, dx*dy)
104+
println("-" ^ 60)
105+
106+
data = create_benchmark_data(dx, dy)
107+
108+
# Warm-up calls
109+
bench_a_structured(data)
110+
bench_a_meanfield(data)
111+
bench_marginal_y_x(data)
112+
113+
# Benchmark a.jl structured
114+
t = @belapsed bench_a_structured($data)
115+
push!(results["a_structured"], (dx, dy, t * 1e6))
116+
@printf(" a.jl Structured: %10.2f μs\n", t * 1e6)
117+
118+
# Benchmark a.jl mean-field
119+
t = @belapsed bench_a_meanfield($data)
120+
push!(results["a_meanfield"], (dx, dy, t * 1e6))
121+
@printf(" a.jl Mean-field: %10.2f μs\n", t * 1e6)
122+
123+
# Benchmark marginals.jl
124+
t = @belapsed bench_marginal_y_x($data)
125+
push!(results["marginal_y_x"], (dx, dy, t * 1e6))
126+
@printf(" marginals.jl y_x: %10.2f μs\n", t * 1e6)
127+
128+
println()
129+
end
130+
131+
# Print summary table
132+
println("=" ^ 80)
133+
println(" SUMMARY TABLE (times in μs)")
134+
println("=" ^ 80)
135+
println()
136+
137+
# Header
138+
@printf(" %-12s", "Dimensions")
139+
for rule in ["a_structured", "a_meanfield", "marginal_y_x"]
140+
@printf(" | %14s", replace(rule, "_" => " "))
141+
end
142+
println()
143+
println(" " * "-" ^ 12, " | ", "-" ^ 14, " | ", "-" ^ 14, " | ", "-" ^ 14)
144+
145+
# Data rows
146+
for i in eachindex(test_dims)
147+
dx, dy = test_dims[i]
148+
@printf(" %4d × %-4d ", dx, dy)
149+
@printf(" | %14.2f", results["a_structured"][i][3])
150+
@printf(" | %14.2f", results["a_meanfield"][i][3])
151+
@printf(" | %14.2f", results["marginal_y_x"][i][3])
152+
println()
153+
end
154+
155+
println()
156+
println("=" ^ 80)
157+
println(" Benchmark Complete")
158+
println("=" ^ 80)
159+
println()
160+
161+
return results
162+
end
163+
164+
# ============================================================================
165+
# Main
166+
# ============================================================================
167+
168+
quick_mode = length(ARGS) > 0 && ARGS[1] == "quick"
169+
run_benchmarks(quick_mode = quick_mode)
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
using BenchmarkTools
2+
using ReactiveMP
3+
using BayesBase
4+
using ExponentialFamily
5+
using Random
6+
using LinearAlgebra
7+
using Distributions
8+
using StableRNGs
9+
10+
import ReactiveMP: CTMeta, Marginal, Message, @call_rule, @call_marginalrule
11+
12+
"""
13+
Creates test data for ContinuousTransition benchmarks.
14+
Returns distributions and meta needed to call the rules.
15+
"""
16+
function create_ct_benchmark_data(dx, dy)
17+
rng = StableRNGs(42)
18+
da = dx * dy # For linear transformation a -> reshape(a, dy, dx)
19+
20+
# Transformation function
21+
transformation = a -> reshape(a, dy, dx)
22+
meta = CTMeta(transformation)
23+
24+
# Create covariance matrices
25+
Lx = rand(rng, dx, dx)
26+
Ly = rand(rng, dy, dy)
27+
La = rand(rng, da, da)
28+
29+
μx, Σx = rand(rng, dx), Lx * Lx' + dx * I
30+
μy, Σy = rand(rng, dy), Ly * Ly' + dy * I
31+
μa, Σa = rand(rng, da), La * La' + da * I
32+
33+
# Create distributions for mean-field factorization
34+
q_y = MvNormalMeanCovariance(μy, Σy)
35+
q_x = MvNormalMeanCovariance(μx, Σx)
36+
q_a = MvNormalMeanCovariance(μa, Σa)
37+
q_W = Wishart(dy + 1, Matrix{Float64}(I, dy, dy))
38+
39+
# Create joint distribution for structured factorization
40+
q_y_x = MvNormalMeanCovariance([μy; μx], [Σy zeros(dy, dx); zeros(dx, dy) Σx])
41+
42+
# Create messages for marginal rule
43+
m_y = MvNormalMeanCovariance(μy, Σy)
44+
m_x = MvNormalMeanCovariance(μx, Σx)
45+
46+
return (meta = meta, q_y = q_y, q_x = q_x, q_a = q_a, q_W = q_W, q_y_x = q_y_x, m_y = m_y, m_x = m_x)
47+
end
48+
49+
"""
50+
Adds ContinuousTransition rule benchmarks to the suite.
51+
"""
52+
function add_continuous_transition_rule_benchmarks(SUITE)
53+
SUITE["ContinuousTransition"] = BenchmarkGroup()
54+
55+
add_continuous_transition_a_benchmarks(SUITE["ContinuousTransition"])
56+
add_continuous_transition_marginals_benchmarks(SUITE["ContinuousTransition"])
57+
end
58+
59+
function add_continuous_transition_a_benchmarks(SUITE)
60+
SUITE["a"] = BenchmarkGroup(["Rules", "ContinuousTransition"])
61+
62+
# Test dimensions: (dx, dy)
63+
test_dims = [(5, 5), (10, 10), (20, 20), (30, 30)]
64+
65+
for (dx, dy) in test_dims
66+
data = create_ct_benchmark_data(dx, dy)
67+
68+
# Structured VMP: q(y,x) joint
69+
SUITE["a"]["Structured"]["dx=$(dx), dy=$(dy)"] = @benchmarkable begin
70+
@call_rule ContinuousTransition(:a, Marginalisation) (q_y_x = $data.q_y_x, q_a = $data.q_a, q_W = $data.q_W, meta = $data.meta)
71+
end
72+
73+
# Mean-field VMP: q(y)q(x)q(a)q(W)
74+
SUITE["a"]["Mean-field"]["dx=$(dx), dy=$(dy)"] = @benchmarkable begin
75+
@call_rule ContinuousTransition(:a, Marginalisation) (q_y = $data.q_y, q_x = $data.q_x, q_a = $data.q_a, q_W = $data.q_W, meta = $data.meta)
76+
end
77+
end
78+
end
79+
80+
function add_continuous_transition_marginals_benchmarks(SUITE)
81+
SUITE["marginals"] = BenchmarkGroup(["Rules", "ContinuousTransition"])
82+
83+
# Test dimensions: (dx, dy)
84+
test_dims = [(5, 5), (10, 10), (20, 20), (30, 30)]
85+
86+
for (dx, dy) in test_dims
87+
data = create_ct_benchmark_data(dx, dy)
88+
89+
# y_x marginal rule
90+
SUITE["marginals"]["y_x"]["dx=$(dx), dy=$(dy)"] = @benchmarkable begin
91+
@call_marginalrule ContinuousTransition(:y_x) (m_y = $data.m_y, m_x = $data.m_x, q_a = $data.q_a, q_W = $data.q_W, meta = $data.meta)
92+
end
93+
end
94+
end

src/nodes/predefined/continuous_transition.jl

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,18 @@ end
127127

128128
g1 = -mA * Vyx'
129129
g2 = g1'
130+
131+
# Optimized: factor out inner summation to reduce complexity from O(dy²) to O(dy)
132+
# Step 1: For each i, compute H[i] = Σⱼ mW[j,i] * Fs[j]
133+
H = [sum(mW[j, i] * Fs[j] for j in 1:dy) for i in 1:dy]
134+
135+
# Step 2: Compute traces
130136
trWSU, trkronxxWSU = zero(eltype(ma)), zero(eltype(ma))
131137
xxt = mx * mx'
132-
for (i, j) in Iterators.product(1:dy, 1:dy)
133-
FjVaFi = Fs[j] * Va * Fs[i]'
134-
trWSU += mW[j, i] * tr(FjVaFi)
135-
trkronxxWSU += mW[j, i] * tr(xxt * FjVaFi)
138+
for i in 1:dy
139+
HVaFi = H[i] * Va * Fs[i]'
140+
trWSU += tr(HVaFi)
141+
trkronxxWSU += tr(xxt * HVaFi)
136142
end
137143
AE = n / 2 * log2π - mean(logdet, q_W) + (tr(mW * (mA * Vx * mA' + g1 + g2 + Vy + (mA * mx - my) * (mA * mx - my)')) + trWSU + trkronxxWSU) / 2
138144

@@ -151,12 +157,17 @@ end
151157
n = div(ndims(q_y), 2)
152158
mA = ctcompanion_matrix(ma, sqrt.(var(q_a)), meta)
153159

160+
# Optimized: factor out inner summation to reduce complexity from O(dy²) to O(dy)
161+
# Step 1: For each i, compute H[i] = Σⱼ mW[j,i] * Fs[j]
162+
H = [sum(mW[j, i] * Fs[j] for j in 1:dy) for i in 1:dy]
163+
164+
# Step 2: Compute traces
154165
trWSU, trkronxxWSU = zero(eltype(ma)), zero(eltype(ma))
155166
xxt = mx * mx'
156-
for (i, j) in Iterators.product(1:dy, 1:dy)
157-
FjVaFi = Fs[j] * Va * Fs[i]'
158-
trWSU += mW[j, i] * tr(FjVaFi)
159-
trkronxxWSU += mW[j, i] * tr(xxt * FjVaFi)
167+
for i in 1:dy
168+
HVaFi = H[i] * Va * Fs[i]'
169+
trWSU += tr(HVaFi)
170+
trkronxxWSU += tr(xxt * HVaFi)
160171
end
161172
AE = n / 2 * log2π - mean(logdet, q_W) + (tr(mW * (mA * Vx * mA' + Vy + (mA * mx - my) * (mA * mx - my)')) + trWSU + trkronxxWSU) / 2
162173

src/rules/continuous_transition/a.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,15 @@
1717

1818
Vxymxy = rank1update(Vyx', mx, my)
1919
Vxmx = rank1update(Vx, mx)
20+
21+
# Optimized: factor out inner summation to reduce complexity from O(dy²) to O(dy)
22+
# Step 1: For each i, compute H[i] = Σⱼ mW[j,i] * Fs[j]
23+
H = [sum(mW[j, i] * Fs[j] for j in 1:dy) for i in 1:dy]
24+
25+
# Step 2: Compute xi and W
2026
for i in 1:dy
2127
xi += Fs[i]' * Vxymxy * mW[:, i]
22-
for j in 1:dy
23-
W += mW[j, i] * Fs[i]' * Vxmx * Fs[j]
24-
end
28+
W += Fs[i]' * Vxmx * H[i]
2529
end
2630

2731
return MvNormalWeightedMeanPrecision(xi, W)
@@ -42,11 +46,14 @@ end
4246
mxmy = mx * my'
4347
Vxmx = rank1update(Vx, mx)
4448

49+
# Optimized: factor out inner summation to reduce complexity from O(dy²) to O(dy)
50+
# Step 1: For each i, compute H[i] = Σⱼ mW[j,i] * Fs[j]
51+
H = [sum(mW[j, i] * Fs[j] for j in 1:dy) for i in 1:dy]
52+
53+
# Step 2: Compute xi and W
4554
for i in 1:dy
4655
xi += Fs[i]' * mxmy * mW[:, i]
47-
for j in 1:dy
48-
W += mW[j, i] * Fs[i]' * Vxmx * Fs[j]
49-
end
56+
W += Fs[i]' * Vxmx * H[i]
5057
end
5158

5259
return MvNormalWeightedMeanPrecision(xi, W)

src/rules/continuous_transition/marginals.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,14 @@ function continuous_tranition_marginal(m_y::MultivariateNormalDistributionsFamil
2424

2525
W_21 = negate_inplace!(mA' * mW)
2626

27+
# Optimized: factor out inner summation to reduce complexity from O(dy²) to O(dy)
28+
# Step 1: For each i, compute H[i] = Σⱼ mW[j,i] * Fs[j]
29+
H = [sum(mW[j, i] * Fs[j] for j in 1:dy) for i in 1:dy]
30+
31+
# Step 2: Compute Ξ
2732
Ξ = Wx
28-
for (i, j) in Iterators.product(1:dy, 1:dy)
29-
Ξ += mW[j, i] * Fs[j] * Va * Fs[i]'
33+
for i in 1:dy
34+
Ξ += H[i] * Va * Fs[i]'
3035
end
3136

3237
W_22 = Ξ + mA' * mW * mA

0 commit comments

Comments
 (0)