Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions ext/ReactiveMPProjectionExt/ReactiveMPProjectionExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ using ReactiveMP, ExponentialFamily, Distributions, ExponentialFamilyProjection,
struct DivisionOf{A, B}
numerator::A
denumerator::B
function DivisionOf(numerator::A, denumerator::B) where {A, B}
if variate_form(A) == variate_form(B)
return new{A, B}(numerator, denumerator)
else
error("DivisionOf does not support arguments of different variate forms: $(variate_form(A)) and $(variate_form(B))")
end
end
end

(divisionof::DivisionOf)(x) = logpdf(divisionof, x)
Expand Down Expand Up @@ -44,6 +51,8 @@ include("layout/cvi_projection.jl")
include("rules/in.jl")
include("rules/out.jl")
include("rules/marginals.jl")
include("divisionof/univariate_gaussian.jl")
include("divisionof/multivariate_gaussian.jl")

# This will enable the extension and make `CVIProjection` compatible with delta nodes
# Otherwise it should throw an error suggesting users to install `ExponentialFamilyProjection`
Expand Down
28 changes: 28 additions & 0 deletions ext/ReactiveMPProjectionExt/divisionof/multivariate_gaussian.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
function BayesBase.prod(
::GenericProd, something::C, division::DivisionOf{A, B}
) where {A <: MultivariateGaussianDistributionsFamily, B <: MultivariateGaussianDistributionsFamily, C <: MultivariateGaussianDistributionsFamily}
d_numerator = convert(MvNormalMeanPrecision, division.numerator)
d_denumerator = convert(MvNormalMeanPrecision, division.denumerator)
d_something = convert(MvNormalMeanPrecision, something)

ef_a = convert(ExponentialFamilyDistribution, d_numerator)
ef_b = convert(ExponentialFamilyDistribution, d_denumerator)
ef_c = convert(ExponentialFamilyDistribution, d_something)

ef_a_typetag = ExponentialFamily.exponential_family_typetag(ef_a)

resulting_nat_params = ExponentialFamily.getnaturalparameters(ef_a) - ExponentialFamily.getnaturalparameters(ef_b) + ExponentialFamily.getnaturalparameters(ef_c)
ef_resulting = ExponentialFamily.ExponentialFamilyDistribution(ef_a_typetag, resulting_nat_params, nothing, nothing)
Copy link
Copy Markdown
Member Author

@Nimrais Nimrais Jul 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bvdmitri This line is the "main" change in this PR, that makes RxInfer test work.


if !ExponentialFamily.isproper(ef_resulting)
@warn "The product of $(something) and $(division.numerator) divided by $(division.denumerator) is not proper" maxlog=1
end

return convert(Distribution, ef_resulting)
end

function BayesBase.prod(
prodtype::GenericProd, division::DivisionOf{A, B}, something::C
) where {A <: MultivariateGaussianDistributionsFamily, B <: MultivariateGaussianDistributionsFamily, C <: MultivariateGaussianDistributionsFamily}
return prod(prodtype, something, division)
end
24 changes: 24 additions & 0 deletions ext/ReactiveMPProjectionExt/divisionof/univariate_gaussian.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
function BayesBase.prod(
::GenericProd, something::C, division::DivisionOf{A, B}
) where {A <: UnivariateGaussianDistributionsFamily, B <: UnivariateGaussianDistributionsFamily, C <: UnivariateGaussianDistributionsFamily}
ef_a = convert(ExponentialFamilyDistribution, division.numerator)
ef_b = convert(ExponentialFamilyDistribution, division.denumerator)
ef_c = convert(ExponentialFamilyDistribution, something)

ef_a_typetag = ExponentialFamily.exponential_family_typetag(ef_a)

resulting_nat_params = ExponentialFamily.getnaturalparameters(ef_a) - ExponentialFamily.getnaturalparameters(ef_b) + ExponentialFamily.getnaturalparameters(ef_c)
ef_resulting = ExponentialFamily.ExponentialFamilyDistribution(ef_a_typetag, resulting_nat_params, nothing, nothing)

if !ExponentialFamily.isproper(ef_resulting)
@warn "The product of $(something) and $(division.numerator) divided by $(division.denumerator) is not proper" maxlog=1
end

return convert(Distribution, ef_resulting)
end

function BayesBase.prod(
prodtype::GenericProd, division::DivisionOf{A, B}, something::C
) where {A <: UnivariateGaussianDistributionsFamily, B <: UnivariateGaussianDistributionsFamily, C <: UnivariateGaussianDistributionsFamily}
return prod(prodtype, something, division)
end
99 changes: 99 additions & 0 deletions test/ext/ReactiveMPProjectionExt/ReactiveMPProjectionExt_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,105 @@
@test ext.DivisionOf(d1, d2) == prod(GenericProd(), missing, ext.DivisionOf(d1, d2))
end

@testitem "Check warning when DivisionOf is not proper" begin
using ExponentialFamily, ExponentialFamilyProjection, BayesBase
ext = Base.get_extension(ReactiveMP, :ReactiveMPProjectionExt)
@test !isnothing(ext)
using .ext

@testset "Check warning when DivisionOf is not proper" begin
d1 = NormalMeanVariance(0, 1)
d2 = NormalMeanVariance(0, 0.01)
d3 = NormalMeanVariance(0, 0.5)
@test_logs (:warn, "The product of $(d3) and $(d1) divided by $(d2) is not proper") prod(GenericProd(), ext.DivisionOf(d1, d2), d3)
end

@testset "Check warning when DivisionOf is not proper" begin
d1 = MvNormalMeanCovariance([0.0, 0.0], [1.0 0.0; 0.0 1.0])
d2 = MvNormalMeanCovariance([1.0, 1.0], [0.01 0.0; 0.0 0.01])
d3 = MvNormalMeanCovariance([2.0, 2.0], [3.0 0.0; 0.0 3.0])
@test_logs (:warn, "The product of $(d3) and $(d1) divided by $(d2) is not proper") prod(GenericProd(), ext.DivisionOf(d1, d2), d3)
end

@testset "Check no warning when DivisionOf is proper" begin
d1 = NormalMeanVariance(0, 1)
d2 = NormalMeanVariance(0, 2)
d3 = NormalMeanVariance(0, 1)
@test_logs prod(GenericProd(), ext.DivisionOf(d1, d2), d3)
end
end

@testitem "DivisionOf(Gaussian, Gaussian) x Gaussian" begin
using ExponentialFamily, ExponentialFamilyProjection, BayesBase

# `DivisionOf` is internal to the extension
ext = Base.get_extension(ReactiveMP, :ReactiveMPProjectionExt)
@test !isnothing(ext)
using .ext

d1 = NormalMeanVariance(0, 1)
d2 = NormalMeanVariance(1, 2)
d3 = NormalMeanVariance(2, 3)
result = NormalMeanVariance(0.2, 1.2)

@test prod(GenericProd(), ext.DivisionOf(d1, d2), d3) ≈ prod(GenericProd(), d3, ext.DivisionOf(d1, d2))
@test prod(GenericProd(), ext.DivisionOf(d1, d2), d3) ≈ result

d1 = NormalMeanVariance(0, 1)
d2 = convert(NormalWeightedMeanPrecision, d2)
d3 = convert(NormalMeanPrecision, d3)

@test prod(GenericProd(), ext.DivisionOf(d1, d2), d3) ≈ prod(GenericProd(), d3, ext.DivisionOf(d1, d2))
@test prod(GenericProd(), ext.DivisionOf(d1, d2), d3) ≈ result
end

@testitem "DivisionOf(MvGaussian, MvGaussian) x MvGaussian" begin
using ExponentialFamily, ExponentialFamilyProjection, BayesBase

# `DivisionOf` is internal to the extension
ext = Base.get_extension(ReactiveMP, :ReactiveMPProjectionExt)
@test !isnothing(ext)
using .ext

d1 = MvNormalMeanCovariance([0.0, 0.0], [1.0 0.0; 0.0 1.0])
d2 = MvNormalMeanCovariance([1.0, 1.0], [2.0 0.0; 0.0 2.0])
d3 = MvNormalMeanCovariance([2.0, 2.0], [3.0 0.0; 0.0 3.0])
result = MvNormalMeanCovariance([0.2, 0.2], [1.2 0.0; 0.0 1.2])

@test prod(GenericProd(), ext.DivisionOf(d1, d2), d3) ≈ prod(GenericProd(), d3, ext.DivisionOf(d1, d2))
@test prod(GenericProd(), ext.DivisionOf(d1, d2), d3) ≈ result

d1 = MvNormalMeanCovariance([0.0, 0.0], [1.0 0.0; 0.0 1.0])
d2 = convert(MvNormalWeightedMeanPrecision, d2)
d3 = convert(MvNormalMeanPrecision, d3)

@test prod(GenericProd(), ext.DivisionOf(d1, d2), d3) ≈ prod(GenericProd(), d3, ext.DivisionOf(d1, d2))
@test prod(GenericProd(), ext.DivisionOf(d1, d2), d3) ≈ result

d1 = MvNormalMeanCovariance([0.0, 0.0], [1.0 0.0; 0.0 1.0])
d2 = MvNormalMeanScalePrecision([1.0, 1.0], 1 / 2)
d3 = convert(MvNormalMeanPrecision, d3)

@test prod(GenericProd(), ext.DivisionOf(d1, d2), d3) ≈ prod(GenericProd(), d3, ext.DivisionOf(d1, d2))
@test prod(GenericProd(), ext.DivisionOf(d1, d2), d3) ≈ result
end

@testitem "Raise error when DivisionOf of Univarive and Multivariate" begin
using ExponentialFamily, ExponentialFamilyProjection, BayesBase

# `DivisionOf` is internal to the extension
ext = Base.get_extension(ReactiveMP, :ReactiveMPProjectionExt)
@test !isnothing(ext)
using .ext

d1 = NormalMeanVariance(0, 1)
d2 = MvNormalMeanCovariance([0.0, 0.0], [1.0 0.0; 0.0 1.0])
d3 = NormalMeanVariance(0, 1)

@test_throws "DivisionOf does not support arguments of different variate forms" prod(GenericProd(), ext.DivisionOf(d1, d2), d2)
@test_throws "DivisionOf does not support arguments of different variate forms" prod(GenericProd(), d2, ext.DivisionOf(d1, d2))
end

@testitem "create_project_to_ins type stability" begin
using ExponentialFamily, ExponentialFamilyProjection, BayesBase, Test
using ReactiveMP: CVIProjection
Expand Down