|
8 | 8 | # this permutation is done multiple times. |
9 | 9 | @is_primitive( |
10 | 10 | DefaultCtx, |
11 | | - ReverseMode, |
12 | 11 | Tuple{ |
13 | 12 | typeof(TensorKit.blas_contract!), |
14 | 13 | AbstractTensorMap, |
@@ -70,6 +69,36 @@ function Mooncake.rrule!!( |
70 | 69 | return C_ΔC, blas_contract_pullback |
71 | 70 | end |
72 | 71 |
|
# Forward-mode rule for the in-place `TensorKit.blas_contract!`.
#
# Writing `X*Y` for the contraction that `blas_contract!` performs with the given
# index tuples, the tangent of `C` is overwritten in place with
#     ΔC*β + C*Δβ + A*B*Δα + ΔA*B*α + A*ΔB*α
# after which the primal call is applied to `C` itself. Only the primal values of
# the index tuples, the backend and the allocator are used. The incoming dual
# `C_ΔC` is returned.
function Mooncake.frule!!(
        ::Dual{typeof(TensorKit.blas_contract!)},
        C_ΔC::Dual{<:AbstractTensorMap},
        A_ΔA::Dual{<:AbstractTensorMap}, pA_ΔpA::Dual{<:Index2Tuple},
        B_ΔB::Dual{<:AbstractTensorMap}, pB_ΔpB::Dual{<:Index2Tuple},
        pAB_ΔpAB::Dual{<:Index2Tuple},
        α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number},
        backend_Δbackend::Dual, allocator_Δallocator::Dual
    )
    # split every dual argument into the pieces that are needed below
    C, dC = arrayify(C_ΔC)
    A, dA = arrayify(A_ΔA)
    B, dB = arrayify(B_ΔB)
    pA = primal(pA_ΔpA)
    pB = primal(pB_ΔpB)
    pAB = primal(pAB_ΔpAB)
    α, dα = extract(α_Δα)
    β, dβ = extract(β_Δβ)
    backend = primal(backend_Δbackend)
    allocator = primal(allocator_Δallocator)

    # β-dependent part of the tangent: ΔC*β, plus C*Δβ when β carries a tangent
    if dβ isa NoTangent
        scale!(dC, β)
    else
        add!(dC, C, dβ, β)
    end

    # α-dependent part, A*B*Δα, is skipped when α carries no tangent
    if !(dα isa NoTangent)
        TensorKit.blas_contract!(dC, A, pA, B, pB, pAB, dα, One(), backend, allocator)
    end

    # contributions of the tensor tangents: ΔA*B*α and A*ΔB*α
    TensorKit.blas_contract!(dC, dA, pA, B, pB, pAB, α, One(), backend, allocator)
    TensorKit.blas_contract!(dC, A, pA, dB, pB, pAB, α, One(), backend, allocator)

    # the primal update comes last, so that the C*Δβ term above saw the input `C`
    TensorKit.blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator)

    return C_ΔC
end
| 101 | + |
73 | 102 | function blas_contract_pullback_ΔA!( |
74 | 103 | ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator |
75 | 104 | ) |
|
124 | 153 | # ------------ |
125 | 154 | @is_primitive( |
126 | 155 | DefaultCtx, |
127 | | - ReverseMode, |
128 | 156 | Tuple{ |
129 | 157 | typeof(TensorKit.trace_permute!), |
130 | 158 | AbstractTensorMap, |
@@ -177,6 +205,37 @@ function Mooncake.rrule!!( |
177 | 205 | return C_ΔC, trace_permute_pullback |
178 | 206 | end |
179 | 207 |
|
# Forward-mode rule for the in-place `TensorKit.trace_permute!`.
#
# Writing `tr(X)` for the operation that `trace_permute!` performs with the index
# tuples `p` and `q`, the tangent of `C` is overwritten in place with
#     dC*β + C*dβ + tr(A)*dα + tr(dA)*α
# after which the primal call is applied to `C` itself. Only the primal values of
# `p`, `q` and the backend are used. The incoming dual `C_ΔC` is returned.
function Mooncake.frule!!(
        ::Dual{typeof(TensorKit.trace_permute!)},
        C_ΔC::Dual{<:AbstractTensorMap},
        A_ΔA::Dual{<:AbstractTensorMap}, p_Δp::Dual{<:Index2Tuple}, q_Δq::Dual{<:Index2Tuple},
        α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number},
        backend_Δbackend::Dual
    )
    # split every dual argument into the pieces that are needed below
    (C, dC), (A, dA) = arrayify.((C_ΔC, A_ΔA))
    p, q = primal.((p_Δp, q_Δq))
    α, dα = extract(α_Δα)
    β, dβ = extract(β_Δβ)
    backend = primal(backend_Δbackend)

    # β-dependent part of the tangent: dC*β, plus C*dβ when β carries a tangent
    if dβ isa NoTangent
        scale!(dC, β)
    else
        add!(dC, C, dβ, β)
    end

    # α-dependent part, tr(A)*dα, is skipped when α carries no tangent
    if !(dα isa NoTangent)
        TensorKit.trace_permute!(dC, A, p, q, dα, One(), backend)
    end

    # contribution of the tensor tangent: tr(dA)*α
    TensorKit.trace_permute!(dC, dA, p, q, α, One(), backend)

    # the primal update comes last, so that the C*dβ term above saw the input `C`
    TensorKit.trace_permute!(C, A, p, q, α, β, backend)

    return C_ΔC
end
| 238 | + |
180 | 239 | function trace_permute_pullback_ΔA!( |
181 | 240 | ΔA, ΔC, A, p, q, α, backend |
182 | 241 | ) |
|
0 commit comments