# Warn when the truncation-error tangent `dϵ` is larger in magnitude than `tol`
# (by default `MatrixAlgebraKit.defaulttol(dϵ)`). When `abs(dϵ) ≤ tol` holds, the
# short-circuiting `||` skips the warning.
1717_warn_pullback_truncerror (dϵ:: Real ; tol = MatrixAlgebraKit. defaulttol (dϵ)) =
1818 abs (dϵ) ≤ tol || @warn " Pullback ignores non-zero tangents for truncation error"
1919
# `_nordata` ignores the arguments it is called with and returns `NoRData()`, so
# `ntuple(_nordata, n)` builds an `n`-tuple of `NoRData()` values for the pullbacks below.
20+ const _nordata = Returns (NoRData ())
21+
2022# No derivatives
2123# --------------
2224Mooncake. tangent_type (:: Type{<:AbstractAlgorithm} ) = Mooncake. NoTangent
2325
2426Mooncake. @zero_derivative Mooncake. DefaultCtx Tuple{typeof (MAK. select_algorithm), Any, Any, Any}
2527Mooncake. @zero_derivative Mooncake. DefaultCtx Tuple{typeof (Core. kwcall), NamedTuple, typeof (MAK. select_algorithm), Any, Any, Any}
2628Mooncake. @zero_derivative Mooncake. DefaultCtx Tuple{typeof (MAK. initialize_output), Any, Any, Any}
27- Mooncake. @zero_derivative Mooncake. DefaultCtx Tuple{typeof (MAK. check_input), Vararg{ Any} }
29+ Mooncake. @zero_derivative Mooncake. DefaultCtx Tuple{typeof (MAK. check_input), Any, Any, Any, Any }
2830
# Reverse rule for `copy_input` (the `+` lines are the current version).
# Forward pass: call `MAK.copy_input` on the primals of `f_df` and `A_dA`, and wrap the
# result with `zero_fcodual`.
# Pullback: call `Mooncake.increment!!` on the tangent of `A_dA` with the tangent of the
# copy (`dAc`), then return three `NoRData()` entries. The return value of `increment!!`
# is not used here.
29- @is_rev_primitive Tuple{typeof (copy_input), Any, Any}
30- function rrule!! (:: CoDual{typeof(copy_input)} , f_df:: CoDual , A_dA:: CoDual )
31- Ac = copy_input (primal (f_df), primal (A_dA))
31+ @is_rev_primitive Tuple{typeof (MAK . copy_input), Any, Any}
32+ function rrule!! (:: CoDual{typeof(MAK. copy_input)} , f_df:: CoDual , A_dA:: CoDual )
33+ Ac = MAK . copy_input (primal (f_df), primal (A_dA))
3234 Ac_dAc = zero_fcodual (Ac)
3335 dAc = tangent (Ac_dAc)
3436 function copy_input_pb (:: NoRData )
3537 Mooncake. increment!! (tangent (A_dA), dAc)
36- return NoRData ( )
38+ return ntuple (_nordata, 3 )
3739 end
3840 return Ac_dAc, copy_input_pb
3941end
4042
4143# Factorizations
4244# --------------
45+
46+ # The general approach here is to define the functions in terms of the non-mutating versions first.
47+ # Since we are not guaranteeing that we will be mutating the input, nor that we will make
48+ # use of the provided output buffers, we can simplify our lives by calling the non-mutating
49+ # implementations instead of the mutating ones.
50+ #
51+ # The main benefit here is that we do not have to guarantee that we will restore the state
52+ # after executing the pullback - ensuring that we don't have to keep as many copied objects
53+ # around. This being said, the total number of allocations does not become smaller because
54+ # of this, and in cases where the pullback would be used multiple times we now have to
55+ # allocate multiple times. On the other hand, we can also free these objects in between, so
56+ # this might also reduce the total GC pressure...
57+
58+
4359for (f, pullback!, adjoint) in (
4460 (:qr_full , :qr_pullback! , :qr_adjoint ),
4561 (:lq_full , :lq_pullback! , :lq_adjoint ),
@@ -72,21 +88,23 @@ for (f, pullback!, adjoint) in (
7288 dargs = last .(arrayify .(args, tangent (args_dargs)))
7389 function $adjoint (:: NoRData )
7490 MAK.$ pullback! (dA, A, args, dargs)
75- return NoRData ( )
91+ return ntuple (_nordata, 3 )
7692 end
7793
78- return args , $ adjoint
94+ return args_dargs , $ adjoint
7995 end
8096
8197 @is_rev_primitive Tuple{typeof ($ f!), Any, Tuple, AbstractAlgorithm}
82- rrule!! (:: CoDual{typeof($f!)} , A_dA:: CoDual , args_dargs:: CoDual , alg_dalg:: CoDual{<:AbstractAlgorithm} ) =
83- rrule!! (zero_fcodual ($ f), A_dA, alg_dalg)
# Rule for the mutating `$f!`: the incoming `args_dargs` output buffers are not
# forwarded. The rule of the non-mutating `$f` is evaluated on `A_dA` and `alg_dalg`
# instead, and its result rebinds `args_dargs`. The returned pullback runs `pb!` and then
# yields a 4-tuple of `NoRData()`, whatever `pb!` itself returned.
98+ function rrule!! (:: CoDual{typeof($f!)} , A_dA:: CoDual , args_dargs:: CoDual , alg_dalg:: CoDual{<:AbstractAlgorithm} )
99+ args_dargs, pb! = rrule!! (zero_fcodual ($ f), A_dA, alg_dalg)
100+ return args_dargs, Returns (ntuple (_nordata, 4 )) ∘ pb!
101+ end
84102 end
85103end
86104
87105# Nullspaces
88106# ----------
89- for (f, pullback, adjoint) in (
107+ for (f, pullback! , adjoint) in (
90108 (:qr_null , :qr_null_pullback! , :qr_null_adjoint ),
91109 (:lq_null , :lq_null_pullback! , :lq_null_adjoint ),
92110 )
@@ -107,24 +125,26 @@ for (f, pullback, adjoint) in (
107125 dN = last (arrayify (N, tangent (N_dN)))
108126 function $adjoint (:: NoRData )
109127 MAK.$ pullback! (dA, A, N, dN)
110- return NoRData ( )
128+ return ntuple (_nordata, 3 )
111129 end
112130
113- return N , $ adjoint
131+ return N_dN , $ adjoint
114132 end
115133
116134 @is_rev_primitive Tuple{typeof ($ f!), Any, Any, AbstractAlgorithm}
117- rrule!! (:: CoDual{typeof($f!)} , A_dA:: CoDual , N_dN:: CoDual , alg_dalg:: CoDual{<:AbstractAlgorithm} ) =
118- rrule!! (zero_fcodual ($ f), A_dA, alg_dalg)
# Rule for the mutating nullspace function `$f!`: the provided `N_dN` buffer is not
# forwarded. The rule of the non-mutating `$f` is evaluated on `A_dA` and `alg_dalg`
# instead. The returned pullback runs `pb!` and then yields a 4-tuple of `NoRData()`.
135+ function rrule!! (:: CoDual{typeof($f!)} , A_dA:: CoDual , N_dN:: CoDual , alg_dalg:: CoDual{<:AbstractAlgorithm} )
136+ arg_darg, pb! = rrule!! (zero_fcodual ($ f), A_dA, alg_dalg)
137+ return arg_darg, Returns (ntuple (_nordata, 4 )) ∘ pb!
138+ end
119139 end
120140end
121141
122142for f in (:eig , :eigh , :svd )
123143 f_vals = Symbol (f, :_vals )
124144 f_vals! = Symbol (f_vals, :! )
125- f_full = Symbol (f, :_full )
126- vals_pulback ! = Symbol (f, :_vals_pullback! )
127- adjoint! = Symbol (f, :_adjoint )
145+ f_full = f === :svd ? Symbol (f, :_compact ) : Symbol (f, :_full )
146+ vals_pullback ! = Symbol (f, :_vals_pullback! )
147+ adjoint = Symbol (f, :_adjoint )
128148
129149 # f_values
130150 # --------
@@ -144,15 +164,17 @@ for f in (:eig, :eigh, :svd)
144164 dvals = last (arrayify (vals, tangent (vals_dvals)))
145165 function $adjoint (:: NoRData )
146166 MAK.$ vals_pullback! (dA, A, F, dvals)
147- return NoRData ( )
167+ return ntuple (_nordata, 3 )
148168 end
149169
150170 return vals_dvals, $ adjoint
151171 end
152172
153173 @is_rev_primitive Tuple{typeof ($ f_vals!), Any, Any, AbstractAlgorithm}
154- rrule!! (:: CoDual{typeof($f_vals!)} , A_dA:: CoDual , D_dD:: CoDual , alg_dalg:: CoDual ) =
155- rrule!! (zero_fcodual ($ f_vals), A_dA, alg_dalg)
# Rule for the mutating `$f_vals!`: the provided `D_dD` buffer is not forwarded. The
# rule of the non-mutating `$f_vals` is evaluated on `A_dA` and `alg_dalg` instead. The
# returned pullback runs `pb!` and then yields a 4-tuple of `NoRData()`.
174+ function rrule!! (:: CoDual{typeof($f_vals!)} , A_dA:: CoDual , D_dD:: CoDual , alg_dalg:: CoDual )
175+ args_dargs, pb! = rrule!! (zero_fcodual ($ f_vals), A_dA, alg_dalg)
176+ return args_dargs, Returns (ntuple (_nordata, 4 )) ∘ pb!
177+ end
156178 end
157179
158180
@@ -180,7 +202,7 @@ for f in (:eig, :eigh, :svd)
180202 function $adjoint (dy)
181203 _warn_pullback_truncerror (last (dy))
182204 MAK.$ trunc_pullback! (dA, A, args, dargs)
183- return NoRData ( )
205+ return ntuple (_nordata, 3 )
184206 end
185207
186208 return argsϵ_dargsϵ, $ adjoint
@@ -199,17 +221,19 @@ for f in (:eig, :eigh, :svd)
199221
200222 # define pullback
201223 dargs = last .(arrayify .(args, Base. front (tangent (argsϵ_dargsϵ))))
202- function $f_adjoint! (dy)
224+ function $adjoint (dy)
203225 _warn_pullback_truncerror (last (dy))
204226 MAK.$ pullback! (dA, A, args_full, dargs, ind)
205- return NoRData ( )
227+ return ntuple (_nordata, 3 )
206228 end
207229
208- return DVtrunc_dDVtrunc , $ f_adjoint!
230+ return argsϵ_dargsϵ , $ adjoint
209231 end
210232 @is_rev_primitive Tuple{typeof ($ f_trunc!), Any, Any, AbstractAlgorithm}
211- rrule!! (:: CoDual{typeof($f_trunc!)} , A_dA:: CoDual , args_dargs:: CoDual , alg_dalg:: CoDual ) =
212- rrule!! (zero_fcodual ($ f_trunc), A_dA, alg_dalg)
# Rule for the mutating `$f_trunc!`: the incoming `args_dargs` is not forwarded; the
# name is rebound to the output of the non-mutating `$f_trunc` rule, evaluated on `A_dA`
# and `alg_dalg`. The returned pullback runs `pb!` and then yields a 4-tuple of `NoRData()`.
233+ function rrule!! (:: CoDual{typeof($f_trunc!)} , A_dA:: CoDual , args_dargs:: CoDual , alg_dalg:: CoDual )
234+ args_dargs, pb! = rrule!! (zero_fcodual ($ f_trunc), A_dA, alg_dalg)
235+ return args_dargs, Returns (ntuple (_nordata, 4 )) ∘ pb!
236+ end
213237 end
214238
215239 # Truncated decompositions - no error
@@ -232,7 +256,7 @@ for f in (:eig, :eigh, :svd)
232256 dargs = last .(arrayify .(args, tangent (args_dargs)))
233257 function $adjoint (:: NoRData )
234258 MAK.$ trunc_pullback! (dA, A, args, dargs)
235- return NoRData ( )
259+ return ntuple (_nordata, 3 )
236260 end
237261
238262 return args_dargs, $ adjoint
@@ -251,15 +275,17 @@ for f in (:eig, :eigh, :svd)
251275 dargs = last .(arrayify .(args, tangent (args_dargs)))
252276 function $adjoint (:: NoRData )
253277 MAK.$ pullback! (dA, A, args_full, dargs, ind)
254- return NoRData ( )
278+ return ntuple (_nordata, 3 )
255279 end
256280
257281 return args_dargs, $ adjoint
258282 end
259283
260284 @is_rev_primitive Tuple{typeof ($ f_trunc_no_error!), Any, Any, AbstractAlgorithm}
261- rrule!! (:: CoDual{typeof($f_trunc_no_error!)} , A_dA:: CoDual , args_dargs:: CoDual , alg_dalg:: CoDual ) =
262- rrule!! (zero_fcodual ($ f_trunc_no_error), A_dA, alg_dalg)
# Rule for the mutating `$f_trunc_no_error!`: the incoming `args_dargs` is not
# forwarded; the name is rebound to the output of the non-mutating `$f_trunc_no_error`
# rule, evaluated on `A_dA` and `alg_dalg`. The returned pullback runs `pb!` and then
# yields a 4-tuple of `NoRData()`.
285+ function rrule!! (:: CoDual{typeof($f_trunc_no_error!)} , A_dA:: CoDual , args_dargs:: CoDual , alg_dalg:: CoDual )
286+ args_dargs, pb! = rrule!! (zero_fcodual ($ f_trunc_no_error), A_dA, alg_dalg)
287+ return args_dargs, Returns (ntuple (_nordata, 4 )) ∘ pb!
288+ end
263289 end
264290end
265291
0 commit comments