Skip to content

Commit 096c100

Browse files
committed
Merge branch 'main' into 63-create-integrationtestyml
2 parents 7e96a0b + f076f24 commit 096c100

7 files changed

Lines changed: 99 additions & 19 deletions

File tree

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ jobs:
6565
- uses: actions/checkout@v4
6666
- uses: julia-actions/setup-julia@v2
6767
with:
68-
version: '1'
68+
version: '1.11'
6969
- uses: julia-actions/cache@v2
7070
- name: Configure doc environment
7171
shell: julia --project=docs --color=yes {0}

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ ExponentialFamily = "2.0.0"
3030
ExponentialFamilyManifolds = "3.0.2"
3131
FastCholesky = "1.3"
3232
FillArrays = "1"
33-
ForwardDiff = "0.10.36"
33+
ForwardDiff = "0.10.36, 1"
3434
LinearAlgebra = "1.10"
3535
Manifolds = "0.11"
3636
ManifoldsBase = "2"

docs/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
[deps]
22
BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e"
3+
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
34
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
45
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
6+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
57
ExponentialFamily = "62312e5e-252a-4322-ace9-a5f4bf9b357b"
68
ExponentialFamilyManifolds = "5c9727c4-3b82-4ab3-b165-76e2eb971b08"
79
ExponentialFamilyProjection = "17f509fa-9a96-44ba-99b2-1c5f01f0931b"
810
LiveServer = "16fef848-5104-11e9-1b77-fb7a48bbb589"
911
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
1012
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
13+
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

docs/make.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ makedocs(;
1616
canonical = "https://reactivebayes.github.io/ExponentialFamilyProjection.jl",
1717
edit_link = "main",
1818
assets = String[],
19-
repolink="github.com/ReactiveBayes/ExponentialFamilyProjection.jl",
19+
repolink="https://github.com/ReactiveBayes/ExponentialFamilyProjection.jl"
2020
),
2121
pages = ["Home" => "index.md"],
2222
)

docs/src/index.md

Lines changed: 67 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ using StableRNGs
162162
using Distributions
163163
using ExponentialFamily
164164
using ExponentialFamilyProjection
165+
using StatsFuns
165166
using Plots
166167
167168
# 1) Generate a reproducible dataset (shared RNG)
@@ -172,8 +173,7 @@ d = input_dim + 1
172173
X_feat = randn(rng, n, input_dim)
173174
X = hcat(ones(n), X_feat)
174175
β_true = [0.5, 2.0, -1.5]
175-
σ(z) = 1 / (1 + exp(-z))
176-
p = map(σ, X * β_true)
176+
p = map(logistic, X * β_true)
177177
y = rand.(Ref(rng), Bernoulli.(p));
178178
nothing # hide
179179
```
@@ -184,16 +184,7 @@ We created a binary logistic regression dataset with an intercept and fixed `rng
184184
# 2) Define in-place log-posterior, gradient, and Hessian
185185
function logpost!(out::AbstractVector{T}, β::AbstractVector{T}) where {T<:Real}
186186
Xβ = X * β
187-
@inline function log1pexp(z)
188-
z > 0 ? z + log1p(exp(-z)) : log1p(exp(z))
189-
end
190-
s = zero(T)
191-
@inbounds for i in 1:n
192-
s += y[i] * Xβ[i] - log1pexp(Xβ[i])
193-
end
194-
# standard normal prior on β
195-
s += -0.5 * dot(β, β)
196-
out[1] = s
187+
out[1] = mean(y .* Xβ .- log.(1 .+ exp.(Xβ)))
197188
return out
198189
end
199190
@@ -204,7 +195,8 @@ function grad!(out::AbstractVector{T}, β::AbstractVector{T}) where {T<:Real}
204195
pi = 1 / (1 + exp(-Xβ[i]))
205196
@views out[:] .+= (y[i] - pi) .* X[i, :]
206197
end
207-
return out
198+
out .= out ./ length(y)
199+
return
208200
end
209201
210202
function hess!(out::AbstractMatrix{T}, β::AbstractVector{T}) where {T<:Real}
@@ -215,6 +207,7 @@ function hess!(out::AbstractMatrix{T}, β::AbstractVector{T}) where {T<:Real}
215207
wi = pi * (1 - pi)
216208
@views out .-= wi .* (X[i, :] * transpose(X[i, :]))
217209
end
210+
out .= out ./ length(y)
218211
return out
219212
end
220213
```
@@ -358,6 +351,67 @@ plt_mc
358351
plot(plt_mean, plt_mc; layout = (1, 2), size = (1100, 450))
359352
```
360353

354+
### How to use autograd with Gauss–Newton (Enzyme.jl)
355+
356+
You do not need to hand-derive gradients or Hessians to use Gauss–Newton. With `Enzyme.jl`, you can automatically obtain both and use them through the same in-place API shown above. In practice, this is typically faster and yields more stable estimates than naïve manual derivatives. `Enzyme.jl` has some sharp edges; please consult the [Enzyme.jl documentation](https://enzymead.github.io/Enzyme.jl/stable/) before use.
357+
358+
```@example gaussnewton
359+
using Enzyme
360+
using BenchmarkTools
361+
362+
# 10) Define the log-posterior for logistic regression with a standard normal prior
363+
function obj(β::AbstractVector, X::AbstractMatrix, y::AbstractVector)
364+
Xβ = X * β
365+
return mean(y .* Xβ .- log.(1 .+ exp.(Xβ)))
366+
end
367+
368+
# Reverse-mode gradient and forward-over-reverse Hessian via Enzyme
369+
grad_enzyme = (β, X, y) -> Enzyme.gradient(Reverse, obj, β, Const(X), Const(y))[1]
370+
function jacobian_enzyme(β, X, y)
371+
Enzyme.jacobian(set_runtime_activity(Forward), grad_enzyme, β, Const(X), Const(y))
372+
end
373+
374+
# 11) In-place wrappers expected by Gauss–Newton
375+
function make_logpost!(X, y)
376+
(out, β) -> (out[1] = obj(β, X, y); out)
377+
end
378+
function make_grad!(X, y)
379+
function _grad!(out::AbstractVector{T}, β::AbstractVector{T}) where {T}
380+
out .= grad_enzyme(β, X, y)
381+
return out
382+
end
383+
_grad!
384+
end
385+
function make_hess!(X, y)
386+
function _hess!(out::AbstractMatrix{T}, β::AbstractVector{T}) where {T}
387+
J, _ = jacobian_enzyme(β, X, y)
388+
out .= J
389+
return out
390+
end
391+
_hess!
392+
end
393+
394+
logpostE! = make_logpost!(X, y)
395+
gradE! = make_grad!(X, y)
396+
hessE! = make_hess!(X, y)
397+
398+
inplace_enzyme = ExponentialFamilyProjection.InplaceLogpdfGradHess(logpostE!, gradE!, hessE!)
399+
prj_enzyme = ProjectedTo(MvNormalMeanCovariance, d; parameters = params)
400+
result_enzyme = project_to(prj_enzyme, inplace_enzyme)
401+
```
402+
403+
We can quickly compare the runtime of the Enzyme-based implementation to the manual one defined above.
404+
405+
```@example gaussnewton
406+
# 12) Speed comparison against the manual implementation from above
407+
t_manual = @belapsed project_to($prj, $inplace)
408+
t_enzyme = @belapsed project_to($prj_enzyme, $inplace_enzyme)
409+
speedup = t_manual / t_enzyme
410+
round.((speedup, t_manual, t_enzyme); digits = 3)
411+
```
412+
413+
On typical runs we observe a substantial speedup (often around 10×) for Enzyme while maintaining the same result.
414+
361415
### Projection with samples
362416

363417
The projection can be done given a set of samples instead of the function directly. For example, let's project a set of samples onto a Beta distribution:

src/projected_to.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ struct ProjectedTo{T,D,C,P,E}
4242
end
4343

4444
ProjectedTo(
45-
dims::Vararg{Int};
45+
dims::Tuple{Vararg{Int}};
4646
conditioner = nothing,
4747
parameters = DefaultProjectionParameters(),
4848
kwargs = nothing,
@@ -53,6 +53,29 @@ ProjectedTo(
5353
parameters = parameters,
5454
kwargs = kwargs,
5555
)
56+
ProjectedTo(;
57+
conditioner = nothing,
58+
parameters = DefaultProjectionParameters(),
59+
kwargs = nothing,
60+
) = ProjectedTo(
61+
ExponentialFamilyDistribution,
62+
()...,
63+
conditioner = conditioner,
64+
parameters = parameters,
65+
kwargs = kwargs,
66+
)
67+
ProjectedTo(
68+
dim::Int;
69+
conditioner = nothing,
70+
parameters = DefaultProjectionParameters(),
71+
kwargs = nothing,
72+
) = ProjectedTo(
73+
ExponentialFamilyDistribution,
74+
dim,
75+
conditioner = conditioner,
76+
parameters = parameters,
77+
kwargs = kwargs,
78+
)
5679
function ProjectedTo(
5780
::Type{T},
5881
dims...;

src/strategies/bonnet/bonnet_logpdf.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ implementation and returns an `InplaceLogpdfGradHess` instance.
3535
# See also
3636
- `NaiveGradHess` — adapter that combines separate `grad!`/`hess!` into `grad_hess!`.
3737
"""
38-
function InplaceLogpdfGradHess(logpdf!::F, grad!::G, hess!::H) where {F,G,H}
38+
function InplaceLogpdfGradHess(__logpdf::F, grad!::G, hess!::H) where {F,G,H}
3939
wrapper_grad_hess! = NaiveGradHess(grad!, hess!)
40-
return InplaceLogpdfGradHess(logpdf!, wrapper_grad_hess!)
40+
return InplaceLogpdfGradHess(__logpdf, wrapper_grad_hess!)
4141
end
4242

4343
"""

0 commit comments

Comments
 (0)