Skip to content

Commit 71ee1f1

Browse files
kshyatt and lkdvos authored
Mooncake forward rules for linalg (QuantumKitHub#434)
* Mooncake forward rules for linalg * Update ext/TensorKitMooncakeExt/linalg.jl Co-authored-by: Lukas Devos <ldevos98@gmail.com> --------- Co-authored-by: Lukas Devos <ldevos98@gmail.com>
1 parent 033704e commit 71ee1f1

3 files changed

Lines changed: 88 additions & 20 deletions

File tree

ext/TensorKitMooncakeExt/linalg.jl

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
pullback_dC!(ΔC, β) = (scale!(ΔC, conj(β)); return NoRData())
44
pullback_dβ(ΔC, C, β) = _needs_tangent(β) ? project_scalar(β, inner(C, ΔC)) : NoRData()
55

6-
@is_primitive DefaultCtx ReverseMode Tuple{typeof(mul!), AbstractTensorMap, AbstractTensorMap, AbstractTensorMap, Number, Number}
6+
@is_primitive DefaultCtx Tuple{typeof(mul!), AbstractTensorMap, AbstractTensorMap, AbstractTensorMap, Number, Number}
77

88
function Mooncake.rrule!!(
99
::CoDual{typeof(mul!)},
@@ -40,9 +40,29 @@ function Mooncake.rrule!!(
4040

4141
return C_ΔC, mul_pullback
4242
end
43+
# Forward-mode (Dual) rule for the five-argument `mul!(C, A, B, α, β)` on tensor maps.
# Both the tangent ΔC and the primal C are updated in place, and the incoming
# `C_ΔC` dual is returned.
# The tangent of α or β may be `Mooncake.NoTangent`; the matching term is then skipped.
function Mooncake.frule!!(
44+
::Dual{typeof(mul!)},
45+
C_ΔC::Dual{<:AbstractTensorMap}, A_ΔA::Dual{<:AbstractTensorMap}, B_ΔB::Dual{<:AbstractTensorMap},
46+
α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number}
47+
)
48+
# Split every tensor dual into its (primal, tangent) pair.
(C, ΔC), (A, ΔA), (B, ΔB) = arrayify.((C_ΔC, A_ΔA, B_ΔB))
49+
α, Δα = Mooncake.extract(α_Δα)
50+
β, Δβ = Mooncake.extract(β_Δβ)
51+
# ΔC′ = ΔC*β + C*Δβ + A*B*Δα + ΔA*B*α + A*ΔB*α
52+
# The tangent is assembled first: the `C*Δβ` term needs the value of C from
# before the primal `mul!` below overwrites it.
scale!(ΔC, β)
53+
if !isa(Δβ, Mooncake.NoTangent)
54+
add!(ΔC, C, Δβ)
55+
end
56+
if !isa(Δα, Mooncake.NoTangent)
57+
# NOTE(review): `project_mul!` is defined outside this view; presumably it
# accumulates `A * B * Δα` into ΔC, as the formula above suggests — confirm.
project_mul!(ΔC, A, B, Δα)
58+
end
59+
project_mul!(ΔC, ΔA, B, α)
60+
project_mul!(ΔC, A, ΔB, α)
61+
# Primal update, done last so the tangent terms above saw the old C.
mul!(C, A, B, α, β)
62+
return C_ΔC
63+
end
4364

44-
@is_primitive DefaultCtx ReverseMode Tuple{typeof(norm), AbstractTensorMap, Real}
45-
65+
@is_primitive DefaultCtx Tuple{typeof(norm), AbstractTensorMap, Real}
4666
function Mooncake.rrule!!(::CoDual{typeof(norm)}, tΔt::CoDual{<:AbstractTensorMap}, pdp::CoDual{<:Real})
4767
t, Δt = arrayify(tΔt)
4868
p = primal(pdp)
@@ -55,9 +75,16 @@ function Mooncake.rrule!!(::CoDual{typeof(norm)}, tΔt::CoDual{<:AbstractTensorM
5575
end
5676
return CoDual(n, Mooncake.NoFData()), norm_pullback
5777
end
78+
# Forward-mode (Dual) rule for `norm(t, p)` on a tensor map.
# Only p == 2 is supported; any other p raises an error.
# Returns `Dual(n, Δn)` with `n = norm(t, p)` and `Δn = real(dot(t, Δt)) * pinv(n)`.
function Mooncake.frule!!(::Dual{typeof(norm)}, tΔt::Dual{<:AbstractTensorMap}, pdp::Dual{<:Real})
79+
t, Δt = arrayify(tΔt)
80+
# Δp is extracted but not used below: p is not differentiated here.
p, Δp = Mooncake.extract(pdp)
81+
p == 2 || error("currently only implemented for p = 2")
82+
n = norm(t, p)
83+
# NOTE(review): `pinv(n)` rather than `inv(n)` presumably guards the n == 0 case
# (scalar `pinv` of zero is zero in LinearAlgebra) — confirm this is intended.
Δn = real(dot(t, Δt)) * pinv(n)
84+
return Dual(n, Δn)
85+
end
5886

59-
@is_primitive DefaultCtx ReverseMode Tuple{typeof(tr), AbstractTensorMap}
60-
87+
@is_primitive DefaultCtx Tuple{typeof(tr), AbstractTensorMap}
6188
function Mooncake.rrule!!(::CoDual{typeof(tr)}, A_ΔA::CoDual{<:AbstractTensorMap})
6289
A, ΔA = arrayify(A_ΔA)
6390
trace = tr(A)
@@ -71,8 +98,12 @@ function Mooncake.rrule!!(::CoDual{typeof(tr)}, A_ΔA::CoDual{<:AbstractTensorMa
7198

7299
return CoDual(trace, Mooncake.NoFData()), tr_pullback
73100
end
101+
# Forward-mode (Dual) rule for `tr(A)` on a tensor map.
# The tangent of the output is the trace of the input tangent.
function Mooncake.frule!!(::Dual{typeof(tr)}, A_ΔA::Dual{<:AbstractTensorMap})
102+
A, ΔA = arrayify(A_ΔA)
103+
return Dual(tr(A), tr(ΔA))
104+
end
74105

75-
@is_primitive DefaultCtx ReverseMode Tuple{typeof(inv), AbstractTensorMap}
106+
@is_primitive DefaultCtx Tuple{typeof(inv), AbstractTensorMap}
76107

77108
function Mooncake.rrule!!(::CoDual{typeof(inv)}, A_ΔA::CoDual{<:AbstractTensorMap})
78109
A, ΔA = arrayify(A_ΔA)
@@ -86,13 +117,21 @@ function Mooncake.rrule!!(::CoDual{typeof(inv)}, A_ΔA::CoDual{<:AbstractTensorM
86117

87118
return Ainv_ΔAinv, inv_pullback
88119
end
120+
# Forward-mode (Dual) rule for `inv(A)` on a tensor map.
# Returns a new `Dual(Ainv, ΔAinv)` with `ΔAinv = -Ainv * ΔA * Ainv`.
function Mooncake.frule!!(::Dual{typeof(inv)}, A_ΔA::Dual{<:AbstractTensorMap})
121+
A, ΔA = arrayify(A_ΔA)
122+
Ainv = inv(A)
123+
# The product is a fresh tensor, so negating it in place with `scale!` does not
# touch A, ΔA or Ainv.
ΔAinv = scale!(Ainv * ΔA * Ainv, -1)
124+
return Dual(Ainv, ΔAinv)
125+
end
89126

90127
# single-output projections: project_hermitian!, project_antihermitian!
91128
for (f!, f, adj) in (
92129
(:project_hermitian!, :project_hermitian, :project_hermitian_adjoint),
93130
(:project_antihermitian!, :project_antihermitian, :project_antihermitian_adjoint),
94131
)
95132
@eval begin
133+
@is_primitive DefaultCtx Tuple{typeof($f!), AbstractTensorMap, AbstractTensorMap, MatrixAlgebraKit.AbstractAlgorithm}
134+
@is_primitive DefaultCtx Tuple{typeof($f), AbstractTensorMap, MatrixAlgebraKit.AbstractAlgorithm}
96135
function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual{<:AbstractTensorMap}, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
97136
A, dA = arrayify(A_dA)
98137
arg, darg = A_dA === arg_darg ? (A, dA) : arrayify(arg_darg)
@@ -113,7 +152,13 @@ for (f!, f, adj) in (
113152

114153
return arg_darg, $adj
115154
end
116-
155+
# Forward-mode (Dual) rule for the in-place projection `$f!(A, arg, alg)`.
# The same projection is applied to the primals and then to the tangents, and the
# output dual `arg_darg` is returned. Only the primal of `alg_dalg` is used.
function Mooncake.frule!!(f_df::Dual{typeof($f!)}, A_dA::Dual{<:AbstractTensorMap}, arg_darg::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm})
156+
A, dA = arrayify(A_dA)
157+
# When input and output are the very same dual (in-place on A), reuse the pair
# obtained above instead of calling `arrayify` a second time.
arg, darg = A_dA === arg_darg ? (A, dA) : arrayify(arg_darg)
158+
# Primal projection; the local `arg` is rebound to the result but not read again.
arg = $f!(A, arg, Mooncake.primal(alg_dalg))
159+
# Tangent projection: dA is projected into darg with the same algorithm.
$f!(dA, darg, Mooncake.primal(alg_dalg))
160+
return arg_darg
161+
end
117162
function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractTensorMap}, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
118163
A, dA = arrayify(A_dA)
119164
output = $f(A, Mooncake.primal(alg_dalg))
@@ -129,5 +174,11 @@ for (f!, f, adj) in (
129174

130175
return output_doutput, $adj
131176
end
177+
# Forward-mode (Dual) rule for the out-of-place projection `$f(A, alg)`.
# The projection is applied separately to the primal A and the tangent dA, and the
# two results are wrapped in a new `Dual`. Only the primal of `alg_dalg` is used.
function Mooncake.frule!!(f_df::Dual{typeof($f)}, A_dA::Dual{<:AbstractTensorMap}, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm})
178+
A, dA = arrayify(A_dA)
179+
output = $f(A, Mooncake.primal(alg_dalg))
180+
doutput = $f(dA, Mooncake.primal(alg_dalg))
181+
return Dual(output, doutput)
182+
end
132183
end
133184
end

ext/TensorKitMooncakeExt/tangent.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
# Arrayify is needed to make MatrixAlgebraKit function properly -
22
# it turns coduals into argument types that MAK knows how to handle.
33
Mooncake.arrayify(A_dA::CoDual{<:TensorMap}) = arrayify(primal(A_dA), tangent(A_dA))
4+
Mooncake.arrayify(A_dA::Dual{<:TensorMap}) = arrayify(primal(A_dA), tangent(A_dA))
45
Mooncake.arrayify(A::TensorMap, dA::TensorMap) = (A, dA)
56

67
Mooncake.arrayify(A_dA::CoDual{<:DiagonalTensorMap}) = arrayify(primal(A_dA), tangent(A_dA))
8+
Mooncake.arrayify(A_dA::Dual{<:DiagonalTensorMap}) = arrayify(primal(A_dA), tangent(A_dA))
79
Mooncake.arrayify(A::DiagonalTensorMap, dA::DiagonalTensorMap) = (A, dA)
810

911
function Mooncake.arrayify(Aᴴ_ΔAᴴ::CoDual{<:TK.AdjointTensorMap})
@@ -14,6 +16,14 @@ function Mooncake.arrayify(Aᴴ_ΔAᴴ::CoDual{<:TK.AdjointTensorMap})
1416
return A', ΔA'
1517
end
1618

19+
# `arrayify` for a forward-mode `Dual` whose primal is an adjoint tensor map.
# Builds a `Dual` from the adjoint of the primal and the tangent's `parent` field,
# runs `arrayify` on that, and returns the adjoints of the resulting pair.
function Mooncake.arrayify(Aᴴ_ΔAᴴ::Dual{<:TK.AdjointTensorMap})
20+
Aᴴ = Mooncake.primal(Aᴴ_ΔAᴴ)
21+
ΔAᴴ = Mooncake.tangent(Aᴴ_ΔAᴴ)
22+
# NOTE(review): `Aᴴ'` presumably recovers the wrapped parent tensor, and
# `ΔAᴴ.fields.parent` assumes the tangent stores its parent under `fields` — confirm.
A_ΔA = Dual(Aᴴ', ΔAᴴ.fields.parent)
23+
A, ΔA = arrayify(A_ΔA)
24+
return A', ΔA'
25+
end
26+
1727
# Define the tangent type of a TensorMap to be TensorMap itself.
1828
# This has a number of benefits, but also correctly alters the
1929
# inner product when dealing with non-abelian symmetries.

test/mooncake/linalg.jl

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ using TensorKit
33
using Mooncake
44
using Random
55

6-
7-
mode = Mooncake.ReverseMode
86
rng = Random.default_rng()
97

108
spacelist = ad_spacelist(fast_tests)
@@ -20,21 +18,30 @@ eltypes = (Float64, ComplexF64)
2018
α = randn(T)
2119
β = randn(T)
2220

23-
Mooncake.TestUtils.test_rule(rng, mul!, C, A, B, α, β; atol, rtol, mode)
24-
Mooncake.TestUtils.test_rule(rng, mul!, C, A, B; atol, rtol, mode, is_primitive = false)
21+
Mooncake.TestUtils.test_rule(rng, mul!, C, A, B, α, β; atol, rtol)
22+
Mooncake.TestUtils.test_rule(rng, mul!, C, A, B; atol, rtol, is_primitive = false)
2523

26-
Mooncake.TestUtils.test_rule(rng, norm, C, 2; atol, rtol, mode)
27-
Mooncake.TestUtils.test_rule(rng, norm, C', 2; atol, rtol, mode)
24+
Mooncake.TestUtils.test_rule(rng, norm, C, 2; atol, rtol)
25+
Mooncake.TestUtils.test_rule(rng, norm, C', 2; atol, rtol)
2826

2927
D1 = randn(T, V[1] ← V[1])
3028
D2 = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2])
3129
D3 = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← V[1] ⊗ V[2] ⊗ V[3])
3230

33-
Mooncake.TestUtils.test_rule(rng, tr, D1; atol, rtol, mode)
34-
Mooncake.TestUtils.test_rule(rng, tr, D2; atol, rtol, mode)
35-
Mooncake.TestUtils.test_rule(rng, tr, D3; atol, rtol, mode)
36-
37-
Mooncake.TestUtils.test_rule(rng, inv, D1; atol, rtol, mode)
38-
Mooncake.TestUtils.test_rule(rng, inv, D2; atol, rtol, mode)
39-
Mooncake.TestUtils.test_rule(rng, inv, D3; atol, rtol, mode)
31+
Mooncake.TestUtils.test_rule(rng, tr, D1; atol, rtol)
32+
Mooncake.TestUtils.test_rule(rng, tr, D2; atol, rtol)
33+
Mooncake.TestUtils.test_rule(rng, tr, D3; atol, rtol)
34+
35+
Mooncake.TestUtils.test_rule(rng, inv, D1; atol, rtol)
36+
Mooncake.TestUtils.test_rule(rng, inv, D2; atol, rtol)
37+
Mooncake.TestUtils.test_rule(rng, inv, D3; atol, rtol)
38+
39+
C = randn(T, V[1] ← V[1])
40+
C′ = similar(C)
41+
Mooncake.TestUtils.test_rule(rng, project_hermitian!, C, C′; atol, rtol, is_primitive = false)
42+
Mooncake.TestUtils.test_rule(rng, project_hermitian!, C, C; atol, rtol, is_primitive = false)
43+
Mooncake.TestUtils.test_rule(rng, project_hermitian, C; atol, rtol, is_primitive = false)
44+
Mooncake.TestUtils.test_rule(rng, project_antihermitian!, C, C′; atol, rtol, is_primitive = false)
45+
Mooncake.TestUtils.test_rule(rng, project_antihermitian!, C, C; atol, rtol, is_primitive = false)
46+
Mooncake.TestUtils.test_rule(rng, project_antihermitian, C; atol, rtol, is_primitive = false)
4047
end

0 commit comments

Comments
 (0)