Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
MooncakeSparse = "85e08d05-c593-4d83-8d7e-0e912d511203"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SelectedInversion = "043bf095-3f01-458a-9f1c-8cf4448fe908"
Expand Down Expand Up @@ -77,6 +78,7 @@ LinearMaps = "3.11"
LinearSolve = "2, 3"
Makie = "0.19 - 0.22"
Mooncake = "0.5.25"
MooncakeSparse = "0.1"
NearestNeighbors = "0.4"
Pardiso = "1"
Random = "<0.0.1, 1"
Expand Down
6 changes: 3 additions & 3 deletions benchmarks/autodiff_comparison.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,17 +189,17 @@ try

# Prepare gradient (includes warmup/compilation)
print(" Preparing... ")
prep_chordal = DifferentiationInterface.prepare_gradient(loss_chordal, AutoMooncake(; config=nothing), θ_init)
prep_chordal = DifferentiationInterface.prepare_gradient(loss_chordal, AutoMooncake(; config = nothing), θ_init)
println("✓")

# Compute gradient once
print(" Computing gradient... ")
grad_chordal = DifferentiationInterface.gradient(loss_chordal, prep_chordal, AutoMooncake(; config=nothing), θ_init)
grad_chordal = DifferentiationInterface.gradient(loss_chordal, prep_chordal, AutoMooncake(; config = nothing), θ_init)
println("✓")

# Benchmark with prepared gradient
print(" Benchmarking... ")
bench_chordal = @benchmark DifferentiationInterface.gradient($loss_chordal, $prep_chordal, AutoMooncake(; config=nothing), $θ_init) samples = 10 seconds = 30
bench_chordal = @benchmark DifferentiationInterface.gradient($loss_chordal, $prep_chordal, AutoMooncake(; config = nothing), $θ_init) samples = 10 seconds = 30

results["ChordalGMRF+Mooncake"] = (
gradient = grad_chordal,
Expand Down
6 changes: 3 additions & 3 deletions benchmarks/gaussian_approximation_comparison.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ for (matrix_name, desc) in test_matrices
# Use prepared gradients for both backends
prep_gmrf = DifferentiationInterface.prepare_gradient(loss_gmrf, AutoZygote(), μ)
grad_gmrf = DifferentiationInterface.gradient(loss_gmrf, prep_gmrf, AutoZygote(), μ)
prep_chordal = DifferentiationInterface.prepare_gradient(loss_chordal, AutoMooncake(; config=nothing), μ)
grad_chordal = DifferentiationInterface.gradient(loss_chordal, prep_chordal, AutoMooncake(; config=nothing), μ)
prep_chordal = DifferentiationInterface.prepare_gradient(loss_chordal, AutoMooncake(; config = nothing), μ)
grad_chordal = DifferentiationInterface.gradient(loss_chordal, prep_chordal, AutoMooncake(; config = nothing), μ)
grad_abs_diff = norm(grad_gmrf - grad_chordal)
grad_rel_diff = grad_abs_diff / (norm(grad_gmrf) + 1.0e-10)

Expand All @@ -168,7 +168,7 @@ for (matrix_name, desc) in test_matrices
println("$(@sprintf("%.3f", time_grad_gmrf)) ms")

print(" ChordalGMRF (Mooncake)... ")
bench_grad_chordal = @benchmark DifferentiationInterface.gradient($loss_chordal, $prep_chordal, AutoMooncake(; config=nothing), $μ) samples = 10 seconds = 10
bench_grad_chordal = @benchmark DifferentiationInterface.gradient($loss_chordal, $prep_chordal, AutoMooncake(; config = nothing), $μ) samples = 10 seconds = 10
time_grad_chordal = minimum(bench_grad_chordal.times) / 1.0e6
println("$(@sprintf("%.3f", time_grad_chordal)) ms")

Expand Down
6 changes: 3 additions & 3 deletions benchmarks/logpdf_comparison.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,10 @@ for (matrix_name, desc) in test_matrices

# Use prepared gradients for both backends
prep_gmrf = DifferentiationInterface.prepare_gradient(gmrf_logpdf, AutoZygote(), z)
prep_chordal = DifferentiationInterface.prepare_gradient(chordal_logpdf, AutoMooncake(; config=nothing), z)
prep_chordal = DifferentiationInterface.prepare_gradient(chordal_logpdf, AutoMooncake(; config = nothing), z)

grad_gmrf = DifferentiationInterface.gradient(gmrf_logpdf, prep_gmrf, AutoZygote(), z)
grad_chordal = DifferentiationInterface.gradient(chordal_logpdf, prep_chordal, AutoMooncake(; config=nothing), z)
grad_chordal = DifferentiationInterface.gradient(chordal_logpdf, prep_chordal, AutoMooncake(; config = nothing), z)
grad_abs_diff = norm(grad_gmrf - grad_chordal)
grad_rel_diff = grad_abs_diff / (norm(grad_gmrf) + 1.0e-10)

Expand All @@ -140,7 +140,7 @@ for (matrix_name, desc) in test_matrices
println("$(@sprintf("%.3f", time_grad_gmrf)) ms")

print(" ChordalGMRF... ")
bench_grad_chordal = @benchmark DifferentiationInterface.gradient($chordal_logpdf, $prep_chordal, AutoMooncake(; config=nothing), $z) samples = 20 seconds = 5
bench_grad_chordal = @benchmark DifferentiationInterface.gradient($chordal_logpdf, $prep_chordal, AutoMooncake(; config = nothing), $z) samples = 20 seconds = 5
time_grad_chordal = minimum(bench_grad_chordal.times) / 1.0e6
println("$(@sprintf("%.3f", time_grad_chordal)) ms")

Expand Down
1 change: 0 additions & 1 deletion deps/MooncakeSparse
Submodule MooncakeSparse deleted from 6b1b2e
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Turing = "0.40"
Turing = "0.40, 0.41, 0.42, 0.43, 0.44"
1 change: 1 addition & 0 deletions docs/src/reference/gmrfs.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
```@docs
AbstractGMRF
GMRF
ChordalGMRF
precision_map
precision_matrix
InformationVector
Expand Down
2 changes: 1 addition & 1 deletion src/GaussianMarkovRandomFields.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module GaussianMarkovRandomFields

include("../deps/MooncakeSparse/MooncakeSparse.jl")
using MooncakeSparse
include("typedefs.jl")
include("utils/utils.jl")
include("linear_maps/linear_maps.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/arithmetic/condition/gaussian_approximation.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using LinearAlgebra
using SparseArrays
using LinearMaps
using CliqueTrees.Multifrontal: chordal, ChordalCholesky, triangular
using CliqueTrees.Multifrontal: ChordalCholesky

export gaussian_approximation

Expand Down
5 changes: 2 additions & 3 deletions src/autodiff/logpdf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ function ChainRulesCore.rrule(::typeof(logpdf), x::ConstrainedGMRF, z::AbstractV
return val, constrained_logpdf_pullback
end

function ChainRulesCore.rrule(::typeof(logpdf), x::AbstractGMRF, z::AbstractVector)
function ChainRulesCore.rrule(::typeof(logpdf), x::GMRF, z::AbstractVector)
μ = mean(x)
Q = precision_matrix(x)
r = z - μ

# Check if GMRF supports selected inversion for efficient gradients
if hasproperty(x, :linsolve_cache) && supports_selinv(x.linsolve_cache.alg) == Val{true}()
if supports_selinv(x.linsolve_cache.alg) == Val{true}()
# Forward computation - use existing implementation
val = logpdf(x, z)

Expand Down Expand Up @@ -98,4 +98,3 @@ function ChainRulesCore.rrule(::typeof(logpdf), x::AbstractGMRF, z::AbstractVect
)
end
end

46 changes: 23 additions & 23 deletions src/autodiff/mooncake_gaussian_approximation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ using CliqueTrees.Multifrontal: ChordalCholesky
@is_primitive MinimalCtx Tuple{Type{ChordalGMRF}, AbstractVector, SparseMatrixCSC}

function Mooncake.rrule!!(
::CoDual{Type{ChordalGMRF}},
cdμ::CoDual{<:AbstractVector},
cdQ::CoDual{<:SparseMatrixCSC},
)
::CoDual{Type{ChordalGMRF}},
cdμ::CoDual{<:AbstractVector},
cdQ::CoDual{<:SparseMatrixCSC},
)
μ, Σμ = MooncakeSparse.primaltangent(cdμ)
Q, ΣQ = MooncakeSparse.primaltangent(cdQ)

Expand All @@ -33,11 +33,11 @@ end
@is_primitive MinimalCtx Tuple{Type{ChordalGMRF}, AbstractVector, Hermitian, ChordalCholesky}

function Mooncake.rrule!!(
::CoDual{Type{ChordalGMRF}},
cdμ::CoDual{<:AbstractVector},
cdQ::CoDual{<:Hermitian},
cdF::CoDual{<:ChordalCholesky},
)
::CoDual{Type{ChordalGMRF}},
cdμ::CoDual{<:AbstractVector},
cdQ::CoDual{<:Hermitian},
cdF::CoDual{<:ChordalCholesky},
)
μ, Σμ = MooncakeSparse.primaltangent(cdμ)
Q, ΣQ = MooncakeSparse.primaltangent(cdQ)
F = primal(cdF)
Expand Down Expand Up @@ -66,10 +66,10 @@ end
@is_primitive MinimalCtx Tuple{typeof(Core.kwcall), Any, typeof(gaussian_approximation_notangent), ChordalGMRF, ObservationLikelihood}

function Mooncake.rrule!!(
::CoDual{typeof(gaussian_approximation_notangent)},
cdprior::CoDual{<:ChordalGMRF},
cdobslik::CoDual{<:ObservationLikelihood},
)
::CoDual{typeof(gaussian_approximation_notangent)},
cdprior::CoDual{<:ChordalGMRF},
cdobslik::CoDual{<:ObservationLikelihood},
)
prior = primal(cdprior)
obslik = primal(cdobslik)
posterior = gaussian_approximation_notangent(prior, obslik)
Expand All @@ -82,12 +82,12 @@ function Mooncake.rrule!!(
end

function Mooncake.rrule!!(
::CoDual{typeof(Core.kwcall)},
cdkwargs::CoDual,
::CoDual{typeof(gaussian_approximation_notangent)},
cdprior::CoDual{<:ChordalGMRF},
cdobslik::CoDual{<:ObservationLikelihood},
)
::CoDual{typeof(Core.kwcall)},
cdkwargs::CoDual,
::CoDual{typeof(gaussian_approximation_notangent)},
cdprior::CoDual{<:ChordalGMRF},
cdobslik::CoDual{<:ObservationLikelihood},
)
prior = primal(cdprior)
obslik = primal(cdobslik)
kwargs = primal(cdkwargs)
Expand All @@ -101,10 +101,10 @@ function Mooncake.rrule!!(
end

@mooncake_overlay function gaussian_approximation(
prior::ChordalGMRF,
obslik::ObservationLikelihood;
kwargs...
)
prior::ChordalGMRF,
obslik::ObservationLikelihood;
kwargs...
)
posterior = gaussian_approximation_notangent(prior, obslik; kwargs...)
x_star = mean(posterior)

Expand Down
19 changes: 19 additions & 0 deletions src/autodiff/precision_gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,25 @@ function ChainRulesCore.rrule(::Type{SymTridiagonal}, dv::AbstractVector, ev::Ab
return y, pullback
end

"""
ChainRulesCore.rrule(::typeof(sum), Q::SymTridiagonal)

ChainRule for `sum(::SymTridiagonal)`.

Restoring the rrules above triggers method invalidation that exposes a bug in
ChainRulesCore's `ProjectTo{SymTridiagonal}`: it extracts only one triangle of
the off-diagonal, dropping the factor of 2 from symmetry. The explicit
`sum(::SymTridiagonal)` rrule sidesteps the projection and returns the
correctly-doubled off-diagonal tangent.
"""
function ChainRulesCore.rrule(::typeof(sum), Q::SymTridiagonal)
function sum_symtridiag_pullback(ȳ)
s = unthunk(ȳ)
return NoTangent(), Tangent{SymTridiagonal}(dv = fill(s, length(Q.dv)), ev = fill(2s, length(Q.ev)))
end
return sum(Q), sum_symtridiag_pullback
end

"""
compute_precision_gradient(Qinv::AbstractMatrix, r::AbstractVector, ȳ::Real)

Expand Down
24 changes: 24 additions & 0 deletions src/chordal_gmrf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,30 @@ using Random: AbstractRNG, randn

export ChordalGMRF

"""
ChordalGMRF{T, Hrm, Fac, Mea} <: AbstractGMRF{T, Hrm}

A `GMRF` backed by a chordal Cholesky factorization (via
`CliqueTrees.Multifrontal.ChordalCholesky`) instead of CHOLMOD.

The pure-Julia chordal factorization composes naturally with `Mooncake`'s
reverse-mode AD through the rrules shipped by `MooncakeSparse`, so `logpdf`
and `gaussian_approximation` give correct gradients with respect to the
hyperparameters that produced `Q`. This is the recommended GMRF type for
Mooncake-based hyperparameter optimization (e.g. L-BFGS / Adam on the
marginal likelihood).

# Fields
- `μ::AbstractVector`: Mean.
- `Q::Hermitian`: Precision matrix.
- `F::ChordalCholesky`: Chordal Cholesky factorization of `Q`.

# Construction
```julia
ChordalGMRF(μ, Q) # factorize Q via ChordalCholesky
ChordalGMRF(μ, Q, F) # reuse a precomputed factorization
```
"""
struct ChordalGMRF{T <: Real, Hrm <: Hermitian, Fac <: ChordalCholesky, Mea <: AbstractVector{T}} <: AbstractGMRF{T, Hrm}
μ::Mea
Q::Hrm
Expand Down
18 changes: 9 additions & 9 deletions test/autodiff/test_gaussian_approximation_chordal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ using Random
using DifferentiationInterface
using FiniteDiff, Mooncake

backends = Any[("Mooncake", AutoMooncake())]
chordal_backends = Any[("Mooncake", AutoMooncake())]

@testset "$backend_name ChordalGMRF autodiff tests" for (backend_name, backend) in backends
@testset "$backend_name ChordalGMRF autodiff tests" for (backend_name, backend) in chordal_backends
# Set seed for reproducibility
Random.seed!(42)
fd_backend = AutoFiniteDiff()
Expand All @@ -21,7 +21,7 @@ backends = Any[("Mooncake", AutoMooncake())]
end

# Test pipeline: hyperparameters → ChordalGMRF → gaussian_approximation → logpdf
function test_gauss_approx_pipeline(θ::Vector, y::Vector, x::Vector, k::Int)
function test_gauss_approx_pipeline_chordal(θ::Vector, y::Vector, x::Vector, k::Int)
# Extract hyperparameters
ρ = θ[1] # AR parameter
μ_const = θ[2] # constant mean
Expand Down Expand Up @@ -54,13 +54,13 @@ backends = Any[("Mooncake", AutoMooncake())]
x = randn(k) .+ 0.5 # Evaluation point

grad_test = DifferentiationInterface.gradient(
θ -> test_gauss_approx_pipeline(θ, y, x, k),
θ -> test_gauss_approx_pipeline_chordal(θ, y, x, k),
backend,
θ
)

grad_fd = DifferentiationInterface.gradient(
θ -> test_gauss_approx_pipeline(θ, y, x, k),
θ -> test_gauss_approx_pipeline_chordal(θ, y, x, k),
fd_backend,
θ
)
Expand All @@ -83,13 +83,13 @@ backends = Any[("Mooncake", AutoMooncake())]
θ = [ρ, μ_const]

grad_test = DifferentiationInterface.gradient(
θ -> test_gauss_approx_pipeline(θ, y, x, k),
θ -> test_gauss_approx_pipeline_chordal(θ, y, x, k),
backend,
θ
)

grad_fd = DifferentiationInterface.gradient(
θ -> test_gauss_approx_pipeline(θ, y, x, k),
θ -> test_gauss_approx_pipeline_chordal(θ, y, x, k),
fd_backend,
θ
)
Expand Down Expand Up @@ -150,13 +150,13 @@ backends = Any[("Mooncake", AutoMooncake())]
x = randn(k) .+ 0.4

grad_test = DifferentiationInterface.gradient(
θ -> test_gauss_approx_pipeline(θ, y, x, k),
θ -> test_gauss_approx_pipeline_chordal(θ, y, x, k),
backend,
θ
)

grad_fd = DifferentiationInterface.gradient(
θ -> test_gauss_approx_pipeline(θ, y, x, k),
θ -> test_gauss_approx_pipeline_chordal(θ, y, x, k),
fd_backend,
θ
)
Expand Down
1 change: 1 addition & 0 deletions test/observation_models/test_negative_binomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ using SparseArrays
# ============================================================
@testset "Poisson limit (r → ∞)" begin
count_data = [3, 1, 8, 0, 5]
Random.seed!(42)
η = randn(5)
r_large = 1.0e8

Expand Down