@@ -3,7 +3,7 @@ module MatrixAlgebraKitMooncakeExt
33using Mooncake
44using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
55using MatrixAlgebraKit
6- using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output
6+ using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output, zero
77using MatrixAlgebraKit: qr_pullback!, lq_pullback!
88using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
99using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
@@ -54,11 +54,11 @@ for (f!, f, pb, adj) in (
5454 $ f! (A, args, Mooncake. primal (alg_dalg))
5555 function $adj (:: NoRData )
5656 copy! (A, Ac)
57- $ pb (dA, A, (arg1, arg2), (darg1, darg2))
5857 copy! (arg1, arg1c)
5958 copy! (arg2, arg2c)
60- MatrixAlgebraKit. zero! (darg1)
61- MatrixAlgebraKit. zero! (darg2)
59+ $ pb (dA, A, (arg1, arg2), (darg1, darg2))
60+ zero! (darg1)
61+ zero! (darg2)
6262 return NoRData (), NoRData (), NoRData (), NoRData ()
6363 end
6464 return args_dargs, $ adj
@@ -78,8 +78,8 @@ for (f!, f, pb, adj) in (
7878 arg1, darg1 = arrayify (arg1, darg1_)
7979 arg2, darg2 = arrayify (arg2, darg2_)
8080 $ pb (dA, A, (arg1, arg2), (darg1, darg2))
81- MatrixAlgebraKit . zero! (darg1)
82- MatrixAlgebraKit . zero! (darg2)
81+ zero! (darg1)
82+ zero! (darg2)
8383 return NoRData (), NoRData (), NoRData ()
8484 end
8585 return output_codual, $ adj
@@ -101,8 +101,8 @@ for (f!, f, pb, adj) in (
101101 $ f! (A, arg, Mooncake. primal (alg_dalg))
102102 function $adj (:: NoRData )
103103 copy! (A, Ac)
104- $ pb (dA, A, arg, darg)
105104 copy! (arg, argc)
105+ $ pb (dA, A, arg, darg)
106106 MatrixAlgebraKit. zero! (darg)
107107 return NoRData (), NoRData (), NoRData (), NoRData ()
108108 end
@@ -139,6 +139,7 @@ for (f!, f, f_full, pb, adj) in (
139139 copy! (D, diagview (DV[1 ]))
140140 V = DV[2 ]
141141 function $adj (:: NoRData )
142+ copy! (D, diagview (DV[1 ]))
142143 $ pb (dA, A, DV, dD)
143144 MatrixAlgebraKit. zero! (dD)
144145 return NoRData (), NoRData (), NoRData (), NoRData ()
@@ -165,12 +166,43 @@ for (f!, f, f_full, pb, adj) in (
165166 end
166167end
167168
168- for (f, f_ne, pb, adj) in (
169- (:eig_trunc , :eig_trunc_no_error , :eig_trunc_pullback! , :eig_trunc_adjoint ),
170- (:eigh_trunc , :eigh_trunc_no_error , :eigh_trunc_pullback! , :eigh_trunc_adjoint ),
169+ for (f!, f, f_ne! , f_ne, pb, adj) in (
170+ (:eig_trunc! , :eig_trunc , :eig_trunc_no_error! , :eig_trunc_no_error , :eig_trunc_pullback! , :eig_trunc_adjoint ),
171+ (:eigh_trunc! , :eigh_trunc , :eigh_trunc_no_error! , :eigh_trunc_no_error , :eigh_trunc_pullback! , :eigh_trunc_adjoint ),
171172 )
172173 @eval begin
174+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f!), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
173175 @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f), Any, MatrixAlgebraKit. AbstractAlgorithm}
176+ function Mooncake. rrule!! (:: CoDual{typeof($f!)} , A_dA:: CoDual , DV_dDV:: CoDual , alg_dalg:: CoDual )
177+ # compute primal
178+ A, dA = arrayify (A_dA)
179+ DV = Mooncake. primal (DV_dDV)
180+ dDV = Mooncake. tangent (DV_dDV)
181+ Ac = copy (A)
182+ DVc = copy .(DV)
183+ alg = Mooncake. primal (alg_dalg)
184+ output = $ f! (A, DV, alg)
185+ # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
186+ # of ComplexF32) into the correct **forwards** data type (since we are now in the forward
187+ # pass). For many types this is done automatically when the forward step returns, but
188+ # not for nested structs with various fields (like Diagonal{Complex})
189+ output_codual = CoDual (output, Mooncake. fdata (Mooncake. zero_tangent (output)))
190+ function $adj (dy:: Tuple{NoRData, NoRData, T} ) where {T <: Real }
191+ copy! (A, Ac)
192+ copy! (DV[1 ], DVc[1 ])
193+ copy! (DV[2 ], DVc[2 ])
194+ Dtrunc, Vtrunc, ϵ = Mooncake. primal (output_codual)
195+ dDtrunc_, dVtrunc_, dϵ = Mooncake. tangent (output_codual)
196+ abs (dy[3 ]) > MatrixAlgebraKit. defaulttol (dy[3 ]) && @warn " Pullback for $f does not yet support non-zero tangent for the truncation error"
197+ D′, dD′ = arrayify (Dtrunc, dDtrunc_)
198+ V′, dV′ = arrayify (Vtrunc, dVtrunc_)
199+ $ pb (dA, A, (D′, V′), (dD′, dV′))
200+ MatrixAlgebraKit. zero! (dD)
201+ MatrixAlgebraKit. zero! (dV)
202+ return NoRData (), NoRData (), NoRData ()
203+ end
204+ return output_codual, $ adj
205+ end
174206 function Mooncake. rrule!! (:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual )
175207 # compute primal
176208 A, dA = arrayify (A_dA)
@@ -194,7 +226,37 @@ for (f, f_ne, pb, adj) in (
194226 end
195227 return output_codual, $ adj
196228 end
229+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f_ne!), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
197230 @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f_ne), Any, MatrixAlgebraKit. AbstractAlgorithm}
231+ function Mooncake. rrule!! (:: CoDual{typeof($f_ne!)} , A_dA:: CoDual , DV_dDV:: CoDual , alg_dalg:: CoDual )
232+ # compute primal
233+ A, dA = arrayify (A_dA)
234+ alg = Mooncake. primal (alg_dalg)
235+ DV = Mooncake. primal (DV_dDV)
236+ dDV = Mooncake. tangent (DV_dDV)
237+ Ac = copy (A)
238+ DVc = copy .(DV)
239+ output = $ f_ne (A, DV, alg)
240+ # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
241+ # of ComplexF32) into the correct **forwards** data type (since we are now in the forward
242+ # pass). For many types this is done automatically when the forward step returns, but
243+ # not for nested structs with various fields (like Diagonal{Complex})
244+ output_codual = CoDual (output, Mooncake. fdata (Mooncake. zero_tangent (output)))
245+ function $adj (:: NoRData )
246+ copy! (A, Ac)
247+ copy! (DV[1 ], DVc[1 ])
248+ copy! (DV[2 ], DVc[2 ])
249+ Dtrunc, Vtrunc = Mooncake. primal (output_codual)
250+ dDtrunc_, dVtrunc_ = Mooncake. tangent (output_codual)
251+ D′, dD′ = arrayify (Dtrunc, dDtrunc_)
252+ V′, dV′ = arrayify (Vtrunc, dVtrunc_)
253+ $ pb (dA, A, (D′, V′), (dD′, dV′))
254+ MatrixAlgebraKit. zero! (dD)
255+ MatrixAlgebraKit. zero! (dV)
256+ return NoRData (), NoRData (), NoRData ()
257+ end
258+ return output_codual, $ adj
259+ end
198260 function Mooncake. rrule!! (:: CoDual{typeof($f_ne)} , A_dA:: CoDual , alg_dalg:: CoDual )
199261 # compute primal
200262 A, dA = arrayify (A_dA)
@@ -234,9 +296,13 @@ for (f!, f) in (
234296 U, dU = arrayify (USVᴴ[1 ], dUSVᴴ[1 ])
235297 S, dS = arrayify (USVᴴ[2 ], dUSVᴴ[2 ])
236298 Vᴴ, dVᴴ = arrayify (USVᴴ[3 ], dUSVᴴ[3 ])
299+ USVᴴc = copy .(USVᴴ)
237300 output = $ f! (A, Mooncake. primal (alg_dalg))
238301 function svd_adjoint (:: NoRData )
239302 copy! (A, Ac)
303+ copy! (U, USVᴴc[1 ])
304+ copy! (S, USVᴴc[2 ])
305+ copy! (Vᴴ, USVᴴc[3 ])
240306 if $ (f! == svd_compact!)
241307 svd_pullback! (dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
242308 else # full
@@ -303,6 +369,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua
303369 function svd_vals_adjoint (:: NoRData )
304370 svd_vals_pullback! (dA, A, USVᴴ, dS)
305371 MatrixAlgebraKit. zero! (dS)
372+ copy! (S, diagview (USVᴴ[2 ]))
306373 return NoRData (), NoRData (), NoRData (), NoRData ()
307374 end
308375 return S_dS, svd_vals_adjoint
@@ -328,6 +395,44 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co
328395 return S_codual, svd_vals_adjoint
329396end
330397
398+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (svd_trunc!), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
399+ function Mooncake. rrule!! (:: CoDual{typeof(svd_trunc!)} , A_dA:: CoDual , USVᴴ_dUSVᴴ:: CoDual , alg_dalg:: CoDual )
400+ # compute primal
401+ A, dA = arrayify (A_dA)
402+ alg = Mooncake. primal (alg_dalg)
403+ Ac = copy (A)
404+ USVᴴ = Mooncake. primal (USVᴴ_dUSVᴴ)
405+ dUSVᴴ = Mooncake. tangent (USVᴴ_dUSVᴴ)
406+ U, dU = arrayify (USVᴴ[1 ], dUSVᴴ[1 ])
407+ S, dS = arrayify (USVᴴ[2 ], dUSVᴴ[2 ])
408+ Vᴴ, dVᴴ = arrayify (USVᴴ[3 ], dUSVᴴ[3 ])
409+ USVᴴc = copy .(USVᴴ)
410+ output = svd_trunc! (A, USVᴴ, alg)
411+ # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
412+ # of ComplexF32) into the correct **forwards** data type (since we are now in the forward
413+ # pass). For many types this is done automatically when the forward step returns, but
414+ # not for nested structs with various fields (like Diagonal{Complex})
415+ output_codual = CoDual (output, Mooncake. fdata (Mooncake. zero_tangent (output)))
416+ function svd_trunc_adjoint (dy:: Tuple{NoRData, NoRData, NoRData, T} ) where {T <: Real }
417+ copy! (A, Ac)
418+ copy! (U, USVᴴc[1 ])
419+ copy! (S, USVᴴc[2 ])
420+ copy! (Vᴴ, USVᴴc[3 ])
421+ Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake. primal (output_codual)
422+ dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake. tangent (output_codual)
423+ abs (dy[4 ]) > MatrixAlgebraKit. defaulttol (dy[4 ]) && @warn " Pullback for svd_trunc does not yet support non-zero tangent for the truncation error"
424+ U′, dU′ = arrayify (Utrunc, dUtrunc_)
425+ S′, dS′ = arrayify (Strunc, dStrunc_)
426+ Vᴴ′, dVᴴ′ = arrayify (Vᴴtrunc, dVᴴtrunc_)
427+ svd_trunc_pullback! (dA, A, (U′, S′, Vᴴ′), (dU′, dS′, dVᴴ′))
428+ MatrixAlgebraKit. zero! (dU)
429+ MatrixAlgebraKit. zero! (dS)
430+ MatrixAlgebraKit. zero! (dVᴴ)
431+ return NoRData (), NoRData (), NoRData ()
432+ end
433+ return output_codual, svd_trunc_adjoint
434+ end
435+
331436@is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (svd_trunc), Any, MatrixAlgebraKit. AbstractAlgorithm}
332437function Mooncake. rrule!! (:: CoDual{typeof(svd_trunc)} , A_dA:: CoDual , alg_dalg:: CoDual )
333438 # compute primal
@@ -357,6 +462,43 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
357462 return output_codual, svd_trunc_adjoint
358463end
359464
465+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (svd_trunc_no_error!), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
466+ function Mooncake. rrule!! (:: CoDual{typeof(svd_trunc_no_error)} , A_dA:: CoDual , USVᴴ_dUSVᴴ:: CoDual , alg_dalg:: CoDual )
467+ # compute primal
468+ A, dA = arrayify (A_dA)
469+ alg = Mooncake. primal (alg_dalg)
470+ Ac = copy (A)
471+ USVᴴ = Mooncake. primal (USVᴴ_dUSVᴴ)
472+ dUSVᴴ = Mooncake. tangent (USVᴴ_dUSVᴴ)
473+ U, dU = arrayify (USVᴴ[1 ], dUSVᴴ[1 ])
474+ S, dS = arrayify (USVᴴ[2 ], dUSVᴴ[2 ])
475+ Vᴴ, dVᴴ = arrayify (USVᴴ[3 ], dUSVᴴ[3 ])
476+ USVᴴc = copy .(USVᴴ)
477+ output = svd_trunc_no_error! (A, USVᴴ, alg)
478+ # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
479+ # of ComplexF32) into the correct **forwards** data type (since we are now in the forward
480+ # pass). For many types this is done automatically when the forward step returns, but
481+ # not for nested structs with various fields (like Diagonal{Complex})
482+ output_codual = CoDual (output, Mooncake. fdata (Mooncake. zero_tangent (output)))
483+ function svd_trunc_adjoint (:: NoRData )
484+ copy! (A, Ac)
485+ copy! (U, USVᴴc[1 ])
486+ copy! (S, USVᴴc[2 ])
487+ copy! (Vᴴ, USVᴴc[3 ])
488+ Utrunc, Strunc, Vᴴtrunc = Mooncake. primal (output_codual)
489+ dUtrunc_, dStrunc_, dVᴴtrunc_ = Mooncake. tangent (output_codual)
490+ U′, dU′ = arrayify (Utrunc, dUtrunc_)
491+ S′, dS′ = arrayify (Strunc, dStrunc_)
492+ Vᴴ′, dVᴴ′ = arrayify (Vᴴtrunc, dVᴴtrunc_)
493+ svd_trunc_pullback! (dA, A, (U′, S′, Vᴴ′), (dU′, dS′, dVᴴ′))
494+ MatrixAlgebraKit. zero! (dU)
495+ MatrixAlgebraKit. zero! (dS)
496+ MatrixAlgebraKit. zero! (dVᴴ)
497+ return NoRData (), NoRData (), NoRData ()
498+ end
499+ return output_codual, svd_trunc_adjoint
500+ end
501+
360502@is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (svd_trunc_no_error), Any, MatrixAlgebraKit. AbstractAlgorithm}
361503function Mooncake. rrule!! (:: CoDual{typeof(svd_trunc_no_error)} , A_dA:: CoDual , alg_dalg:: CoDual )
362504 # compute primal
0 commit comments