@@ -3,156 +3,220 @@ module MatrixAlgebraKitMooncakeExt
33using Mooncake
44using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
55using MatrixAlgebraKit
6- using MatrixAlgebraKit: inv_safe, diagview
6+ using MatrixAlgebraKit: inv_safe, diagview, copy_input
77using MatrixAlgebraKit: qr_pullback!, lq_pullback!
88using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
99using MatrixAlgebraKit: eig_pullback!, eigh_pullback!
1010using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
1111using LinearAlgebra
1212
13+
# Register `copy_input(f, A)` as a Mooncake primitive in reverse mode so that the
# hand-written rule below is used instead of differentiating through its body.
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(copy_input), Any, Any}

# Reverse rule for `copy_input(f, A)`.
#
# Forward pass: runs `copy_input` on the primals and pairs the result with a
# freshly created zero tangent.
# Reverse pass: adds whatever has been accumulated in that output tangent onto
# the tangent of `A`. All three arguments report `NoRData`.
function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDual)
    f = Mooncake.primal(f_df)
    A = Mooncake.primal(A_dA)
    out = copy_input(f, A)
    dout = Mooncake.zero_tangent(out)
    function copy_input_pb(::Mooncake.NoRData)
        # `dout` is captured by reference, so it holds the cotangent of the copy
        # by the time the pullback runs.
        Mooncake.increment!!(Mooncake.tangent(A_dA), dout)
        return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
    end
    return CoDual(out, dout), copy_input_pb
end
24+
# two-argument factorizations like LQ, QR, EIG
#
# For every `(f, pb, adj)` triple this generates
#   * an `@is_primitive` declaration for `f(A, (arg1, arg2), alg)`, and
#   * a reverse rule whose pullback (named `adj`) calls the MatrixAlgebraKit
#     pullback `pb` and accumulates the result into the tangent of `A`.
for (f, pb, adj) in (
        (qr_full!, qr_pullback!, :dqr_adjoint),
        (qr_compact!, qr_pullback!, :dqr_adjoint),
        (lq_full!, lq_pullback!, :dlq_adjoint),
        (lq_compact!, lq_pullback!, :dlq_adjoint),
        (eig_full!, eig_pullback!, :deig_adjoint),
        (eigh_full!, eigh_pullback!, :deigh_adjoint),
        (left_polar!, left_polar_pullback!, :dleft_polar_adjoint),
        (right_polar!, right_polar_pullback!, :dright_polar_adjoint),
    )

    @eval begin
        @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, Tuple{<:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
        function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
            A, dA = arrayify(A_dA)
            args = Mooncake.primal(args_dargs)
            dargs = Mooncake.tangent(args_dargs)
            arg1, darg1 = arrayify(args[1], dargs[1])
            arg2, darg2 = arrayify(args[2], dargs[2])
            # Snapshot everything the in-place call may overwrite so the
            # pullback can put the primal state back.
            Ac = copy(A)
            arg1c = copy(arg1)
            arg2c = copy(arg2)
            # The return value is intentionally not bound: the rule hands back
            # `args` itself, i.e. it relies on the factors being written into
            # the caller-provided buffers.
            $f(A, args, Mooncake.primal(alg_dalg); kwargs...)
            function $adj(::Mooncake.NoRData)
                # Restore `A` first so the pullback is handed the caller's
                # original matrix, not whatever the in-place factorization left
                # behind (same order as the SVD rule in this file).
                A .= Ac
                dAtmp_ = zero(dA)
                dAtmp_ .= $pb(dAtmp_, A, (arg1, arg2), (darg1, darg2); kwargs...)
                # For non-real element types the contribution is rebuilt
                # element-wise via `build_tangent` from real and imaginary parts
                # before it is accumulated.
                dAtmp = if eltype(dA) <: Real
                    dAtmp_
                else
                    map(A_ -> Mooncake.build_tangent(typeof(A_), real(A_), imag(A_)), dAtmp_)
                end
                Mooncake.increment!!(Mooncake.tangent(A_dA), dAtmp)
                # Undo the forward-pass mutation of the output buffers and clear
                # their cotangents now that they have been consumed.
                arg1 .= arg1c
                arg2 .= arg2c
                darg1 .= zero(darg1)
                darg2 .= zero(darg2)
                return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
            end
            return Mooncake.CoDual(args, dargs), $adj
        end
    end
end
4469
# Null-space factorizations (`qr_null!`, `lq_null!`).
#
# For every tuple this generates an `@is_primitive` declaration and a reverse
# rule whose pullback (named `adj`) calls the MatrixAlgebraKit pullback `pb`.
# NOTE(review): `f_full` is carried in the tuple but not used by the rule below.
for (f, f_full, pb, adj) in (
        (qr_null!, qr_full, qr_null_pullback!, :dqr_null_adjoint),
        (lq_null!, lq_full, lq_null_pullback!, :dlq_null_adjoint),
    )
    @eval begin
        @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
        function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
            A, dA = arrayify(A_dA)
            Ac = copy(A)
            arg, darg = arrayify(arg_darg)
            argc = copy(arg)
            # TODO: the factorization runs on a copy of `A`; it is still an open
            # question why this copy is needed.
            # The result gets its own single-assignment name: rebinding `arg`
            # here would make the closure below capture a reassigned variable,
            # which Julia has to box.
            N = $f(copy(A), arg, Mooncake.primal(alg_dalg))
            function $adj(::Mooncake.NoRData)
                dAtmp_ = zero(dA)
                dAtmp_ .= $pb(dAtmp_, A, N, darg; kwargs...)
                # For non-real element types the contribution is rebuilt
                # element-wise via `build_tangent` from real and imaginary parts
                # before it is accumulated.
                dAtmp = if eltype(dA) <: Real
                    dAtmp_
                else
                    map(A_ -> Mooncake.build_tangent(typeof(A_), real(A_), imag(A_)), dAtmp_)
                end
                Mooncake.increment!!(Mooncake.tangent(A_dA), dAtmp)
                # Put the primal state back and clear the consumed cotangent.
                A .= Ac
                N .= argc
                darg .= zero(darg)
                return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
            end
            return arg_darg, $adj
        end
    end
end
64101
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eig_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}

# Reverse rule for `eig_vals!(A, D, alg)`.
#
# Forward pass: snapshots `A`, `D` and the tangent of `D`, obtains eigenvectors
# from `eig_full`, runs `eig_vals!` on the caller's buffers, and zeroes the
# tangent of `D`.
# Reverse pass: overwrites the tangent of `A` with the eigenvalue contribution
# `V' \ Diagonal(dD) * V'` (real part only when `dA` is real) and restores `A`,
# `D` and the tangent of `D` to their snapshots.
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eig_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...)
    alg = Mooncake.primal(alg_dalg)
    A, dA = arrayify(Mooncake.primal(A_dA), Mooncake.tangent(A_dA))
    D, dD = arrayify(Mooncake.primal(D_dD), Mooncake.tangent(D_dD))
    A_saved = copy(A)
    D_saved = copy(D)
    dD_saved = copy(dD)
    # Eigenvectors are needed by the pullback; eigenvalues are then written
    # into `D` by the in-place call.
    V = eig_full(A, alg; kwargs...)[2]
    eig_vals!(A, D, alg)
    dD .= zero(eltype(D))
    function deig_vals_adjoint(::Mooncake.NoRData)
        dA .= zero(eltype(dA))
        A .= A_saved
        PΔV = V' \ Diagonal(dD)
        if eltype(dA) <: Real
            dA .+= real.(PΔV * V')
        else
            mul!(dA, PΔV, V', 1, 0)
        end
        D .= D_saved
        dD .= dD_saved
        return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
    end
    return D_dD, deig_vals_adjoint
end
91135
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eigh_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}

# Reverse rule for `eigh_vals!(A, D, alg)`.
#
# Forward pass: snapshots `D` and its tangent, computes the decomposition with
# `eigh_full`, writes the eigenvalues into the caller's `D`, and zeroes the
# tangent of `D`. The rule returns `D_dD` itself so that the returned primal
# and tangent refer to the same caller-owned buffers, matching the `eig_vals!`
# and `svd_vals!` rules in this file.
# Reverse pass: overwrites the tangent of `A` with `V * Diagonal(real(dD)) * V'`
# and restores `D` and its tangent to their snapshots.
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eigh_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...)
    # compute primal
    D_ = Mooncake.primal(D_dD)
    dD_ = Mooncake.tangent(D_dD)
    A_ = Mooncake.primal(A_dA)
    dA_ = Mooncake.tangent(A_dA)
    A, dA = arrayify(A_, dA_)
    D, dD = arrayify(D_, dD_)
    Dc = copy(D)
    dDc = copy(dD)
    DV = eigh_full(A, Mooncake.primal(alg_dalg); kwargs...)
    # Write the eigenvalues into the provided output buffer instead of handing
    # back a freshly allocated vector paired with the tangent of `D`.
    D .= DV[1].diag
    dD .= zero(eltype(D))
    function deigh_vals_adjoint(::Mooncake.NoRData)
        mul!(dA, DV[2] * Diagonal(real(dD)), DV[2]', 1, 0)
        # Undo the forward-pass writes to `D` and its tangent.
        D .= Dc
        dD .= dDc
        return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
    end
    return D_dD, deigh_vals_adjoint
end
108153
109154
# Reverse rules for the in-place SVDs `svd_full!` and `svd_compact!`.
#
# Forward pass: snapshots `A`, `U`, `S` and `Vᴴ`, then runs the factorization.
# Reverse pass: restores `A`, calls `MatrixAlgebraKit.svd_pullback!` (on the
# leading `min(size(A)...)` columns/rows when the factors are not compact),
# accumulates the result into the tangent of `A`, restores the factor buffers
# and zeroes their cotangents.
for f in (svd_full!, svd_compact!)
    @eval begin
        @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
        function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual; kwargs...)
            A, dA = arrayify(A_dA)
            A_saved = copy(A)
            factors = Mooncake.primal(USVᴴ_dUSVᴴ)
            dfactors = Mooncake.tangent(USVᴴ_dUSVᴴ)
            U, dU = arrayify(factors[1], dfactors[1])
            S, dS = arrayify(factors[2], dfactors[2])
            Vᴴ, dVᴴ = arrayify(factors[3], dfactors[3])
            U_saved = copy(U)
            S_saved = copy(S)
            Vᴴ_saved = copy(Vᴴ)
            result = $f(A, factors, Mooncake.primal(alg_dalg); kwargs...)
            minmn = min(size(A)...)
            function dsvd_adjoint(::Mooncake.NoRData)
                ΔA = zero(dA)
                A .= A_saved
                is_compact = size(U, 2) == size(Vᴴ, 1) == minmn
                # Pick the factors and cotangents handed to the pullback: the
                # buffers themselves when compact, otherwise their leading
                # `minmn` part.
                fac, dfac = if is_compact
                    (U, S, Vᴴ), (dU, dS, dVᴴ)
                else
                    vU = view(U, :, 1:minmn)
                    vS = Diagonal(diagview(S)[1:minmn])
                    vVᴴ = view(Vᴴ, 1:minmn, :)
                    vdU = view(dU, :, 1:minmn)
                    vdS = Diagonal(diagview(dS)[1:minmn])
                    vdVᴴ = view(dVᴴ, 1:minmn, :)
                    (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ)
                end
                ΔA = MatrixAlgebraKit.svd_pullback!(ΔA, A, fac, dfac)
                # For non-real element types the contribution is rebuilt
                # element-wise via `build_tangent` from real and imaginary parts.
                ΔA_tangent = if eltype(dA) <: Real
                    ΔA
                else
                    map(A_ -> Mooncake.build_tangent(typeof(A_), real(A_), imag(A_)), ΔA)
                end
                Mooncake.increment!!(Mooncake.tangent(A_dA), ΔA_tangent)
                # Restore the factor buffers and clear the consumed cotangents.
                U .= U_saved
                S .= S_saved
                Vᴴ .= Vᴴ_saved
                dU .= zero(dU)
                dS .= zero(dS)
                dVᴴ .= zero(dVᴴ)
                return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
            end
            return Mooncake.CoDual(result, dfactors), dsvd_adjoint
        end
    end
end
141203
142- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (MatrixAlgebraKit. svd_vals!), AbstractMatrix, AbstractVector , MatrixAlgebraKit. AbstractAlgorithm}
204+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (MatrixAlgebraKit. svd_vals!), Any, Any , MatrixAlgebraKit. AbstractAlgorithm}
143205function Mooncake. rrule!! (:: CoDual{<:typeof(MatrixAlgebraKit.svd_vals!)} , A_dA:: CoDual , S_dS:: CoDual , alg_dalg:: CoDual ; kwargs... )
144206 # compute primal
145- S_ = Mooncake. primal (S_dS)
146- dS_ = Mooncake. tangent (S_dS)
147- A_ = Mooncake. primal (A_dA)
148- dA_ = Mooncake. tangent (A_dA)
207+ S_ = Mooncake. primal (S_dS)
208+ dS_ = Mooncake. tangent (S_dS)
209+ A_ = Mooncake. primal (A_dA)
210+ dA_ = Mooncake. tangent (A_dA)
149211 A, dA = arrayify (A_, dA_)
150212 S, dS = arrayify (S_, dS_)
213+ Ac = copy (A)
151214 U, nS, Vᴴ = svd_compact (A, Mooncake. primal (alg_dalg); kwargs... )
152- S .= diagview (nS)
153- dS .= zero (eltype (S))
215+ S .= diagview (nS)
154216 function dsvd_vals_adjoint (:: Mooncake.NoRData )
155- dA .= U * Diagonal (dS) * Vᴴ
217+ dA .= U * Diagonal (dS) * Vᴴ
218+ A .= Ac
219+ dS .= zero (eltype (S))
156220 return Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData ()
157221 end
158222 return S_dS, dsvd_vals_adjoint
0 commit comments