33pullback_dC! (ΔC, β) = (scale! (ΔC, conj (β)); return NoRData ())
44pullback_dβ (ΔC, C, β) = _needs_tangent (β) ? project_scalar (β, inner (C, ΔC)) : NoRData ()
55
6- @is_primitive DefaultCtx ReverseMode Tuple{typeof (mul!), AbstractTensorMap, AbstractTensorMap, AbstractTensorMap, Number, Number}
6+ @is_primitive DefaultCtx Tuple{typeof (mul!), AbstractTensorMap, AbstractTensorMap, AbstractTensorMap, Number, Number}
77
88function Mooncake. rrule!! (
99 :: CoDual{typeof(mul!)} ,
@@ -40,9 +40,29 @@ function Mooncake.rrule!!(
4040
4141 return C_ΔC, mul_pullback
4242end
43+ function Mooncake. frule!! (
44+ :: Dual{typeof(mul!)} ,
45+ C_ΔC:: Dual{<:AbstractTensorMap} , A_ΔA:: Dual{<:AbstractTensorMap} , B_ΔB:: Dual{<:AbstractTensorMap} ,
46+ α_Δα:: Dual{<:Number} , β_Δβ:: Dual{<:Number}
47+ )
48+ (C, ΔC), (A, ΔA), (B, ΔB) = arrayify .((C_ΔC, A_ΔA, B_ΔB))
49+ α, Δα = Mooncake. extract (α_Δα)
50+ β, Δβ = Mooncake. extract (β_Δβ)
51+ # ΔC′ = ΔC*β + C*Δβ + A*B*Δα + ΔA*B*α + A*ΔB*α
52+ scale! (ΔC, β)
53+ if ! isa (Δβ, Mooncake. NoTangent)
54+ add! (ΔC, C, Δβ)
55+ end
56+ if ! isa (Δα, Mooncake. NoTangent)
57+ project_mul! (ΔC, A, B, Δα)
58+ end
59+ project_mul! (ΔC, ΔA, B, α)
60+ project_mul! (ΔC, A, ΔB, α)
61+ mul! (C, A, B, α, β)
62+ return C_ΔC
63+ end
4364
44- @is_primitive DefaultCtx ReverseMode Tuple{typeof (norm), AbstractTensorMap, Real}
45-
65+ @is_primitive DefaultCtx Tuple{typeof (norm), AbstractTensorMap, Real}
4666function Mooncake. rrule!! (:: CoDual{typeof(norm)} , tΔt:: CoDual{<:AbstractTensorMap} , pdp:: CoDual{<:Real} )
4767 t, Δt = arrayify (tΔt)
4868 p = primal (pdp)
@@ -55,9 +75,16 @@ function Mooncake.rrule!!(::CoDual{typeof(norm)}, tΔt::CoDual{<:AbstractTensorM
5575 end
5676 return CoDual (n, Mooncake. NoFData ()), norm_pullback
5777end
78+ function Mooncake. frule!! (:: Dual{typeof(norm)} , tΔt:: Dual{<:AbstractTensorMap} , pdp:: Dual{<:Real} )
79+ t, Δt = arrayify (tΔt)
80+ p, Δp = Mooncake. extract (pdp)
81+ p == 2 || error (" currently only implemented for p = 2" )
82+ n = norm (t, p)
83+ Δn = real (dot (t, Δt)) * pinv (n)
84+ return Dual (n, Δn)
85+ end
5886
59- @is_primitive DefaultCtx ReverseMode Tuple{typeof (tr), AbstractTensorMap}
60-
87+ @is_primitive DefaultCtx Tuple{typeof (tr), AbstractTensorMap}
6188function Mooncake. rrule!! (:: CoDual{typeof(tr)} , A_ΔA:: CoDual{<:AbstractTensorMap} )
6289 A, ΔA = arrayify (A_ΔA)
6390 trace = tr (A)
@@ -71,8 +98,12 @@ function Mooncake.rrule!!(::CoDual{typeof(tr)}, A_ΔA::CoDual{<:AbstractTensorMa
7198
7299 return CoDual (trace, Mooncake. NoFData ()), tr_pullback
73100end
101+ function Mooncake. frule!! (:: Dual{typeof(tr)} , A_ΔA:: Dual{<:AbstractTensorMap} )
102+ A, ΔA = arrayify (A_ΔA)
103+ return Dual (tr (A), tr (ΔA))
104+ end
74105
75- @is_primitive DefaultCtx ReverseMode Tuple{typeof (inv), AbstractTensorMap}
106+ @is_primitive DefaultCtx Tuple{typeof (inv), AbstractTensorMap}
76107
77108function Mooncake. rrule!! (:: CoDual{typeof(inv)} , A_ΔA:: CoDual{<:AbstractTensorMap} )
78109 A, ΔA = arrayify (A_ΔA)
@@ -86,13 +117,21 @@ function Mooncake.rrule!!(::CoDual{typeof(inv)}, A_ΔA::CoDual{<:AbstractTensorM
86117
87118 return Ainv_ΔAinv, inv_pullback
88119end
120+ function Mooncake. frule!! (:: Dual{typeof(inv)} , A_ΔA:: Dual{<:AbstractTensorMap} )
121+ A, ΔA = arrayify (A_ΔA)
122+ Ainv = inv (A)
123+ ΔAinv = scale! (Ainv * ΔA * Ainv, - 1 )
124+ return Dual (Ainv, ΔAinv)
125+ end
89126
90127# single-output projections: project_hermitian!, project_antihermitian!
91128for (f!, f, adj) in (
92129 (:project_hermitian! , :project_hermitian , :project_hermitian_adjoint ),
93130 (:project_antihermitian! , :project_antihermitian , :project_antihermitian_adjoint ),
94131 )
95132 @eval begin
133+ @is_primitive DefaultCtx Tuple{typeof ($ f!), AbstractTensorMap, AbstractTensorMap, MatrixAlgebraKit. AbstractAlgorithm}
134+ @is_primitive DefaultCtx Tuple{typeof ($ f), AbstractTensorMap, MatrixAlgebraKit. AbstractAlgorithm}
96135 function Mooncake. rrule!! (f_df:: CoDual{typeof($f!)} , A_dA:: CoDual{<:AbstractTensorMap} , arg_darg:: CoDual , alg_dalg:: CoDual{<:MatrixAlgebraKit.AbstractAlgorithm} )
97136 A, dA = arrayify (A_dA)
98137 arg, darg = A_dA === arg_darg ? (A, dA) : arrayify (arg_darg)
@@ -113,7 +152,13 @@ for (f!, f, adj) in (
113152
114153 return arg_darg, $ adj
115154 end
116-
155+ function Mooncake. frule!! (f_df:: Dual{typeof($f!)} , A_dA:: Dual{<:AbstractTensorMap} , arg_darg:: Dual , alg_dalg:: Dual{<:MatrixAlgebraKit.AbstractAlgorithm} )
156+ A, dA = arrayify (A_dA)
157+ arg, darg = A_dA === arg_darg ? (A, dA) : arrayify (arg_darg)
158+ arg = $ f! (A, arg, Mooncake. primal (alg_dalg))
159+ $ f! (dA, darg, Mooncake. primal (alg_dalg))
160+ return arg_darg
161+ end
117162 function Mooncake. rrule!! (f_df:: CoDual{typeof($f)} , A_dA:: CoDual{<:AbstractTensorMap} , alg_dalg:: CoDual{<:MatrixAlgebraKit.AbstractAlgorithm} )
118163 A, dA = arrayify (A_dA)
119164 output = $ f (A, Mooncake. primal (alg_dalg))
@@ -129,5 +174,11 @@ for (f!, f, adj) in (
129174
130175 return output_doutput, $ adj
131176 end
177+ function Mooncake. frule!! (f_df:: Dual{typeof($f)} , A_dA:: Dual{<:AbstractTensorMap} , alg_dalg:: Dual{<:MatrixAlgebraKit.AbstractAlgorithm} )
178+ A, dA = arrayify (A_dA)
179+ output = $ f (A, Mooncake. primal (alg_dalg))
180+ doutput = $ f (dA, Mooncake. primal (alg_dalg))
181+ return Dual (output, doutput)
182+ end
132183 end
133184end
0 commit comments