diff --git a/Project.toml b/Project.toml index b88423b..e29d0d0 100644 --- a/Project.toml +++ b/Project.toml @@ -22,9 +22,13 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +[extensions] +ClosedFormExpectationsExt = "ClosedFormExpectations" + [compat] BayesBase = "1.5.0" Bumper = "0.6" +ClosedFormExpectations = "0.3.0" Distributions = "0.25" ExponentialFamily = "2.0.0" ExponentialFamilyManifolds = "3.0.3" @@ -43,9 +47,13 @@ StaticArrays = "1.9" StatsFuns = "1.3" julia = "1.10" +[weakdeps] +ClosedFormExpectations = "70ff922c-62d4-418d-abfc-e284e489b734" + [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +ClosedFormExpectations = "70ff922c-62d4-418d-abfc-e284e489b734" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" @@ -58,4 +66,4 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Aqua", "BenchmarkTools", "Hwloc", "Plots", "Printf", "ForwardDiff", "Manifolds", "ReTestItems", "RollingFunctions", "JET", "StableRNGs"] +test = ["Test", "Aqua", "BenchmarkTools", "ClosedFormExpectations", "Hwloc", "Plots", "Printf", "ForwardDiff", "Manifolds", "ReTestItems", "RollingFunctions", "JET", "StableRNGs"] diff --git a/docs/Project.toml b/docs/Project.toml index d57696e..2757f25 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,6 +1,7 @@ [deps] BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +ClosedFormExpectations = "70ff922c-62d4-418d-abfc-e284e489b734" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" diff --git a/docs/make.jl b/docs/make.jl index 03f76f4..c83b395 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -3,7 +3,7 @@ using Documenter, ExponentialFamilyProjection DocMeta.setdocmeta!( ExponentialFamilyProjection, :DocTestSetup, - :(using ExponentialFamilyProjection); + :(using ExponentialFamilyProjection, ClosedFormExpectations); recursive = true, ) diff --git a/docs/src/index.md b/docs/src/index.md index bd01fdd..f5ce806 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -41,6 +41,7 @@ The optimization procedure requires computing the expectation of the gradient to ```@docs ExponentialFamilyProjection.DefaultStrategy ExponentialFamilyProjection.ControlVariateStrategy +ExponentialFamilyProjection.ClosedFormStrategy ExponentialFamilyProjection.MLEStrategy ExponentialFamilyProjection.BonnetStrategy ExponentialFamilyProjection.GaussNewton @@ -412,6 +413,75 @@ round.((speedup, t_manual, t_enzyme); digits = 3) On typical runs we observe a substantial speedup (often around 10×) for Enzyme while maintaining the same result. +### Closed-form strategy (zero-variance gradients) + +The `ClosedFormStrategy` uses `ClosedFormExpectations.jl` to compute exact, analytic gradients without Monte Carlo sampling. This provides "zero-variance" gradients, leading to faster and more accurate convergence compared to sampling-based strategies. + +!!! note + To use `ClosedFormStrategy`, you must install and load `ClosedFormExpectations.jl`: + ```julia + using Pkg + Pkg.add("ClosedFormExpectations") + using ClosedFormExpectations + ``` + +Let's compare `ClosedFormStrategy` with `ControlVariateStrategy` by projecting a LogNormal distribution onto a Gamma distribution: + +```@example projection +using ClosedFormExpectations +using BenchmarkTools + +# Target: LogNormal(μ=1.0, σ=0.5) +target_dist = LogNormal(1.0, 0.5) + +# Initial point +initial_dist = Gamma(2.0, 2.0) + +# Project using ClosedFormStrategy (pass distribution directly) +t_closed = @elapsed result_closed = project_to( + ProjectedTo(Gamma; parameters=ProjectionParameters(strategy=ClosedFormStrategy(), niterations=50, tolerance=1e-5)), + target_dist; + initialpoint = initial_dist +) + +# Project using ControlVariateStrategy (with a function) +t_cv = @elapsed result_cv = project_to( + ProjectedTo(Gamma; parameters=ProjectionParameters(strategy=ControlVariateStrategy(nsamples=500), niterations=50, tolerance=1e-5)), + (x) -> logpdf(target_dist, x); + initialpoint = initial_dist +) + +println("ClosedFormStrategy time: $(round(t_closed * 1000, digits=2)) ms") +println("ControlVariateStrategy time: $(round(t_cv * 1000, digits=2)) ms") +println("Speedup: $(round(t_cv / t_closed, digits=2))x") +``` + +Now let's visualize the results to see how both strategies compare: + +```@example projection +using Plots + +xs = 0.01:0.01:10.0 + +plot(xs, x -> pdf(target_dist, x), + label="Target (LogNormal)", linewidth=2, + fill=0, fillalpha=0.2, color=:blue) +plot!(xs, x -> pdf(result_closed, x), + label="ClosedForm", linewidth=2, + linestyle=:dash, color=:red) +plot!(xs, x -> pdf(result_cv, x), + label="ControlVariate", linewidth=2, + linestyle=:dot, color=:green) +xlabel!("x") +ylabel!("Density") +title!("LogNormal → Gamma Projection Comparison") +``` + +The `ClosedFormStrategy` typically provides: +- **Faster convergence**: No Monte Carlo noise in gradients +- **Better accuracy**: Exact gradient computations +- **Speed advantages**: Especially significant for lower-dimensional problems + ### Projection with samples The projection can be done given a set of samples instead of the function directly. For example, let's project an set of samples onto a Beta distribution: diff --git a/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl b/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl new file mode 100644 index 0000000..10e98a1 --- /dev/null +++ b/ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl @@ -0,0 +1,174 @@ +module ClosedFormExpectationsExt + +using ExponentialFamilyProjection +using ClosedFormExpectations +using ExponentialFamily +using ExponentialFamilyManifolds +using Manifolds +using ManifoldsBase +using LinearAlgebra +using Distributions + +# Import types needed for RxInfer closure unwrapping +import ExponentialFamily: ProductOf +import ClosedFormExpectations: Logpdf + +ExponentialFamilyProjection.get_nsamples(::ClosedFormStrategy) = 0 + +function logbasemeasure_correction( + ::ClosedFormStrategy, + ::ExponentialFamily.ConstantBaseMeasure, + q_dist, + grad_target, +) + grad_target +end + +function ExponentialFamilyProjection.compute_gradient!( + M::AbstractManifold, + strategy::ClosedFormStrategy, + state, + X, + η, + logpartition, + gradlogpartition, + inv_fisher, +) + # The gradient of KL(q||p) involves E_q[(log p̃ - log h_q) * (T - μ)] + # where h_q is the base measure of q (the variational distribution). + # + # For constant base measure: + # E[(log p̃ - log h) * (T - μ)] = E[log p̃ * (T - μ)] - log h * E[(T - μ)] + # Since E[T] = μ, the second term is zero. + + target_fn = state.target + + # Convert natural parameters on manifold to an ExponentialFamilyDistribution object + q_dist = convert( + ExponentialFamilyDistribution, + M, + ExponentialFamilyManifolds.partition_point(M, η), + ) + + # Compute ∇_η E[log p̃ * (T - μ)] + grad_target = mean(ClosedWilliamsProduct(), target_fn, q_dist) + grad_eta = logbasemeasure_correction( + strategy, + ExponentialFamily.isbasemeasureconstant(q_dist), + q_dist, + grad_target, + ) + + # Natural Gradient Update: X = η - F⁻¹ * ∇_η E + X .= η .- inv_fisher * grad_eta + + return X +end + +# Helper to create state +# Note: We use a mutable struct to ensure each create_state! call +# produces a distinct object in memory, even with the same target +mutable struct ClosedFormStrategyState{T} + target::T +end + +function ExponentialFamilyProjection.create_state!( + strategy::ClosedFormStrategy, + M::AbstractManifold, + parameters::ProjectionParameters, + projection_argument, + initial_ef, + supplementary_η, +) + return ClosedFormStrategyState(projection_argument) +end + +function ExponentialFamilyProjection.prepare_state!( + strategy::ClosedFormStrategy, + state::ClosedFormStrategyState, + M::AbstractManifold, + parameters::ProjectionParameters, + projection_argument, + current_ef, + supplementary_η, +) + return state +end + +# Compute cost for logging/convergence check +function ExponentialFamilyProjection.compute_cost( + M::AbstractManifold, + strategy::ClosedFormStrategy, + state::ClosedFormStrategyState, + η, + logpartition, + gradlogpartition, + inv_fisher, +) + # Cost = KL(q || p) = E_q[log q] - E_q[log p] + + # Reconstruct distribution + q_dist = convert( + ExponentialFamilyDistribution, + M, + ExponentialFamilyManifolds.partition_point(M, η), + ) + dist_std = convert(Distribution, q_dist) + + # E_q[log p] via CFE + E_log_p = mean(ClosedFormExpectation(), state.target, dist_std) + + # E_q[log q] = -entropy(q) + return -entropy(dist_std) - E_log_p +end + +# preprocess_strategy_argument for ClosedFormStrategy +# Special handling for RxInfer closures that wrap ProductOf +function ExponentialFamilyProjection.preprocess_strategy_argument( + strategy::ClosedFormStrategy, + argument::F, +) where {F<:Function} + # RxInfer wraps ProductOf or Distribution in a closure. + # Extract the ProductOf/Distribution from the closure's captured variables. + # The closure typically has one field holding the ProductOf/Distribution. + field_names = fieldnames(F) + + if isempty(field_names) + error( + """`ClosedFormStrategy` requires a function that captures a `Distribution` or `ProductOf` in its closure. + + Expected form: + let dist = Normal(0, 1) + (x) -> logpdf(dist, x) + end + + Got a function without captured variables: $F + + If you want to use a plain function, pass the `Distribution` directly instead of wrapping it in a function. + """, + ) + end + + captured = getfield(argument, first(field_names)) + return (strategy, Logpdf(captured)) +end + +# Generic fallback for non-Function arguments +function ExponentialFamilyProjection.preprocess_strategy_argument( + strategy::ClosedFormStrategy, + argument::Distribution, +) + # ClosedFormStrategy accepts any callable or distribution as argument + return (strategy, Logpdf(argument)) +end + +# Generic fallback for non-Function arguments +function ExponentialFamilyProjection.preprocess_strategy_argument( + strategy::ClosedFormStrategy, + argument, +) + # ClosedFormStrategy accepts any callable or distribution as argument + return (strategy, argument) +end + +end diff --git a/src/ExponentialFamilyProjection.jl b/src/ExponentialFamilyProjection.jl index 8431f42..39adb1a 100644 --- a/src/ExponentialFamilyProjection.jl +++ b/src/ExponentialFamilyProjection.jl @@ -123,6 +123,7 @@ function compute_gradient! end include("strategies/control_variate.jl") include("strategies/mle.jl") include("strategies/default.jl") +include("strategies/closed_form.jl") # Bonnet strategy include("strategies/bonnet/naive_grad_hess.jl") include("strategies/bonnet/bonnet_logpdf.jl") diff --git a/src/strategies/closed_form.jl b/src/strategies/closed_form.jl new file mode 100644 index 0000000..603d794 --- /dev/null +++ b/src/strategies/closed_form.jl @@ -0,0 +1,62 @@ +export ClosedFormStrategy + +""" + ClosedFormStrategy <: ExponentialFamilyProjection.AbstractStrategy + +A projection strategy that uses `ClosedFormExpectations.jl` to compute the exact gradient +of the cross-entropy term \$\\mathbb{E}_{q_\\eta}[\\log \\tilde{p}(x)]\$ analytically. + +This strategy provides a "Zero-Variance" gradient estimator, avoiding the noise associated +with Monte Carlo sampling (like in `ControlVariateStrategy`). + +# Requirements + +To use this strategy, you **must** load the `ClosedFormExpectations` package: + +```julia +using ClosedFormExpectations +``` + +Loading `ClosedFormExpectations` will trigger a package extension that implements +the gradient computation for this strategy. + +# When to Use + +Use `ClosedFormStrategy` when: +- You need exact, deterministic gradients without Monte Carlo variance +- The target-to-variational family pair is supported by `ClosedFormExpectations.jl` +- You want faster convergence with fewer iterations +- Reproducibility is critical (no random sampling) + +# Example + +```julia +using ExponentialFamilyProjection, ClosedFormExpectations +using Distributions + +# Target distribution +target = LogNormal(1.0, 0.5) + +# Project to Gamma using closed-form gradients +result = project_to( + ProjectedTo( + Gamma; + parameters = ProjectionParameters( + strategy = ClosedFormStrategy(), + niterations = 50 + ) + ), + Logpdf(target) +) +``` + +# References + +This estimator was proposed in [Lukashchuk et al., 2024](https://proceedings.mlr.press/v246/lukashchuk24a.html). + +!!! note + This strategy requires that `ClosedFormExpectations.jl` implements `ClosedWilliamsProduct` + for the specific pair of target distribution and variational family you're using. + See the `ClosedFormExpectations.jl` documentation for supported combinations. +""" +struct ClosedFormStrategy end diff --git a/src/strategies/control_variate.jl b/src/strategies/control_variate.jl index 9cdb7ef..d5de541 100644 --- a/src/strategies/control_variate.jl +++ b/src/strategies/control_variate.jl @@ -1,3 +1,5 @@ +export ControlVariateStrategy + using StableRNGs, Bumper, FillArrays import Random: AbstractRNG diff --git a/test/manopt/bounded_norm_update_rule_tests.jl b/test/manopt/bounded_norm_update_rule_tests.jl index 38e4684..fbea5f4 100644 --- a/test/manopt/bounded_norm_update_rule_tests.jl +++ b/test/manopt/bounded_norm_update_rule_tests.jl @@ -82,8 +82,20 @@ for limit in (1, 1.0, 1.0f0), p in (zeros(Float64, 3), zeros(Float32, 3)) cpa = DefaultManoptProblem(M, ManifoldGradientObjective(f, grad_f)) gst = GradientDescentState(M; p = zero(p)) - @test_opt target_modules=(ExponentialFamilyProjection,) BoundedNormUpdateRule(limit)(cpa, gst, 1) - @test_opt target_modules=(ExponentialFamilyProjection,) BoundedNormUpdateRule(static(limit))(cpa, gst, 1) + @test_opt target_modules=(ExponentialFamilyProjection,) BoundedNormUpdateRule( + limit, + )( + cpa, + gst, + 1, + ) + @test_opt target_modules=(ExponentialFamilyProjection,) BoundedNormUpdateRule( + static(limit), + )( + cpa, + gst, + 1, + ) @test_opt target_modules=(ExponentialFamilyProjection,) BoundedNormUpdateRule( static(limit); direction = BoundedNormUpdateRule(limit), diff --git a/test/strategies/closed_form_tests.jl b/test/strategies/closed_form_tests.jl new file mode 100644 index 0000000..074a67f --- /dev/null +++ b/test/strategies/closed_form_tests.jl @@ -0,0 +1,505 @@ +@testitem "ClosedFormStrategy generic properties" begin + using ExponentialFamilyProjection + using ClosedFormExpectations + import ExponentialFamilyProjection: get_nsamples + + strategy = ClosedFormStrategy() + + @test strategy isa ClosedFormStrategy + @test get_nsamples(strategy) == 0 +end + +@testitem "ClosedFormStrategy create_state!" begin + using ExponentialFamilyProjection + using ClosedFormExpectations + using ExponentialFamilyManifolds + using ExponentialFamily + using Distributions + import ExponentialFamilyProjection: create_state!, ProjectionParameters + + distributions = [Beta(5, 5), NormalMeanVariance(0, 1), Gamma(2, 2)] + parameters = ProjectionParameters() + + for dist in distributions + ef = convert(ExponentialFamilyDistribution, dist) + T = ExponentialFamily.exponential_family_typetag(ef) + d = size(mean(ef)) + c = getconditioner(ef) + M = ExponentialFamilyManifolds.get_natural_manifold(T, d, c) + + # Target wrapped in Logpdf + target = Logpdf(dist) + + # Test without supplementary parameters + state1 = create_state!(ClosedFormStrategy(), M, parameters, target, ef, ()) + state2 = create_state!(ClosedFormStrategy(), M, parameters, target, ef, ()) + + @test state1.target === target + @test state2.target === target + @test state1 !== state2 # Different objects in memory + + # Test with supplementary parameters + state3 = create_state!(ClosedFormStrategy(), M, parameters, target, ef, (ef,)) + state4 = create_state!(ClosedFormStrategy(), M, parameters, target, ef, (ef,)) + + @test state3.target === target + @test state4.target === target + @test state3 !== state4 + + # Test with multiple supplementary parameters + state5 = create_state!(ClosedFormStrategy(), M, parameters, target, ef, (ef, ef)) + state6 = create_state!(ClosedFormStrategy(), M, parameters, target, ef, (ef, ef)) + + @test state5.target === target + @test state6.target === target + @test state5 !== state6 + end +end + +@testitem "ClosedFormStrategy prepare_state!" begin + using ExponentialFamilyProjection + using ClosedFormExpectations + using ExponentialFamilyManifolds + using ExponentialFamily + using Distributions + import ExponentialFamilyProjection: create_state!, prepare_state!, ProjectionParameters + + distributions = [NormalMeanVariance(0, 1), Gamma(2, 2), Beta(3, 3)] + + for dist in distributions + target_dist = dist + target = Logpdf(target_dist) + + ef = convert(ExponentialFamilyDistribution, dist) + T = ExponentialFamily.exponential_family_typetag(ef) + d = size(mean(ef)) + c = getconditioner(ef) + M = ExponentialFamilyManifolds.get_natural_manifold(T, d, c) + + strategy = ClosedFormStrategy() + parameters = ProjectionParameters() + + # Create initial state + state1 = create_state!(strategy, M, parameters, target, ef, ()) + + # prepare_state! should return the same state object + state2 = prepare_state!(strategy, state1, M, parameters, target, ef, ()) + + @test state1 === state2 + @test state1.target === state2.target + + # Test with supplementary parameters + supplementary_η = (getnaturalparameters(ef),) + state3 = create_state!(strategy, M, parameters, target, ef, supplementary_η) + state4 = + prepare_state!(strategy, state3, M, parameters, target, ef, supplementary_η) + + @test state3 === state4 + end +end + +@testitem "ClosedFormStrategy should fail if given a list of samples instead of a function" begin + using ExponentialFamily + using ClosedFormExpectations + + prj = ProjectedTo( + Beta; + parameters = ProjectionParameters(strategy = ClosedFormStrategy()), + ) + + # ClosedFormStrategy doesn't explicitly reject arrays, but it won't work properly + # The extension's preprocess_strategy_argument will keep the array as-is + # and then the compute_gradient! will fail when trying to use it + # This is the expected behavior - it will error during execution + + samples = [0.5, 0.6, 0.7] + + # This should fail during the projection, not during preprocessing + @test_throws Exception project_to(prj, samples) +end + + +@testitem "ClosedFormStrategy argument preprocessing for Distribution in closure" begin + using ExponentialFamilyProjection + using ClosedFormExpectations + using Distributions + import ExponentialFamilyProjection: preprocess_strategy_argument + + strategy = ClosedFormStrategy() + dist1 = Normal(0, 1) + + # Test extraction of Distribution from a closure (simulating RxInfer behavior) + # The closure captures a Distribution, and preprocess should extract it + closure_with_dist = let d = dist1 + (x) -> logpdf(d, x) + end + + result_strat, result_arg = preprocess_strategy_argument(strategy, closure_with_dist) + @test result_strat === strategy + + # The function should extract the captured Distribution and wrap it in Logpdf + @test result_arg isa Logpdf + @test result_arg.dist === dist1 +end + +@testitem "ClosedFormStrategy argument preprocessing for ProductOf in closure" begin + using ExponentialFamilyProjection + using ClosedFormExpectations + using Distributions + using BayesBase + import ExponentialFamilyProjection: preprocess_strategy_argument + import BayesBase: ProductOf + + strategy = ClosedFormStrategy() + + # Test extraction of ProductOf from a closure (RxInfer use case) + left = Beta(10, 10) + right = Beta(3, 3) + prod = ProductOf(left, right) + + closure_with_product = let p = prod + (x) -> logpdf(p, x) + end + + result_strat, result_arg = preprocess_strategy_argument(strategy, closure_with_product) + @test result_strat === strategy + + # The function should extract the captured ProductOf and wrap it in Logpdf + @test result_arg isa Logpdf + @test result_arg.dist === prod +end + +@testitem "ClosedFormStrategy argument preprocessing for plain function should error" begin + using ExponentialFamilyProjection + using ClosedFormExpectations + import ExponentialFamilyProjection: preprocess_strategy_argument + + strategy = ClosedFormStrategy() + + # Plain function without captured variables should throw an error + # because ClosedFormStrategy needs to extract Distribution/ProductOf from closure + fn = (x) -> x^2 + + @test_throws "`ClosedFormStrategy` requires a function that captures a `Distribution` or `ProductOf` in its closure" preprocess_strategy_argument( + strategy, + fn, + ) +end + +@testitem "ClosedFormStrategy argument preprocessing for direct Distribution" begin + using ExponentialFamilyProjection + using ClosedFormExpectations + using Distributions + import ExponentialFamilyProjection: preprocess_strategy_argument + + strategy = ClosedFormStrategy() + dist1 = Normal(0, 1) + + # Case 4: Distribution directly (non-Function argument) + result_strat, result_arg = preprocess_strategy_argument(strategy, dist1) + @test result_strat === strategy + @test result_arg isa ClosedFormExpectations.Logpdf + @test result_arg.dist == dist1 +end + +@testitem "ClosedFormStrategy cost computation" begin + using ExponentialFamilyProjection + using ClosedFormExpectations + using ExponentialFamily + using Distributions + using ExponentialFamilyManifolds + using StableRNGs + import ExponentialFamilyProjection: compute_cost, create_state!, ProjectionParameters + + strategy = ClosedFormStrategy() + + # Target: Normal(0, 1) + target_dist = Normal(0.0, 1.0) + target = Logpdf(target_dist) + + # Variational: Normal(1, 1) + # KL(q || p) where p=N(0,1), q=N(1,1) (variances equal) + # KL = 0.5 * ( (μ1-μ2)^2/σ^2 ) = 0.5 * (1-0)^2/1 = 0.5 + + # Manifold setup + M = ExponentialFamilyManifolds.get_natural_manifold(NormalMeanVariance, ()) + q_dist = NormalMeanVariance(1.0, 1.0) + ef = convert(ExponentialFamilyDistribution, q_dist) + η = getnaturalparameters(ef) + + # Create state + parameters = ProjectionParameters() + state = create_state!(strategy, M, parameters, target, ef, ()) + + # Dummy args for cost (not all used by ClosedFormStrategy's compute_cost) + logp = logpartition(ef) + gradlogp = gradlogpartition(ef) + inv_fisher = inv(fisherinformation(ef)) + + cost = compute_cost(M, strategy, state, η, logp, gradlogp, inv_fisher) + + # compute_cost returns: -entropy(q) - E_q[log p] + # KL(q||p) = E_q[log q] - E_q[log p] = -entropy(q) - E_q[log p] + # So it should approximately equal KL + @test cost ≈ 0.5 atol = 1e-6 +end + +@testitem "ClosedFormStrategy gradient computation" begin + using ExponentialFamilyProjection + using ClosedFormExpectations + using ExponentialFamily + using Distributions + using ExponentialFamilyManifolds + using LinearAlgebra + using StableRNGs + import ExponentialFamilyProjection: + compute_gradient!, create_state!, ProjectionParameters + + strategy = ClosedFormStrategy() + + # Simple test: project Normal to Normal + target_dist = Normal(2.0, 1.0) + target = Logpdf(target_dist) + + # Start from a different point + q_dist = NormalMeanVariance(0.0, 1.0) + ef = convert(ExponentialFamilyDistribution, q_dist) + + M = ExponentialFamilyManifolds.get_natural_manifold(NormalMeanVariance, ()) + η = getnaturalparameters(ef) + + parameters = ProjectionParameters() + state = create_state!(strategy, M, parameters, target, ef, ()) + + # Compute gradient + X = similar(η) + logp = logpartition(ef) + gradlogp = gradlogpartition(ef) + inv_fisher = inv(fisherinformation(ef)) + + X_result = compute_gradient!(M, strategy, state, X, η, logp, gradlogp, inv_fisher) + + @test X_result === X + @test length(X) == length(η) + @test all(isfinite.(X)) + + # The gradient should point towards the target + # Since target is at μ=2, gradient should push η in that direction +end + + +@testitem "LogGamma projected to Normal" begin + using BayesBase + using ExponentialFamilyProjection + using ClosedFormExpectations + using Distributions + using ExponentialFamily + using StableRNGs + using LinearAlgebra + + target_dist = LogGamma(20.0, 1.0) + target = Logpdf(target_dist) + + # Project with ClosedFormStrategy + result = project_to( + ProjectedTo( + NormalMeanVariance; + parameters = ProjectionParameters( + strategy = ClosedFormStrategy(), + niterations = 50, + tolerance = 1e-5 + ) + ), + target + ) + + @test result isa NormalMeanVariance + μ, v = mean(result), var(result) + + # Mode of LogGamma(α, β) is at x = log(α*β) + # Here α=20, β=1. Mode at log(20) ≈ 3.0 + @test 2.0 < μ < 4.0 +end + +@testitem "LogNormal projected to Gamma" begin + using BayesBase + using ExponentialFamilyProjection + using ExponentialFamilyProjection: ControlVariateStrategy + using ClosedFormExpectations + using Distributions + using ExponentialFamily + using StableRNGs + using LinearAlgebra + + # Target: LogNormal(μ=1.0, σ=0.5) + target_dist = LogNormal(1.0, 0.5) + target = Logpdf(target_dist) + + # Initial: Gamma(2.0, 2.0) + initial_dist = Gamma(2.0, 2.0) + + result = project_to( + ProjectedTo( + Gamma; + parameters = ProjectionParameters( + strategy = ClosedFormStrategy(), + niterations = 50, + tolerance = 1e-5 + ) + ), + target; + initialpoint = initial_dist + ) + + @test result isa GammaDistributionsFamily + + # Check if it ran without error and produced valid parameters + @test shape(result) > 0 + @test scale(result) > 0 + + # Comparison with ControlVariateStrategy + result_cv = project_to( + ProjectedTo( + Gamma; + parameters = ProjectionParameters( + strategy = ControlVariateStrategy(nsamples = 500), + niterations = 50, + tolerance = 1e-4 + ) + ), + target; + initialpoint = initial_dist + ) + + # Should be close + @test isapprox(mean(result), mean(result_cv), rtol = 0.1) +end + +@testitem "ClosedFormStrategy vs ControlVariateStrategy: Speed and Accuracy" begin + using BayesBase + using ExponentialFamilyProjection + using ExponentialFamilyProjection: ControlVariateStrategy + using ClosedFormExpectations + using Distributions + using ExponentialFamily + using StableRNGs + using LinearAlgebra + using BenchmarkTools + + # Simple case: Normal to Normal + target_dist = NormalMeanVariance(5.0, 2.0) + target = Logpdf(target_dist) + + initial = NormalMeanVariance(0.0, 1.0) + + # Create projection objects + prj_analytic = ProjectedTo( + NormalMeanVariance; + parameters = ProjectionParameters( + strategy = ClosedFormStrategy(), + niterations = 100 + ) + ) + + prj_mc = ProjectedTo( + NormalMeanVariance; + parameters = ProjectionParameters( + strategy = ControlVariateStrategy(nsamples = 1000), + niterations = 100 + ) + ) + + # Benchmark with @belapsed for robust timing + t_analytic = @belapsed project_to($prj_analytic, $target; initialpoint=$initial) + t_mc = @belapsed project_to($prj_mc, $target; initialpoint=$initial) + + # Get results for accuracy testing + res_analytic = project_to(prj_analytic, target; initialpoint = initial) + res_mc = project_to(prj_mc, target; initialpoint = initial) + + # Analytic should be more accurate (converge to exact target) + @test isapprox(mean(res_analytic), mean(target_dist), atol = 2e-2) + @test isapprox(std(res_analytic), std(target_dist), atol = 2e-2) + + # ClosedFormStrategy should be at least as accurate as MC + @test abs(mean(res_analytic) - 5.0) <= abs(mean(res_mc) - 5.0) + 0.1 + + # ClosedFormStrategy should be faster than MC sampling + @test t_analytic < t_mc +end + + + +@testitem "ClosedFormStrategy ProjectionCostGradientObjective integration" begin + using ExponentialFamilyProjection + using ClosedFormExpectations + using ExponentialFamily + using Distributions + using ExponentialFamilyManifolds + using Manifolds + using StableRNGs + import ExponentialFamilyProjection: ProjectionCostGradientObjective, create_state! + + # Test that the objective works correctly with ClosedFormStrategy + strategy = ClosedFormStrategy() + target_dist = Normal(2.0, 1.0) + target = Logpdf(target_dist) + + q_dist = NormalMeanVariance(0.0, 1.0) + ef = convert(ExponentialFamilyDistribution, q_dist) + + M = ExponentialFamilyManifolds.get_natural_manifold(NormalMeanVariance, ()) + η = getnaturalparameters(ef) + + parameters = ProjectionParameters(rng = StableRNG(42)) + state = create_state!(strategy, M, parameters, target, ef, ()) + + obj = ProjectionCostGradientObjective(parameters, target, copy(η), (), strategy, state) + + # Test evaluation + p = ExponentialFamilyManifolds.partition_point(M, η) + X = Manifolds.zero_vector(M, p) + + cost, X_result = obj(M, X, p) + + @test isfinite(cost) + @test cost > 0 # KL divergence should be positive + @test all(isfinite.(X_result)) +end + +@testitem "ClosedFormStrategy logbasemeasure_correction for ConstantBaseMeasure" begin + using ExponentialFamilyProjection + using ClosedFormExpectations + using ExponentialFamily + using Distributions + using Test + + # Get the extension module + # The extension should be loaded because both ExponentialFamilyProjection and ClosedFormExpectations are loaded + ClosedFormExpectationsExt = Base.get_extension(ExponentialFamilyProjection, :ClosedFormExpectationsExt) + + @test !isnothing(ClosedFormExpectationsExt) + + strategy = ClosedFormStrategy() + + # Create a mock or usage that triggers ConstantBaseMeasure + # We will use reflection/internals to test the specific method + + # Using a distribution that we know has ConstantBaseMeasure + # or constructing it manually if possible. + # ExponentialFamily.ConstantBaseMeasure is a singleton struct usually. + + base_measure = ExponentialFamily.ConstantBaseMeasure() + q_dist = Normal(0, 1) # Any distribution works as q_dist is passed through + grad_target = [1.0, 2.0] + + result = ClosedFormExpectationsExt.logbasemeasure_correction( + strategy, + base_measure, + q_dist, + grad_target + ) + + # The function should return grad_target exactly for ConstantBaseMeasure + @test result === grad_target +end