Skip to content

Commit 13e6245

Browse files
Jutholkdvos
andauthored
WIP: chainrules (#11)
* first chainrules * add LQ * revert default positive lq qr * add qr_null and lq_null rrule * add polar AD rules * fix bug * add orth null tests * Update ext/MatrixAlgebraKitChainRulesCoreExt.jl Co-authored-by: Lukas Devos <ldevos98@gmail.com> * increase coverage --------- Co-authored-by: Lukas Devos <ldevos98@gmail.com>
1 parent 5810500 commit 13e6245

17 files changed

Lines changed: 907 additions & 373 deletions

Project.toml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,32 @@ version = "0.1.1"
66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88

9+
[weakdeps]
10+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
11+
12+
[extensions]
13+
MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore"
14+
915
[compat]
1016
Aqua = "0.6, 0.7, 0.8"
17+
ChainRulesCore = "1"
18+
ChainRulesTestUtils = "1"
1119
JET = "0.9"
1220
LinearAlgebra = "1"
1321
StableRNGs = "1"
1422
Test = "1"
1523
TestExtras = "0.2,0.3"
24+
Zygote = "0.7"
1625
julia = "1.10"
1726

1827
[extras]
1928
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
29+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
2030
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
2131
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
2232
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2333
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
34+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2435

2536
[targets]
26-
test = ["Aqua", "JET", "Test", "TestExtras", "StableRNGs"]
37+
test = ["Aqua", "JET", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote"]
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
module MatrixAlgebraKitChainRulesCoreExt
2+
3+
using MatrixAlgebraKit
4+
using MatrixAlgebraKit: copy_input, TruncatedAlgorithm, zero!
5+
using ChainRulesCore
6+
using LinearAlgebra
7+
8+
# TODO: Decide on an interface to pass on the kwargs for the pullback functions
9+
# from the primal function calls
10+
11+
MatrixAlgebraKit.iszerotangent(::AbstractZero) = true
12+
13+
function ChainRulesCore.rrule(::typeof(copy_input), f, A::AbstractMatrix)
14+
project = ProjectTo(A)
15+
copy_input_pullback(ΔA) = (NoTangent(), NoTangent(), project(unthunk(ΔA)))
16+
return copy_input(f, A), copy_input_pullback
17+
end
18+
19+
for qr_f in (:qr_compact, :qr_full)
20+
qr_f! = Symbol(qr_f, '!')
21+
@eval begin
22+
function ChainRulesCore.rrule(::typeof($qr_f!), A::AbstractMatrix, QR, alg)
23+
Ac = copy_input($qr_f, A)
24+
QR = $(qr_f!)(Ac, QR, alg)
25+
function qr_pullback(ΔQR)
26+
ΔA = zero(A)
27+
MatrixAlgebraKit.qr_compact_pullback!(ΔA, QR, unthunk.(ΔQR))
28+
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
29+
end
30+
function qr_pullback(::Tuple{ZeroTangent,ZeroTangent}) # is this extra definition useful?
31+
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
32+
end
33+
return QR, qr_pullback
34+
end
35+
end
36+
end
37+
function ChainRulesCore.rrule(::typeof(qr_null!), A::AbstractMatrix, N, alg)
38+
Ac = copy_input(qr_full, A)
39+
QR = MatrixAlgebraKit.initialize_output(qr_full!, A, alg)
40+
Q, R = qr_full!(Ac, QR, alg)
41+
N = copy!(N, view(Q, 1:size(A, 1), (size(A, 2) + 1):size(A, 1)))
42+
function qr_null_pullback(ΔN)
43+
ΔA = zero(A)
44+
(m, n) = size(A)
45+
minmn = min(m, n)
46+
ΔQ = zero!(similar(A, (m, m)))
47+
view(ΔQ, 1:m, (minmn + 1):m) .= unthunk.(ΔN)
48+
MatrixAlgebraKit.qr_compact_pullback!(ΔA, (Q, R), (ΔQ, ZeroTangent()))
49+
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
50+
end
51+
function qr_null_pullback(::ZeroTangent) # is this extra definition useful?
52+
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
53+
end
54+
return N, qr_null_pullback
55+
end
56+
57+
for lq_f in (:lq_compact, :lq_full)
58+
lq_f! = Symbol(lq_f, '!')
59+
@eval begin
60+
function ChainRulesCore.rrule(::typeof($lq_f!), A::AbstractMatrix, LQ, alg)
61+
Ac = copy_input($lq_f, A)
62+
LQ = $(lq_f!)(Ac, LQ, alg)
63+
function lq_pullback(ΔLQ)
64+
ΔA = zero(A)
65+
MatrixAlgebraKit.lq_compact_pullback!(ΔA, LQ, unthunk.(ΔLQ))
66+
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
67+
end
68+
function lq_pullback(::Tuple{ZeroTangent,ZeroTangent}) # is this extra definition useful?
69+
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
70+
end
71+
return LQ, lq_pullback
72+
end
73+
end
74+
end
75+
function ChainRulesCore.rrule(::typeof(lq_null!), A::AbstractMatrix, Nᴴ, alg)
76+
Ac = copy_input(lq_full, A)
77+
LQ = MatrixAlgebraKit.initialize_output(lq_full!, A, alg)
78+
L, Q = lq_full!(Ac, LQ, alg)
79+
Nᴴ = copy!(Nᴴ, view(Q, (size(A, 1) + 1):size(A, 2), 1:size(A, 2)))
80+
function lq_null_pullback(ΔNᴴ)
81+
ΔA = zero(A)
82+
(m, n) = size(A)
83+
minmn = min(m, n)
84+
ΔQ = zero!(similar(A, (n, n)))
85+
view(ΔQ, (minmn + 1):n, 1:n) .= unthunk.(ΔNᴴ)
86+
MatrixAlgebraKit.lq_compact_pullback!(ΔA, (L, Q), (ZeroTangent(), ΔQ))
87+
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
88+
end
89+
function lq_null_pullback(::ZeroTangent) # is this extra definition useful?
90+
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
91+
end
92+
return Nᴴ, lq_null_pullback
93+
end
94+
95+
for eig in (:eig, :eigh)
96+
eig_f = Symbol(eig, "_full")
97+
eig_f! = Symbol(eig_f, "!")
98+
eig_f_pb! = Symbol(eig, "_full_pullback!")
99+
eig_pb = Symbol(eig, "_pullback")
100+
@eval begin
101+
function ChainRulesCore.rrule(::typeof($eig_f!), A::AbstractMatrix, DV, alg)
102+
Ac = copy_input($eig_f, A)
103+
DV = $(eig_f!)(Ac, DV, alg)
104+
function $eig_pb(ΔDV)
105+
ΔA = zero(A)
106+
MatrixAlgebraKit.$eig_f_pb!(ΔA, DV, unthunk.(ΔDV))
107+
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
108+
end
109+
function $eig_pb(::Tuple{ZeroTangent,ZeroTangent}) # is this extra definition useful?
110+
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
111+
end
112+
return DV, $eig_pb
113+
end
114+
end
115+
end
116+
117+
for svd_f in (:svd_compact, :svd_full)
118+
svd_f! = Symbol(svd_f, "!")
119+
@eval begin
120+
function ChainRulesCore.rrule(::typeof($svd_f!), A::AbstractMatrix, USVᴴ, alg)
121+
Ac = copy_input($svd_f, A)
122+
USVᴴ = $(svd_f!)(Ac, USVᴴ, alg)
123+
function svd_pullback(ΔUSVᴴ)
124+
ΔA = zero(A)
125+
MatrixAlgebraKit.svd_compact_pullback!(ΔA, USVᴴ, unthunk.(ΔUSVᴴ))
126+
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
127+
end
128+
function svd_pullback(::Tuple{ZeroTangent,ZeroTangent,ZeroTangent}) # is this extra definition useful?
129+
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
130+
end
131+
return USVᴴ, svd_pullback
132+
end
133+
end
134+
end
135+
136+
function ChainRulesCore.rrule(::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ,
137+
alg::TruncatedAlgorithm)
138+
Ac = MatrixAlgebraKit.copy_input(svd_compact, A)
139+
USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
140+
function svd_trunc_pullback(ΔUSVᴴ)
141+
ΔA = zero(A)
142+
MatrixAlgebraKit.svd_compact_pullback!(ΔA, USVᴴ, unthunk.(ΔUSVᴴ))
143+
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
144+
end
145+
function svd_trunc_pullback(::Tuple{ZeroTangent,ZeroTangent,ZeroTangent}) # is this extra definition useful?
146+
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
147+
end
148+
return MatrixAlgebraKit.truncate!(svd_trunc!, USVᴴ, alg.trunc), svd_trunc_pullback
149+
end
150+
151+
function ChainRulesCore.rrule(::typeof(left_polar!), A::AbstractMatrix, WP, alg)
152+
Ac = copy_input(left_polar, A)
153+
WP = left_polar!(Ac, WP, alg)
154+
function left_polar_pullback(ΔWP)
155+
ΔA = zero(A)
156+
MatrixAlgebraKit.left_polar_pullback!(ΔA, WP, unthunk.(ΔWP))
157+
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
158+
end
159+
function left_polar_pullback(::Tuple{ZeroTangent,ZeroTangent}) # is this extra definition useful?
160+
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
161+
end
162+
return WP, left_polar_pullback
163+
end
164+
165+
function ChainRulesCore.rrule(::typeof(right_polar!), A::AbstractMatrix, PWᴴ, alg)
166+
Ac = copy_input(left_polar, A)
167+
PWᴴ = right_polar!(Ac, PWᴴ, alg)
168+
function right_polar_pullback(ΔPWᴴ)
169+
ΔA = zero(A)
170+
MatrixAlgebraKit.right_polar_pullback!(ΔA, PWᴴ, unthunk.(ΔPWᴴ))
171+
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
172+
end
173+
function right_polar_pullback(::Tuple{ZeroTangent,ZeroTangent}) # is this extra definition useful?
174+
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
175+
end
176+
return PWᴴ, right_polar_pullback
177+
end
178+
179+
end

src/MatrixAlgebraKit.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@ module MatrixAlgebraKit
33
using LinearAlgebra: LinearAlgebra
44
using LinearAlgebra: norm # TODO: eleminate if we use VectorInterface.jl?
55
using LinearAlgebra: mul!, rmul!, lmul!
6+
using LinearAlgebra: sylvester
67
using LinearAlgebra: isposdef, ishermitian
78
using LinearAlgebra: Diagonal, diag, diagind
8-
using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt, triu!, tril!
9+
using LinearAlgebra: UpperTriangular, LowerTriangular
10+
using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt, triu!, tril!, rdiv!, ldiv!
911

1012
export qr_compact, qr_full, qr_null, lq_compact, lq_full, lq_null
1113
export qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null!
@@ -22,9 +24,10 @@ export left_polar!, right_polar!
2224
export left_orth, right_orth, left_null, right_null
2325
export left_orth!, right_orth!, left_null!, right_null!
2426

25-
export LAPACK_HouseholderQR, LAPACK_Simple, LAPACK_Expert, LAPACK_QRIteration,
26-
LAPACK_Bisection, LAPACK_DivideAndConquer,
27-
LAPACK_MultipleRelativelyRobustRepresentations, LAPACK_Jacobi
27+
export LAPACK_HouseholderQR, LAPACK_HouseholderLQ,
28+
LAPACK_Simple, LAPACK_Expert,
29+
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
30+
LAPACK_DivideAndConquer, LAPACK_Jacobi
2831
export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered
2932

3033
include("common/defaults.jl")
@@ -56,4 +59,11 @@ include("implementations/schur.jl")
5659
include("implementations/polar.jl")
5760
include("implementations/orthnull.jl")
5861

62+
include("pullbacks/qr.jl")
63+
include("pullbacks/lq.jl")
64+
include("pullbacks/eig.jl")
65+
include("pullbacks/eigh.jl")
66+
include("pullbacks/svd.jl")
67+
include("pullbacks/polar.jl")
68+
5969
end

src/common/pullbacks.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,5 @@ pullback definitions in term of it, we will be able to hook into different AD
77
ecosystems
88
"""
99
function iszerotangent end
10+
11+
iszerotangent(::Any) = false

src/common/safemethods.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
# Sign
44
"""
5-
safesign(s::Number)
5+
sign_safe(s::Number)
66
77
Compute the sign of a number `s`, but return `+1` if `s` is zero so that the result is
88
always a number with modulus 1, i.e. an element of the unitary group U(1).
99
"""
10-
safesign(s::Real) = ifelse(s < zero(s), -one(s), +one(s))
11-
safesign(s::Complex) = ifelse(iszero(s), one(s), s / abs(s))
10+
sign_safe(s::Real) = ifelse(s < zero(s), -one(s), +one(s))
11+
sign_safe(s::Complex) = ifelse(iszero(s), one(s), s / abs(s))
1212

1313
# Inverse
1414

src/implementations/decompositions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ elements of `R` are non-negative.
2727
Algorithm type to denote the standard LAPACK algorithm for computing the LQ decomposition of
2828
a matrix using Householder reflectors. The specific LAPACK function can be controlled using
2929
the keyword arugments, i.e. `?gelqt` will be chosen if `blocksize > 1` or `?gelqf` will be
30-
chosen if `blocksize == 1`. The keyword `positive =true` can be used to ensure that the diagonal
30+
chosen if `blocksize == 1`. The keyword `positive=true` can be used to ensure that the diagonal
3131
elements of `L` are non-negative.
3232
"""
3333
@algdef LAPACK_HouseholderLQ

src/implementations/lq.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ function _lapack_lq!(A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix;
119119
if positive # already fix Q even if we do not need R
120120
@inbounds for j in 1:n
121121
@simd for i in 1:minmn
122-
s = safesign(A[i, i])
122+
s = sign_safe(A[i, i])
123123
Q[i, j] *= s
124124
end
125125
end
@@ -129,7 +129,7 @@ function _lapack_lq!(A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix;
129129
= tril!(view(A, axes(L)...))
130130
if positive
131131
@inbounds for j in 1:minmn
132-
s = conj(safesign(L̃[j, j]))
132+
s = conj(sign_safe(L̃[j, j]))
133133
@simd for i in j:m
134134
L̃[i, j] = L̃[i, j] * s
135135
end

src/implementations/qr.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ function _lapack_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
122122

123123
if positive # already fix Q even if we do not need R
124124
@inbounds for j in 1:minmn
125-
s = safesign(A[j, j])
125+
s = sign_safe(A[j, j])
126126
@simd for i in 1:m
127127
Q[i, j] *= s
128128
end
@@ -134,7 +134,7 @@ function _lapack_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
134134
if positive
135135
@inbounds for j in n:-1:1
136136
@simd for i in 1:min(minmn, j)
137-
R̃[i, j] = R̃[i, j] * conj(safesign(R̃[i, i]))
137+
R̃[i, j] = R̃[i, j] * conj(sign_safe(R̃[i, i]))
138138
end
139139
end
140140
end

0 commit comments

Comments
 (0)