Skip to content

Commit 16da8a9

Browse files
committed
Add support for QRAdjoint
1 parent 37fca8d commit 16da8a9

1 file changed

Lines changed: 21 additions & 1 deletion

File tree

ext/PEPSKitMooncakeExt.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
module PEPSKitMooncakeExt
22

33
using PEPSKit, TensorKit, Mooncake, MatrixAlgebraKit
4-
using PEPSKit: SVDAdjoint, EighAdjoint
4+
using PEPSKit: SVDAdjoint, EighAdjoint, QRAdjoint
55
using Mooncake: DefaultCtx, CoDual, Dual, NoRData, primal, rrule!!, arrayify, @is_primitive
66

77
_warn_pullback_truncerror(dϵ::Real; tol = MatrixAlgebraKit.defaulttol(dϵ)) =
88
abs(dϵ) tol || @warn "Pullback ignores non-zero tangents for truncation error"
99

1010
Mooncake.tangent_type(::Type{<:PEPSKit.SVDAdjoint}) = Mooncake.NoTangent
1111
Mooncake.tangent_type(::Type{<:PEPSKit.EighAdjoint}) = Mooncake.NoTangent
12+
Mooncake.tangent_type(::Type{<:PEPSKit.QRAdjoint}) = Mooncake.NoTangent
1213

1314
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), TensorKit.AbstractTensorMap, SVDAdjoint}
1415
function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_trunc)}, t_dt::CoDual{<:TensorKit.AbstractTensorMap}, alg_dalg::CoDual{SVDAdjoint{F, R}}) where {F, R <: PEPSKit.FullPullback}
@@ -109,4 +110,23 @@ function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.eigh_trunc)}, t_dt::C
109110
return output_codual, eigh_trunc!_trunc_pullback
110111
end
111112

113+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(left_orth), TensorKit.AbstractTensorMap, QRAdjoint}
114+
function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.left_orth)}, t_dt::CoDual{<:TensorKit.AbstractTensorMap}, alg_dalg::CoDual{QRAdjoint})
115+
t, dt = arrayify(t_dt)
116+
alg = primal(alg_dalg)
117+
118+
QR = left_orth(t, alg)
119+
gtol = PEPSKit._get_pullback_gauge_tol(alg.rrule_alg.verbosity)
120+
121+
output_codual = Mooncake.zero_fcodual(QR)
122+
dQ_, dR_ = Mooncake.tangent(output_codual)
123+
Q, dQ = arrayify(Q, dQ_)
124+
R, dR = arrayify(R, dR_)
125+
function left_orth_pullback(::NoRData)
126+
MatrixAlgebraKit.qr_pullback!(dt, t, QR, (dQ, dR); gauge_atol = gtol(dQR))
127+
return ntuple(Returns(NoRData()), 3)
128+
end
129+
return output_codual, left_orth_pullback
130+
end
131+
112132
end

0 commit comments

Comments
 (0)