From 988fb88dadb1f173bee198c47a62ff9c3aae5fa7 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Fri, 21 Nov 2025 12:44:26 +0100 Subject: [PATCH 01/17] feat: draft closed form strategy --- Project.toml | 7 +- .../ClosedFormExpectationsExt.jl | 101 +++++++++++++ src/ExponentialFamilyProjection.jl | 1 + src/strategies/closed_form.jl | 22 +++ test/strategies/closed_form_tests.jl | 134 ++++++++++++++++++ 5 files changed, 264 insertions(+), 1 deletion(-) create mode 100644 ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl create mode 100644 src/strategies/closed_form.jl create mode 100644 test/strategies/closed_form_tests.jl diff --git a/Project.toml b/Project.toml index b88423b..c153099 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "3.1.3" [deps] BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e" Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e" +ClosedFormExpectations = "70ff922c-62d4-418d-abfc-e284e489b734" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" ExponentialFamily = "62312e5e-252a-4322-ace9-a5f4bf9b357b" ExponentialFamilyManifolds = "5c9727c4-3b82-4ab3-b165-76e2eb971b08" @@ -22,6 +23,9 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +[extensions] +ClosedFormExpectationsExt = "ClosedFormExpectations" + [compat] BayesBase = "1.5.0" Bumper = "0.6" @@ -46,6 +50,7 @@ julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +ClosedFormExpectations = "675a34d0-5017-47c3-81d8-6952f69c707e" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" @@ -58,4 +63,4 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Aqua", "BenchmarkTools", "Hwloc", "Plots", "Printf", "ForwardDiff", 
"Manifolds", "ReTestItems", "RollingFunctions", "JET", "StableRNGs"] +test = ["Test", "Aqua", "BenchmarkTools", "ClosedFormExpectations", "Hwloc", "Plots", "Printf", "ForwardDiff", "Manifolds", "ReTestItems", "RollingFunctions", "JET", "StableRNGs"] diff --git a/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl b/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl new file mode 100644 index 0000000..d3054ce --- /dev/null +++ b/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl @@ -0,0 +1,101 @@ +module ClosedFormExpectationsExt + +using ExponentialFamilyProjection +using ClosedFormExpectations +using ExponentialFamily +using ExponentialFamilyManifolds +using Manifolds +using ManifoldsBase +using LinearAlgebra +using Distributions + +ExponentialFamilyProjection.get_nsamples(::ClosedFormStrategy) = 0 + +function logbasemeasure_correction(::ClosedFormStrategy, ::ExponentialFamily.ConstantBaseMeasure, q_dist, grad_target) + grad_target +end + +function ExponentialFamilyProjection.compute_gradient!( + M::AbstractManifold, + strategy::ClosedFormStrategy, + state, + X, + η, + logpartition, + gradlogpartition, + inv_fisher +) + # The gradient of KL(q||p) involves E_q[(log p̃ - log h_q) * (T - μ)] + # where h_q is the base measure of q (the variational distribution). + # + # For constant base measure: + # E[(log p̃ - log h) * (T - μ)] = E[log p̃ * (T - μ)] - log h * E[(T - μ)] + # Since E[T] = μ, the second term is zero. 
+ + target_fn = state.target + + # Convert natural parameters on manifold to an ExponentialFamilyDistribution object + q_dist = convert(ExponentialFamilyDistribution, M, ExponentialFamilyManifolds.partition_point(M, η)) + + # Compute ∇_η E[log p̃ * (T - μ)] + grad_target = mean(ClosedWilliamsProduct(), target_fn, q_dist) + grad_eta = logbasemeasure_correction(strategy, ExponentialFamily.isbasemeasureconstant(q_dist), q_dist, grad_target) + + # Natural Gradient Update: X = η - F⁻¹ * ∇_η E + X .= η .- inv_fisher * grad_eta + + return X +end + +# Helper to create state +struct ClosedFormStrategyState{T} + target::T +end + +function ExponentialFamilyProjection.create_state!( + strategy::ClosedFormStrategy, + M::AbstractManifold, + parameters::ProjectionParameters, + projection_argument, + initial_ef, + supplementary_η, +) + return ClosedFormStrategyState(projection_argument) +end + +function ExponentialFamilyProjection.prepare_state!( + strategy::ClosedFormStrategy, + state::ClosedFormStrategyState, + M::AbstractManifold, + parameters::ProjectionParameters, + projection_argument, + current_ef, + supplementary_η, +) + return state +end + +# Compute cost for logging/convergence check +function ExponentialFamilyProjection.compute_cost( + M::AbstractManifold, + strategy::ClosedFormStrategy, + state::ClosedFormStrategyState, + η, + logpartition, + gradlogpartition, + inv_fisher +) + # Cost = KL(q || p) = E_q[log q] - E_q[log p] + + # Reconstruct distribution + q_dist = convert(ExponentialFamilyDistribution, M, ExponentialFamilyManifolds.partition_point(M, η)) + dist_std = convert(Distribution, q_dist) + + # E_q[log p] via CFE + E_log_p = mean(ClosedFormExpectation(), state.target, dist_std) + + # E_q[log q] = -entropy(q) + return -entropy(dist_std) - E_log_p +end + +end diff --git a/src/ExponentialFamilyProjection.jl b/src/ExponentialFamilyProjection.jl index 8431f42..39adb1a 100644 --- a/src/ExponentialFamilyProjection.jl +++ b/src/ExponentialFamilyProjection.jl @@ -123,6 
+123,7 @@ function compute_gradient! end include("strategies/control_variate.jl") include("strategies/mle.jl") include("strategies/default.jl") +include("strategies/closed_form.jl") # Bonnet strategy include("strategies/bonnet/naive_grad_hess.jl") include("strategies/bonnet/bonnet_logpdf.jl") diff --git a/src/strategies/closed_form.jl b/src/strategies/closed_form.jl new file mode 100644 index 0000000..6c2b86f --- /dev/null +++ b/src/strategies/closed_form.jl @@ -0,0 +1,22 @@ +export ClosedFormStrategy + +""" + ClosedFormStrategy <: ExponentialFamilyProjection.AbstractStrategy + +A projection strategy that uses `ClosedFormExpectations.jl` to compute the exact gradient +of the cross-entropy term \$\\mathbb{E}_{q_\\eta}[\\log \\tilde{p}(x)]\$ analytically. + +This strategy provides a "Zero-Variance" gradient estimator, avoiding the noise associated +with Monte Carlo sampling (like in `ControlVariateStrategy`). + +This estimator was proposed in [Lukashchuk et al., 2024](https://proceedings.mlr.press/v246/lukashchuk24a.html). + +!!! note + To use this strategy, you **must** load the `ClosedFormExpectations` package in your environment. + Loading `ClosedFormExpectations` will trigger the package extension that implements `compute_gradient!` + and `compute_cost` for this strategy. + + It requires that `ClosedFormExpectations.jl` implements `ClosedWilliamsProduct` for the + specific pair of target function and variational family. 
+""" +struct ClosedFormStrategy end diff --git a/test/strategies/closed_form_tests.jl b/test/strategies/closed_form_tests.jl new file mode 100644 index 0000000..1aa99d2 --- /dev/null +++ b/test/strategies/closed_form_tests.jl @@ -0,0 +1,134 @@ +module ClosedFormStrategyTests + +using Test +using ExponentialFamilyProjection +using ExponentialFamilyProjection: ControlVariateStrategy +using ClosedFormExpectations +using Distributions +using ExponentialFamily +using StableRNGs +using LinearAlgebra + +@testset "ClosedFormStrategy" begin + + @testset "LogGamma projected to Normal" begin + # Target: LogGamma(α=2.0, β=1.0) + # Density: p(x) ∝ x^(α-1) * e^(-βx) + # Log density: (α-1)log(x) - βx + + # Note: Normal support is (-∞, ∞), Gamma is (0, ∞). + # This projection is technically improper if Normal puts mass on x < 0. + # But we can test if the gradient machinery works. + # To make it "reasonable", we put the Normal far from 0. + + target_dist = LogGamma(20.0, 1.0) # Mean 20, Var 20. + target = Logpdf(target_dist) + + # Initial approximation: Normal(15, 5) + initial_dist = Normal(15.0, 5.0) + + # Create strategy + strategy = ClosedFormStrategy() + + # Project + # We use a small number of iterations to check it runs and descends + result = project_to(ProjectedTo(NormalMeanVariance), target; + strategy = strategy, + initial_point = initial_dist, + parameters = ProjectionParameters(niterations = 50, tolerance = 1e-5) + ) + + @test result isa NormalMeanVariance + μ, v = mean(result), var(result) + + # We expect it to move towards the mode of the LogGamma. + # Mode of LogGamma(α, β) is at x = log(α*β). + # Here α=20, β=1. Mode at log(20) ≈ 3.0. 
+ + @test 2.0 < μ < 4.0 + end + + @testset "LogNormal projected to Gamma" begin + # Target: LogNormal(μ=1.0, σ=0.5) + # Approx: Gamma + + target_dist = LogNormal(1.0, 0.5) + target = Logpdf(target_dist) + + # Initial: Gamma(2.0, 2.0) -> mean 4, var 8 + initial_dist = Gamma(2.0, 2.0) + + strategy = ClosedFormStrategy() + + result = project_to(ProjectedTo(Gamma), target; + strategy = strategy, + initial_point = initial_dist, + parameters = ProjectionParameters(niterations = 50, tolerance = 1e-5) + ) + + @test result isa GammaDistributionsFamily + + # Check if it ran without error and produced valid parameters + @test shape(result) > 0 + @test scale(result) > 0 + + # Comparison with ControlVariateStrategy + # CV strategy is stochastic, so it might be slightly different but should be close. + + cv_strategy = ControlVariateStrategy(nsamples = 500) + result_cv = project_to(ProjectedTo(Gamma), target; + strategy = cv_strategy, + initial_point = initial_dist, + parameters = ProjectionParameters(niterations = 50, tolerance = 1e-4) + ) + + # KL between result and result_cv should be small + # Since they are both Gamma, we can measure parameter distance or KL + + @test isapprox(mean(result), mean(result_cv), rtol=0.1) + end + + @testset "Comparison: Speed and Accuracy" begin + # Simple case: Normal to Normal (should be exact one step if Newton, but we use GradientDescent) + # Let's use a target that is actually a Normal + target_dist = Normal(5.0, 2.0) + target = Logpdf(target_dist) + + initial = Normal(0.0, 1.0) + + # Analytic + t_analytic = @elapsed begin + res_analytic = project_to(ProjectedTo(NormalMeanVariance), target; + strategy = ClosedFormStrategy(), + initial_point = initial, + parameters = ProjectionParameters(niterations=100) + ) + end + + # MC + t_mc = @elapsed begin + res_mc = project_to(ProjectedTo(NormalMeanVariance), target; + strategy = ControlVariateStrategy(nsamples=1000), + initial_point = initial, + parameters = ProjectionParameters(niterations=100) + ) 
+ end + + println("Analytic time: $t_analytic") + println("MC time: $t_mc") + + # Analytic should be faster (no sampling) and more accurate (converge to exact target) + @test isapprox(mean(res_analytic), mean(target_dist), atol=2e-2) + @test isapprox(std(res_analytic), std(target_dist), atol=2e-2) + + # MC might have some noise + @test abs(mean(res_analytic) - 5.0) < abs(mean(res_mc) - 5.0) + 0.1 # Heuristic check + end + +end + +end + + + + From ec21709c9b142ef90bb9e3b44d650372ea05b10e Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Fri, 21 Nov 2025 15:53:09 +0100 Subject: [PATCH 02/17] fix: add process strategy into extension --- .../ClosedFormExpectationsExt.jl | 80 ++++++++++++++-- src/manopt/projection_objective.jl | 13 ++- test/strategies/closed_form_tests.jl | 93 ++++++++++--------- 3 files changed, 134 insertions(+), 52 deletions(-) diff --git a/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl b/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl index d3054ce..53cf568 100644 --- a/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl +++ b/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl @@ -9,9 +9,18 @@ using ManifoldsBase using LinearAlgebra using Distributions +# Import types needed for RxInfer closure unwrapping +import ExponentialFamily: ProductOf +import ClosedFormExpectations: Logpdf + ExponentialFamilyProjection.get_nsamples(::ClosedFormStrategy) = 0 -function logbasemeasure_correction(::ClosedFormStrategy, ::ExponentialFamily.ConstantBaseMeasure, q_dist, grad_target) +function logbasemeasure_correction( + ::ClosedFormStrategy, + ::ExponentialFamily.ConstantBaseMeasure, + q_dist, + grad_target, +) grad_target end @@ -23,7 +32,7 @@ function ExponentialFamilyProjection.compute_gradient!( η, logpartition, gradlogpartition, - inv_fisher + inv_fisher, ) # The gradient of KL(q||p) involves E_q[(log p̃ - log h_q) * (T - μ)] # where h_q is the base measure of q (the variational distribution). 
@@ -35,11 +44,20 @@ function ExponentialFamilyProjection.compute_gradient!( target_fn = state.target # Convert natural parameters on manifold to an ExponentialFamilyDistribution object - q_dist = convert(ExponentialFamilyDistribution, M, ExponentialFamilyManifolds.partition_point(M, η)) + q_dist = convert( + ExponentialFamilyDistribution, + M, + ExponentialFamilyManifolds.partition_point(M, η), + ) # Compute ∇_η E[log p̃ * (T - μ)] grad_target = mean(ClosedWilliamsProduct(), target_fn, q_dist) - grad_eta = logbasemeasure_correction(strategy, ExponentialFamily.isbasemeasureconstant(q_dist), q_dist, grad_target) + grad_eta = logbasemeasure_correction( + strategy, + ExponentialFamily.isbasemeasureconstant(q_dist), + q_dist, + grad_target, + ) # Natural Gradient Update: X = η - F⁻¹ * ∇_η E X .= η .- inv_fisher * grad_eta @@ -83,19 +101,63 @@ function ExponentialFamilyProjection.compute_cost( η, logpartition, gradlogpartition, - inv_fisher + inv_fisher, ) # Cost = KL(q || p) = E_q[log q] - E_q[log p] - + # Reconstruct distribution - q_dist = convert(ExponentialFamilyDistribution, M, ExponentialFamilyManifolds.partition_point(M, η)) + q_dist = convert( + ExponentialFamilyDistribution, + M, + ExponentialFamilyManifolds.partition_point(M, η), + ) dist_std = convert(Distribution, q_dist) - + # E_q[log p] via CFE E_log_p = mean(ClosedFormExpectation(), state.target, dist_std) - + # E_q[log q] = -entropy(q) return -entropy(dist_std) - E_log_p end +# preprocess_strategy_argument for ClosedFormStrategy +# Special handling for RxInfer closures that wrap ProductOf +function ExponentialFamilyProjection.preprocess_strategy_argument( + strategy::ClosedFormStrategy, + argument::Function, +) + # RxInfer wraps ProductOf in a closure. + # Extract the ProductOf from the closure's captured variables. + # The closure typically has one field holding the ProductOf. 
+ fn_type = typeof(argument) + field_names = fieldnames(fn_type) + + if !isempty(field_names) + # Get the first field (usually the captured ProductOf) + captured = getfield(argument, first(field_names)) + + # If it's a ProductOf, use it directly + if captured isa ProductOf + return (strategy, Logpdf(captured)) + end + + # If it's a Distribution (e.g. LogNormal inside ProjectionExt closure), use it directly + if captured isa Distribution + return (strategy, Logpdf(captured)) + end + end + + # Fallback: keep the function as-is + return (strategy, argument) +end + +# Generic fallback for non-Function arguments +function ExponentialFamilyProjection.preprocess_strategy_argument( + strategy::ClosedFormStrategy, + argument, +) + # ClosedFormStrategy accepts any callable or distribution as argument + return (strategy, argument) +end + end diff --git a/src/manopt/projection_objective.jl b/src/manopt/projection_objective.jl index 0db69d4..a1ec6a3 100644 --- a/src/manopt/projection_objective.jl +++ b/src/manopt/projection_objective.jl @@ -58,7 +58,18 @@ function call_objective( logpartition = ExponentialFamily.logpartition(current_ef) gradlogpartition = ExponentialFamily.gradlogpartition(current_ef) - inv_fisher = cholinv(ExponentialFamily.fisherinformation(current_ef)) + fisher = ExponentialFamily.fisherinformation(current_ef) + inv_fisher = try + cholinv(fisher) + catch e + # If the Fisher information matrix is not positive definite, we try to compute the pseudo-inverse + # This can happen if the distribution is degenerate or if the parameters are on the boundary + if e isa PosDefException || e isa LinearAlgebra.PosDefException + pinv(fisher) + else + rethrow(e) + end + end # If we have some supplementary natural parameters in the objective # we must subtract them from the natural parameters of the current η diff --git a/test/strategies/closed_form_tests.jl b/test/strategies/closed_form_tests.jl index 1aa99d2..0d19f1d 100644 --- a/test/strategies/closed_form_tests.jl +++ 
b/test/strategies/closed_form_tests.jl @@ -1,6 +1,7 @@ module ClosedFormStrategyTests using Test +using BayesBase using ExponentialFamilyProjection using ExponentialFamilyProjection: ControlVariateStrategy using ClosedFormExpectations @@ -9,35 +10,39 @@ using ExponentialFamily using StableRNGs using LinearAlgebra +import BayesBase: ProductOf + @testset "ClosedFormStrategy" begin @testset "LogGamma projected to Normal" begin # Target: LogGamma(α=2.0, β=1.0) # Density: p(x) ∝ x^(α-1) * e^(-βx) # Log density: (α-1)log(x) - βx - + # Note: Normal support is (-∞, ∞), Gamma is (0, ∞). # This projection is technically improper if Normal puts mass on x < 0. # But we can test if the gradient machinery works. # To make it "reasonable", we put the Normal far from 0. - + target_dist = LogGamma(20.0, 1.0) # Mean 20, Var 20. target = Logpdf(target_dist) - + # Initial approximation: Normal(15, 5) initial_dist = Normal(15.0, 5.0) - + # Create strategy strategy = ClosedFormStrategy() - + # Project # We use a small number of iterations to check it runs and descends - result = project_to(ProjectedTo(NormalMeanVariance), target; - strategy = strategy, + result = project_to( + ProjectedTo(NormalMeanVariance), + target; + strategy = strategy, initial_point = initial_dist, - parameters = ProjectionParameters(niterations = 50, tolerance = 1e-5) + parameters = ProjectionParameters(niterations = 50, tolerance = 1e-5), ) - + @test result isa NormalMeanVariance μ, v = mean(result), var(result) @@ -45,47 +50,51 @@ using LinearAlgebra # Mode of LogGamma(α, β) is at x = log(α*β). # Here α=20, β=1. Mode at log(20) ≈ 3.0. 
- @test 2.0 < μ < 4.0 + @test 2.0 < μ < 4.0 end @testset "LogNormal projected to Gamma" begin # Target: LogNormal(μ=1.0, σ=0.5) # Approx: Gamma - + target_dist = LogNormal(1.0, 0.5) target = Logpdf(target_dist) - + # Initial: Gamma(2.0, 2.0) -> mean 4, var 8 initial_dist = Gamma(2.0, 2.0) - + strategy = ClosedFormStrategy() - - result = project_to(ProjectedTo(Gamma), target; + + result = project_to( + ProjectedTo(Gamma), + target; strategy = strategy, initial_point = initial_dist, - parameters = ProjectionParameters(niterations = 50, tolerance = 1e-5) + parameters = ProjectionParameters(niterations = 50, tolerance = 1e-5), ) - + @test result isa GammaDistributionsFamily - + # Check if it ran without error and produced valid parameters @test shape(result) > 0 @test scale(result) > 0 - + # Comparison with ControlVariateStrategy # CV strategy is stochastic, so it might be slightly different but should be close. - + cv_strategy = ControlVariateStrategy(nsamples = 500) - result_cv = project_to(ProjectedTo(Gamma), target; + result_cv = project_to( + ProjectedTo(Gamma), + target; strategy = cv_strategy, initial_point = initial_dist, - parameters = ProjectionParameters(niterations = 50, tolerance = 1e-4) + parameters = ProjectionParameters(niterations = 50, tolerance = 1e-4), ) - + # KL between result and result_cv should be small # Since they are both Gamma, we can measure parameter distance or KL - - @test isapprox(mean(result), mean(result_cv), rtol=0.1) + + @test isapprox(mean(result), mean(result_cv), rtol = 0.1) end @testset "Comparison: Speed and Accuracy" begin @@ -93,34 +102,38 @@ using LinearAlgebra # Let's use a target that is actually a Normal target_dist = Normal(5.0, 2.0) target = Logpdf(target_dist) - + initial = Normal(0.0, 1.0) - + # Analytic t_analytic = @elapsed begin - res_analytic = project_to(ProjectedTo(NormalMeanVariance), target; + res_analytic = project_to( + ProjectedTo(NormalMeanVariance), + target; strategy = ClosedFormStrategy(), initial_point 
= initial, - parameters = ProjectionParameters(niterations=100) + parameters = ProjectionParameters(niterations = 100), ) end - + # MC t_mc = @elapsed begin - res_mc = project_to(ProjectedTo(NormalMeanVariance), target; - strategy = ControlVariateStrategy(nsamples=1000), + res_mc = project_to( + ProjectedTo(NormalMeanVariance), + target; + strategy = ControlVariateStrategy(nsamples = 1000), initial_point = initial, - parameters = ProjectionParameters(niterations=100) + parameters = ProjectionParameters(niterations = 100), ) end - + println("Analytic time: $t_analytic") println("MC time: $t_mc") - + # Analytic should be faster (no sampling) and more accurate (converge to exact target) - @test isapprox(mean(res_analytic), mean(target_dist), atol=2e-2) - @test isapprox(std(res_analytic), std(target_dist), atol=2e-2) - + @test isapprox(mean(res_analytic), mean(target_dist), atol = 2e-2) + @test isapprox(std(res_analytic), std(target_dist), atol = 2e-2) + # MC might have some noise @test abs(mean(res_analytic) - 5.0) < abs(mean(res_mc) - 5.0) + 0.1 # Heuristic check end @@ -128,7 +141,3 @@ using LinearAlgebra end end - - - - From efa331ed49d183c63c06cb2656fea8c9c1d0eddb Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Sun, 23 Nov 2025 17:32:44 +0100 Subject: [PATCH 03/17] fix: update ClosedFormExpectations.jl to 0.3.0 --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index c153099..13ffaac 100644 --- a/Project.toml +++ b/Project.toml @@ -29,6 +29,7 @@ ClosedFormExpectationsExt = "ClosedFormExpectations" [compat] BayesBase = "1.5.0" Bumper = "0.6" +ClosedFormExpectations = "0.3.0" Distributions = "0.25" ExponentialFamily = "2.0.0" ExponentialFamilyManifolds = "3.0.3" From 5792428fbcf132bf2a7ead556cbcaa55ac6f0532 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Sun, 23 Nov 2025 17:35:48 +0100 Subject: [PATCH 04/17] fix: Project.toml --- Project.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git 
a/Project.toml b/Project.toml index 13ffaac..e0c2439 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ version = "3.1.3" [deps] BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e" Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e" -ClosedFormExpectations = "70ff922c-62d4-418d-abfc-e284e489b734" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" ExponentialFamily = "62312e5e-252a-4322-ace9-a5f4bf9b357b" ExponentialFamilyManifolds = "5c9727c4-3b82-4ab3-b165-76e2eb971b08" @@ -51,7 +50,7 @@ julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" -ClosedFormExpectations = "675a34d0-5017-47c3-81d8-6952f69c707e" +ClosedFormExpectations = "70ff922c-62d4-418d-abfc-e284e489b734" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" From 57156b513e902c2790ff6f21d02dd5110bf0bfef Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Sun, 23 Nov 2025 17:59:29 +0100 Subject: [PATCH 05/17] fix: properly mark tests --- Project.toml | 3 + test/strategies/closed_form_tests.jl | 264 ++++++++++++++------------- 2 files changed, 138 insertions(+), 129 deletions(-) diff --git a/Project.toml b/Project.toml index e0c2439..e29d0d0 100644 --- a/Project.toml +++ b/Project.toml @@ -47,6 +47,9 @@ StaticArrays = "1.9" StatsFuns = "1.3" julia = "1.10" +[weakdeps] +ClosedFormExpectations = "70ff922c-62d4-418d-abfc-e284e489b734" + [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" diff --git a/test/strategies/closed_form_tests.jl b/test/strategies/closed_form_tests.jl index 0d19f1d..fd87889 100644 --- a/test/strategies/closed_form_tests.jl +++ b/test/strategies/closed_form_tests.jl @@ -1,143 +1,149 @@ -module ClosedFormStrategyTests - -using Test -using BayesBase -using ExponentialFamilyProjection -using ExponentialFamilyProjection: ControlVariateStrategy -using 
ClosedFormExpectations -using Distributions -using ExponentialFamily -using StableRNGs -using LinearAlgebra - -import BayesBase: ProductOf - -@testset "ClosedFormStrategy" begin - - @testset "LogGamma projected to Normal" begin - # Target: LogGamma(α=2.0, β=1.0) - # Density: p(x) ∝ x^(α-1) * e^(-βx) - # Log density: (α-1)log(x) - βx - - # Note: Normal support is (-∞, ∞), Gamma is (0, ∞). - # This projection is technically improper if Normal puts mass on x < 0. - # But we can test if the gradient machinery works. - # To make it "reasonable", we put the Normal far from 0. - - target_dist = LogGamma(20.0, 1.0) # Mean 20, Var 20. - target = Logpdf(target_dist) - - # Initial approximation: Normal(15, 5) - initial_dist = Normal(15.0, 5.0) +@testitem "LogGamma projected to Normal" begin + using BayesBase + using ExponentialFamilyProjection + using ClosedFormExpectations + using Distributions + using ExponentialFamily + using StableRNGs + using LinearAlgebra + + # Target: LogGamma(α=2.0, β=1.0) + # Density: p(x) ∝ x^(α-1) * e^(-βx) + # Log density: (α-1)log(x) - βx + + # Note: Normal support is (-∞, ∞), Gamma is (0, ∞). + # This projection is technically improper if Normal puts mass on x < 0. + # But we can test if the gradient machinery works. + # To make it "reasonable", we put the Normal far from 0. + + target_dist = LogGamma(20.0, 1.0) # Mean 20, Var 20. + target = Logpdf(target_dist) + + # Initial approximation: Normal(15, 5) + initial_dist = Normal(15.0, 5.0) + + # Create strategy + strategy = ClosedFormStrategy() + + # Project + # We use a small number of iterations to check it runs and descends + result = project_to( + ProjectedTo(NormalMeanVariance), + target; + strategy = strategy, + initial_point = initial_dist, + parameters = ProjectionParameters(niterations = 50, tolerance = 1e-5), + ) + + @test result isa NormalMeanVariance + μ, v = mean(result), var(result) + + # We expect it to move towards the mode of the LogGamma. 
+ # Mode of LogGamma(α, β) is at x = log(α*β). + # Here α=20, β=1. Mode at log(20) ≈ 3.0. + + @test 2.0 < μ < 4.0 +end - # Create strategy - strategy = ClosedFormStrategy() +@testitem "LogNormal projected to Gamma" begin + using BayesBase + using ExponentialFamilyProjection + using ExponentialFamilyProjection: ControlVariateStrategy + using ClosedFormExpectations + using Distributions + using ExponentialFamily + using StableRNGs + using LinearAlgebra + + # Target: LogNormal(μ=1.0, σ=0.5) + # Approx: Gamma + + target_dist = LogNormal(1.0, 0.5) + target = Logpdf(target_dist) + + # Initial: Gamma(2.0, 2.0) -> mean 4, var 8 + initial_dist = Gamma(2.0, 2.0) + + strategy = ClosedFormStrategy() + + result = project_to( + ProjectedTo(Gamma), + target; + strategy = strategy, + initial_point = initial_dist, + parameters = ProjectionParameters(niterations = 50, tolerance = 1e-5), + ) + + @test result isa GammaDistributionsFamily + + # Check if it ran without error and produced valid parameters + @test shape(result) > 0 + @test scale(result) > 0 + + # Comparison with ControlVariateStrategy + # CV strategy is stochastic, so it might be slightly different but should be close. 
+ + cv_strategy = ControlVariateStrategy(nsamples = 500) + result_cv = project_to( + ProjectedTo(Gamma), + target; + strategy = cv_strategy, + initial_point = initial_dist, + parameters = ProjectionParameters(niterations = 50, tolerance = 1e-4), + ) + + # KL between result and result_cv should be small + # Since they are both Gamma, we can measure parameter distance or KL + + @test isapprox(mean(result), mean(result_cv), rtol = 0.1) +end - # Project - # We use a small number of iterations to check it runs and descends - result = project_to( +@testitem "Comparison: Speed and Accuracy" begin + using BayesBase + using ExponentialFamilyProjection + using ExponentialFamilyProjection: ControlVariateStrategy + using ClosedFormExpectations + using Distributions + using ExponentialFamily + using StableRNGs + using LinearAlgebra + + # Simple case: Normal to Normal (should be exact one step if Newton, but we use GradientDescent) + # Let's use a target that is actually a Normal + target_dist = Normal(5.0, 2.0) + target = Logpdf(target_dist) + + initial = Normal(0.0, 1.0) + + # Analytic + t_analytic = @elapsed begin + res_analytic = project_to( ProjectedTo(NormalMeanVariance), target; - strategy = strategy, - initial_point = initial_dist, - parameters = ProjectionParameters(niterations = 50, tolerance = 1e-5), + strategy = ClosedFormStrategy(), + initial_point = initial, + parameters = ProjectionParameters(niterations = 100), ) - - @test result isa NormalMeanVariance - μ, v = mean(result), var(result) - - # We expect it to move towards the mode of the LogGamma. - # Mode of LogGamma(α, β) is at x = log(α*β). - # Here α=20, β=1. Mode at log(20) ≈ 3.0. 
- - @test 2.0 < μ < 4.0 end - @testset "LogNormal projected to Gamma" begin - # Target: LogNormal(μ=1.0, σ=0.5) - # Approx: Gamma - - target_dist = LogNormal(1.0, 0.5) - target = Logpdf(target_dist) - - # Initial: Gamma(2.0, 2.0) -> mean 4, var 8 - initial_dist = Gamma(2.0, 2.0) - - strategy = ClosedFormStrategy() - - result = project_to( - ProjectedTo(Gamma), - target; - strategy = strategy, - initial_point = initial_dist, - parameters = ProjectionParameters(niterations = 50, tolerance = 1e-5), - ) - - @test result isa GammaDistributionsFamily - - # Check if it ran without error and produced valid parameters - @test shape(result) > 0 - @test scale(result) > 0 - - # Comparison with ControlVariateStrategy - # CV strategy is stochastic, so it might be slightly different but should be close. - - cv_strategy = ControlVariateStrategy(nsamples = 500) - result_cv = project_to( - ProjectedTo(Gamma), + # MC + t_mc = @elapsed begin + res_mc = project_to( + ProjectedTo(NormalMeanVariance), target; - strategy = cv_strategy, - initial_point = initial_dist, - parameters = ProjectionParameters(niterations = 50, tolerance = 1e-4), + strategy = ControlVariateStrategy(nsamples = 1000), + initial_point = initial, + parameters = ProjectionParameters(niterations = 100), ) - - # KL between result and result_cv should be small - # Since they are both Gamma, we can measure parameter distance or KL - - @test isapprox(mean(result), mean(result_cv), rtol = 0.1) end - @testset "Comparison: Speed and Accuracy" begin - # Simple case: Normal to Normal (should be exact one step if Newton, but we use GradientDescent) - # Let's use a target that is actually a Normal - target_dist = Normal(5.0, 2.0) - target = Logpdf(target_dist) - - initial = Normal(0.0, 1.0) - - # Analytic - t_analytic = @elapsed begin - res_analytic = project_to( - ProjectedTo(NormalMeanVariance), - target; - strategy = ClosedFormStrategy(), - initial_point = initial, - parameters = ProjectionParameters(niterations = 100), - ) - 
end - - # MC - t_mc = @elapsed begin - res_mc = project_to( - ProjectedTo(NormalMeanVariance), - target; - strategy = ControlVariateStrategy(nsamples = 1000), - initial_point = initial, - parameters = ProjectionParameters(niterations = 100), - ) - end - - println("Analytic time: $t_analytic") - println("MC time: $t_mc") - - # Analytic should be faster (no sampling) and more accurate (converge to exact target) - @test isapprox(mean(res_analytic), mean(target_dist), atol = 2e-2) - @test isapprox(std(res_analytic), std(target_dist), atol = 2e-2) - - # MC might have some noise - @test abs(mean(res_analytic) - 5.0) < abs(mean(res_mc) - 5.0) + 0.1 # Heuristic check - end + println("Analytic time: $t_analytic") + println("MC time: $t_mc") -end + # Analytic should be faster (no sampling) and more accurate (converge to exact target) + @test isapprox(mean(res_analytic), mean(target_dist), atol = 2e-2) + @test isapprox(std(res_analytic), std(target_dist), atol = 2e-2) + # MC might have some noise + @test abs(mean(res_analytic) - 5.0) < abs(mean(res_mc) - 5.0) + 0.1 # Heuristic check end From 42494b83bea13c770449cfb8fda3895de67f106a Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Sun, 23 Nov 2025 20:17:25 +0100 Subject: [PATCH 06/17] fix: revery try catch left because of debugging --- src/manopt/projection_objective.jl | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/src/manopt/projection_objective.jl b/src/manopt/projection_objective.jl index a1ec6a3..ecc459f 100644 --- a/src/manopt/projection_objective.jl +++ b/src/manopt/projection_objective.jl @@ -58,18 +58,7 @@ function call_objective( logpartition = ExponentialFamily.logpartition(current_ef) gradlogpartition = ExponentialFamily.gradlogpartition(current_ef) - fisher = ExponentialFamily.fisherinformation(current_ef) - inv_fisher = try - cholinv(fisher) - catch e - # If the Fisher information matrix is not positive definite, we try to compute the pseudo-inverse - # This can happen if 
the distribution is degenerate or if the parameters are on the boundary - if e isa PosDefException || e isa LinearAlgebra.PosDefException - pinv(fisher) - else - rethrow(e) - end - end + inv_fisher = cholinv(ExponentialFamily.fisherinformation(current_ef)) # If we have some supplementary natural parameters in the objective # we must subtract them from the natural parameters of the current η @@ -104,4 +93,4 @@ end function (objective::ProjectionCostGradientObjective)(M::AbstractManifold, X, p) return call_objective(objective, M, X, p) -end +end \ No newline at end of file From 95626b24331f37105aa47948ec0e477eaf572a9b Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Sun, 23 Nov 2025 20:18:22 +0100 Subject: [PATCH 07/17] style: :art: --- ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl | 6 +++--- src/manopt/projection_objective.jl | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl b/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl index 53cf568..7cda308 100644 --- a/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl +++ b/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl @@ -131,11 +131,11 @@ function ExponentialFamilyProjection.preprocess_strategy_argument( # The closure typically has one field holding the ProductOf. 
fn_type = typeof(argument) field_names = fieldnames(fn_type) - + if !isempty(field_names) # Get the first field (usually the captured ProductOf) captured = getfield(argument, first(field_names)) - + # If it's a ProductOf, use it directly if captured isa ProductOf return (strategy, Logpdf(captured)) @@ -146,7 +146,7 @@ function ExponentialFamilyProjection.preprocess_strategy_argument( return (strategy, Logpdf(captured)) end end - + # Fallback: keep the function as-is return (strategy, argument) end diff --git a/src/manopt/projection_objective.jl b/src/manopt/projection_objective.jl index ecc459f..0db69d4 100644 --- a/src/manopt/projection_objective.jl +++ b/src/manopt/projection_objective.jl @@ -93,4 +93,4 @@ end function (objective::ProjectionCostGradientObjective)(M::AbstractManifold, X, p) return call_objective(objective, M, X, p) -end \ No newline at end of file +end From c519322a1f39c1065a9cbffd81fd3f8f3798e76f Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 24 Nov 2025 14:34:39 +0100 Subject: [PATCH 08/17] test(fix): cover the closed form strategy --- .../ClosedFormExpectationsExt.jl | 9 + test/strategies/closed_form_tests.jl | 385 ++++++++++++++++-- 2 files changed, 359 insertions(+), 35 deletions(-) diff --git a/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl b/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl index 7cda308..c8efd80 100644 --- a/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl +++ b/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl @@ -151,6 +151,15 @@ function ExponentialFamilyProjection.preprocess_strategy_argument( return (strategy, argument) end +# Generic fallback for non-Function arguments +function ExponentialFamilyProjection.preprocess_strategy_argument( + strategy::ClosedFormStrategy, + argument::Distribution, +) + # ClosedFormStrategy accepts any callable or distribution as argument + return (strategy, Logpdf(argument)) +end + # Generic fallback for non-Function arguments function
ExponentialFamilyProjection.preprocess_strategy_argument( strategy::ClosedFormStrategy, diff --git a/test/strategies/closed_form_tests.jl b/test/strategies/closed_form_tests.jl index fd87889..216d01b 100644 --- a/test/strategies/closed_form_tests.jl +++ b/test/strategies/closed_form_tests.jl @@ -1,47 +1,323 @@ -@testitem "LogGamma projected to Normal" begin - using BayesBase +@testitem "ClosedFormStrategy generic properties" begin using ExponentialFamilyProjection using ClosedFormExpectations + import ExponentialFamilyProjection: get_nsamples + + strategy = ClosedFormStrategy() + + @test strategy isa ClosedFormStrategy + @test get_nsamples(strategy) == 0 +end + +@testitem "ClosedFormStrategy create_state!" begin + using ExponentialFamilyProjection + using ClosedFormExpectations + using ExponentialFamilyManifolds + using ExponentialFamily + using Distributions + import ExponentialFamilyProjection: create_state!, ProjectionParameters + + distributions = [Beta(5, 5), NormalMeanVariance(0, 1), Gamma(2, 2)] + parameters = ProjectionParameters() + + for dist in distributions + ef = convert(ExponentialFamilyDistribution, dist) + T = ExponentialFamily.exponential_family_typetag(ef) + d = size(mean(ef)) + c = getconditioner(ef) + M = ExponentialFamilyManifolds.get_natural_manifold(T, d, c) + + # Target wrapped in Logpdf + target = Logpdf(dist) + + # Test without supplementary parameters + state1 = create_state!(ClosedFormStrategy(), M, parameters, target, ef, ()) + state2 = create_state!(ClosedFormStrategy(), M, parameters, target, ef, ()) + + @test state1.target === target + @test state2.target === target + @test state1 !== state2 # Different objects in memory + + # Test with supplementary parameters + state3 = create_state!(ClosedFormStrategy(), M, parameters, target, ef, (ef,)) + state4 = create_state!(ClosedFormStrategy(), M, parameters, target, ef, (ef,)) + + @test state3.target === target + @test state4.target === target + @test state3 !== state4 + + # Test with 
multiple supplementary parameters + state5 = create_state!(ClosedFormStrategy(), M, parameters, target, ef, (ef, ef)) + state6 = create_state!(ClosedFormStrategy(), M, parameters, target, ef, (ef, ef)) + + @test state5.target === target + @test state6.target === target + @test state5 !== state6 + end +end + +@testitem "ClosedFormStrategy prepare_state!" begin + using ExponentialFamilyProjection + using ClosedFormExpectations + using ExponentialFamilyManifolds + using ExponentialFamily using Distributions + import ExponentialFamilyProjection: create_state!, prepare_state!, ProjectionParameters + + distributions = [ + NormalMeanVariance(0, 1), + Gamma(2, 2), + Beta(3, 3), + ] + + for dist in distributions + target_dist = dist + target = Logpdf(target_dist) + + ef = convert(ExponentialFamilyDistribution, dist) + T = ExponentialFamily.exponential_family_typetag(ef) + d = size(mean(ef)) + c = getconditioner(ef) + M = ExponentialFamilyManifolds.get_natural_manifold(T, d, c) + + strategy = ClosedFormStrategy() + parameters = ProjectionParameters() + + # Create initial state + state1 = create_state!(strategy, M, parameters, target, ef, ()) + + # prepare_state! 
should return the same state object + state2 = prepare_state!(strategy, state1, M, parameters, target, ef, ()) + + @test state1 === state2 + @test state1.target === state2.target + + # Test with supplementary parameters + supplementary_η = (getnaturalparameters(ef),) + state3 = create_state!(strategy, M, parameters, target, ef, supplementary_η) + state4 = prepare_state!(strategy, state3, M, parameters, target, ef, supplementary_η) + + @test state3 === state4 + end +end + +@testitem "`ClosedFormStrategy` should fail if given a list of samples instead of a function" begin + using ExponentialFamily + using ClosedFormExpectations + + prj = ProjectedTo( + Beta; + parameters = ProjectionParameters( + strategy = ClosedFormStrategy(), + ), + ) + + # ClosedFormStrategy doesn't explicitly reject arrays, but it won't work properly + # The extension's preprocess_strategy_argument will keep the array as-is + # and then the compute_gradient! will fail when trying to use it + # This is the expected behavior - it will error during execution + + samples = [0.5, 0.6, 0.7] + + # This should fail during the projection, not during preprocessing + @test_throws Exception project_to(prj, samples) +end + +@testitem "ClosedFormStrategy argument preprocessing for ProductOf" begin + using ExponentialFamilyProjection + using ClosedFormExpectations + using ExponentialFamily + using Distributions + import ExponentialFamily: ProductOf + import ExponentialFamilyProjection: preprocess_strategy_argument + + strategy = ClosedFormStrategy() + + # Case 1: ProductOf wrapped in closure-like struct + dist1 = Normal(0, 1) + dist2 = Normal(1, 1) + prod_dist = ProductOf(dist1, dist2) + + # Simulate RxInfer-style closure + struct MockClosure{T} + captured::T + end + (c::MockClosure)(x) = 1.0 + + closure = MockClosure(prod_dist) + + # preprocess should extract ProductOf and wrap in Logpdf + result_strat, result_arg = preprocess_strategy_argument(strategy, closure) + @test result_strat === strategy + @test 
result_arg isa Logpdf + @test result_arg.distribution === prod_dist +end + +@testitem "ClosedFormStrategy argument preprocessing for Distribution" begin + using ExponentialFamilyProjection + using ClosedFormExpectations + using Distributions + import ExponentialFamilyProjection: preprocess_strategy_argument + + strategy = ClosedFormStrategy() + dist1 = Normal(0, 1) + + # Case 2: Distribution wrapped in closure + struct MockClosure{T} + captured::T + end + (c::MockClosure)(x) = 1.0 + + closure_dist = MockClosure(dist1) + result_strat, result_arg = preprocess_strategy_argument(strategy, closure_dist) + @test result_strat === strategy + @test result_arg isa Logpdf + @test result_arg.distribution === dist1 +end + +@testitem "ClosedFormStrategy argument preprocessing for plain function" begin + using ExponentialFamilyProjection + using ClosedFormExpectations + import ExponentialFamilyProjection: preprocess_strategy_argument + + strategy = ClosedFormStrategy() + + # Case 3: Plain function (should return as-is) + fn = (x) -> x^2 + result_strat, result_arg = preprocess_strategy_argument(strategy, fn) + @test result_strat === strategy + @test result_arg === fn +end + +@testitem "ClosedFormStrategy argument preprocessing for direct Distribution" begin + using ExponentialFamilyProjection + using ClosedFormExpectations + using Distributions + import ExponentialFamilyProjection: preprocess_strategy_argument + + strategy = ClosedFormStrategy() + dist1 = Normal(0, 1) + + # Case 4: Distribution directly (non-Function argument) + result_strat, result_arg = preprocess_strategy_argument(strategy, dist1) + @test result_strat === strategy + @test result_arg isa ClosedFormExpectations.Logpdf + @test result_arg.dist == dist1 +end + +@testitem "ClosedFormStrategy cost computation" begin + using ExponentialFamilyProjection + using ClosedFormExpectations using ExponentialFamily + using Distributions + using ExponentialFamilyManifolds using StableRNGs - using LinearAlgebra + import 
ExponentialFamilyProjection: compute_cost, create_state!, ProjectionParameters + + strategy = ClosedFormStrategy() + + # Target: Normal(0, 1) + target_dist = Normal(0.0, 1.0) + target = Logpdf(target_dist) + + # Variational: Normal(1, 1) + # KL(q || p) where p=N(0,1), q=N(1,1) (variances equal) + # KL = 0.5 * ( (μ1-μ2)^2/σ^2 ) = 0.5 * (1-0)^2/1 = 0.5 + + # Manifold setup + M = ExponentialFamilyManifolds.get_natural_manifold(NormalMeanVariance, ()) + q_dist = NormalMeanVariance(1.0, 1.0) + ef = convert(ExponentialFamilyDistribution, q_dist) + η = getnaturalparameters(ef) + + # Create state + parameters = ProjectionParameters() + state = create_state!(strategy, M, parameters, target, ef, ()) + + # Dummy args for cost (not all used by ClosedFormStrategy's compute_cost) + logp = logpartition(ef) + gradlogp = gradlogpartition(ef) + inv_fisher = inv(fisherinformation(ef)) + + cost = compute_cost(M, strategy, state, η, logp, gradlogp, inv_fisher) + + # compute_cost returns: -entropy(q) - E_q[log p] + # KL(q||p) = E_q[log q] - E_q[log p] = -entropy(q) - E_q[log p] + # So it should approximately equal KL + @test cost ≈ 0.5 atol = 1e-6 +end - # Target: LogGamma(α=2.0, β=1.0) - # Density: p(x) ∝ x^(α-1) * e^(-βx) - # Log density: (α-1)log(x) - βx +@testitem "ClosedFormStrategy gradient computation" begin + using ExponentialFamilyProjection + using ClosedFormExpectations + using ExponentialFamily + using Distributions + using ExponentialFamilyManifolds + using LinearAlgebra + using StableRNGs + import ExponentialFamilyProjection: compute_gradient!, create_state!, ProjectionParameters - # Note: Normal support is (-∞, ∞), Gamma is (0, ∞). - # This projection is technically improper if Normal puts mass on x < 0. - # But we can test if the gradient machinery works. - # To make it "reasonable", we put the Normal far from 0. + strategy = ClosedFormStrategy() - target_dist = LogGamma(20.0, 1.0) # Mean 20, Var 20. 
+ # Simple test: project Normal to Normal + target_dist = Normal(2.0, 1.0) target = Logpdf(target_dist) - # Initial approximation: Normal(15, 5) - initial_dist = Normal(15.0, 5.0) + # Start from a different point + q_dist = NormalMeanVariance(0.0, 1.0) + ef = convert(ExponentialFamilyDistribution, q_dist) + + M = ExponentialFamilyManifolds.get_natural_manifold(NormalMeanVariance, ()) + η = getnaturalparameters(ef) + + parameters = ProjectionParameters() + state = create_state!(strategy, M, parameters, target, ef, ()) + + # Compute gradient + X = similar(η) + logp = logpartition(ef) + gradlogp = gradlogpartition(ef) + inv_fisher = inv(fisherinformation(ef)) + + X_result = compute_gradient!(M, strategy, state, X, η, logp, gradlogp, inv_fisher) + + @test X_result === X + @test length(X) == length(η) + @test all(isfinite.(X)) + + # The gradient should point towards the target + # Since target is at μ=2, gradient should push η in that direction +end + + +@testitem "LogGamma projected to Normal" begin + using BayesBase + using ExponentialFamilyProjection + using ClosedFormExpectations + using Distributions + using ExponentialFamily + using StableRNGs + using LinearAlgebra + + target_dist = LogGamma(20.0, 1.0) + target = Logpdf(target_dist) # Create strategy strategy = ClosedFormStrategy() # Project - # We use a small number of iterations to check it runs and descends result = project_to( ProjectedTo(NormalMeanVariance), target; strategy = strategy, - initial_point = initial_dist, parameters = ProjectionParameters(niterations = 50, tolerance = 1e-5), ) @test result isa NormalMeanVariance μ, v = mean(result), var(result) - # We expect it to move towards the mode of the LogGamma. - # Mode of LogGamma(α, β) is at x = log(α*β). - # Here α=20, β=1. Mode at log(20) ≈ 3.0. - + # Mode of LogGamma(α, β) is at x = log(α*β) + # Here α=20, β=1. 
Mode at log(20) ≈ 3.0 @test 2.0 < μ < 4.0 end @@ -56,12 +332,10 @@ end using LinearAlgebra # Target: LogNormal(μ=1.0, σ=0.5) - # Approx: Gamma - target_dist = LogNormal(1.0, 0.5) target = Logpdf(target_dist) - # Initial: Gamma(2.0, 2.0) -> mean 4, var 8 + # Initial: Gamma(2.0, 2.0) initial_dist = Gamma(2.0, 2.0) strategy = ClosedFormStrategy() @@ -81,8 +355,6 @@ end @test scale(result) > 0 # Comparison with ControlVariateStrategy - # CV strategy is stochastic, so it might be slightly different but should be close. - cv_strategy = ControlVariateStrategy(nsamples = 500) result_cv = project_to( ProjectedTo(Gamma), @@ -92,13 +364,11 @@ end parameters = ProjectionParameters(niterations = 50, tolerance = 1e-4), ) - # KL between result and result_cv should be small - # Since they are both Gamma, we can measure parameter distance or KL - + # Should be close @test isapprox(mean(result), mean(result_cv), rtol = 0.1) end -@testitem "Comparison: Speed and Accuracy" begin +@testitem "ClosedFormStrategy vs ControlVariateStrategy: Speed and Accuracy" begin using BayesBase using ExponentialFamilyProjection using ExponentialFamilyProjection: ControlVariateStrategy @@ -108,8 +378,7 @@ end using StableRNGs using LinearAlgebra - # Simple case: Normal to Normal (should be exact one step if Newton, but we use GradientDescent) - # Let's use a target that is actually a Normal + # Simple case: Normal to Normal target_dist = Normal(5.0, 2.0) target = Logpdf(target_dist) @@ -137,13 +406,59 @@ end ) end - println("Analytic time: $t_analytic") - println("MC time: $t_mc") + println("ClosedFormStrategy time: $t_analytic") + println("ControlVariateStrategy time: $t_mc") - # Analytic should be faster (no sampling) and more accurate (converge to exact target) + # Analytic should be more accurate (converge to exact target) @test isapprox(mean(res_analytic), mean(target_dist), atol = 2e-2) @test isapprox(std(res_analytic), std(target_dist), atol = 2e-2) - # MC might have some noise - @test 
abs(mean(res_analytic) - 5.0) < abs(mean(res_mc) - 5.0) + 0.1 # Heuristic check + # ClosedFormStrategy should be at least as accurate as MC + @test abs(mean(res_analytic) - 5.0) <= abs(mean(res_mc) - 5.0) + 0.1 +end + + + +@testitem "ClosedFormStrategy ProjectionCostGradientObjective integration" begin + using ExponentialFamilyProjection + using ClosedFormExpectations + using ExponentialFamily + using Distributions + using ExponentialFamilyManifolds + using Manifolds + using StableRNGs + import ExponentialFamilyProjection: ProjectionCostGradientObjective, create_state! + + # Test that the objective works correctly with ClosedFormStrategy + strategy = ClosedFormStrategy() + target_dist = Normal(2.0, 1.0) + target = Logpdf(target_dist) + + q_dist = NormalMeanVariance(0.0, 1.0) + ef = convert(ExponentialFamilyDistribution, q_dist) + + M = ExponentialFamilyManifolds.get_natural_manifold(NormalMeanVariance, ()) + η = getnaturalparameters(ef) + + parameters = ProjectionParameters(rng = StableRNG(42)) + state = create_state!(strategy, M, parameters, target, ef, ()) + + obj = ProjectionCostGradientObjective( + parameters, + target, + copy(η), + (), + strategy, + state, + ) + + # Test evaluation + p = ExponentialFamilyManifolds.partition_point(M, η) + X = Manifolds.zero_vector(M, p) + + cost, X_result = obj(M, X, p) + + @test isfinite(cost) + @test cost > 0 # KL divergence should be positive + @test all(isfinite.(X_result)) end From 1f057f6f28ed7e5d68285837c45743806de38bd9 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 24 Nov 2025 14:39:48 +0100 Subject: [PATCH 09/17] fix: proper tests for captured in closure --- test/strategies/closed_form_tests.jl | 53 ++++++++++++++-------------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/test/strategies/closed_form_tests.jl b/test/strategies/closed_form_tests.jl index 216d01b..f1ffca1 100644 --- a/test/strategies/closed_form_tests.jl +++ b/test/strategies/closed_form_tests.jl @@ -123,56 +123,55 @@ end 
@test_throws Exception project_to(prj, samples) end -@testitem "ClosedFormStrategy argument preprocessing for ProductOf" begin + +@testitem "ClosedFormStrategy argument preprocessing for Distribution in closure" begin using ExponentialFamilyProjection using ClosedFormExpectations - using ExponentialFamily using Distributions - import ExponentialFamily: ProductOf import ExponentialFamilyProjection: preprocess_strategy_argument strategy = ClosedFormStrategy() - - # Case 1: ProductOf wrapped in closure-like struct dist1 = Normal(0, 1) - dist2 = Normal(1, 1) - prod_dist = ProductOf(dist1, dist2) - # Simulate RxInfer-style closure - struct MockClosure{T} - captured::T + # Test extraction of Distribution from a closure (simulating RxInfer behavior) + # The closure captures a Distribution, and preprocess should extract it + closure_with_dist = let d = dist1 + (x) -> logpdf(d, x) end - (c::MockClosure)(x) = 1.0 - closure = MockClosure(prod_dist) - - # preprocess should extract ProductOf and wrap in Logpdf - result_strat, result_arg = preprocess_strategy_argument(strategy, closure) + result_strat, result_arg = preprocess_strategy_argument(strategy, closure_with_dist) @test result_strat === strategy + + # The function should extract the captured Distribution and wrap it in Logpdf @test result_arg isa Logpdf - @test result_arg.distribution === prod_dist + @test result_arg.dist === dist1 end -@testitem "ClosedFormStrategy argument preprocessing for Distribution" begin +@testitem "ClosedFormStrategy argument preprocessing for ProductOf in closure" begin using ExponentialFamilyProjection using ClosedFormExpectations using Distributions + using BayesBase import ExponentialFamilyProjection: preprocess_strategy_argument + import BayesBase: ProductOf strategy = ClosedFormStrategy() - dist1 = Normal(0, 1) - - # Case 2: Distribution wrapped in closure - struct MockClosure{T} - captured::T + + # Test extraction of ProductOf from a closure (RxInfer use case) + left = Beta(10, 10) + 
right = Beta(3, 3) + prod = ProductOf(left, right) + + closure_with_product = let p = prod + (x) -> logpdf(p, x) end - (c::MockClosure)(x) = 1.0 - closure_dist = MockClosure(dist1) - result_strat, result_arg = preprocess_strategy_argument(strategy, closure_dist) + result_strat, result_arg = preprocess_strategy_argument(strategy, closure_with_product) @test result_strat === strategy + + # The function should extract the captured ProductOf and wrap it in Logpdf @test result_arg isa Logpdf - @test result_arg.distribution === dist1 + @test result_arg.dist === prod end @testitem "ClosedFormStrategy argument preprocessing for plain function" begin @@ -182,7 +181,7 @@ end strategy = ClosedFormStrategy() - # Case 3: Plain function (should return as-is) + # Plain function without extractable Distribution (should return as-is) fn = (x) -> x^2 result_strat, result_arg = preprocess_strategy_argument(strategy, fn) @test result_strat === strategy From d521816554517c3c3d89940a4beb370863e5ad43 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 24 Nov 2025 15:10:33 +0100 Subject: [PATCH 10/17] fix: normal readable error --- .../ClosedFormExpectationsExt.jl | 52 ++++++++++--------- test/strategies/closed_form_tests.jl | 10 ++-- 2 files changed, 32 insertions(+), 30 deletions(-) diff --git a/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl b/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl index c8efd80..afd8929 100644 --- a/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl +++ b/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl @@ -66,7 +66,9 @@ function ExponentialFamilyProjection.compute_gradient!( end # Helper to create state -struct ClosedFormStrategyState{T} +# Note: We use a mutable struct to ensure each create_state! 
call +# produces a distinct object in memory, even with the same target +mutable struct ClosedFormStrategyState{T} target::T end @@ -124,31 +126,31 @@ end # Special handling for RxInfer closures that wrap ProductOf function ExponentialFamilyProjection.preprocess_strategy_argument( strategy::ClosedFormStrategy, - argument::Function, -) - # RxInfer wraps ProductOf in a closure. - # Extract the ProductOf from the closure's captured variables. - # The closure typically has one field holding the ProductOf. - fn_type = typeof(argument) - field_names = fieldnames(fn_type) - - if !isempty(field_names) - # Get the first field (usually the captured ProductOf) - captured = getfield(argument, first(field_names)) - - # If it's a ProductOf, use it directly - if captured isa ProductOf - return (strategy, Logpdf(captured)) - end - - # If it's a Distribution (e.g. LogNormal inside ProjectionExt closure), use it directly - if captured isa Distribution - return (strategy, Logpdf(captured)) - end + argument::F, +) where {F<:Function} + # RxInfer wraps ProductOf or Distribution in a closure. + # Extract the ProductOf/Distribution from the closure's captured variables. + # The closure typically has one field holding the ProductOf/Distribution. + field_names = fieldnames(F) + + if isempty(field_names) + error( + """`ClosedFormStrategy` requires a function that captures a `Distribution` or `ProductOf` in its closure. + + Expected form: + let dist = Normal(0, 1) + (x) -> logpdf(dist, x) + end + + Got a function without captured variables: $F + + If you want to use a plain function, pass the `Distribution` directly instead of wrapping it in a function. 
+ """ + ) end - - # Fallback: keep the function as-is - return (strategy, argument) + + captured = getfield(argument, first(field_names)) + return (strategy, Logpdf(captured)) end # Generic fallback for non-Function arguments diff --git a/test/strategies/closed_form_tests.jl b/test/strategies/closed_form_tests.jl index f1ffca1..7eb9c03 100644 --- a/test/strategies/closed_form_tests.jl +++ b/test/strategies/closed_form_tests.jl @@ -174,18 +174,18 @@ end @test result_arg.dist === prod end -@testitem "ClosedFormStrategy argument preprocessing for plain function" begin +@testitem "ClosedFormStrategy argument preprocessing for plain function should error" begin using ExponentialFamilyProjection using ClosedFormExpectations import ExponentialFamilyProjection: preprocess_strategy_argument strategy = ClosedFormStrategy() - # Plain function without extractable Distribution (should return as-is) + # Plain function without captured variables should throw an error + # because ClosedFormStrategy needs to extract Distribution/ProductOf from closure fn = (x) -> x^2 - result_strat, result_arg = preprocess_strategy_argument(strategy, fn) - @test result_strat === strategy - @test result_arg === fn + + @test_throws "`ClosedFormStrategy` requires a function that captures a `Distribution` or `ProductOf` in its closure" preprocess_strategy_argument(strategy, fn) end @testitem "ClosedFormStrategy argument preprocessing for direct Distribution" begin From 8f71ceb50b6d5ff3df2db2a11cc1e21a52908816 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 24 Nov 2025 15:15:21 +0100 Subject: [PATCH 11/17] style: :art: --- test/strategies/closed_form_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/strategies/closed_form_tests.jl b/test/strategies/closed_form_tests.jl index 7eb9c03..ff7d3b9 100644 --- a/test/strategies/closed_form_tests.jl +++ b/test/strategies/closed_form_tests.jl @@ -101,7 +101,7 @@ end end end -@testitem "`ClosedFormStrategy` should fail if 
given a list of samples instead of a function" begin +@testitem "ClosedFormStrategy should fail if given a list of samples instead of a function" begin using ExponentialFamily using ClosedFormExpectations From 7d30e2bb5d0b24d58a6422acf5a370f564e99b83 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 24 Nov 2025 17:38:12 +0100 Subject: [PATCH 12/17] style: make format --- .../ClosedFormExpectationsExt.jl | 12 +-- test/manopt/bounded_norm_update_rule_tests.jl | 16 +++- test/strategies/closed_form_tests.jl | 76 +++++++++---------- 3 files changed, 54 insertions(+), 50 deletions(-) diff --git a/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl b/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl index afd8929..10e98a1 100644 --- a/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl +++ b/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl @@ -132,23 +132,23 @@ function ExponentialFamilyProjection.preprocess_strategy_argument( # Extract the ProductOf/Distribution from the closure's captured variables. # The closure typically has one field holding the ProductOf/Distribution. field_names = fieldnames(F) - + if isempty(field_names) error( """`ClosedFormStrategy` requires a function that captures a `Distribution` or `ProductOf` in its closure. - + Expected form: let dist = Normal(0, 1) (x) -> logpdf(dist, x) end - + Got a function without captured variables: $F - + If you want to use a plain function, pass the `Distribution` directly instead of wrapping it in a function. 
- """ + """, ) end - + captured = getfield(argument, first(field_names)) return (strategy, Logpdf(captured)) end diff --git a/test/manopt/bounded_norm_update_rule_tests.jl b/test/manopt/bounded_norm_update_rule_tests.jl index 38e4684..fbea5f4 100644 --- a/test/manopt/bounded_norm_update_rule_tests.jl +++ b/test/manopt/bounded_norm_update_rule_tests.jl @@ -82,8 +82,20 @@ for limit in (1, 1.0, 1.0f0), p in (zeros(Float64, 3), zeros(Float32, 3)) cpa = DefaultManoptProblem(M, ManifoldGradientObjective(f, grad_f)) gst = GradientDescentState(M; p = zero(p)) - @test_opt target_modules=(ExponentialFamilyProjection,) BoundedNormUpdateRule(limit)(cpa, gst, 1) - @test_opt target_modules=(ExponentialFamilyProjection,) BoundedNormUpdateRule(static(limit))(cpa, gst, 1) + @test_opt target_modules=(ExponentialFamilyProjection,) BoundedNormUpdateRule( + limit, + )( + cpa, + gst, + 1, + ) + @test_opt target_modules=(ExponentialFamilyProjection,) BoundedNormUpdateRule( + static(limit), + )( + cpa, + gst, + 1, + ) @test_opt target_modules=(ExponentialFamilyProjection,) BoundedNormUpdateRule( static(limit); direction = BoundedNormUpdateRule(limit), diff --git a/test/strategies/closed_form_tests.jl b/test/strategies/closed_form_tests.jl index ff7d3b9..db1a04d 100644 --- a/test/strategies/closed_form_tests.jl +++ b/test/strategies/closed_form_tests.jl @@ -4,7 +4,7 @@ import ExponentialFamilyProjection: get_nsamples strategy = ClosedFormStrategy() - + @test strategy isa ClosedFormStrategy @test get_nsamples(strategy) == 0 end @@ -19,37 +19,37 @@ end distributions = [Beta(5, 5), NormalMeanVariance(0, 1), Gamma(2, 2)] parameters = ProjectionParameters() - + for dist in distributions ef = convert(ExponentialFamilyDistribution, dist) T = ExponentialFamily.exponential_family_typetag(ef) d = size(mean(ef)) c = getconditioner(ef) M = ExponentialFamilyManifolds.get_natural_manifold(T, d, c) - + # Target wrapped in Logpdf target = Logpdf(dist) - + # Test without supplementary parameters state1 = 
create_state!(ClosedFormStrategy(), M, parameters, target, ef, ()) state2 = create_state!(ClosedFormStrategy(), M, parameters, target, ef, ()) - + @test state1.target === target @test state2.target === target @test state1 !== state2 # Different objects in memory - + # Test with supplementary parameters state3 = create_state!(ClosedFormStrategy(), M, parameters, target, ef, (ef,)) state4 = create_state!(ClosedFormStrategy(), M, parameters, target, ef, (ef,)) - + @test state3.target === target @test state4.target === target @test state3 !== state4 - + # Test with multiple supplementary parameters state5 = create_state!(ClosedFormStrategy(), M, parameters, target, ef, (ef, ef)) state6 = create_state!(ClosedFormStrategy(), M, parameters, target, ef, (ef, ef)) - + @test state5.target === target @test state6.target === target @test state5 !== state6 @@ -64,11 +64,7 @@ end using Distributions import ExponentialFamilyProjection: create_state!, prepare_state!, ProjectionParameters - distributions = [ - NormalMeanVariance(0, 1), - Gamma(2, 2), - Beta(3, 3), - ] + distributions = [NormalMeanVariance(0, 1), Gamma(2, 2), Beta(3, 3)] for dist in distributions target_dist = dist @@ -82,21 +78,22 @@ end strategy = ClosedFormStrategy() parameters = ProjectionParameters() - + # Create initial state state1 = create_state!(strategy, M, parameters, target, ef, ()) - + # prepare_state! 
should return the same state object state2 = prepare_state!(strategy, state1, M, parameters, target, ef, ()) - + @test state1 === state2 @test state1.target === state2.target - + # Test with supplementary parameters supplementary_η = (getnaturalparameters(ef),) state3 = create_state!(strategy, M, parameters, target, ef, supplementary_η) - state4 = prepare_state!(strategy, state3, M, parameters, target, ef, supplementary_η) - + state4 = + prepare_state!(strategy, state3, M, parameters, target, ef, supplementary_η) + @test state3 === state4 end end @@ -107,18 +104,16 @@ end prj = ProjectedTo( Beta; - parameters = ProjectionParameters( - strategy = ClosedFormStrategy(), - ), + parameters = ProjectionParameters(strategy = ClosedFormStrategy()), ) # ClosedFormStrategy doesn't explicitly reject arrays, but it won't work properly # The extension's preprocess_strategy_argument will keep the array as-is # and then the compute_gradient! will fail when trying to use it # This is the expected behavior - it will error during execution - + samples = [0.5, 0.6, 0.7] - + # This should fail during the projection, not during preprocessing @test_throws Exception project_to(prj, samples) end @@ -141,7 +136,7 @@ end result_strat, result_arg = preprocess_strategy_argument(strategy, closure_with_dist) @test result_strat === strategy - + # The function should extract the captured Distribution and wrap it in Logpdf @test result_arg isa Logpdf @test result_arg.dist === dist1 @@ -156,19 +151,19 @@ end import BayesBase: ProductOf strategy = ClosedFormStrategy() - + # Test extraction of ProductOf from a closure (RxInfer use case) left = Beta(10, 10) right = Beta(3, 3) prod = ProductOf(left, right) - + closure_with_product = let p = prod (x) -> logpdf(p, x) end result_strat, result_arg = preprocess_strategy_argument(strategy, closure_with_product) @test result_strat === strategy - + # The function should extract the captured ProductOf and wrap it in Logpdf @test result_arg isa Logpdf @test 
result_arg.dist === prod @@ -184,8 +179,11 @@ end # Plain function without captured variables should throw an error # because ClosedFormStrategy needs to extract Distribution/ProductOf from closure fn = (x) -> x^2 - - @test_throws "`ClosedFormStrategy` requires a function that captures a `Distribution` or `ProductOf` in its closure" preprocess_strategy_argument(strategy, fn) + + @test_throws "`ClosedFormStrategy` requires a function that captures a `Distribution` or `ProductOf` in its closure" preprocess_strategy_argument( + strategy, + fn, + ) end @testitem "ClosedFormStrategy argument preprocessing for direct Distribution" begin @@ -254,7 +252,8 @@ end using ExponentialFamilyManifolds using LinearAlgebra using StableRNGs - import ExponentialFamilyProjection: compute_gradient!, create_state!, ProjectionParameters + import ExponentialFamilyProjection: + compute_gradient!, create_state!, ProjectionParameters strategy = ClosedFormStrategy() @@ -283,7 +282,7 @@ end @test X_result === X @test length(X) == length(η) @test all(isfinite.(X)) - + # The gradient should point towards the target # Since target is at μ=2, gradient should push η in that direction end @@ -442,19 +441,12 @@ end parameters = ProjectionParameters(rng = StableRNG(42)) state = create_state!(strategy, M, parameters, target, ef, ()) - obj = ProjectionCostGradientObjective( - parameters, - target, - copy(η), - (), - strategy, - state, - ) + obj = ProjectionCostGradientObjective(parameters, target, copy(η), (), strategy, state) # Test evaluation p = ExponentialFamilyManifolds.partition_point(M, η) X = Manifolds.zero_vector(M, p) - + cost, X_result = obj(M, X, p) @test isfinite(cost) From 7624aeab40df41ff2bc00a5b6cdbe1eed2cd4698 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 24 Nov 2025 17:48:58 +0100 Subject: [PATCH 13/17] test: add logbasemeasure test --- test/strategies/closed_form_tests.jl | 37 ++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git 
a/test/strategies/closed_form_tests.jl b/test/strategies/closed_form_tests.jl index db1a04d..d6fc471 100644 --- a/test/strategies/closed_form_tests.jl +++ b/test/strategies/closed_form_tests.jl @@ -453,3 +453,40 @@ end @test cost > 0 # KL divergence should be positive @test all(isfinite.(X_result)) end + +@testitem "ClosedFormStrategy logbasemeasure_correction for ConstantBaseMeasure" begin + using ExponentialFamilyProjection + using ClosedFormExpectations + using ExponentialFamily + using Distributions + using Test + + # Get the extension module + # The extension should be loaded because both ExponentialFamilyProjection and ClosedFormExpectations are loaded + ClosedFormExpectationsExt = Base.get_extension(ExponentialFamilyProjection, :ClosedFormExpectationsExt) + + @test !isnothing(ClosedFormExpectationsExt) + + strategy = ClosedFormStrategy() + + # Create a mock or usage that triggers ConstantBaseMeasure + # We will use reflection/internals to test the specific method + + # Using a distribution that we know has ConstantBaseMeasure + # or constructing it manually if possible. + # ExponentialFamily.ConstantBaseMeasure is a singleton struct usually. 
+ + base_measure = ExponentialFamily.ConstantBaseMeasure() + q_dist = Normal(0, 1) # Any distribution works as q_dist is passed through + grad_target = [1.0, 2.0] + + result = ClosedFormExpectationsExt.logbasemeasure_correction( + strategy, + base_measure, + q_dist, + grad_target + ) + + # The function should return grad_target exactly for ConstantBaseMeasure + @test result === grad_target +end From 4e21f93c19f7d7df61e442f116a530f9a2a089e1 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 24 Nov 2025 18:43:10 +0100 Subject: [PATCH 14/17] fix: warning in tests --- test/strategies/closed_form_tests.jl | 91 ++++++++++++++++------------ 1 file changed, 52 insertions(+), 39 deletions(-) diff --git a/test/strategies/closed_form_tests.jl b/test/strategies/closed_form_tests.jl index d6fc471..074a67f 100644 --- a/test/strategies/closed_form_tests.jl +++ b/test/strategies/closed_form_tests.jl @@ -300,15 +300,17 @@ end target_dist = LogGamma(20.0, 1.0) target = Logpdf(target_dist) - # Create strategy - strategy = ClosedFormStrategy() - - # Project + # Project with ClosedFormStrategy result = project_to( - ProjectedTo(NormalMeanVariance), - target; - strategy = strategy, - parameters = ProjectionParameters(niterations = 50, tolerance = 1e-5), + ProjectedTo( + NormalMeanVariance; + parameters = ProjectionParameters( + strategy = ClosedFormStrategy(), + niterations = 50, + tolerance = 1e-5 + ) + ), + target ) @test result isa NormalMeanVariance @@ -336,14 +338,17 @@ end # Initial: Gamma(2.0, 2.0) initial_dist = Gamma(2.0, 2.0) - strategy = ClosedFormStrategy() - result = project_to( - ProjectedTo(Gamma), + ProjectedTo( + Gamma; + parameters = ProjectionParameters( + strategy = ClosedFormStrategy(), + niterations = 50, + tolerance = 1e-5 + ) + ), target; - strategy = strategy, - initial_point = initial_dist, - parameters = ProjectionParameters(niterations = 50, tolerance = 1e-5), + initialpoint = initial_dist ) @test result isa GammaDistributionsFamily @@ -353,13 
+358,17 @@ end @test scale(result) > 0 # Comparison with ControlVariateStrategy - cv_strategy = ControlVariateStrategy(nsamples = 500) result_cv = project_to( - ProjectedTo(Gamma), + ProjectedTo( + Gamma; + parameters = ProjectionParameters( + strategy = ControlVariateStrategy(nsamples = 500), + niterations = 50, + tolerance = 1e-4 + ) + ), target; - strategy = cv_strategy, - initial_point = initial_dist, - parameters = ProjectionParameters(niterations = 50, tolerance = 1e-4), + initialpoint = initial_dist ) # Should be close @@ -375,37 +384,38 @@ end using ExponentialFamily using StableRNGs using LinearAlgebra + using BenchmarkTools # Simple case: Normal to Normal - target_dist = Normal(5.0, 2.0) + target_dist = NormalMeanVariance(5.0, 2.0) target = Logpdf(target_dist) - initial = Normal(0.0, 1.0) + initial = NormalMeanVariance(0.0, 1.0) - # Analytic - t_analytic = @elapsed begin - res_analytic = project_to( - ProjectedTo(NormalMeanVariance), - target; + # Create projection objects + prj_analytic = ProjectedTo( + NormalMeanVariance; + parameters = ProjectionParameters( strategy = ClosedFormStrategy(), - initial_point = initial, - parameters = ProjectionParameters(niterations = 100), + niterations = 100 ) - end + ) - # MC - t_mc = @elapsed begin - res_mc = project_to( - ProjectedTo(NormalMeanVariance), - target; + prj_mc = ProjectedTo( + NormalMeanVariance; + parameters = ProjectionParameters( strategy = ControlVariateStrategy(nsamples = 1000), - initial_point = initial, - parameters = ProjectionParameters(niterations = 100), + niterations = 100 ) - end + ) - println("ClosedFormStrategy time: $t_analytic") - println("ControlVariateStrategy time: $t_mc") + # Benchmark with @belapsed for robust timing + t_analytic = @belapsed project_to($prj_analytic, $target; initialpoint=$initial) + t_mc = @belapsed project_to($prj_mc, $target; initialpoint=$initial) + + # Get results for accuracy testing + res_analytic = project_to(prj_analytic, target; initialpoint = initial) + 
res_mc = project_to(prj_mc, target; initialpoint = initial) # Analytic should be more accurate (converge to exact target) @test isapprox(mean(res_analytic), mean(target_dist), atol = 2e-2) @@ -413,6 +423,9 @@ end # ClosedFormStrategy should be at least as accurate as MC @test abs(mean(res_analytic) - 5.0) <= abs(mean(res_mc) - 5.0) + 0.1 + + # ClosedFormStrategy should be faster than MC sampling + @test t_analytic < t_mc end From d0e2f053afe3b528e075ad2d2e4311163f740b3e Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 24 Nov 2025 18:44:33 +0100 Subject: [PATCH 15/17] docs: add ClosedFormExpectation docs --- docs/Project.toml | 1 + docs/make.jl | 2 +- docs/src/index.md | 71 +++++++++++++++++++++++++++++++++++ src/strategies/closed_form.jl | 52 ++++++++++++++++++++++--- 4 files changed, 119 insertions(+), 7 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index d57696e..2757f25 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,6 +1,7 @@ [deps] BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +ClosedFormExpectations = "70ff922c-62d4-418d-abfc-e284e489b734" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" diff --git a/docs/make.jl b/docs/make.jl index 03f76f4..c83b395 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -3,7 +3,7 @@ using Documenter, ExponentialFamilyProjection DocMeta.setdocmeta!( ExponentialFamilyProjection, :DocTestSetup, - :(using ExponentialFamilyProjection); + :(using ExponentialFamilyProjection, ClosedFormExpectations); recursive = true, ) diff --git a/docs/src/index.md b/docs/src/index.md index bd01fdd..92c9f21 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -41,6 +41,7 @@ The optimization procedure requires computing the expectation of the gradient to ```@docs ExponentialFamilyProjection.DefaultStrategy 
ExponentialFamilyProjection.ControlVariateStrategy +ExponentialFamilyProjection.ClosedFormStrategy ExponentialFamilyProjection.MLEStrategy ExponentialFamilyProjection.BonnetStrategy ExponentialFamilyProjection.GaussNewton @@ -412,6 +413,76 @@ round.((speedup, t_manual, t_enzyme); digits = 3) On typical runs we observe a substantial speedup (often around 10×) for Enzyme while maintaining the same result. +### Closed-form strategy (zero-variance gradients) + +The `ClosedFormStrategy` uses `ClosedFormExpectations.jl` to compute exact, analytic gradients without Monte Carlo sampling. This provides "zero-variance" gradients, leading to faster and more accurate convergence compared to sampling-based strategies. + +!!! note + To use `ClosedFormStrategy`, you must install and load `ClosedFormExpectations.jl`: + ```julia + using Pkg + Pkg.add("ClosedFormExpectations") + using ClosedFormExpectations + ``` + +Let's compare `ClosedFormStrategy` with `ControlVariateStrategy` by projecting a LogNormal distribution onto a Gamma distribution: + +```@example projection +using ClosedFormExpectations +using BenchmarkTools + +# Target: LogNormal(μ=1.0, σ=0.5) +target_dist = LogNormal(1.0, 0.5) +target = (x) -> logpdf(target_dist, x) + +# Initial point +initial_dist = Gamma(2.0, 2.0) + +# Project using ClosedFormStrategy +t_closed = @elapsed result_closed = project_to( + ProjectedTo(Gamma; parameters=ProjectionParameters(strategy=ClosedFormStrategy(), niterations=50, tolerance=1e-5)), + target; + initialpoint = initial_dist +) + +# Project using ControlVariateStrategy +t_cv = @elapsed result_cv = project_to( + ProjectedTo(Gamma; parameters=ProjectionParameters(strategy=ControlVariateStrategy(nsamples=500), niterations=50, tolerance=1e-5)), + target; + initialpoint = initial_dist +) + +println("ClosedFormStrategy time: $(round(t_closed * 1000, digits=2)) ms") +println("ControlVariateStrategy time: $(round(t_cv * 1000, digits=2)) ms") +println("Speedup: $(round(t_cv / t_closed, 
digits=2))x") +``` + +Now let's visualize the results to see how both strategies compare: + +```@example projection +using Plots + +xs = 0.01:0.01:10.0 + +plot(xs, x -> pdf(target_dist, x), + label="Target (LogNormal)", linewidth=2, + fill=0, fillalpha=0.2, color=:blue) +plot!(xs, x -> pdf(result_closed, x), + label="ClosedForm", linewidth=2, + linestyle=:dash, color=:red) +plot!(xs, x -> pdf(result_cv, x), + label="ControlVariate", linewidth=2, + linestyle=:dot, color=:green) +xlabel!("x") +ylabel!("Density") +title!("LogNormal → Gamma Projection Comparison") +``` + +The `ClosedFormStrategy` typically provides: +- **Faster convergence**: No Monte Carlo noise in gradients +- **Better accuracy**: Exact gradient computations +- **Speed advantages**: Especially significant for lower-dimensional problems + ### Projection with samples The projection can be done given a set of samples instead of the function directly. For example, let's project an set of samples onto a Beta distribution: diff --git a/src/strategies/closed_form.jl b/src/strategies/closed_form.jl index 6c2b86f..f165946 100644 --- a/src/strategies/closed_form.jl +++ b/src/strategies/closed_form.jl @@ -9,14 +9,54 @@ of the cross-entropy term \$\\mathbb{E}_{q_\\eta}[\\log \\tilde{p}(x)]\$ analyti This strategy provides a "Zero-Variance" gradient estimator, avoiding the noise associated with Monte Carlo sampling (like in `ControlVariateStrategy`). +# Requirements + +To use this strategy, you **must** load the `ClosedFormExpectations` package: + +```julia +using ClosedFormExpectations +``` + +Loading `ClosedFormExpectations` will trigger a package extension that implements +the gradient computation for this strategy. 
+ +# When to Use + +Use `ClosedFormStrategy` when: +- You need exact, deterministic gradients without Monte Carlo variance +- The target-to-variational family pair is supported by `ClosedFormExpectations.jl` +- You want faster convergence with fewer iterations +- Reproducibility is critical (no random sampling) + +# Example + +```julia +using ExponentialFamilyProjection, ClosedFormExpectations +using Distributions + +# Target distribution +target = LogNormal(1.0, 0.5) + +# Project to Gamma using closed-form gradients +result = project_to( + ProjectedTo( + Gamma; + parameters = ProjectionParameters( + strategy = ClosedFormStrategy(), + niterations = 50 + ) + ), + target +) +``` + +# References + This estimator was proposed in [Lukashchuk et al., 2024](https://proceedings.mlr.press/v246/lukashchuk24a.html). !!! note - To use this strategy, you **must** load the `ClosedFormExpectations` package in your environment. - Loading `ClosedFormExpectations` will trigger the package extension that implements `compute_gradient!` - and `compute_cost` for this strategy. - - It requires that `ClosedFormExpectations.jl` implements `ClosedWilliamsProduct` for the - specific pair of target function and variational family. + This strategy requires that `ClosedFormExpectations.jl` implements `ClosedWilliamsProduct` + for the specific pair of target distribution and variational family you're using. + See the `ClosedFormExpectations.jl` documentation for supported combinations. 
""" struct ClosedFormStrategy end From a1ecf6e7c411e5287fe311284557d484d912b5ae Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 24 Nov 2025 18:45:56 +0100 Subject: [PATCH 16/17] docs(fix): ensure that docstring is correct --- src/strategies/closed_form.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strategies/closed_form.jl b/src/strategies/closed_form.jl index f165946..603d794 100644 --- a/src/strategies/closed_form.jl +++ b/src/strategies/closed_form.jl @@ -46,7 +46,7 @@ result = project_to( niterations = 50 ) ), - target + Logpdf(target) ) ``` From c92114f053e496f8f352d152a9297501c13b2a82 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 24 Nov 2025 18:53:18 +0100 Subject: [PATCH 17/17] fix: export ControlVariateStrategy --- docs/src/index.md | 9 ++++----- src/strategies/control_variate.jl | 2 ++ 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 92c9f21..f5ce806 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -433,22 +433,21 @@ using BenchmarkTools # Target: LogNormal(μ=1.0, σ=0.5) target_dist = LogNormal(1.0, 0.5) -target = (x) -> logpdf(target_dist, x) # Initial point initial_dist = Gamma(2.0, 2.0) -# Project using ClosedFormStrategy +# Project using ClosedFormStrategy (pass distribution directly) t_closed = @elapsed result_closed = project_to( ProjectedTo(Gamma; parameters=ProjectionParameters(strategy=ClosedFormStrategy(), niterations=50, tolerance=1e-5)), - target; + target_dist; initialpoint = initial_dist ) -# Project using ControlVariateStrategy +# Project using ControlVariateStrategy (with a function) t_cv = @elapsed result_cv = project_to( ProjectedTo(Gamma; parameters=ProjectionParameters(strategy=ControlVariateStrategy(nsamples=500), niterations=50, tolerance=1e-5)), - target; + (x) -> logpdf(target_dist, x); initialpoint = initial_dist ) diff --git a/src/strategies/control_variate.jl b/src/strategies/control_variate.jl index 
9cdb7ef..d5de541 100644 --- a/src/strategies/control_variate.jl +++ b/src/strategies/control_variate.jl @@ -1,3 +1,5 @@ +export ControlVariateStrategy + using StableRNGs, Bumper, FillArrays import Random: AbstractRNG