Skip to content

Commit 47da83e

Browse files
authored
Merge branch 'main' into ksh/enz_vi
2 parents e74ebd5 + 1de432d commit 47da83e

3 files changed

Lines changed: 68 additions & 11 deletions

File tree

ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module TensorKitMooncakeExt
33
using Mooncake
44
using Mooncake: @zero_derivative, @is_primitive,
55
DefaultCtx, MinimalCtx, ReverseMode, NoFData, NoRData, NoTangent,
6-
CoDual, Dual, arrayify, primal, tangent, zero_fcodual
6+
CoDual, Dual, arrayify, primal, tangent, zero_fcodual, extract
77
using TensorKit
88
import TensorKit as TK
99
using VectorInterface

ext/TensorKitMooncakeExt/tensoroperations.jl

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
# this permutation is done multiple times.
99
@is_primitive(
1010
DefaultCtx,
11-
ReverseMode,
1211
Tuple{
1312
typeof(TensorKit.blas_contract!),
1413
AbstractTensorMap,
@@ -70,6 +69,36 @@ function Mooncake.rrule!!(
7069
return C_ΔC, blas_contract_pullback
7170
end
7271

72+
# Forward-mode rule for `TensorKit.blas_contract!`.
#
# The tangent of `C` is updated in place according to
#     ΔC′ = β * ΔC + Δβ * C + Δα * (A * B) + α * (ΔA * B) + α * (A * ΔB)
# after which the primal contraction is performed on `C` itself. Contributions that
# belong to a scalar whose tangent is `NoTangent` are dropped. The dual `C_ΔC` that was
# passed in is returned.
function Mooncake.frule!!(
        ::Dual{typeof(TensorKit.blas_contract!)},
        C_ΔC::Dual{<:AbstractTensorMap},
        A_ΔA::Dual{<:AbstractTensorMap}, pA_ΔpA::Dual{<:Index2Tuple},
        B_ΔB::Dual{<:AbstractTensorMap}, pB_ΔpB::Dual{<:Index2Tuple},
        pAB_ΔpAB::Dual{<:Index2Tuple},
        α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number},
        backend_Δbackend::Dual, allocator_Δallocator::Dual
    )
    # Tensors: split each dual into primal and tangent.
    C, ΔC = arrayify(C_ΔC)
    A, ΔA = arrayify(A_ΔA)
    B, ΔB = arrayify(B_ΔB)

    # Index tuples, backend and allocator: only the primal value is used.
    pA = primal(pA_ΔpA)
    pB = primal(pB_ΔpB)
    pAB = primal(pAB_ΔpAB)
    backend = primal(backend_Δbackend)
    allocator = primal(allocator_Δallocator)

    # Scalars: the tangent may be `NoTangent`.
    α, Δα = extract(α_Δα)
    β, Δβ = extract(β_Δβ)

    # Adds `s * (X * Y)` onto the tangent of `C`, reusing the primal index pattern.
    accumulate!(X, Y, s) =
        TensorKit.blas_contract!(ΔC, X, pA, Y, pB, pAB, s, One(), backend, allocator)

    # β-part first: ΔC ← β * ΔC (+ Δβ * C); this reads `C` before it is overwritten below.
    Δβ isa NoTangent ? scale!(ΔC, β) : add!(ΔC, C, Δβ, β)

    # α-part: Δα * (A * B) + α * (ΔA * B) + α * (A * ΔB)
    Δα isa NoTangent || accumulate!(A, B, Δα)
    accumulate!(ΔA, B, α)
    accumulate!(A, ΔB, α)

    # Finally the primal computation itself.
    TensorKit.blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator)
    return C_ΔC
end
101+
73102
function blas_contract_pullback_ΔA!(
74103
ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator
75104
)
@@ -124,7 +153,6 @@ end
124153
# ------------
125154
@is_primitive(
126155
DefaultCtx,
127-
ReverseMode,
128156
Tuple{
129157
typeof(TensorKit.trace_permute!),
130158
AbstractTensorMap,
@@ -177,6 +205,37 @@ function Mooncake.rrule!!(
177205
return C_ΔC, trace_permute_pullback
178206
end
179207

208+
# Forward-mode rule for `TensorKit.trace_permute!`.
#
# The tangent of `C` is updated in place according to
#     dC′ = β * dC + dβ * C + dα * tr(A) + α * tr(dA)
# after which the primal operation is performed on `C` itself. Contributions that belong
# to a scalar whose tangent is `NoTangent` are dropped. The dual `C_ΔC` that was passed
# in is returned.
function Mooncake.frule!!(
        ::Dual{typeof(TensorKit.trace_permute!)},
        C_ΔC::Dual{<:AbstractTensorMap},
        A_ΔA::Dual{<:AbstractTensorMap}, p_Δp::Dual{<:Index2Tuple}, q_Δq::Dual{<:Index2Tuple},
        α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number},
        backend_Δbackend::Dual
    )
    # Tensors: split each dual into primal and tangent.
    C, ΔC = arrayify(C_ΔC)
    A, ΔA = arrayify(A_ΔA)

    # Index tuples and backend: only the primal value is used.
    p, q = primal(p_Δp), primal(q_Δq)
    backend = primal(backend_Δbackend)

    # Scalars: the tangent may be `NoTangent`.
    α, Δα = extract(α_Δα)
    β, Δβ = extract(β_Δβ)

    # Adds `s * tr(X)` onto the tangent of `C`, reusing the primal index pattern.
    accumulate!(X, s) = TensorKit.trace_permute!(ΔC, X, p, q, s, One(), backend)

    # β-part first: dC ← β * dC (+ dβ * C); this reads `C` before it is overwritten below.
    Δβ isa NoTangent ? scale!(ΔC, β) : add!(ΔC, C, Δβ, β)

    # α-part: dα * tr(A) + α * tr(dA)
    Δα isa NoTangent || accumulate!(A, Δα)
    accumulate!(ΔA, α)

    # Finally the primal computation itself.
    TensorKit.trace_permute!(C, A, p, q, α, β, backend)
    return C_ΔC
end
238+
180239
function trace_permute_pullback_ΔA!(
181240
ΔA, ΔC, A, p, q, α, backend
182241
)

test/mooncake/tensoroperations.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ using VectorInterface: One, Zero
55
using Mooncake
66
using Random
77

8-
9-
mode = Mooncake.ReverseMode
108
rng = Random.default_rng()
119

1210
spacelist = ad_spacelist(fast_tests)
@@ -53,32 +51,32 @@ eltypes = (Float64, ComplexF64)
5351
rng, TensorKit.blas_contract!,
5452
C, A, pA, B, pB, pAB, One(), Zero(),
5553
TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator();
56-
atol, rtol, mode
54+
atol, rtol
5755
)
5856
Mooncake.TestUtils.test_rule(
5957
rng, TensorKit.blas_contract!,
6058
C, A, pA, B, pB, pAB, α, β,
6159
TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator();
62-
atol, rtol, mode
60+
atol, rtol
6361
)
6462
if !(T <: Real)
6563
Mooncake.TestUtils.test_rule(
6664
rng, TensorKit.blas_contract!,
6765
C, A, pA, B, pB, pAB, real(α), real(β),
6866
TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator();
69-
atol, rtol, mode
67+
atol, rtol
7068
)
7169
Mooncake.TestUtils.test_rule(
7270
rng, TensorKit.blas_contract!,
7371
C, real(A), pA, B, pB, pAB, real(α), real(β),
7472
TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator();
75-
atol, rtol, mode
73+
atol, rtol
7674
)
7775
Mooncake.TestUtils.test_rule(
7876
rng, TensorKit.blas_contract!,
7977
C, A, pA, real(B), pB, pAB, real(α), real(β),
8078
TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator();
81-
atol, rtol, mode
79+
atol, rtol
8280
)
8381
end
8482
end
@@ -102,7 +100,7 @@ eltypes = (Float64, ComplexF64)
102100
C = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false)))
103101
Mooncake.TestUtils.test_rule(
104102
rng, TensorKit.trace_permute!, C, A, p, q, α, β, TensorOperations.DefaultBackend();
105-
atol, rtol, mode
103+
atol, rtol
106104
)
107105
end
108106
end

0 commit comments

Comments
 (0)