Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions ext/MatrixAlgebraKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ for qr_f in (:qr_compact, :qr_full)
QR = $(qr_f!)(Ac, QR, alg)
function qr_pullback(ΔQR)
ΔA = zero(A)
MatrixAlgebraKit.qr_compact_pullback!(ΔA, A, QR, unthunk.(ΔQR))
MatrixAlgebraKit.qr_pullback!(ΔA, A, QR, unthunk.(ΔQR))
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
end
function qr_pullback(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful?
Expand All @@ -46,7 +46,7 @@ function ChainRulesCore.rrule(::typeof(qr_null!), A::AbstractMatrix, N, alg)
minmn = min(m, n)
ΔQ = zero!(similar(A, (m, m)))
view(ΔQ, 1:m, (minmn + 1):m) .= unthunk.(ΔN)
MatrixAlgebraKit.qr_compact_pullback!(ΔA, A, (Q, R), (ΔQ, ZeroTangent()))
MatrixAlgebraKit.qr_pullback!(ΔA, A, (Q, R), (ΔQ, ZeroTangent()))
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
end
function qr_null_pullback(::ZeroTangent) # is this extra definition useful?
Expand All @@ -63,7 +63,7 @@ for lq_f in (:lq_compact, :lq_full)
LQ = $(lq_f!)(Ac, LQ, alg)
function lq_pullback(ΔLQ)
ΔA = zero(A)
MatrixAlgebraKit.lq_compact_pullback!(ΔA, A, LQ, unthunk.(ΔLQ))
MatrixAlgebraKit.lq_pullback!(ΔA, A, LQ, unthunk.(ΔLQ))
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
end
function lq_pullback(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful?
Expand All @@ -84,7 +84,7 @@ function ChainRulesCore.rrule(::typeof(lq_null!), A::AbstractMatrix, Nᴴ, alg)
minmn = min(m, n)
ΔQ = zero!(similar(A, (n, n)))
view(ΔQ, (minmn + 1):n, 1:n) .= unthunk.(ΔNᴴ)
MatrixAlgebraKit.lq_compact_pullback!(ΔA, A, (L, Q), (ZeroTangent(), ΔQ))
MatrixAlgebraKit.lq_pullback!(ΔA, A, (L, Q), (ZeroTangent(), ΔQ))
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
end
function lq_null_pullback(::ZeroTangent) # is this extra definition useful?
Expand Down
8 changes: 5 additions & 3 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ export notrunc, truncrank, trunctol, truncerror, truncfilter
)
eval(
Expr(
:public, :qr_compact_pullback!, :lq_compact_pullback!, :left_polar_pullback!,
:right_polar_pullback!, :eig_pullback!, :eig_trunc_pullback!, :eigh_pullback!,
:eigh_trunc_pullback!, :svd_pullback!, :svd_trunc_pullback!
:public, :qr_pullback!, :lq_pullback!, :svd_pullback!, :svd_trunc_pullback!,
:eig_pullback!, :eig_trunc_pullback!, :eigh_pullback!, :eigh_trunc_pullback!,
:left_polar_pullback!, :right_polar_pullback!
)
)
end
Expand Down Expand Up @@ -103,4 +103,6 @@ include("pullbacks/eigh.jl")
include("pullbacks/svd.jl")
include("pullbacks/polar.jl")

include("deprecate.jl")

end
2 changes: 2 additions & 0 deletions src/deprecate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Base.@deprecate qr_compact_pullback!(args...) qr_pullback!(args...) false
Base.@deprecate lq_compact_pullback!(args...) lq_pullback!(args...) false
Comment thread
lkdvos marked this conversation as resolved.
Outdated
10 changes: 5 additions & 5 deletions src/pullbacks/lq.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
lq_compact_pullback!(
lq_pullback!(
ΔA, A, LQ, ΔLQ;
tol::Real=default_pullback_gaugetol(LQ[1]),
rank_atol::Real=tol,
gauge_atol::Real=tol
tol::Real = default_pullback_gaugetol(LQ[1]),
rank_atol::Real = tol,
gauge_atol::Real = tol
)

Adds the pullback from the LQ decomposition of `A` to `ΔA` given the output `LQ` and
Expand All @@ -16,7 +16,7 @@ well-defined, and also the adjoint variables `ΔL` and `ΔQ` should have nonzero
in the first `r` columns and rows respectively. If nonzero values in the remaining columns
or rows exceed `gauge_atol`, a warning will be printed.
"""
function lq_compact_pullback!(
function lq_pullback!(
ΔA::AbstractMatrix, A, LQ, ΔLQ;
tol::Real = default_pullback_gaugetol(LQ[1]),
rank_atol::Real = tol,
Expand Down
10 changes: 5 additions & 5 deletions src/pullbacks/qr.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
qr_compact_pullback!(
qr_pullback!(
ΔA, A, QR, ΔQR;
tol::Real=default_pullback_gaugetol(QR[2]),
rank_atol::Real=tol,
gauge_atol::Real=tol
tol::Real = default_pullback_gaugetol(QR[2]),
rank_atol::Real = tol,
gauge_atol::Real = tol
)

Adds the pullback from the QR decomposition of `A` to `ΔA` given the output `QR` and
Expand All @@ -16,7 +16,7 @@ and also the adjoint variables `ΔQ` and `ΔR` should have nonzero values only i
`r` columns and rows respectively. If nonzero values in the remaining columns or rows exceed
`gauge_atol`, a warning will be printed.
"""
function qr_compact_pullback!(
function qr_pullback!(
ΔA::AbstractMatrix, A, QR, ΔQR;
tol::Real = default_pullback_gaugetol(QR[2]),
rank_atol::Real = tol,
Expand Down