forked from timweiland/GaussianMarkovRandomFields.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprecision_gradient.jl
More file actions
171 lines (137 loc) · 5.87 KB
/
precision_gradient.jl
File metadata and controls
171 lines (137 loc) · 5.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
using SparseArrays
using LinearAlgebra
using ChainRulesCore
# TODO: Make PR for SymTridiagonal rrules into ChainRules, to avoid type piracy
"""
    ChainRulesCore.rrule(::typeof(*), Q::SymTridiagonal, v::AbstractVector)

Custom rrule for `SymTridiagonal` matrix-vector multiplication `y = Q * v`.

The cotangent for `Q` is the outer product `ȳ * v'` restricted to the tridiagonal
band. It is returned as a `Tridiagonal`, so the lower (`[i+1, i]`) and upper
(`[i, i+1]`) contributions stay separate. Each off-diagonal element of `Q` appears
at both positions, so a consumer of this cotangent must sum both contributions.
The `SymTridiagonal` constructor rrule does this for matrix cotangents.

The cotangent for `v` is `Q' * ȳ`, projected onto `v`'s type.
"""
function ChainRulesCore.rrule(
        ::typeof(*),
        Q::SymTridiagonal{T, V},
        v::AbstractVector{<:Union{Real, Complex}}
    ) where {T <: Union{Real, Complex}, V <: AbstractVector{T}}
    y = Q * v
    project_v = ProjectTo(v)
    function symtri_mul_pullback(ȳ)
        # Materialize a possibly-thunked cotangent before indexing into it. A fresh
        # name avoids reassigning a variable captured by the thunks below.
        Δy = unthunk(ȳ)
        Δy isa AbstractZero && return NoTangent(), Δy, Δy
        n = length(v)
        # Gradient w.r.t. Q is Δy * v' (the adjoint conjugates complex v; conj is a
        # no-op for real v). Only the tridiagonal band is kept.
        Q̄ = @thunk begin
            v̄ = conj(v)
            Tridiagonal(
                Δy[2:n] .* v̄[1:(n - 1)],
                Δy .* v̄,
                Δy[1:(n - 1)] .* v̄[2:n]
            )
        end
        return NoTangent(), Q̄, @thunk(project_v(Q' * Δy))
    end
    return y, symtri_mul_pullback
end
"""
    ChainRulesCore.rrule(::Type{SymTridiagonal}, dv::AbstractVector, ev::AbstractVector)

ChainRule for the `SymTridiagonal` constructor to enable Zygote differentiation.

A `SymTridiagonal` stores its diagonal (`dv`) and off-diagonal (`ev`) separately.
The pullback maps the incoming cotangent back onto those two vectors:
- `SymTridiagonal` cotangent: its `dv` / `ev` fields are used directly.
- structural `Tangent`: its `dv` / `ev` components are used; a missing component
  is treated as `ZeroTangent()`.
- any other `AbstractMatrix`: the diagonal is read off directly. Each `ev[i]`
  receives the sum of entries `[i, i+1]` and `[i+1, i]`, since it occupies both.
- zero cotangents (`ZeroTangent` / `NoTangent`) propagate as zero.

Any other cotangent type raises an `ArgumentError`. Both components are passed
through `ProjectTo` of the corresponding primal vector.
"""
function ChainRulesCore.rrule(::Type{SymTridiagonal}, dv::AbstractVector, ev::AbstractVector)
    y = SymTridiagonal(dv, ev)
    project_d = ProjectTo(dv)
    project_e = ProjectTo(ev)
    function pullback(ȳ)
        d̄, ē = _symtridiagonal_cotangent_parts(unthunk(ȳ))
        return (NoTangent(), project_d(d̄), project_e(ē))
    end
    return y, pullback
end

# Split a `SymTridiagonal` cotangent into its (dv, ev) components by reading its fields.
# NOTE(review): this reads the fields rather than treating the value as a matrix
# (which would double `ev`); confirm this matches how upstream rules build such cotangents.
_symtridiagonal_cotangent_parts(ȳ::SymTridiagonal) = (ȳ.dv, ȳ.ev)

# Structural tangent: take the named components, defaulting to zero when absent.
function _symtridiagonal_cotangent_parts(ȳ::Tangent)
    d̄ = haskey(ȳ, :dv) ? ȳ.dv : ZeroTangent()
    ē = haskey(ȳ, :ev) ? ȳ.ev : ZeroTangent()
    return d̄, ē
end

# Generic matrix cotangent: ev[i] sits at both [i, i+1] and [i+1, i], so its adjoint
# is the sum of those two entries.
_symtridiagonal_cotangent_parts(ȳ::AbstractMatrix) = (diag(ȳ), diag(ȳ, 1) + diag(ȳ, -1))

# Zero cotangent: nothing to propagate to either component.
_symtridiagonal_cotangent_parts(::AbstractZero) = (ZeroTangent(), ZeroTangent())

# Anything else is unsupported; fail loudly instead of with an undefined-variable error.
function _symtridiagonal_cotangent_parts(ȳ)
    throw(ArgumentError("unsupported cotangent type $(typeof(ȳ)) for the SymTridiagonal constructor"))
end
"""
    ChainRulesCore.rrule(::typeof(sum), Q::SymTridiagonal)

ChainRule for `sum(::SymTridiagonal)`.

Every off-diagonal element used by `Q` appears at both `[i, i+1]` and `[i+1, i]`, so
`sum(Q)` counts it twice. The pullback therefore returns a structural `Tangent`
holding `s` for every diagonal entry and `2s` for every used off-diagonal entry.
If `Q.ev` has length `n` (its optional last element, which `SymTridiagonal`
ignores), that element gets a zero tangent.

A dedicated rule is used because `ProjectTo{SymTridiagonal}` in ChainRulesCore was
observed to keep only one triangle of the off-diagonal, dropping the factor of 2
from symmetry. This surfaced once the SymTridiagonal rrules in this file were added.
Returning the structural tangent directly avoids that projection.
"""
function ChainRulesCore.rrule(::typeof(sum), Q::SymTridiagonal)
    n = length(Q.dv)
    nev = length(Q.ev)
    function sum_symtridiag_pullback(ȳ)
        s = unthunk(ȳ)
        s isa AbstractZero && return NoTangent(), s
        d̄v = fill(s, n)
        # Only the first n-1 entries of ev belong to the matrix; an n-th entry is ignored.
        ēv = [i < n ? 2s : zero(2s) for i in 1:nev]
        return NoTangent(), Tangent{typeof(Q)}(dv = d̄v, ev = ēv)
    end
    return sum(Q), sum_symtridiag_pullback
end
"""
    compute_precision_gradient(Qinv::AbstractMatrix, r::AbstractVector, ȳ::Real)

Return the gradient of the log-density with respect to the precision matrix `Q`,
given the inverse precision `Qinv` (e.g. from selected inversion):

    ∂logpdf/∂Q = 0.5 * ȳ * (Q⁻¹ - r*rᵀ)

Specialized methods exist for matrix types that `selinv` may return, so the dense
work here is avoided where possible:
- `SparseMatrixCSC`: evaluates only at the stored entries of `Qinv`
- `SymTridiagonal`: evaluates only on the tridiagonal band
- `Symmetric{SparseMatrixCSC}`: evaluates at the stored entries and keeps the wrapper
- any other `AbstractMatrix` (this method): forms the full outer product and logs a
  warning (at most once)

# Arguments
- `Qinv`: Inverse precision matrix from selected inversion
- `r`: Residual vector (z - μ)
- `ȳ`: Incoming gradient scalar

# Returns
For the specialized methods, a gradient matrix with the same structure as `Qinv`.
This fallback returns the result of broadcasting over `Qinv` and the outer product
`r .* r'`, typically a dense `Matrix`.
"""
function compute_precision_gradient(Qinv::AbstractMatrix, r::AbstractVector, ȳ::Real)
    @warn "Using generic fallback for precision gradient computation with matrix type $(typeof(Qinv)). " *
        "This may be inefficient for large matrices. Consider using a factorization that returns " *
        "SparseMatrixCSC or SymTridiagonal." maxlog = 1
    scale = 0.5 * ȳ
    # A single fused broadcast materializes the full n×n outer product r .* r'.
    return scale .* (Qinv .- r .* r')
end
"""
    compute_precision_gradient(Qinv::SparseMatrixCSC, r::AbstractVector, ȳ::Real)

Sparse-aware gradient computation. Evaluates `0.5 * ȳ * (Qinv - r*rᵀ)` only at the
entries stored in `Qinv`, so the dense outer product is never formed. Returns a
`SparseMatrixCSC` assembled from exactly those stored positions.
"""
function compute_precision_gradient(Qinv::SparseMatrixCSC, r::AbstractVector, ȳ::Real)
    rowidx, colidx, stored = findnz(Qinv)
    scale = 0.5 * ȳ
    # One gradient value per stored (i, j): scale * (Qinv[i, j] - r[i] * r[j]).
    gradvals = [scale * (q - r[i] * r[j]) for (i, j, q) in zip(rowidx, colidx, stored)]
    nrows, ncols = size(Qinv)
    return sparse(rowidx, colidx, gradvals, nrows, ncols)
end
"""
    compute_precision_gradient(Qinv::SymTridiagonal, r::AbstractVector, ȳ::Real)

Efficient gradient computation for symmetric tridiagonal matrices.

Evaluates `0.5 * ȳ * (Qinv - r*rᵀ)` only on the tridiagonal band and returns it as a
`SymTridiagonal`:
- diagonal: `0.5ȳ (Qinv[i, i] - r[i]^2)`
- off-diagonal: `0.5ȳ (Qinv[i, i+1] - r[i] r[i+1])`

`SymTridiagonal` allows `Qinv.ev` to carry an optional n-th element that it ignores.
Only the first n-1 off-diagonal entries are used, so the returned `ev` always has
length n-1.

# Throws
- `DimensionMismatch` if `length(r)` differs from the size of `Qinv`.
"""
function compute_precision_gradient(Qinv::SymTridiagonal, r::AbstractVector, ȳ::Real)
    n = length(r)
    n == size(Qinv, 1) ||
        throw(DimensionMismatch("length of r ($n) must match the size of Qinv ($(size(Qinv, 1)))"))
    scale = 0.5 * ȳ
    m = max(n - 1, 0)
    # Diagonal: Qinv.dv - r .* r
    dv = scale .* (Qinv.dv .- r .* r)
    # Off-diagonal: Qinv.ev[i] - r[i] * r[i+1] for i = 1:n-1. Views avoid copying the
    # slices, and truncating ev drops the ignored n-th element if present.
    ev = @views scale .* (Qinv.ev[1:m] .- r[1:m] .* r[2:n])
    return SymTridiagonal(dv, ev)
end
"""
    compute_precision_gradient(Qinv::Symmetric{T, <:SparseMatrixCSC}, r, ȳ) where T

Gradient computation for a `Symmetric` wrapper around a sparse matrix. Evaluates
`0.5 * ȳ * (Qinv - r*rᵀ)` at every entry stored in the wrapped sparse matrix. Then
re-wraps the sparse result in `Symmetric` using the same triangle (`uplo`) as `Qinv`.
"""
function compute_precision_gradient(Qinv::Symmetric{T, <:SparseMatrixCSC}, r::AbstractVector, ȳ::Real) where {T}
    storage = parent(Qinv)
    rowidx, colidx, stored = findnz(storage)
    scale = 0.5 * ȳ
    gradvals = [scale * (q - r[i] * r[j]) for (i, j, q) in zip(rowidx, colidx, stored)]
    nrows, ncols = size(Qinv)
    triangle = Symbol(Qinv.uplo)
    return Symmetric(sparse(rowidx, colidx, gradvals, nrows, ncols), triangle)
end