Skip to content

Commit cd2e0e3

Browse files
Jutholkdvos
authored andcommitted
update qr/lq_null pullbacks
1 parent 9448878 commit cd2e0e3

3 files changed

Lines changed: 40 additions & 24 deletions

File tree

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,10 @@ for qr_f in (:qr_compact, :qr_full)
4343
end
4444
function ChainRulesCore.rrule(::typeof(qr_null!), A::AbstractMatrix, N, alg)
4545
Ac = copy_input(qr_full, A)
46-
QR = qr_full!(Ac, initialize_output(qr_full!, A, alg), alg)
47-
N = copy!(N, view(QR[1], 1:size(A, 1), (size(A, 2) + 1):size(A, 1)))
46+
N = qr_null!(Ac, N, alg)
4847
function qr_null_pullback(ΔN)
4948
ΔA = zero(A)
50-
MatrixAlgebraKit.qr_null_pullback!(ΔA, A, QR, unthunk(ΔN))
49+
MatrixAlgebraKit.qr_null_pullback!(ΔA, A, N, unthunk(ΔN))
5150
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
5251
end
5352
function qr_null_pullback(::ZeroTangent) # is this extra definition useful?
@@ -76,11 +75,10 @@ for lq_f in (:lq_compact, :lq_full)
7675
end
7776
function ChainRulesCore.rrule(::typeof(lq_null!), A::AbstractMatrix, Nᴴ, alg)
7877
Ac = copy_input(lq_full, A)
79-
LQ = lq_full!(Ac, initialize_output(lq_full!, A, alg), alg)
80-
Nᴴ = copy!(Nᴴ, view(LQ[2], (size(A, 1) + 1):size(A, 2), 1:size(A, 2)))
78+
Nᴴ = lq_null!(Ac, Nᴴ, alg)
8179
function lq_null_pullback(ΔNᴴ)
8280
ΔA = zero(A)
83-
MatrixAlgebraKit.lq_null_pullback!(ΔA, A, LQ, unthunk(ΔNᴴ))
81+
MatrixAlgebraKit.lq_null_pullback!(ΔA, A, Nᴴ, unthunk(ΔNᴴ))
8482
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
8583
end
8684
function lq_null_pullback(::ZeroTangent) # is this extra definition useful?

src/pullbacks/lq.jl

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -105,18 +105,26 @@ function lq_pullback!(
105105
end
106106

107107
"""
108-
lq_null_pullback(ΔA, A, LQ, ΔNᴴ)
108+
lq_null_pullback(ΔA, A, Nᴴ, ΔNᴴ)
109109
110-
Adds the pullback from the nullspace of the LQ decomposition of `A` to `ΔA` given the
111-
factorization `LQ` and cotangent `ΔNᴴ` of `lq_null(A)`.
110+
Adds the pullback from the left nullspace of `A` to `ΔA`, given the nullspace basis
111+
`Nᴴ` and its cotangent `ΔNᴴ` of `lq_null(A)`.
112112
113113
See also [`lq_compact_pullback!`](@ref).
114114
"""
115-
function lq_null_pullback!(ΔA::AbstractMatrix, A, LQ, ΔNᴴ)
116-
(m, n) = size(A)
117-
minmn = min(m, n)
118-
ΔQ = zero!(similar(A, (n, n)))
119-
view(ΔQ, (minmn + 1):n, 1:n) .= ΔNᴴ
120-
lq_compact_pullback!(ΔA, A, LQ, (nothing, ΔQ))
115+
function lq_null_pullback!(
116+
ΔA::AbstractMatrix, A, Nᴴ, ΔNᴴ;
117+
tol::Real = default_pullback_gaugetol(A),
118+
gauge_atol::Real = tol
119+
)
120+
if !iszerotangent(ΔNᴴ) && size(Nᴴ, 1) > 0
121+
NᴴΔN = Nᴴ * ΔNᴴ'
122+
Δgauge = norm((NᴴΔN .- NᴴΔN') ./ 2)
123+
Δgauge < tol ||
124+
@warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
125+
L, Q = lq_compact(A; positive = true) # should we be able to provide algorithm here?
126+
X = ldiv!(LowerTriangular(L)', Q * ΔNᴴ')
127+
ΔA = mul!(ΔA, X, Nᴴ, -1, 1)
128+
end
121129
return ΔA
122130
end

src/pullbacks/qr.jl

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -104,17 +104,27 @@ function qr_pullback!(
104104
end
105105

106106
"""
107-
qr_null_pullback(ΔA, A, QR, ΔNᴴ)
107+
qr_null_pullback(ΔA, A, N, ΔN)
108108
109-
Adds the pullback from the nullspace of the QR decomposition of `A` to `ΔA` given the
110-
factorization `QR` and cotangent `ΔN` of `qr_null(A)`.
109+
Adds the pullback from the right nullspace of `A` to `ΔA`, given the nullspace basis
110+
`N` and its cotangent `ΔN` of `qr_null(A)`.
111111
112112
See also [`qr_compact_pullback!`](@ref).
113113
"""
114-
function qr_null_pullback!(ΔA::AbstractMatrix, A, QR, ΔN)
115-
m, n = size(A)
116-
minmn = min(m, n)
117-
ΔQ = zero!(similar(A, (m, m)))
118-
view(ΔQ, 1:m, (minmn + 1):m) .= ΔN
119-
return qr_compact_pullback!(ΔA, A, QR, (ΔQ, nothing))
114+
function qr_null_pullback!(
115+
ΔA::AbstractMatrix, A, N, ΔN;
116+
tol::Real = default_pullback_gaugetol(A),
117+
gauge_atol::Real = tol
118+
)
119+
if !iszerotangent(ΔN) && size(N, 2) > 0
120+
NᴴΔN = N' * ΔN
121+
Δgauge = norm((NᴴΔN .- NᴴΔN') ./ 2)
122+
Δgauge < tol ||
123+
@warn "`qr_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
124+
125+
Q, R = qr_compact(A; positive = true)
126+
X = rdiv!(ΔN' * Q, UpperTriangular(R)')
127+
ΔA = mul!(ΔA, N, X, -1, 1)
128+
end
129+
return ΔA
120130
end

0 commit comments

Comments
 (0)