Skip to content

Commit f3f04ef

Browse files
authored
Merge pull request #579 from ReactiveBayes/optimize-mv-normal-mean-scale-precision
Optimize mv normal mean scale precision
2 parents 3b97062 + fae43e3 commit f3f04ef

4 files changed

Lines changed: 84 additions & 13 deletions

File tree

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11

22
# Variational #
33
# --------------------------------- #
4-
@rule MvNormalMeanScalePrecision(, Marginalisation) (q_out::Any, q_γ::Any) = MvNormalMeanPrecision(mean(q_out), mean(q_γ) * diageye(eltype(q_out), ndims(q_out)))
4+
@rule MvNormalMeanScalePrecision(, Marginalisation) (q_out::Any, q_γ::Any) = MvNormalMeanScalePrecision(mean(q_out), mean(q_γ))
55

66
@rule MvNormalMeanScalePrecision(, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, q_γ::Any) = begin
77
m_out_mean, m_out_cov = mean_cov(m_out)
88
return MvNormalMeanCovariance(m_out_mean, m_out_cov + inv(mean(q_γ)) * diageye(samplefloattype(m_out), ndims(m_out)))
99
end
10+
11+
@rule MvNormalMeanScalePrecision(, Marginalisation) (m_out::MvNormalMeanScalePrecision, q_γ::Any) = begin
12+
m_out_mean = mean(m_out)
13+
l_γ = m_out.γ
14+
r_γ = mean(q_γ)
15+
return MvNormalMeanScalePrecision(m_out_mean, (l_γ*r_γ)/(l_γ+r_γ))
16+
end
Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
11
# Variational #
22
# --------------------------------- #
3-
@rule MvNormalMeanScalePrecision(:out, Marginalisation) (q_μ::Any, q_γ::Any) = MvNormalMeanPrecision(mean(q_μ), mean(q_γ) * diageye(eltype(q_μ), ndims(q_μ)))
3+
@rule MvNormalMeanScalePrecision(:out, Marginalisation) (q_μ::Any, q_γ::Any) = MvNormalMeanScalePrecision(mean(q_μ), mean(q_γ))
44

55
@rule MvNormalMeanScalePrecision(:out, Marginalisation) (m_μ::MultivariateNormalDistributionsFamily, q_γ::Any) = begin
66
m_μ_mean, m_μ_cov = mean_cov(m_μ)
77
return MvNormalMeanCovariance(m_μ_mean, m_μ_cov + inv(mean(q_γ)) * diageye(eltype(m_μ), ndims(m_μ)))
88
end
9+
10+
@rule MvNormalMeanScalePrecision(:out, Marginalisation) (m_μ::MvNormalMeanScalePrecision, q_γ::Any) = begin
11+
m_out_mean = mean(m_μ)
12+
l_γ = m_μ.γ
13+
r_γ = mean(q_γ)
14+
return MvNormalMeanScalePrecision(m_out_mean, (l_γ*r_γ)/(l_γ+r_γ))
15+
end

test/rules/mv_normal_mean_scale_precision/mean_tests.jl

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,8 @@
66

77
@testset "Variational: (q_out::MultivariateNormalDistributionsFamily, q_γ::Gamma)" begin
88
@test_rules [check_type_promotion = true] MvNormalMeanScalePrecision(, Marginalisation) [
9-
(
10-
input = (q_out = MvNormalMeanCovariance([2.0, 1.0], [3.0 2.0; 2.0 4.0]), q_γ = GammaShapeRate(1.0, 1.0)),
11-
output = MvNormalMeanPrecision([2.0, 1.0], [1.0 0.0; 0.0 1.0])
12-
),
13-
(input = (q_out = MvNormalMeanPrecision([2.0, 1.0], [3.0 2.0; 2.0 4.0]), q_γ = Gamma(3.0, 1.0)), output = MvNormalMeanPrecision([2.0, 1.0], [3.0 0.0; 0.0 3.0]))
9+
(input = (q_out = MvNormalMeanCovariance([2.0, 1.0], [3.0 2.0; 2.0 4.0]), q_γ = GammaShapeRate(1.0, 1.0)), output = MvNormalMeanScalePrecision([2.0, 1.0], 1.0)),
10+
(input = (q_out = MvNormalMeanPrecision([2.0, 1.0], [3.0 2.0; 2.0 4.0]), q_γ = Gamma(3.0, 1.0)), output = MvNormalMeanScalePrecision([2.0, 1.0], 3.0))
1411
]
1512
end
1613

@@ -27,4 +24,37 @@
2724
)
2825
]
2926
end
27+
28+
@testset "Structured variational: (m_out::MvNormalMeanScalePrecision, q_γ::Gamma)" begin
29+
@test_rules [check_type_promotion = true] MvNormalMeanScalePrecision(, Marginalisation) [
30+
(input = (m_out = MvNormalMeanScalePrecision([2.0, 1.0], 3.0), q_γ = Gamma(1.0, 1.0)), output = MvNormalMeanScalePrecision([2.0, 1.0], 3.0 * 1.0 / (3.0 + 1.0))),
31+
(
32+
input = (m_out = MvNormalMeanScalePrecision([0.0, 0.0], 4.0), q_γ = GammaShapeRate(4.0, 2.0)),
33+
output = MvNormalMeanScalePrecision([0.0, 0.0], 4.0 * 2.0 / (4.0 + 2.0))
34+
),
35+
(
36+
input = (m_out = MvNormalMeanScalePrecision([3.0, -1.0], 2.0), q_γ = GammaShapeRate(2.0, 1.0)),
37+
output = MvNormalMeanScalePrecision([3.0, -1.0], 2.0 * 2.0 / (2.0 + 2.0))
38+
)
39+
]
40+
end
41+
42+
@testset "Performance: MvNormalMeanScalePrecision rule allocates less than general rule" begin
43+
import ReactiveMP: @call_rule
44+
45+
for n in (10, 100)
46+
m_out_scale = MvNormalMeanScalePrecision(zeros(n), 3.0)
47+
m_out_general = MvNormalMeanCovariance(zeros(n), diageye(Float64, n) / 3.0)
48+
q_γ = GammaShapeRate(2.0, 1.0)
49+
50+
# Warm up
51+
@call_rule MvNormalMeanScalePrecision(, Marginalisation) (m_out = m_out_scale, q_γ = q_γ)
52+
@call_rule MvNormalMeanScalePrecision(, Marginalisation) (m_out = m_out_general, q_γ = q_γ)
53+
54+
allocs_scale = @allocated @call_rule MvNormalMeanScalePrecision(, Marginalisation) (m_out = m_out_scale, q_γ = q_γ)
55+
allocs_general = @allocated @call_rule MvNormalMeanScalePrecision(, Marginalisation) (m_out = m_out_general, q_γ = q_γ)
56+
57+
@test allocs_scale < allocs_general
58+
end
59+
end
3060
end

test/rules/mv_normal_mean_scale_precision/out_tests.jl

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,9 @@
66

77
@testset "Variational: (q_out::MultivariateNormalDistributionsFamily, q_γ::Gamma)" begin
88
@test_rules [check_type_promotion = true] MvNormalMeanScalePrecision(:out, Marginalisation) [
9-
(input = (q_μ = MvNormalMeanCovariance([2.0, 1.0], [3.0 2.0; 2.0 4.0]), q_γ = Gamma(1.0, 1.0)), output = MvNormalMeanPrecision([2.0, 1.0], [1.0 0.0; 0.0 1.0])),
10-
(input = (q_μ = MvNormalMeanPrecision([2.0, 1.0], [3.0 2.0; 2.0 4.0]), q_γ = Gamma(3.0, 2.0)), output = MvNormalMeanPrecision([2.0, 1.0], [6.0 0.0; 0.0 6.0])),
11-
(
12-
input = (q_μ = MvNormalWeightedMeanPrecision([2.0, 1.0], [3.0 2.0; 2.0 4.0]), q_γ = Gamma(4.0, 2.0)),
13-
output = MvNormalMeanPrecision([3 / 4, -1 / 8], [8.0 0.0; 0.0 8.0])
14-
)
9+
(input = (q_μ = MvNormalMeanCovariance([2.0, 1.0], [3.0 2.0; 2.0 4.0]), q_γ = Gamma(1.0, 1.0)), output = MvNormalMeanScalePrecision([2.0, 1.0], 1.0)),
10+
(input = (q_μ = MvNormalMeanPrecision([2.0, 1.0], [3.0 2.0; 2.0 4.0]), q_γ = Gamma(3.0, 2.0)), output = MvNormalMeanScalePrecision([2.0, 1.0], 6.0)),
11+
(input = (q_μ = MvNormalWeightedMeanPrecision([2.0, 1.0], [3.0 2.0; 2.0 4.0]), q_γ = Gamma(4.0, 2.0)), output = MvNormalMeanScalePrecision([3 / 4, -1 / 8], 8.0))
1512
]
1613
end
1714

@@ -25,4 +22,34 @@
2522
)
2623
]
2724
end
25+
26+
@testset "Structured variational: (m_μ::MvNormalMeanScalePrecision, q_γ::Gamma)" begin
27+
@test_rules [check_type_promotion = true] MvNormalMeanScalePrecision(:out, Marginalisation) [
28+
(input = (m_μ = MvNormalMeanScalePrecision([2.0, 1.0], 3.0), q_γ = Gamma(1.0, 1.0)), output = MvNormalMeanScalePrecision([2.0, 1.0], 3.0 * 1.0 / (3.0 + 1.0))),
29+
(input = (m_μ = MvNormalMeanScalePrecision([0.0, 0.0], 4.0), q_γ = GammaShapeRate(4.0, 2.0)), output = MvNormalMeanScalePrecision([0.0, 0.0], 4.0 * 2.0 / (4.0 + 2.0))),
30+
(
31+
input = (m_μ = MvNormalMeanScalePrecision([3.0, -1.0], 2.0), q_γ = GammaShapeRate(2.0, 1.0)),
32+
output = MvNormalMeanScalePrecision([3.0, -1.0], 2.0 * 2.0 / (2.0 + 2.0))
33+
)
34+
]
35+
end
36+
37+
@testset "Performance: MvNormalMeanScalePrecision rule allocates less than general rule" begin
38+
import ReactiveMP: @call_rule
39+
40+
for n in (10, 100)
41+
m_μ_scale = MvNormalMeanScalePrecision(zeros(n), 3.0)
42+
m_μ_general = MvNormalMeanCovariance(zeros(n), diageye(Float64, n) / 3.0)
43+
q_γ = GammaShapeRate(2.0, 1.0)
44+
45+
# Warm up
46+
@call_rule MvNormalMeanScalePrecision(:out, Marginalisation) (m_μ = m_μ_scale, q_γ = q_γ)
47+
@call_rule MvNormalMeanScalePrecision(:out, Marginalisation) (m_μ = m_μ_general, q_γ = q_γ)
48+
49+
allocs_scale = @allocated @call_rule MvNormalMeanScalePrecision(:out, Marginalisation) (m_μ = m_μ_scale, q_γ = q_γ)
50+
allocs_general = @allocated @call_rule MvNormalMeanScalePrecision(:out, Marginalisation) (m_μ = m_μ_general, q_γ = q_γ)
51+
52+
@test allocs_scale < allocs_general
53+
end
54+
end
2855
end

0 commit comments

Comments
 (0)