|
6 | 6 |
|
7 | 7 | @testset "Variational: (q_out::MultivariateNormalDistributionsFamily, q_γ::Gamma)" begin |
8 | 8 | @test_rules [check_type_promotion = true] MvNormalMeanScalePrecision(:out, Marginalisation) [ |
9 | | - (input = (q_μ = MvNormalMeanCovariance([2.0, 1.0], [3.0 2.0; 2.0 4.0]), q_γ = Gamma(1.0, 1.0)), output = MvNormalMeanPrecision([2.0, 1.0], [1.0 0.0; 0.0 1.0])), |
10 | | - (input = (q_μ = MvNormalMeanPrecision([2.0, 1.0], [3.0 2.0; 2.0 4.0]), q_γ = Gamma(3.0, 2.0)), output = MvNormalMeanPrecision([2.0, 1.0], [6.0 0.0; 0.0 6.0])), |
11 | | - ( |
12 | | - input = (q_μ = MvNormalWeightedMeanPrecision([2.0, 1.0], [3.0 2.0; 2.0 4.0]), q_γ = Gamma(4.0, 2.0)), |
13 | | - output = MvNormalMeanPrecision([3 / 4, -1 / 8], [8.0 0.0; 0.0 8.0]) |
14 | | - ) |
| 9 | + (input = (q_μ = MvNormalMeanCovariance([2.0, 1.0], [3.0 2.0; 2.0 4.0]), q_γ = Gamma(1.0, 1.0)), output = MvNormalMeanScalePrecision([2.0, 1.0], 1.0)), |
| 10 | + (input = (q_μ = MvNormalMeanPrecision([2.0, 1.0], [3.0 2.0; 2.0 4.0]), q_γ = Gamma(3.0, 2.0)), output = MvNormalMeanScalePrecision([2.0, 1.0], 6.0)), |
| 11 | + (input = (q_μ = MvNormalWeightedMeanPrecision([2.0, 1.0], [3.0 2.0; 2.0 4.0]), q_γ = Gamma(4.0, 2.0)), output = MvNormalMeanScalePrecision([3 / 4, -1 / 8], 8.0)) |
15 | 12 | ] |
16 | 13 | end |
17 | 14 |
|
|
25 | 22 | ) |
26 | 23 | ] |
27 | 24 | end |
| 25 | + |
| 26 | + @testset "Structured variational: (m_μ::MvNormalMeanScalePrecision, q_γ::Gamma)" begin |
| 27 | + @test_rules [check_type_promotion = true] MvNormalMeanScalePrecision(:out, Marginalisation) [ |
| 28 | + (input = (m_μ = MvNormalMeanScalePrecision([2.0, 1.0], 3.0), q_γ = Gamma(1.0, 1.0)), output = MvNormalMeanScalePrecision([2.0, 1.0], 3.0 * 1.0 / (3.0 + 1.0))), |
| 29 | + (input = (m_μ = MvNormalMeanScalePrecision([0.0, 0.0], 4.0), q_γ = GammaShapeRate(4.0, 2.0)), output = MvNormalMeanScalePrecision([0.0, 0.0], 4.0 * 2.0 / (4.0 + 2.0))), |
| 30 | + ( |
| 31 | + input = (m_μ = MvNormalMeanScalePrecision([3.0, -1.0], 2.0), q_γ = GammaShapeRate(2.0, 1.0)), |
| 32 | + output = MvNormalMeanScalePrecision([3.0, -1.0], 2.0 * 2.0 / (2.0 + 2.0)) |
| 33 | + ) |
| 34 | + ] |
| 35 | + end |
| 36 | + |
| 37 | + @testset "Performance: MvNormalMeanScalePrecision rule allocates less than general rule" begin |
| 38 | + import ReactiveMP: @call_rule |
| 39 | + |
| 40 | + for n in (10, 100) |
| 41 | + m_μ_scale = MvNormalMeanScalePrecision(zeros(n), 3.0) |
| 42 | + m_μ_general = MvNormalMeanCovariance(zeros(n), diageye(Float64, n) / 3.0) |
| 43 | + q_γ = GammaShapeRate(2.0, 1.0) |
| 44 | + |
| 45 | + # Warm up |
| 46 | + @call_rule MvNormalMeanScalePrecision(:out, Marginalisation) (m_μ = m_μ_scale, q_γ = q_γ) |
| 47 | + @call_rule MvNormalMeanScalePrecision(:out, Marginalisation) (m_μ = m_μ_general, q_γ = q_γ) |
| 48 | + |
| 49 | + allocs_scale = @allocated @call_rule MvNormalMeanScalePrecision(:out, Marginalisation) (m_μ = m_μ_scale, q_γ = q_γ) |
| 50 | + allocs_general = @allocated @call_rule MvNormalMeanScalePrecision(:out, Marginalisation) (m_μ = m_μ_general, q_γ = q_γ) |
| 51 | + |
| 52 | + @test allocs_scale < allocs_general |
| 53 | + end |
| 54 | + end |
28 | 55 | end |
0 commit comments