11module MatrixAlgebraKitMooncakeExt
22
3- using Mooncake
4- using Mooncake: CoDual, Dual, NoRData, arrayify, primal, tangent, zero_fcodual
5- import Mooncake: rrule!!
3+ using Mooncake: Mooncake as MC,
4+ CoDual, Dual, NoRData, arrayify, primal, tangent, zero_fcodual
65using MatrixAlgebraKit
7- using MatrixAlgebraKit: MatrixAlgebraKit as MAK, diagview, zero!, AbstractAlgorithm, TruncatedAlgorithm
6+ using MatrixAlgebraKit: MatrixAlgebraKit as MAK,
7+ diagview, zero!, AbstractAlgorithm, TruncatedAlgorithm
88using LinearAlgebra
99
1010
1111# Utility
1212# -------
1313# convenience helper for marking DefaultCtx ReverseMode signature as primitive
1414macro is_rev_primitive (sig)
15- return esc (:(Mooncake . @is_primitive Mooncake . DefaultCtx Mooncake . ReverseMode $ sig))
15+ return esc (:(MC . @is_primitive MC . DefaultCtx MC . ReverseMode $ sig))
1616end
17+
1718_warn_pullback_truncerror (dϵ:: Real ; tol = MatrixAlgebraKit. defaulttol (dϵ)) =
1819 abs (dϵ) ≤ tol || @warn " Pullback ignores non-zero tangents for truncation error"
1920
2021const _nordata = Returns (NoRData ())
2122
2223# No derivatives
2324# --------------
24- Mooncake . tangent_type (:: Type{<:AbstractAlgorithm} ) = Mooncake . NoTangent
25+ MC . tangent_type (:: Type{<:AbstractAlgorithm} ) = MC . NoTangent
2526
26- Mooncake . @zero_derivative Mooncake . DefaultCtx Tuple{typeof (MAK. select_algorithm), Any, Any, Any}
27- Mooncake . @zero_derivative Mooncake . DefaultCtx Tuple{typeof (Core. kwcall), NamedTuple, typeof (MAK. select_algorithm), Any, Any, Any}
28- Mooncake . @zero_derivative Mooncake . DefaultCtx Tuple{typeof (MAK. initialize_output), Any, Any, Any}
29- Mooncake . @zero_derivative Mooncake . DefaultCtx Tuple{typeof (MAK. check_input), Any, Any, Any, Any}
27+ MC . @zero_derivative MC . DefaultCtx Tuple{typeof (MAK. select_algorithm), Any, Any, Any}
28+ MC . @zero_derivative MC . DefaultCtx Tuple{typeof (Core. kwcall), NamedTuple, typeof (MAK. select_algorithm), Any, Any, Any}
29+ MC . @zero_derivative MC . DefaultCtx Tuple{typeof (MAK. initialize_output), Any, Any, Any}
30+ MC . @zero_derivative MC . DefaultCtx Tuple{typeof (MAK. check_input), Any, Any, Any, Any}
3031
3132@is_rev_primitive Tuple{typeof (MAK. copy_input), Any, Any}
32- function rrule!! (:: CoDual{typeof(MAK.copy_input)} , f_df:: CoDual , A_dA:: CoDual )
33+ function MC . rrule!! (:: CoDual{typeof(MAK.copy_input)} , f_df:: CoDual , A_dA:: CoDual )
3334 Ac = MAK. copy_input (primal (f_df), primal (A_dA))
3435 Ac_dAc = zero_fcodual (Ac)
3536 dAc = tangent (Ac_dAc)
3637 function copy_input_pb (:: NoRData )
37- Mooncake . increment!! (tangent (A_dA), dAc)
38+ MC . increment!! (tangent (A_dA), dAc)
3839 return ntuple (_nordata, 3 )
3940 end
4041 return Ac_dAc, copy_input_pb
@@ -75,7 +76,7 @@ for (f, pullback!, adjoint) in (
7576
7677 @eval begin
7778 @is_rev_primitive Tuple{typeof ($ f), Any, AbstractAlgorithm}
78- function rrule!! (:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual{<:AbstractAlgorithm} )
79+ function MC . rrule!! (:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual{<:AbstractAlgorithm} )
7980 # unpack variables
8081 A, dA = arrayify (A_dA)
8182 alg = primal (alg_dalg)
@@ -95,8 +96,8 @@ for (f, pullback!, adjoint) in (
9596 end
9697
9798 @is_rev_primitive Tuple{typeof ($ f!), Any, Tuple, AbstractAlgorithm}
98- function rrule!! (:: CoDual{typeof($f!)} , A_dA:: CoDual , args_dargs:: CoDual , alg_dalg:: CoDual{<:AbstractAlgorithm} )
99- args_dargs, pb! = rrule!! (zero_fcodual ($ f), A_dA, alg_dalg)
99+ function MC . rrule!! (:: CoDual{typeof($f!)} , A_dA:: CoDual , args_dargs:: CoDual , alg_dalg:: CoDual{<:AbstractAlgorithm} )
100+ args_dargs, pb! = MC . rrule!! (zero_fcodual ($ f), A_dA, alg_dalg)
100101 return args_dargs, Returns (ntuple (_nordata, 4 )) ∘ pb!
101102 end
102103 end
@@ -112,7 +113,7 @@ for (f, pullback!, adjoint) in (
112113
113114 @eval begin
114115 @is_rev_primitive Tuple{typeof ($ f), Any, AbstractAlgorithm}
115- function rrule!! (:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual{<:AbstractAlgorithm} )
116+ function MC . rrule!! (:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual{<:AbstractAlgorithm} )
116117 # unpack variables
117118 A, dA = arrayify (A_dA)
118119 alg = primal (alg_dalg)
@@ -132,8 +133,8 @@ for (f, pullback!, adjoint) in (
132133 end
133134
134135 @is_rev_primitive Tuple{typeof ($ f!), Any, Any, AbstractAlgorithm}
135- function rrule!! (:: CoDual{typeof($f!)} , A_dA:: CoDual , N_dN:: CoDual , alg_dalg:: CoDual{<:AbstractAlgorithm} )
136- arg_darg, pb! = rrule!! (zero_fcodual ($ f), A_dA, alg_dalg)
136+ function MC . rrule!! (:: CoDual{typeof($f!)} , A_dA:: CoDual , N_dN:: CoDual , alg_dalg:: CoDual{<:AbstractAlgorithm} )
137+ arg_darg, pb! = MC . rrule!! (zero_fcodual ($ f), A_dA, alg_dalg)
137138 return arg_darg, Returns (ntuple (_nordata, 4 )) ∘ pb!
138139 end
139140 end
@@ -150,7 +151,7 @@ for f in (:eig, :eigh, :svd)
150151 # --------
151152 @eval begin
152153 @is_rev_primitive Tuple{typeof ($ f_vals), Any, AbstractAlgorithm}
153- function rrule!! (:: CoDual{typeof($f_vals)} , A_dA:: CoDual , alg_dalg:: CoDual )
154+ function MC . rrule!! (:: CoDual{typeof($f_vals)} , A_dA:: CoDual , alg_dalg:: CoDual )
154155 # unpack variables
155156 A, dA = arrayify (A_dA)
156157 alg = primal (alg_dalg)
@@ -171,8 +172,8 @@ for f in (:eig, :eigh, :svd)
171172 end
172173
173174 @is_rev_primitive Tuple{typeof ($ f_vals!), Any, Any, AbstractAlgorithm}
174- function rrule!! (:: CoDual{typeof($f_vals!)} , A_dA:: CoDual , D_dD:: CoDual , alg_dalg:: CoDual )
175- args_dargs, pb! = rrule!! (zero_fcodual ($ f_vals), A_dA, alg_dalg)
175+ function MC . rrule!! (:: CoDual{typeof($f_vals!)} , A_dA:: CoDual , D_dD:: CoDual , alg_dalg:: CoDual )
176+ args_dargs, pb! = MC . rrule!! (zero_fcodual ($ f_vals), A_dA, alg_dalg)
176177 return args_dargs, Returns (ntuple (_nordata, 4 )) ∘ pb!
177178 end
178179 end
@@ -184,10 +185,11 @@ for f in (:eig, :eigh, :svd)
184185 f_trunc! = Symbol (f_trunc, :! )
185186 pullback! = Symbol (f, :_pullback! )
186187 trunc_pullback! = Symbol (f_trunc, :_pullback! )
188+ f_trunc_no_error = Symbol (f_trunc, :_no_error )
187189
188190 @eval begin
189191 @is_rev_primitive Tuple{typeof ($ f_trunc), Any, AbstractAlgorithm}
190- function rrule!! (:: CoDual{typeof($f_trunc)} , A_dA:: CoDual , alg_dalg:: CoDual )
192+ function MC . rrule!! (:: CoDual{typeof($f_trunc)} , A_dA:: CoDual , alg_dalg:: CoDual )
191193 # unpack variables
192194 A, dA = arrayify (A_dA)
193195 alg = primal (alg_dalg)
@@ -207,15 +209,15 @@ for f in (:eig, :eigh, :svd)
207209
208210 return argsϵ_dargsϵ, $ adjoint
209211 end
210- function rrule!! (:: CoDual{typeof($f_trunc)} , A_dA:: CoDual , alg_dalg:: CoDual{<:TruncatedAlgorithm} )
212+ function MC . rrule!! (:: CoDual{typeof($f_trunc)} , A_dA:: CoDual , alg_dalg:: CoDual{<:TruncatedAlgorithm} )
211213 # unpack variables
212214 A, dA = arrayify (A_dA)
213- alg = Mooncake . primal (alg_dalg)
215+ alg = primal (alg_dalg)
214216
215217 # compute primal and pack output - capture full DV and ind
216218 args_full = $ f_full (A, alg. alg)
217219 args, ind = MAK. truncate ($ f_trunc!, args_full, alg. trunc)
218- ϵ = MAK. truncation_error (diagview (args[ 1 ]), ind)
220+ ϵ = MAK. truncation_error (diagview (args_full[ $ (f === :svd ? 2 : 1 ) ]), ind)
219221 argsϵ = (args... , ϵ)
220222 argsϵ_dargsϵ = zero_fcodual (argsϵ)
221223
@@ -229,39 +231,15 @@ for f in (:eig, :eigh, :svd)
229231
230232 return argsϵ_dargsϵ, $ adjoint
231233 end
234+
232235 @is_rev_primitive Tuple{typeof ($ f_trunc!), Any, Any, AbstractAlgorithm}
233- function rrule!! (:: CoDual{typeof($f_trunc!)} , A_dA:: CoDual , args_dargs:: CoDual , alg_dalg:: CoDual )
234- args_dargs, pb! = rrule!! (zero_fcodual ($ f_trunc), A_dA, alg_dalg)
236+ function MC . rrule!! (:: CoDual{typeof($f_trunc!)} , A_dA:: CoDual , args_dargs:: CoDual , alg_dalg:: CoDual )
237+ args_dargs, pb! = MC . rrule!! (zero_fcodual ($ f_trunc), A_dA, alg_dalg)
235238 return args_dargs, Returns (ntuple (_nordata, 4 )) ∘ pb!
236239 end
237- end
238-
239- # Truncated decompositions - no error
240- # -----------------------------------
241- f_trunc_no_error = Symbol (f_trunc, :_no_error )
242- f_trunc_no_error! = Symbol (f_trunc_no_error, :! )
243240
244- @eval begin
245- @is_rev_primitive Tuple{typeof ($ f_trunc_no_error), Any, AbstractAlgorithm}
246- function rrule!! (:: CoDual{typeof($f_trunc_no_error)} , A_dA:: CoDual , alg_dalg:: CoDual )
247- # unpack variables
248- A, dA = arrayify (A_dA)
249- alg = primal (alg_dalg)
250-
251- # compute primal and pack output
252- args = $ f_trunc (A, alg)
253- args_dargs = zero_fcodual (args)
254-
255- # define pullback
256- dargs = last .(arrayify .(args, tangent (args_dargs)))
257- function $adjoint (:: NoRData )
258- MAK.$ trunc_pullback! (dA, A, args, dargs)
259- return ntuple (_nordata, 3 )
260- end
261-
262- return args_dargs, $ adjoint
263- end
264- function rrule!! (:: CoDual{typeof($f_trunc_no_error)} , A_dA:: CoDual , alg_dalg:: CoDual{<:TruncatedAlgorithm} )
241+ # still need specialized implementation for <:TruncatedAlgorithm
242+ function MC. rrule!! (:: CoDual{typeof($f_trunc_no_error)} , A_dA:: CoDual , alg_dalg:: CoDual{<:TruncatedAlgorithm} )
265243 # unpack variables
266244 A, dA = arrayify (A_dA)
267245 alg = primal (alg_dalg)
@@ -280,12 +258,6 @@ for f in (:eig, :eigh, :svd)
280258
281259 return args_dargs, $ adjoint
282260 end
283-
284- @is_rev_primitive Tuple{typeof ($ f_trunc_no_error!), Any, Any, AbstractAlgorithm}
285- function rrule!! (:: CoDual{typeof($f_trunc_no_error!)} , A_dA:: CoDual , args_dargs:: CoDual , alg_dalg:: CoDual )
286- args_dargs, pb! = rrule!! (zero_fcodual ($ f_trunc_no_error), A_dA, alg_dalg)
287- return args_dargs, Returns (ntuple (_nordata, 4 )) ∘ pb!
288- end
289261 end
290262end
291263
0 commit comments