|
1 | 1 | module PEPSKitMooncakeExt |
2 | 2 |
|
3 | 3 | using PEPSKit, TensorKit, Mooncake, MatrixAlgebraKit |
4 | | -using PEPSKit: SVDAdjoint, EighAdjoint |
| 4 | +using PEPSKit: SVDAdjoint, EighAdjoint, QRAdjoint |
5 | 5 | using Mooncake: DefaultCtx, CoDual, Dual, NoRData, primal, rrule!!, arrayify, @is_primitive |
6 | 6 |
|
7 | 7 | _warn_pullback_truncerror(dϵ::Real; tol = MatrixAlgebraKit.defaulttol(dϵ)) = |
8 | 8 | abs(dϵ) ≤ tol || @warn "Pullback ignores non-zero tangents for truncation error" |
9 | 9 |
|
10 | 10 | Mooncake.tangent_type(::Type{<:PEPSKit.SVDAdjoint}) = Mooncake.NoTangent |
11 | 11 | Mooncake.tangent_type(::Type{<:PEPSKit.EighAdjoint}) = Mooncake.NoTangent |
| 12 | +Mooncake.tangent_type(::Type{<:PEPSKit.QRAdjoint}) = Mooncake.NoTangent |
12 | 13 |
|
13 | 14 | @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), TensorKit.AbstractTensorMap, SVDAdjoint} |
14 | 15 | function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_trunc)}, t_dt::CoDual{<:TensorKit.AbstractTensorMap}, alg_dalg::CoDual{SVDAdjoint{F, R}}) where {F, R <: PEPSKit.FullPullback} |
@@ -109,4 +110,23 @@ function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.eigh_trunc)}, t_dt::C |
109 | 110 | return output_codual, eigh_trunc!_trunc_pullback |
110 | 111 | end |
111 | 112 |
|
| 113 | +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(left_orth), TensorKit.AbstractTensorMap, QRAdjoint} |
| 114 | +function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.left_orth)}, t_dt::CoDual{<:TensorKit.AbstractTensorMap}, alg_dalg::CoDual{QRAdjoint}) |
| 115 | + t, dt = arrayify(t_dt) |
| 116 | + alg = primal(alg_dalg) |
| 117 | + |
| 118 | + QR = left_orth(t, alg) |
| 119 | + gtol = PEPSKit._get_pullback_gauge_tol(alg.rrule_alg.verbosity) |
| 120 | + |
| 121 | + output_codual = Mooncake.zero_fcodual(QR) |
| 122 | + dQ_, dR_ = Mooncake.tangent(output_codual) |
| 123 | + Q, dQ = arrayify(Q, dQ_) |
| 124 | + R, dR = arrayify(R, dR_) |
| 125 | + function left_orth_pullback(::NoRData) |
| 126 | + MatrixAlgebraKit.qr_pullback!(dt, t, QR, (dQ, dR); gauge_atol = gtol(dQR)) |
| 127 | + return ntuple(Returns(NoRData()), 3) |
| 128 | + end |
| 129 | + return output_codual, left_orth_pullback |
| 130 | +end |
| 131 | + |
112 | 132 | end |
0 commit comments