Skip to content

Commit 3a64838

Browse files
authored
Merge pull request #92 from ReactiveBayes/add-autograd-backed-for-closed-form-startegy
fix: add Enzyme backend into ClosedFormStrategy
2 parents b41491b + 05dbbf8 commit 3a64838

6 files changed

Lines changed: 278 additions & 22 deletions

File tree

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ ClosedFormExpectationsExt = "ClosedFormExpectations"
2828
[compat]
2929
BayesBase = "1.5.0"
3030
Bumper = "0.6"
31-
ClosedFormExpectations = "0.3.0"
31+
ClosedFormExpectations = "0.4.0"
3232
Distributions = "0.25"
3333
ExponentialFamily = "2.0.0"
3434
ExponentialFamilyManifolds = "3.0.3"
@@ -54,6 +54,7 @@ ClosedFormExpectations = "70ff922c-62d4-418d-abfc-e284e489b734"
5454
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
5555
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
5656
ClosedFormExpectations = "70ff922c-62d4-418d-abfc-e284e489b734"
57+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
5758
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
5859
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
5960
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
@@ -66,4 +67,4 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
6667
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6768

6869
[targets]
69-
test = ["Test", "Aqua", "BenchmarkTools", "ClosedFormExpectations", "Hwloc", "Plots", "Printf", "ForwardDiff", "Manifolds", "ReTestItems", "RollingFunctions", "JET", "StableRNGs"]
70+
test = ["Test", "Aqua", "BenchmarkTools", "ClosedFormExpectations", "Enzyme", "Hwloc", "Plots", "Printf", "ForwardDiff", "Manifolds", "ReTestItems", "RollingFunctions", "JET", "StableRNGs"]

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@ ExponentialFamilyManifolds = "5c9727c4-3b82-4ab3-b165-76e2eb971b08"
1010
ExponentialFamilyProjection = "17f509fa-9a96-44ba-99b2-1c5f01f0931b"
1111
LiveServer = "16fef848-5104-11e9-1b77-fb7a48bbb589"
1212
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
13+
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
1314
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1415
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

docs/src/index.md

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,56 @@ The `ClosedFormStrategy` typically provides:
482482
- **Better accuracy**: Exact gradient computations
483483
- **Speed advantages**: Especially significant for lower-dimensional problems
484484

485+
### Autograd-backed closed-form strategy (Enzyme.jl)
486+
487+
When `ClosedFormExpectations.jl` provides `ClosedFormExpectation` (i.e. ``\mathbb{E}_q[f]``) for a target–variational pair but does **not** yet have a hand-coded `ClosedWilliamsProduct` (i.e. ``\mathbb{E}_q[f \nabla_\eta \log q]``), you can still use `ClosedFormStrategy` by passing an `EnzymeBackend`. This exploits the identity
488+
489+
```math
490+
\nabla_\eta \mathbb{E}_q[f(x)] = \mathbb{E}_q[f(x) \nabla_\eta \log q(x;\eta)]
491+
```
492+
493+
and lets Enzyme.jl compute the gradient automatically by differentiating the closed-form expectation with respect to the natural parameters.
494+
495+
#### Example: Gamma projected to LogNormal
496+
497+
For this pair the `ClosedFormExpectation` is available but there is no hand-coded `ClosedWilliamsProduct`, so the default `ClosedFormStrategy()` would fail. With `EnzymeBackend` it works out of the box:
498+
499+
```@example enzymebackend
500+
using ExponentialFamilyProjection, ClosedFormExpectations, Enzyme
501+
using ExponentialFamily, BayesBase, Distributions
502+
using Plots
503+
504+
target_dist = Gamma(3.0, 2.0)
505+
506+
result = project_to(
507+
ProjectedTo(
508+
LogNormal;
509+
parameters = ProjectionParameters(
510+
strategy = ClosedFormStrategy(EnzymeBackend()),
511+
niterations = 100,
512+
tolerance = 1e-6,
513+
),
514+
),
515+
target_dist,
516+
)
517+
518+
xs = 0.01:0.05:20.0
519+
520+
plot(xs, x -> pdf(target_dist, x),
521+
label="Target (Gamma)", linewidth=2,
522+
fill=0, fillalpha=0.2, color=:blue)
523+
plot!(xs, x -> pdf(result, x),
524+
label="Projection (LogNormal)", linewidth=2,
525+
linestyle=:dash, color=:red)
526+
xlabel!("x")
527+
ylabel!("Density")
528+
title!("Gamma → LogNormal (EnzymeBackend)")
529+
```
530+
531+
The `EnzymeBackend` supports both reverse and forward mode:
532+
- `ClosedFormStrategy(EnzymeBackend())` — reverse mode (default)
533+
- `ClosedFormStrategy(EnzymeBackend(EnzymeForward()))` — forward mode
534+
485535
### Projection with samples
486536

487537
The projection can be done given a set of samples instead of the function directly. For example, let's project an set of samples onto a Beta distribution:

ext/ClosedFormExpectationsExt/ClosedFormExpectationsExt.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ function ExponentialFamilyProjection.compute_gradient!(
5151
)
5252

5353
# Compute ∇_η E[log p̃ * (T - μ)]
54-
grad_target = mean(ClosedWilliamsProduct(), target_fn, q_dist)
54+
grad_target = mean(ClosedWilliamsProduct(strategy.backend), target_fn, q_dist)
5555
grad_eta = logbasemeasure_correction(
5656
strategy,
5757
ExponentialFamily.isbasemeasureconstant(q_dist),
@@ -153,12 +153,19 @@ function ExponentialFamilyProjection.preprocess_strategy_argument(
153153
return (strategy, Logpdf(captured))
154154
end
155155

156-
# Generic fallback for non-Function arguments
156+
# Wrap Distribution in Logpdf for ClosedFormStrategy
157157
function ExponentialFamilyProjection.preprocess_strategy_argument(
158158
strategy::ClosedFormStrategy,
159159
argument::Distribution,
160160
)
161-
# ClosedFormStrategy accepts any callable or distribution as argument
161+
return (strategy, Logpdf(argument))
162+
end
163+
164+
# Wrap ProductOf in Logpdf for ClosedFormStrategy
165+
function ExponentialFamilyProjection.preprocess_strategy_argument(
166+
strategy::ClosedFormStrategy,
167+
argument::ProductOf,
168+
)
162169
return (strategy, Logpdf(argument))
163170
end
164171

@@ -167,7 +174,6 @@ function ExponentialFamilyProjection.preprocess_strategy_argument(
167174
strategy::ClosedFormStrategy,
168175
argument,
169176
)
170-
# ClosedFormStrategy accepts any callable or distribution as argument
171177
return (strategy, argument)
172178
end
173179

src/strategies/closed_form.jl

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
export ClosedFormStrategy
22

33
"""
4-
ClosedFormStrategy <: ExponentialFamilyProjection.AbstractStrategy
4+
ClosedFormStrategy{B} <: ExponentialFamilyProjection.AbstractStrategy
55
6-
A projection strategy that uses `ClosedFormExpectations.jl` to compute the exact gradient
6+
A projection strategy that uses `ClosedFormExpectations.jl` to compute the exact gradient
77
of the cross-entropy term \$\\mathbb{E}_{q_\\eta}[\\log \\tilde{p}(x)]\$ analytically.
88
9-
This strategy provides a "Zero-Variance" gradient estimator, avoiding the noise associated
9+
This strategy provides a "Zero-Variance" gradient estimator, avoiding the noise associated
1010
with Monte Carlo sampling (like in `ControlVariateStrategy`).
1111
12+
The optional `backend` field selects the differentiation backend used for computing
13+
`ClosedWilliamsProduct`. When `backend = nothing` (the default), hand-coded closed-form
14+
implementations are used. When an `EnzymeBackend` is provided, Enzyme.jl automatically
15+
differentiates the `ClosedFormExpectation` to obtain the Williams product gradient, enabling
16+
the strategy to work for any target-variational pair where the expectation is implemented
17+
but the Williams product is not.
18+
1219
# Requirements
1320
1421
To use this strategy, you **must** load the `ClosedFormExpectations` package:
@@ -17,7 +24,7 @@ To use this strategy, you **must** load the `ClosedFormExpectations` package:
1724
using ClosedFormExpectations
1825
```
1926
20-
Loading `ClosedFormExpectations` will trigger a package extension that implements
27+
Loading `ClosedFormExpectations` will trigger a package extension that implements
2128
the gradient computation for this strategy.
2229
2330
# When to Use
@@ -28,7 +35,11 @@ Use `ClosedFormStrategy` when:
2835
- You want faster convergence with fewer iterations
2936
- Reproducibility is critical (no random sampling)
3037
31-
# Example
38+
Use `ClosedFormStrategy(EnzymeBackend())` when:
39+
- A `ClosedFormExpectation` is implemented for the pair, but `ClosedWilliamsProduct` is not
40+
- You want to exploit the identity \$\\nabla_\\eta \\mathbb{E}_q[f] = \\mathbb{E}_q[f \\nabla_\\eta \\log q]\$ via autodiff
41+
42+
# Examples
3243
3344
```julia
3445
using ExponentialFamilyProjection, ClosedFormExpectations
@@ -37,7 +48,7 @@ using Distributions
3748
# Target distribution
3849
target = LogNormal(1.0, 0.5)
3950
40-
# Project to Gamma using closed-form gradients
51+
# Project to Gamma using closed-form gradients (hand-coded Williams product)
4152
result = project_to(
4253
ProjectedTo(
4354
Gamma;
@@ -50,13 +61,37 @@ result = project_to(
5061
)
5162
```
5263
64+
```julia
65+
using ExponentialFamilyProjection, ClosedFormExpectations, Enzyme
66+
using Distributions
67+
68+
# Target distribution (Gamma → LogNormal: ClosedFormExpectation is available
69+
# but ClosedWilliamsProduct is not, so we use EnzymeBackend to autodiff it)
70+
target = Gamma(2.0, 1.0)
71+
72+
result = project_to(
73+
ProjectedTo(
74+
LogNormal;
75+
parameters = ProjectionParameters(
76+
strategy = ClosedFormStrategy(EnzymeBackend()),
77+
niterations = 50
78+
)
79+
),
80+
Logpdf(target)
81+
)
82+
```
83+
5384
# References
5485
5586
This estimator was proposed in [Lukashchuk et al., 2024](https://proceedings.mlr.press/v246/lukashchuk24a.html).
5687
5788
!!! note
58-
This strategy requires that `ClosedFormExpectations.jl` implements `ClosedWilliamsProduct`
59-
for the specific pair of target distribution and variational family you're using.
60-
See the `ClosedFormExpectations.jl` documentation for supported combinations.
89+
Without a backend, this strategy requires that `ClosedFormExpectations.jl` implements
90+
`ClosedWilliamsProduct` for the specific target-variational pair. With an `EnzymeBackend`,
91+
it suffices to have `ClosedFormExpectation` implemented. See the `ClosedFormExpectations.jl`
92+
documentation for supported combinations.
6193
"""
62-
struct ClosedFormStrategy end
94+
struct ClosedFormStrategy{B}
95+
backend::B
96+
end
97+
ClosedFormStrategy() = ClosedFormStrategy(nothing)

0 commit comments

Comments
 (0)