Skip to content

Commit 6f74cef

Browse files
committed
Some more Mooncake support
1 parent 16da8a9 commit 6f74cef

1 file changed

Lines changed: 21 additions & 7 deletions

File tree

ext/PEPSKitMooncakeExt.jl

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,29 @@
11
module PEPSKitMooncakeExt
22

3-
using PEPSKit, TensorKit, Mooncake, MatrixAlgebraKit
4-
using PEPSKit: SVDAdjoint, EighAdjoint, QRAdjoint
5-
using Mooncake: DefaultCtx, CoDual, Dual, NoRData, primal, rrule!!, arrayify, @is_primitive
3+
using PEPSKit, MPSKit, TensorKit, Mooncake, MatrixAlgebraKit
4+
using PEPSKit: SVDAdjoint, EighAdjoint, QRAdjoint, CTMRGAlgorithm, FixedPointGradient, sdiag_pow
5+
import PEPSKit: real_inner
6+
using Mooncake: DefaultCtx, CoDual, Dual, NoRData, primal, tangent, rrule!!, arrayify, @is_primitive
7+
8+
function Mooncake.arrayify::PEPSKit.InfinitePEPS{T}, dψ) where {T}
9+
Δψmat = map((a, da) -> Mooncake.arrayify(a, da)[2], ψ.A, dψ.fields.A)
10+
Δψ = PEPSKit.InfinitePEPS{T}(Δψmat)
11+
return ψ, Δψ
12+
end
613

714
_warn_pullback_truncerror(dϵ::Real; tol = MatrixAlgebraKit.defaulttol(dϵ)) =
815
abs(dϵ) tol || @warn "Pullback ignores non-zero tangents for truncation error"
916

1017
Mooncake.tangent_type(::Type{<:PEPSKit.SVDAdjoint}) = Mooncake.NoTangent
1118
Mooncake.tangent_type(::Type{<:PEPSKit.EighAdjoint}) = Mooncake.NoTangent
1219
Mooncake.tangent_type(::Type{<:PEPSKit.QRAdjoint}) = Mooncake.NoTangent
20+
Mooncake.tangent_type(::Type{<:PEPSKit.CTMRGAlgorithm}) = Mooncake.NoTangent
21+
Mooncake.tangent_type(::Type{<:PEPSKit.FixedPointGradient}) = Mooncake.NoTangent
22+
23+
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(PEPSKit.eachcoordinate), Any, Any}
24+
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(PEPSKit._next_coordinate), Int, Int}
25+
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(PEPSKit._set_decomposition_truncation), Any, Any}
26+
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(PEPSKit.CTMRGEnv), Union{PEPSKit.InfinitePartitionFunction, PEPSKit.InfinitePEPS}, Vararg}
1327

1428
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), TensorKit.AbstractTensorMap, SVDAdjoint}
1529
function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_trunc)}, t_dt::CoDual{<:TensorKit.AbstractTensorMap}, alg_dalg::CoDual{SVDAdjoint{F, R}}) where {F, R <: PEPSKit.FullPullback}
@@ -32,7 +46,7 @@ function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_trunc)}, t_dt::Co
3246
Δt, t, (U, S, V⁺), ΔUSVᴴtrunc, inds;
3347
gauge_atol = gtol(ΔUSVᴴtrunc), degeneracy_atol = alg.rrule_alg.degeneracy_atol,
3448
)
35-
return NoRData(), NoRData(), NoRData()
49+
return NoRData(), NoRData(), NoRData(), zero(dϵ)
3650
end
3751
return output_codual, svd_trunc!_full_pullback
3852
end
@@ -64,11 +78,11 @@ end
6478
function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.eigh_trunc)}, t_dt::CoDual{<:TensorKit.AbstractTensorMap}, alg_dalg::CoDual{EighAdjoint{F, R}}) where {F, R <: PEPSKit.FullPullback}
6579
t, dt = arrayify(t_dt)
6680
alg = primal(alg_dalg)
67-
81+
6882
D, V = eigh_full!(t; alg.fwd_alg.alg)
6983
(D̃, Ṽ), inds = MatrixAlgebraKit.truncate(eigh_trunc!, (D, V), alg.fwd_alg.trunc)
7084
ϵ = MatrixAlgebraKit.truncation_error(diagview(D), inds)
71-
85+
7286
DVtrunc = (D̃, Ṽ)
7387
# pack output
7488
DVtrunc_dDVtrunc = Mooncake.zero_fcodual((DVtrunc..., ϵ))
@@ -89,7 +103,7 @@ end
89103
function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.eigh_trunc)}, t_dt::CoDual{<:TensorKit.AbstractTensorMap}, alg_dalg::CoDual{EighAdjoint{F, R}}) where {F, R <: PEPSKit.TruncPullback}
90104
t, dt = arrayify(t_dt)
91105
alg = primal(alg_dalg)
92-
106+
93107
D, V, truncerror = eigh_trunc(t, alg)
94108
gtol = PEPSKit._get_pullback_gauge_tol(alg.rrule_alg.verbosity)
95109
output = (D, V, truncerror)

0 commit comments

Comments
 (0)