From c8bd85166ae1f06a743b891f34a04779bd3f6dc9 Mon Sep 17 00:00:00 2001 From: skoghoern <136440882+skoghoern@users.noreply.github.com> Date: Mon, 27 Oct 2025 19:28:30 +0100 Subject: [PATCH 1/4] Update precision.jl MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Regularize invS in MvNormalMeanPrecision(:Λ, Marginalisation) rule to prevent singularity --- src/rules/mv_normal_mean_precision/precision.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/rules/mv_normal_mean_precision/precision.jl b/src/rules/mv_normal_mean_precision/precision.jl index e4f28b0e0..9865db329 100644 --- a/src/rules/mv_normal_mean_precision/precision.jl +++ b/src/rules/mv_normal_mean_precision/precision.jl @@ -1,15 +1,16 @@ - # Variational # # --------------------------------- # -@rule MvNormalMeanPrecision(:Λ, Marginalisation) (q_out::Any, q_μ::Any) = begin + @rule MvNormalMeanPrecision(:Λ, Marginalisation) (q_out::Any, q_μ::Any) = begin m_out, v_out = mean_cov(q_out) m_mean, v_mean = mean_cov(q_μ) + d = ndims(q_μ) - df = ndims(q_μ) + 2 - invS = v_mean + v_out + (m_mean - m_out) * (m_mean - m_out)' - + df = d + 2 + invS_raw = v_mean + v_out + (m_mean - m_out) * (m_mean - m_out)' + invS = invS_raw + 1e-6 * diagm(ones(d)) + return WishartFast(df, invS) -end +end @rule MvNormalMeanPrecision(:Λ, Marginalisation) (q_out_μ::Any,) = begin m_out_μ, v_out_μ = mean_cov(q_out_μ) From a5a484ca401c088e2f148bd6f31db8bc9d58d296 Mon Sep 17 00:00:00 2001 From: skoghoern <136440882+skoghoern@users.noreply.github.com> Date: Mon, 27 Oct 2025 19:29:40 +0100 Subject: [PATCH 2/4] Update precision.jl MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Regularize invS in MvNormalMeanPrecision(:Λ, Marginalisation) rule to prevent singularity --- src/rules/mv_normal_mean_precision/precision.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules/mv_normal_mean_precision/precision.jl b/src/rules/mv_normal_mean_precision/precision.jl index 9865db329..b715eb522 100644 --- a/src/rules/mv_normal_mean_precision/precision.jl +++ b/src/rules/mv_normal_mean_precision/precision.jl @@ -10,7 +10,7 @@ invS = invS_raw + 1e-6 * diagm(ones(d)) return WishartFast(df, invS) -end +end @rule MvNormalMeanPrecision(:Λ, Marginalisation) (q_out_μ::Any,) = begin m_out_μ, v_out_μ = mean_cov(q_out_μ) From 1fa15619013d3fbdcdd6b4bb3afe6e76367d1373 Mon Sep 17 00:00:00 2001 From: skoghoern <136440882+skoghoern@users.noreply.github.com> Date: Mon, 27 Oct 2025 19:31:07 +0100 Subject: [PATCH 3/4] Update precision.jl deleted space --- src/rules/mv_normal_mean_precision/precision.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules/mv_normal_mean_precision/precision.jl b/src/rules/mv_normal_mean_precision/precision.jl index b715eb522..78638f495 100644 --- a/src/rules/mv_normal_mean_precision/precision.jl +++ b/src/rules/mv_normal_mean_precision/precision.jl @@ -1,6 +1,6 @@ # Variational # # --------------------------------- # - @rule MvNormalMeanPrecision(:Λ, Marginalisation) (q_out::Any, q_μ::Any) = begin +@rule MvNormalMeanPrecision(:Λ, Marginalisation) (q_out::Any, q_μ::Any) = begin m_out, v_out = mean_cov(q_out) m_mean, v_mean = mean_cov(q_μ) d = ndims(q_μ) From 161acf59c8ac806537f0ab014595240fadfe385a Mon Sep 17 00:00:00 2001 From: skoghoern <136440882+skoghoern@users.noreply.github.com> Date: Thu, 30 Oct 2025 20:08:17 +0100 Subject: [PATCH 4/4] Update precision.jl - using correction! from https://github.com/ReactiveBayes/MatrixCorrectionTools.jl --- src/rules/mv_normal_mean_precision/precision.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/rules/mv_normal_mean_precision/precision.jl b/src/rules/mv_normal_mean_precision/precision.jl index 78638f495..352224b52 100644 --- a/src/rules/mv_normal_mean_precision/precision.jl +++ b/src/rules/mv_normal_mean_precision/precision.jl @@ -1,25 +1,24 @@ # Variational # # --------------------------------- # -@rule MvNormalMeanPrecision(:Λ, Marginalisation) (q_out::Any, q_μ::Any) = begin +@rule MvNormalMeanPrecision(:Λ, Marginalisation) (q_out::Any, q_μ::Any, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin m_out, v_out = mean_cov(q_out) m_mean, v_mean = mean_cov(q_μ) d = ndims(q_μ) df = d + 2 - invS_raw = v_mean + v_out + (m_mean - m_out) * (m_mean - m_out)' - invS = invS_raw + 1e-6 * diagm(ones(d)) + invS = correction!(meta, v_mean + v_out + (m_mean - m_out) * (m_mean - m_out)') return WishartFast(df, invS) end -@rule MvNormalMeanPrecision(:Λ, Marginalisation) (q_out_μ::Any,) = begin +@rule MvNormalMeanPrecision(:Λ, Marginalisation) (q_out_μ::Any, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin m_out_μ, v_out_μ = mean_cov(q_out_μ) d = div(ndims(q_out_μ), 2) mdiff = @views m_out_μ[1:d] - m_out_μ[(d + 1):end] vdiff = @views v_out_μ[1:d, 1:d] - v_out_μ[1:d, (d + 1):end] - v_out_μ[(d + 1):end, 1:d] + v_out_μ[(d + 1):end, (d + 1):end] - invS = vdiff + mdiff * mdiff' + invS = correction!(meta, vdiff + mdiff * mdiff') return WishartFast(d + 2, invS) end