Skip to content

Commit 2e832bd

Browse files
committed
some import changes and cleanup
1 parent a6a1821 commit 2e832bd

1 file changed

Lines changed: 32 additions & 60 deletions

File tree

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 32 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,41 @@
11
module MatrixAlgebraKitMooncakeExt
22

3-
using Mooncake
4-
using Mooncake: CoDual, Dual, NoRData, arrayify, primal, tangent, zero_fcodual
5-
import Mooncake: rrule!!
3+
using Mooncake: Mooncake as MC,
4+
CoDual, Dual, NoRData, arrayify, primal, tangent, zero_fcodual
65
using MatrixAlgebraKit
7-
using MatrixAlgebraKit: MatrixAlgebraKit as MAK, diagview, zero!, AbstractAlgorithm, TruncatedAlgorithm
6+
using MatrixAlgebraKit: MatrixAlgebraKit as MAK,
7+
diagview, zero!, AbstractAlgorithm, TruncatedAlgorithm
88
using LinearAlgebra
99

1010

1111
# Utility
1212
# -------
1313
# convenience helper for marking DefaultCtx ReverseMode signature as primitive
1414
macro is_rev_primitive(sig)
15-
return esc(:(Mooncake.@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode $sig))
15+
return esc(:(MC.@is_primitive MC.DefaultCtx MC.ReverseMode $sig))
1616
end
17+
1718
_warn_pullback_truncerror(dϵ::Real; tol = MatrixAlgebraKit.defaulttol(dϵ)) =
1819
abs(dϵ) tol || @warn "Pullback ignores non-zero tangents for truncation error"
1920

2021
const _nordata = Returns(NoRData())
2122

2223
# No derivatives
2324
# --------------
24-
Mooncake.tangent_type(::Type{<:AbstractAlgorithm}) = Mooncake.NoTangent
25+
MC.tangent_type(::Type{<:AbstractAlgorithm}) = MC.NoTangent
2526

26-
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(MAK.select_algorithm), Any, Any, Any}
27-
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(Core.kwcall), NamedTuple, typeof(MAK.select_algorithm), Any, Any, Any}
28-
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(MAK.initialize_output), Any, Any, Any}
29-
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(MAK.check_input), Any, Any, Any, Any}
27+
MC.@zero_derivative MC.DefaultCtx Tuple{typeof(MAK.select_algorithm), Any, Any, Any}
28+
MC.@zero_derivative MC.DefaultCtx Tuple{typeof(Core.kwcall), NamedTuple, typeof(MAK.select_algorithm), Any, Any, Any}
29+
MC.@zero_derivative MC.DefaultCtx Tuple{typeof(MAK.initialize_output), Any, Any, Any}
30+
MC.@zero_derivative MC.DefaultCtx Tuple{typeof(MAK.check_input), Any, Any, Any, Any}
3031

3132
@is_rev_primitive Tuple{typeof(MAK.copy_input), Any, Any}
32-
function rrule!!(::CoDual{typeof(MAK.copy_input)}, f_df::CoDual, A_dA::CoDual)
33+
function MC.rrule!!(::CoDual{typeof(MAK.copy_input)}, f_df::CoDual, A_dA::CoDual)
3334
Ac = MAK.copy_input(primal(f_df), primal(A_dA))
3435
Ac_dAc = zero_fcodual(Ac)
3536
dAc = tangent(Ac_dAc)
3637
function copy_input_pb(::NoRData)
37-
Mooncake.increment!!(tangent(A_dA), dAc)
38+
MC.increment!!(tangent(A_dA), dAc)
3839
return ntuple(_nordata, 3)
3940
end
4041
return Ac_dAc, copy_input_pb
@@ -75,7 +76,7 @@ for (f, pullback!, adjoint) in (
7576

7677
@eval begin
7778
@is_rev_primitive Tuple{typeof($f), Any, AbstractAlgorithm}
78-
function rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:AbstractAlgorithm})
79+
function MC.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:AbstractAlgorithm})
7980
# unpack variables
8081
A, dA = arrayify(A_dA)
8182
alg = primal(alg_dalg)
@@ -95,8 +96,8 @@ for (f, pullback!, adjoint) in (
9596
end
9697

9798
@is_rev_primitive Tuple{typeof($f!), Any, Tuple, AbstractAlgorithm}
98-
function rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual{<:AbstractAlgorithm})
99-
args_dargs, pb! = rrule!!(zero_fcodual($f), A_dA, alg_dalg)
99+
function MC.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual{<:AbstractAlgorithm})
100+
args_dargs, pb! = MC.rrule!!(zero_fcodual($f), A_dA, alg_dalg)
100101
return args_dargs, Returns(ntuple(_nordata, 4)) pb!
101102
end
102103
end
@@ -112,7 +113,7 @@ for (f, pullback!, adjoint) in (
112113

113114
@eval begin
114115
@is_rev_primitive Tuple{typeof($f), Any, AbstractAlgorithm}
115-
function rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:AbstractAlgorithm})
116+
function MC.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:AbstractAlgorithm})
116117
# unpack variables
117118
A, dA = arrayify(A_dA)
118119
alg = primal(alg_dalg)
@@ -132,8 +133,8 @@ for (f, pullback!, adjoint) in (
132133
end
133134

134135
@is_rev_primitive Tuple{typeof($f!), Any, Any, AbstractAlgorithm}
135-
function rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, N_dN::CoDual, alg_dalg::CoDual{<:AbstractAlgorithm})
136-
arg_darg, pb! = rrule!!(zero_fcodual($f), A_dA, alg_dalg)
136+
function MC.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, N_dN::CoDual, alg_dalg::CoDual{<:AbstractAlgorithm})
137+
arg_darg, pb! = MC.rrule!!(zero_fcodual($f), A_dA, alg_dalg)
137138
return arg_darg, Returns(ntuple(_nordata, 4)) pb!
138139
end
139140
end
@@ -150,7 +151,7 @@ for f in (:eig, :eigh, :svd)
150151
# --------
151152
@eval begin
152153
@is_rev_primitive Tuple{typeof($f_vals), Any, AbstractAlgorithm}
153-
function rrule!!(::CoDual{typeof($f_vals)}, A_dA::CoDual, alg_dalg::CoDual)
154+
function MC.rrule!!(::CoDual{typeof($f_vals)}, A_dA::CoDual, alg_dalg::CoDual)
154155
# unpack variables
155156
A, dA = arrayify(A_dA)
156157
alg = primal(alg_dalg)
@@ -171,8 +172,8 @@ for f in (:eig, :eigh, :svd)
171172
end
172173

173174
@is_rev_primitive Tuple{typeof($f_vals!), Any, Any, AbstractAlgorithm}
174-
function rrule!!(::CoDual{typeof($f_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual)
175-
args_dargs, pb! = rrule!!(zero_fcodual($f_vals), A_dA, alg_dalg)
175+
function MC.rrule!!(::CoDual{typeof($f_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual)
176+
args_dargs, pb! = MC.rrule!!(zero_fcodual($f_vals), A_dA, alg_dalg)
176177
return args_dargs, Returns(ntuple(_nordata, 4)) pb!
177178
end
178179
end
@@ -184,10 +185,11 @@ for f in (:eig, :eigh, :svd)
184185
f_trunc! = Symbol(f_trunc, :!)
185186
pullback! = Symbol(f, :_pullback!)
186187
trunc_pullback! = Symbol(f_trunc, :_pullback!)
188+
f_trunc_no_error = Symbol(f_trunc, :_no_error)
187189

188190
@eval begin
189191
@is_rev_primitive Tuple{typeof($f_trunc), Any, AbstractAlgorithm}
190-
function rrule!!(::CoDual{typeof($f_trunc)}, A_dA::CoDual, alg_dalg::CoDual)
192+
function MC.rrule!!(::CoDual{typeof($f_trunc)}, A_dA::CoDual, alg_dalg::CoDual)
191193
# unpack variables
192194
A, dA = arrayify(A_dA)
193195
alg = primal(alg_dalg)
@@ -207,15 +209,15 @@ for f in (:eig, :eigh, :svd)
207209

208210
return argsϵ_dargsϵ, $adjoint
209211
end
210-
function rrule!!(::CoDual{typeof($f_trunc)}, A_dA::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm})
212+
function MC.rrule!!(::CoDual{typeof($f_trunc)}, A_dA::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm})
211213
# unpack variables
212214
A, dA = arrayify(A_dA)
213-
alg = Mooncake.primal(alg_dalg)
215+
alg = primal(alg_dalg)
214216

215217
# compute primal and pack output - capture full DV and ind
216218
args_full = $f_full(A, alg.alg)
217219
args, ind = MAK.truncate($f_trunc!, args_full, alg.trunc)
218-
ϵ = MAK.truncation_error(diagview(args[1]), ind)
220+
ϵ = MAK.truncation_error(diagview(args_full[$(f === :svd ? 2 : 1)]), ind)
219221
argsϵ = (args..., ϵ)
220222
argsϵ_dargsϵ = zero_fcodual(argsϵ)
221223

@@ -229,39 +231,15 @@ for f in (:eig, :eigh, :svd)
229231

230232
return argsϵ_dargsϵ, $adjoint
231233
end
234+
232235
@is_rev_primitive Tuple{typeof($f_trunc!), Any, Any, AbstractAlgorithm}
233-
function rrule!!(::CoDual{typeof($f_trunc!)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual)
234-
args_dargs, pb! = rrule!!(zero_fcodual($f_trunc), A_dA, alg_dalg)
236+
function MC.rrule!!(::CoDual{typeof($f_trunc!)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual)
237+
args_dargs, pb! = MC.rrule!!(zero_fcodual($f_trunc), A_dA, alg_dalg)
235238
return args_dargs, Returns(ntuple(_nordata, 4)) pb!
236239
end
237-
end
238-
239-
# Truncated decompositions - no error
240-
# -----------------------------------
241-
f_trunc_no_error = Symbol(f_trunc, :_no_error)
242-
f_trunc_no_error! = Symbol(f_trunc_no_error, :!)
243240

244-
@eval begin
245-
@is_rev_primitive Tuple{typeof($f_trunc_no_error), Any, AbstractAlgorithm}
246-
function rrule!!(::CoDual{typeof($f_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual)
247-
# unpack variables
248-
A, dA = arrayify(A_dA)
249-
alg = primal(alg_dalg)
250-
251-
# compute primal and pack output
252-
args = $f_trunc(A, alg)
253-
args_dargs = zero_fcodual(args)
254-
255-
# define pullback
256-
dargs = last.(arrayify.(args, tangent(args_dargs)))
257-
function $adjoint(::NoRData)
258-
MAK.$trunc_pullback!(dA, A, args, dargs)
259-
return ntuple(_nordata, 3)
260-
end
261-
262-
return args_dargs, $adjoint
263-
end
264-
function rrule!!(::CoDual{typeof($f_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm})
241+
# still need specialized implementation for <:TruncatedAlgorithm
242+
function MC.rrule!!(::CoDual{typeof($f_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm})
265243
# unpack variables
266244
A, dA = arrayify(A_dA)
267245
alg = primal(alg_dalg)
@@ -280,12 +258,6 @@ for f in (:eig, :eigh, :svd)
280258

281259
return args_dargs, $adjoint
282260
end
283-
284-
@is_rev_primitive Tuple{typeof($f_trunc_no_error!), Any, Any, AbstractAlgorithm}
285-
function rrule!!(::CoDual{typeof($f_trunc_no_error!)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual)
286-
args_dargs, pb! = rrule!!(zero_fcodual($f_trunc_no_error), A_dA, alg_dalg)
287-
return args_dargs, Returns(ntuple(_nordata, 4)) pb!
288-
end
289261
end
290262
end
291263

0 commit comments

Comments
 (0)