Skip to content

Commit a6a1821

Browse files
committed
various fixes and clarifications
1 parent 37ba892 commit a6a1821

1 file changed

Lines changed: 56 additions & 30 deletions

File tree

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 56 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,45 @@ end
1717
_warn_pullback_truncerror(dϵ::Real; tol = MatrixAlgebraKit.defaulttol(dϵ)) =
1818
abs(dϵ) tol || @warn "Pullback ignores non-zero tangents for truncation error"
1919

20+
const _nordata = Returns(NoRData())
21+
2022
# No derivatives
2123
# --------------
2224
Mooncake.tangent_type(::Type{<:AbstractAlgorithm}) = Mooncake.NoTangent
2325

2426
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(MAK.select_algorithm), Any, Any, Any}
2527
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(Core.kwcall), NamedTuple, typeof(MAK.select_algorithm), Any, Any, Any}
2628
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(MAK.initialize_output), Any, Any, Any}
27-
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(MAK.check_input), Vararg{Any}}
29+
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(MAK.check_input), Any, Any, Any, Any}
2830

29-
@is_rev_primitive Tuple{typeof(copy_input), Any, Any}
30-
function rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDual)
31-
Ac = copy_input(primal(f_df), primal(A_dA))
31+
@is_rev_primitive Tuple{typeof(MAK.copy_input), Any, Any}
32+
function rrule!!(::CoDual{typeof(MAK.copy_input)}, f_df::CoDual, A_dA::CoDual)
33+
Ac = MAK.copy_input(primal(f_df), primal(A_dA))
3234
Ac_dAc = zero_fcodual(Ac)
3335
dAc = tangent(Ac_dAc)
3436
function copy_input_pb(::NoRData)
3537
Mooncake.increment!!(tangent(A_dA), dAc)
36-
return NoRData()
38+
return ntuple(_nordata, 3)
3739
end
3840
return Ac_dAc, copy_input_pb
3941
end
4042

4143
# Factorizations
4244
# --------------
45+
46+
# The general approach here is to define the functions in terms of the non-mutating versions first.
47+
# Since we are not guaranteeing that we will be mutating the input, nor that we will make
48+
# use of the provided output buffers, we can simplify our lives by calling the non-mutating
49+
# implementations instead of the mutating ones.
50+
#
51+
# The main benefit here is that we do not have to guarantee that we will restore the state
52+
# after executing the pullback - ensuring that we don't have to keep as many copied objects
53+
# around. This being said, the total number of allocations does not become smaller because
54+
# of this, and in cases where the pullback would be used multiple times we now have to
55+
# allocate multiple times. On the other hand, we can also free these objects inbetween, so
56+
# this might also reduce the total GC pressure...
57+
58+
4359
for (f, pullback!, adjoint) in (
4460
(:qr_full, :qr_pullback!, :qr_adjoint),
4561
(:lq_full, :lq_pullback!, :lq_adjoint),
@@ -72,21 +88,23 @@ for (f, pullback!, adjoint) in (
7288
dargs = last.(arrayify.(args, tangent(args_dargs)))
7389
function $adjoint(::NoRData)
7490
MAK.$pullback!(dA, A, args, dargs)
75-
return NoRData()
91+
return ntuple(_nordata, 3)
7692
end
7793

78-
return args, $adjoint
94+
return args_dargs, $adjoint
7995
end
8096

8197
@is_rev_primitive Tuple{typeof($f!), Any, Tuple, AbstractAlgorithm}
82-
rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual{<:AbstractAlgorithm}) =
83-
rrule!!(zero_fcodual($f), A_dA, alg_dalg)
98+
function rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual{<:AbstractAlgorithm})
99+
args_dargs, pb! = rrule!!(zero_fcodual($f), A_dA, alg_dalg)
100+
return args_dargs, Returns(ntuple(_nordata, 4)) pb!
101+
end
84102
end
85103
end
86104

87105
# Nullspaces
88106
# ----------
89-
for (f, pullback, adjoint) in (
107+
for (f, pullback!, adjoint) in (
90108
(:qr_null, :qr_null_pullback!, :qr_null_adjoint),
91109
(:lq_null, :lq_null_pullback!, :lq_null_adjoint),
92110
)
@@ -107,24 +125,26 @@ for (f, pullback, adjoint) in (
107125
dN = last(arrayify(N, tangent(N_dN)))
108126
function $adjoint(::NoRData)
109127
MAK.$pullback!(dA, A, N, dN)
110-
return NoRData()
128+
return ntuple(_nordata, 3)
111129
end
112130

113-
return N, $adjoint
131+
return N_dN, $adjoint
114132
end
115133

116134
@is_rev_primitive Tuple{typeof($f!), Any, Any, AbstractAlgorithm}
117-
rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, N_dN::CoDual, alg_dalg::CoDual{<:AbstractAlgorithm}) =
118-
rrule!!(zero_fcodual($f), A_dA, alg_dalg)
135+
function rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, N_dN::CoDual, alg_dalg::CoDual{<:AbstractAlgorithm})
136+
arg_darg, pb! = rrule!!(zero_fcodual($f), A_dA, alg_dalg)
137+
return arg_darg, Returns(ntuple(_nordata, 4)) pb!
138+
end
119139
end
120140
end
121141

122142
for f in (:eig, :eigh, :svd)
123143
f_vals = Symbol(f, :_vals)
124144
f_vals! = Symbol(f_vals, :!)
125-
f_full = Symbol(f, :_full)
126-
vals_pulback! = Symbol(f, :_vals_pullback!)
127-
adjoint! = Symbol(f, :_adjoint)
145+
f_full = f === :svd ? Symbol(f, :_compact) : Symbol(f, :_full)
146+
vals_pullback! = Symbol(f, :_vals_pullback!)
147+
adjoint = Symbol(f, :_adjoint)
128148

129149
# f_values
130150
# --------
@@ -144,15 +164,17 @@ for f in (:eig, :eigh, :svd)
144164
dvals = last(arrayify(vals, tangent(vals_dvals)))
145165
function $adjoint(::NoRData)
146166
MAK.$vals_pullback!(dA, A, F, dvals)
147-
return NoRData()
167+
return ntuple(_nordata, 3)
148168
end
149169

150170
return vals_dvals, $adjoint
151171
end
152172

153173
@is_rev_primitive Tuple{typeof($f_vals!), Any, Any, AbstractAlgorithm}
154-
rrule!!(::CoDual{typeof($f_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual) =
155-
rrule!!(zero_fcodual($f_vals), A_dA, alg_dalg)
174+
function rrule!!(::CoDual{typeof($f_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual)
175+
args_dargs, pb! = rrule!!(zero_fcodual($f_vals), A_dA, alg_dalg)
176+
return args_dargs, Returns(ntuple(_nordata, 4)) pb!
177+
end
156178
end
157179

158180

@@ -180,7 +202,7 @@ for f in (:eig, :eigh, :svd)
180202
function $adjoint(dy)
181203
_warn_pullback_truncerror(last(dy))
182204
MAK.$trunc_pullback!(dA, A, args, dargs)
183-
return NoRData()
205+
return ntuple(_nordata, 3)
184206
end
185207

186208
return argsϵ_dargsϵ, $adjoint
@@ -199,17 +221,19 @@ for f in (:eig, :eigh, :svd)
199221

200222
# define pullback
201223
dargs = last.(arrayify.(args, Base.front(tangent(argsϵ_dargsϵ))))
202-
function $f_adjoint!(dy)
224+
function $adjoint(dy)
203225
_warn_pullback_truncerror(last(dy))
204226
MAK.$pullback!(dA, A, args_full, dargs, ind)
205-
return NoRData()
227+
return ntuple(_nordata, 3)
206228
end
207229

208-
return DVtrunc_dDVtrunc, $f_adjoint!
230+
return argsϵ_dargsϵ, $adjoint
209231
end
210232
@is_rev_primitive Tuple{typeof($f_trunc!), Any, Any, AbstractAlgorithm}
211-
rrule!!(::CoDual{typeof($f_trunc!)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual) =
212-
rrule!!(zero_fcodual($f_trunc), A_dA, alg_dalg)
233+
function rrule!!(::CoDual{typeof($f_trunc!)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual)
234+
args_dargs, pb! = rrule!!(zero_fcodual($f_trunc), A_dA, alg_dalg)
235+
return args_dargs, Returns(ntuple(_nordata, 4)) pb!
236+
end
213237
end
214238

215239
# Truncated decompositions - no error
@@ -232,7 +256,7 @@ for f in (:eig, :eigh, :svd)
232256
dargs = last.(arrayify.(args, tangent(args_dargs)))
233257
function $adjoint(::NoRData)
234258
MAK.$trunc_pullback!(dA, A, args, dargs)
235-
return NoRData()
259+
return ntuple(_nordata, 3)
236260
end
237261

238262
return args_dargs, $adjoint
@@ -251,15 +275,17 @@ for f in (:eig, :eigh, :svd)
251275
dargs = last.(arrayify.(args, tangent(args_dargs)))
252276
function $adjoint(::NoRData)
253277
MAK.$pullback!(dA, A, args_full, dargs, ind)
254-
return NoRData()
278+
return ntuple(_nordata, 3)
255279
end
256280

257281
return args_dargs, $adjoint
258282
end
259283

260284
@is_rev_primitive Tuple{typeof($f_trunc_no_error!), Any, Any, AbstractAlgorithm}
261-
rrule!!(::CoDual{typeof($f_trunc_no_error!)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual) =
262-
rrule!!(zero_fcodual($f_trunc_no_error), A_dA, alg_dalg)
285+
function rrule!!(::CoDual{typeof($f_trunc_no_error!)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual)
286+
args_dargs, pb! = rrule!!(zero_fcodual($f_trunc_no_error), A_dA, alg_dalg)
287+
return args_dargs, Returns(ntuple(_nordata, 4)) pb!
288+
end
263289
end
264290
end
265291

0 commit comments

Comments
 (0)