10 changes: 9 additions & 1 deletion Project.toml
@@ -22,9 +22,13 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[extensions]
ClosedFormExpectationsExt = "ClosedFormExpectations"

[compat]
BayesBase = "1.5.0"
Bumper = "0.6"
ClosedFormExpectations = "0.3.0"
Distributions = "0.25"
ExponentialFamily = "2.0.0"
ExponentialFamilyManifolds = "3.0.3"
@@ -43,9 +47,13 @@ StaticArrays = "1.9"
StatsFuns = "1.3"
julia = "1.10"

[weakdeps]
ClosedFormExpectations = "70ff922c-62d4-418d-abfc-e284e489b734"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ClosedFormExpectations = "70ff922c-62d4-418d-abfc-e284e489b734"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
@@ -58,4 +66,4 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
-test = ["Test", "Aqua", "BenchmarkTools", "Hwloc", "Plots", "Printf", "ForwardDiff", "Manifolds", "ReTestItems", "RollingFunctions", "JET", "StableRNGs"]
+test = ["Test", "Aqua", "BenchmarkTools", "ClosedFormExpectations", "Hwloc", "Plots", "Printf", "ForwardDiff", "Manifolds", "ReTestItems", "RollingFunctions", "JET", "StableRNGs"]
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -1,6 +1,7 @@
[deps]
BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ClosedFormExpectations = "70ff922c-62d4-418d-abfc-e284e489b734"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
2 changes: 1 addition & 1 deletion docs/make.jl
@@ -3,7 +3,7 @@ using Documenter, ExponentialFamilyProjection
DocMeta.setdocmeta!(
ExponentialFamilyProjection,
:DocTestSetup,
-:(using ExponentialFamilyProjection);
+:(using ExponentialFamilyProjection, ClosedFormExpectations);
recursive = true,
)

70 changes: 70 additions & 0 deletions docs/src/index.md
@@ -41,6 +41,7 @@ The optimization procedure requires computing the expectation of the gradient to
```@docs
ExponentialFamilyProjection.DefaultStrategy
ExponentialFamilyProjection.ControlVariateStrategy
ExponentialFamilyProjection.ClosedFormStrategy
ExponentialFamilyProjection.MLEStrategy
ExponentialFamilyProjection.BonnetStrategy
ExponentialFamilyProjection.GaussNewton
@@ -412,6 +413,75 @@ round.((speedup, t_manual, t_enzyme); digits = 3)

On typical runs we observe a substantial speedup (often around 10×) for Enzyme while maintaining the same result.

### Closed-form strategy (zero-variance gradients)

The `ClosedFormStrategy` uses `ClosedFormExpectations.jl` to compute the gradient analytically instead of estimating it by Monte Carlo sampling. Because no samples are drawn, the gradient has zero variance: the optimization is deterministic and is not perturbed by sampling noise, which typically gives faster and more accurate convergence than the sampling-based strategies.
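
For intuition, the quantity involved is the standard score-function identity for exponential families. The sketch below assumes a variational family of the form $q_\eta(x) = h(x)\exp(\eta^\top T(x) - A(\eta))$ and writes $f$ for the target log-density; the symbols $h$, $T$, $A$, $\mu$ and $f$ are notation chosen here for illustration, not names used by the package.

```math
\begin{aligned}
\nabla_\eta \mathbb{E}_{q_\eta}[f(x)]
  &= \int f(x)\, \nabla_\eta q_\eta(x)\, \mathrm{d}x \\
  &= \int f(x)\, q_\eta(x)\, \bigl(T(x) - \nabla_\eta A(\eta)\bigr)\, \mathrm{d}x \\
  &= \mathbb{E}_{q_\eta}\bigl[f(x)\,(T(x) - \mu)\bigr],
  \qquad \mu = \nabla_\eta A(\eta) = \mathbb{E}_{q_\eta}[T(x)].
\end{aligned}
```

A sampling-based strategy approximates the last expectation by averaging over draws from $q_\eta$, which introduces variance. When the expectation is available in closed form for the given pair of target and variational family, it can be evaluated exactly instead.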

!!! note
    To use `ClosedFormStrategy`, you must install and load `ClosedFormExpectations.jl`:
    ```julia
    using Pkg
    Pkg.add("ClosedFormExpectations")
    using ClosedFormExpectations
    ```

Let's compare `ClosedFormStrategy` with `ControlVariateStrategy` by projecting a LogNormal distribution onto a Gamma distribution:

```@example projection
using ClosedFormExpectations
using BenchmarkTools

# Target: LogNormal(μ=1.0, σ=0.5)
target_dist = LogNormal(1.0, 0.5)

# Initial point
initial_dist = Gamma(2.0, 2.0)

# Project using ClosedFormStrategy (pass distribution directly)
t_closed = @elapsed result_closed = project_to(
ProjectedTo(Gamma; parameters=ProjectionParameters(strategy=ClosedFormStrategy(), niterations=50, tolerance=1e-5)),
target_dist;
initialpoint = initial_dist
)

# Project using ControlVariateStrategy (with a function)
t_cv = @elapsed result_cv = project_to(
ProjectedTo(Gamma; parameters=ProjectionParameters(strategy=ControlVariateStrategy(nsamples=500), niterations=50, tolerance=1e-5)),
(x) -> logpdf(target_dist, x);
initialpoint = initial_dist
)

println("ClosedFormStrategy time: $(round(t_closed * 1000, digits=2)) ms")
println("ControlVariateStrategy time: $(round(t_cv * 1000, digits=2)) ms")
println("Speedup: $(round(t_cv / t_closed, digits=2))x")
```
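
Note that `@elapsed` on the first call of a function also measures compilation time, so single-shot timings like the ones above can be dominated by compilation rather than by the strategy itself. A minimal sketch of a fairer measurement is shown below; `timed_after_warmup` and `work` are hypothetical names used only for this illustration and are not part of `ExponentialFamilyProjection`.

```julia
# Hypothetical helper (not part of ExponentialFamilyProjection):
# run `f` once so that it is compiled, then time a second call.
function timed_after_warmup(f)
    f()                  # warm-up call; its runtime includes compilation
    return @elapsed f()  # timing of the already compiled code
end

# Stand-in workload; in the example above this would be a closure
# that calls `project_to` with the chosen strategy.
work() = sum(sqrt(i) for i in 1:1_000_000)

t_first = @elapsed work()          # may include compilation time
t_warm = timed_after_warmup(work)  # excludes compilation of `work`

println("first call: ", t_first, " s, after warm-up: ", t_warm, " s")
```

The `@belapsed` macro from `BenchmarkTools` is another option when repeated measurements are acceptable.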

Now let's visualize the results to see how both strategies compare:

```@example projection
using Plots

xs = 0.01:0.01:10.0

plot(xs, x -> pdf(target_dist, x),
label="Target (LogNormal)", linewidth=2,
fill=0, fillalpha=0.2, color=:blue)
plot!(xs, x -> pdf(result_closed, x),
label="ClosedForm", linewidth=2,
linestyle=:dash, color=:red)
plot!(xs, x -> pdf(result_cv, x),
label="ControlVariate", linewidth=2,
linestyle=:dot, color=:green)
xlabel!("x")
ylabel!("Density")
title!("LogNormal → Gamma Projection Comparison")
```

Compared with the sampling-based strategies, `ClosedFormStrategy` typically provides:
- **Faster convergence**: the gradients contain no Monte Carlo noise
- **Better accuracy**: the gradients are computed exactly
- **Lower runtime**: no samples are drawn or evaluated, which is especially noticeable for lower-dimensional problems
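
To illustrate what "no Monte Carlo noise" means, the sketch below compares an exact gradient with a sampled estimate of the same quantity in a toy setting. It does not use `ExponentialFamilyProjection` or `ClosedFormExpectations`. The setting is an assumption chosen because the closed form is easy to check by hand: `q = Normal(μ, 1)` with natural parameter `η = μ` and sufficient statistic `T(x) = x`, and the function `f(x) = x^2` stands in for a target log-density.

```julia
using Random, Statistics

# Toy setting (assumed for illustration): q = Normal(μ, 1), η = μ, T(x) = x.
f(x) = x^2
μ = 1.5

# Closed form: E_q[f] = μ^2 + 1, hence ∇_η E_q[f] = 2μ.
exact_grad = 2μ

# Score-function estimate of the same gradient: the sample mean of f(x) * (T(x) - μ).
function mc_grad(rng, μ, nsamples)
    xs = μ .+ randn(rng, nsamples)
    return mean(f.(xs) .* (xs .- μ))
end

rng = MersenneTwister(42)
estimates = [mc_grad(rng, μ, 500) for _ in 1:200]

println("exact gradient:       ", exact_grad)
println("mean of MC estimates: ", mean(estimates))
println("std of MC estimates:  ", std(estimates))
```

The sampled estimates scatter around the exact value, and that scatter is the noise a sampling-based optimizer has to average out. The exact value has none.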

### Projection with samples

The projection can be done given a set of samples instead of the function directly. For example, let's project a set of samples onto a Beta distribution:
174 changes: 174 additions & 0 deletions ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl
@@ -0,0 +1,174 @@
module ClosedFormExpectationsExt

using ExponentialFamilyProjection
using ClosedFormExpectations
using ExponentialFamily
using ExponentialFamilyManifolds
using Manifolds
using ManifoldsBase
using LinearAlgebra
using Distributions

# Import types needed for RxInfer closure unwrapping
import ExponentialFamily: ProductOf
import ClosedFormExpectations: Logpdf

ExponentialFamilyProjection.get_nsamples(::ClosedFormStrategy) = 0

function logbasemeasure_correction(
::ClosedFormStrategy,
::ExponentialFamily.ConstantBaseMeasure,
q_dist,
grad_target,
)
grad_target
end

function ExponentialFamilyProjection.compute_gradient!(
M::AbstractManifold,
strategy::ClosedFormStrategy,
state,
X,
η,
logpartition,
gradlogpartition,
inv_fisher,
)
# The gradient of KL(q||p) involves E_q[(log p̃ - log h_q) * (T - μ)]
# where h_q is the base measure of q (the variational distribution).
#
# For constant base measure:
# E[(log p̃ - log h) * (T - μ)] = E[log p̃ * (T - μ)] - log h * E[(T - μ)]
# Since E[T] = μ, the second term is zero.

target_fn = state.target

# Convert natural parameters on manifold to an ExponentialFamilyDistribution object
q_dist = convert(
ExponentialFamilyDistribution,
M,
ExponentialFamilyManifolds.partition_point(M, η),
)

# Compute ∇_η E_q[log p̃] = E_q[log p̃ * (T - μ)] in closed form
grad_target = mean(ClosedWilliamsProduct(), target_fn, q_dist)
grad_eta = logbasemeasure_correction(
strategy,
ExponentialFamily.isbasemeasureconstant(q_dist),
q_dist,
grad_target,
)

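# Natural gradient of KL(q || p̃): for a constant base measure the Euclidean
# gradient is F * η - ∇_η E_q[log p̃], so applying F⁻¹ gives X = η - F⁻¹ * ∇_η E_q[log p̃]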
X .= η .- inv_fisher * grad_eta

return X
end

# Helper to create state
# Note: We use a mutable struct to ensure each create_state! call
# produces a distinct object in memory, even with the same target
mutable struct ClosedFormStrategyState{T}
target::T
end

function ExponentialFamilyProjection.create_state!(
strategy::ClosedFormStrategy,
M::AbstractManifold,
parameters::ProjectionParameters,
projection_argument,
initial_ef,
supplementary_η,
)
return ClosedFormStrategyState(projection_argument)
end

function ExponentialFamilyProjection.prepare_state!(
strategy::ClosedFormStrategy,
state::ClosedFormStrategyState,
M::AbstractManifold,
parameters::ProjectionParameters,
projection_argument,
current_ef,
supplementary_η,
)
return state
end

# Compute cost for logging/convergence check
function ExponentialFamilyProjection.compute_cost(
M::AbstractManifold,
strategy::ClosedFormStrategy,
state::ClosedFormStrategyState,
η,
logpartition,
gradlogpartition,
inv_fisher,
)
# Cost = KL(q || p) = E_q[log q] - E_q[log p]

# Reconstruct distribution
q_dist = convert(
ExponentialFamilyDistribution,
M,
ExponentialFamilyManifolds.partition_point(M, η),
)
dist_std = convert(Distribution, q_dist)

# E_q[log p] via CFE
E_log_p = mean(ClosedFormExpectation(), state.target, dist_std)

# E_q[log q] = -entropy(q)
return -entropy(dist_std) - E_log_p
end

# preprocess_strategy_argument for ClosedFormStrategy
# Special handling for RxInfer closures that wrap ProductOf
function ExponentialFamilyProjection.preprocess_strategy_argument(
strategy::ClosedFormStrategy,
argument::F,
) where {F<:Function}
# RxInfer wraps ProductOf or Distribution in a closure.
# Extract the ProductOf/Distribution from the closure's captured variables.
# The closure typically has one field holding the ProductOf/Distribution.
field_names = fieldnames(F)

if isempty(field_names)
error(
"""`ClosedFormStrategy` requires a function that captures a `Distribution` or `ProductOf` in its closure.

Expected form:
let dist = Normal(0, 1)
(x) -> logpdf(dist, x)
end

Got a function without captured variables: $F

If you want to use a plain function, pass the `Distribution` directly instead of wrapping it in a function.
""",
)
end

captured = getfield(argument, first(field_names))
return (strategy, Logpdf(captured))
end

# `Distribution` arguments are wrapped in `Logpdf`
function ExponentialFamilyProjection.preprocess_strategy_argument(
strategy::ClosedFormStrategy,
argument::Distribution,
)
# ClosedFormStrategy accepts any callable or distribution as argument
return (strategy, Logpdf(argument))
end

# Generic fallback for arguments that are neither a `Function` nor a `Distribution`:
# the argument is passed through unchanged
function ExponentialFamilyProjection.preprocess_strategy_argument(
strategy::ClosedFormStrategy,
argument,
)
# ClosedFormStrategy accepts any callable or distribution as argument
return (strategy, argument)
end

end
1 change: 1 addition & 0 deletions src/ExponentialFamilyProjection.jl
@@ -123,6 +123,7 @@ function compute_gradient! end
include("strategies/control_variate.jl")
include("strategies/mle.jl")
include("strategies/default.jl")
include("strategies/closed_form.jl")
# Bonnet strategy
include("strategies/bonnet/naive_grad_hess.jl")
include("strategies/bonnet/bonnet_logpdf.jl")
62 changes: 62 additions & 0 deletions src/strategies/closed_form.jl
@@ -0,0 +1,62 @@
export ClosedFormStrategy

"""
ClosedFormStrategy

A projection strategy that uses `ClosedFormExpectations.jl` to compute the exact gradient
of the cross-entropy term \$\\mathbb{E}_{q_\\eta}[\\log \\tilde{p}(x)]\$ analytically.

This strategy provides a "zero-variance" gradient estimator, avoiding the noise associated
with Monte Carlo sampling (as in `ControlVariateStrategy`).

# Requirements

To use this strategy, you **must** load the `ClosedFormExpectations` package:

```julia
using ClosedFormExpectations
```

Loading `ClosedFormExpectations` will trigger a package extension that implements
the gradient computation for this strategy.

# When to Use

Use `ClosedFormStrategy` when:
- You need exact, deterministic gradients without Monte Carlo variance
- The target-to-variational family pair is supported by `ClosedFormExpectations.jl`
- You want faster convergence with fewer iterations
- Reproducibility is critical (no random sampling)

# Example

```julia
using ExponentialFamilyProjection, ClosedFormExpectations
using Distributions

# Target distribution
target = LogNormal(1.0, 0.5)

# Project to Gamma using closed-form gradients
result = project_to(
ProjectedTo(
Gamma;
parameters = ProjectionParameters(
strategy = ClosedFormStrategy(),
niterations = 50
)
),
Logpdf(target)
)
```

# References

This estimator was proposed in [Lukashchuk et al., 2024](https://proceedings.mlr.press/v246/lukashchuk24a.html).

!!! note
This strategy requires that `ClosedFormExpectations.jl` implements `ClosedWilliamsProduct`
for the specific pair of target distribution and variational family you're using.
See the `ClosedFormExpectations.jl` documentation for supported combinations.
"""
struct ClosedFormStrategy end
2 changes: 2 additions & 0 deletions src/strategies/control_variate.jl
@@ -1,3 +1,5 @@
export ControlVariateStrategy

using StableRNGs, Bumper, FillArrays

import Random: AbstractRNG