Skip to content

Commit 743950b

Browse files
committed
In-place tests and fixes
1 parent 1d1ca63 commit 743950b

5 files changed

Lines changed: 530 additions & 317 deletions

File tree

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4747

4848
[targets]
4949
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "Mooncake"]
50+
51+
# NOTE(review): the path below is an absolute, developer-local checkout; it
# presumably will not resolve on CI or on any other machine. Confirm whether
# this should be a registered release or a url/rev source before merging.
[sources]
52+
Mooncake = {path="/Users/khyatt/.julia/dev/Mooncake"}

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 141 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -3,156 +3,220 @@ module MatrixAlgebraKitMooncakeExt
33
using Mooncake
44
using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
55
using MatrixAlgebraKit
6-
using MatrixAlgebraKit: inv_safe, diagview
6+
using MatrixAlgebraKit: inv_safe, diagview, copy_input
77
using MatrixAlgebraKit: qr_pullback!, lq_pullback!
88
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
99
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!
1010
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
1111
using LinearAlgebra
1212

13+
14+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(copy_input), Any, Any}
# Reverse-mode rule for `copy_input(f, A)`.
#
# Forward pass: calls `copy_input` on the primals and pairs the resulting copy
# with a fresh zero tangent.
# Reverse pass: adds whatever has been accumulated on the copy's tangent onto the
# tangent of the original `A`, and reports no rdata for any of the three arguments.
function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDual)
    f = Mooncake.primal(f_df)
    A = Mooncake.primal(A_dA)
    A_copy = copy_input(f, A)
    dA_copy = Mooncake.zero_tangent(A_copy)
    function copy_input_pb(::Mooncake.NoRData)
        # Route the cotangent of the copy back to the argument it was copied from.
        Mooncake.increment!!(Mooncake.tangent(A_dA), dA_copy)
        no_rdata = Mooncake.NoRData()
        return no_rdata, no_rdata, no_rdata
    end
    return CoDual(A_copy, dA_copy), copy_input_pb
end
24+
1325
# two-argument factorizations like LQ, QR, EIG
14-
for (f, pb, adj) in ((qr_full!, qr_pullback!, :dqr_adjoint),
15-
(qr_compact!, qr_pullback!, :dqr_adjoint),
16-
(lq_full!, lq_pullback!, :dlq_adjoint),
17-
(lq_compact!, lq_pullback!, :dlq_adjoint),
18-
(eig_full!, eig_pullback!, :deig_adjoint),
19-
(eigh_full!, eigh_pullback!, :deigh_adjoint),
20-
(left_polar!, left_polar_pullback!, :dleft_polar_adjoint),
21-
(right_polar!, right_polar_pullback!, :dright_polar_adjoint),
22-
)
26+
for (f, pb, adj) in (
27+
(qr_full!, qr_pullback!, :dqr_adjoint),
28+
(qr_compact!, qr_pullback!, :dqr_adjoint),
29+
(lq_full!, lq_pullback!, :dlq_adjoint),
30+
(lq_compact!, lq_pullback!, :dlq_adjoint),
31+
(eig_full!, eig_pullback!, :deig_adjoint),
32+
(eigh_full!, eigh_pullback!, :deigh_adjoint),
33+
(left_polar!, left_polar_pullback!, :dleft_polar_adjoint),
34+
(right_polar!, right_polar_pullback!, :dright_polar_adjoint),
35+
)
2336

2437
@eval begin
25-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
26-
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractMatrix}, args_dargs::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
27-
A, dA = arrayify(A_dA)
28-
dA .= zero(eltype(A))
29-
args = Mooncake.primal(args_dargs)
30-
dargs = Mooncake.tangent(args_dargs)
31-
arg1, darg1 = arrayify(args[1], dargs[1])
32-
arg2, darg2 = arrayify(args[2], dargs[2])
38+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, Tuple{<:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
39+
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
40+
A, dA = arrayify(A_dA)
41+
args = Mooncake.primal(args_dargs)
42+
dargs = Mooncake.tangent(args_dargs)
43+
arg1, darg1 = arrayify(args[1], dargs[1])
44+
arg2, darg2 = arrayify(args[2], dargs[2])
45+
Ac = copy(A)
46+
arg1c = copy(arg1)
47+
arg2c = copy(arg2)
48+
output = $f(A, args, Mooncake.primal(alg_dalg); kwargs...)
3349
function $adj(::Mooncake.NoRData)
34-
dA = $pb(dA, A, (arg1, arg2), (darg1, darg2); kwargs...)
50+
dAtmp_ = zero(dA)
51+
dAtmp_ .= $pb(dAtmp_, A, (arg1, arg2), (darg1, darg2); kwargs...)
52+
dAtmp = if eltype(dA) <: Real
53+
dAtmp_
54+
else
55+
map(A_ -> Mooncake.build_tangent(typeof(A_), real(A_), imag(A_)), dAtmp_)
56+
end
57+
Mooncake.increment!!(Mooncake.tangent(A_dA), dAtmp)
58+
arg1 .= arg1c
59+
arg2 .= arg2c
60+
A .= Ac
61+
darg1 .= zero(darg1)
62+
darg2 .= zero(darg2)
3563
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
3664
end
37-
args = $f(A, args, Mooncake.primal(alg_dalg); kwargs...)
38-
darg1 .= zero(eltype(arg1))
39-
darg2 .= zero(eltype(arg2))
4065
return Mooncake.CoDual(args, dargs), $adj
4166
end
4267
end
4368
end
4469

45-
for (f, f_full, pb, adj) in ((qr_null!, qr_full, qr_null_pullback!, :dqr_null_adjoint),
46-
(lq_null!, lq_full, lq_null_pullback!, :dlq_null_adjoint),
47-
)
70+
for (f, f_full, pb, adj) in (
71+
(qr_null!, qr_full, qr_null_pullback!, :dqr_null_adjoint),
72+
(lq_null!, lq_full, lq_null_pullback!, :dlq_null_adjoint),
73+
)
4874
@eval begin
49-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, AbstractMatrix, MatrixAlgebraKit.AbstractAlgorithm}
50-
function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractMatrix}, arg_darg::CoDual{<:AbstractMatrix}, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
51-
A, dA = arrayify(A_dA)
52-
Ac = MatrixAlgebraKit.copy_input($f_full, A)
53-
arg, darg = arrayify(Mooncake.primal(arg_darg), Mooncake.tangent(arg_darg))
54-
arg = $f(Ac, arg, Mooncake.primal(alg_dalg))
75+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
76+
function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
77+
A, dA = arrayify(A_dA)
78+
Ac = copy(A)
79+
arg, darg = arrayify(arg_darg)
80+
argc = copy(arg)
81+
# WHY is this copy needed?
82+
arg = $f(copy(A), arg, Mooncake.primal(alg_dalg))
5583
function $adj(::Mooncake.NoRData)
56-
dA .= zero(eltype(A))
57-
$pb(dA, A, arg, darg; kwargs...)
84+
dAtmp_ = zero(dA)
85+
dAtmp_ .= $pb(dAtmp_, A, arg, darg; kwargs...)
86+
dAtmp = if eltype(dA) <: Real
87+
dAtmp_
88+
else
89+
map(A_ -> Mooncake.build_tangent(typeof(A_), real(A_), imag(A_)), dAtmp_)
90+
end
91+
Mooncake.increment!!(Mooncake.tangent(A_dA), dAtmp)
92+
A .= Ac
93+
arg .= argc
94+
darg .= zero(darg)
5895
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
5996
end
60-
return arg_darg, $adj
97+
return arg_darg, $adj
6198
end
6299
end
63100
end
64101

65-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eig_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
102+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eig_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
66103
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eig_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...)
67104
# compute primal
68-
D_ = Mooncake.primal(D_dD)
69-
dD_ = Mooncake.tangent(D_dD)
70-
A_ = Mooncake.primal(A_dA)
71-
dA_ = Mooncake.tangent(A_dA)
105+
D_ = Mooncake.primal(D_dD)
106+
dD_ = Mooncake.tangent(D_dD)
107+
A_ = Mooncake.primal(A_dA)
108+
dA_ = Mooncake.tangent(A_dA)
72109
A, dA = arrayify(A_, dA_)
73110
D, dD = arrayify(D_, dD_)
74-
dA .= zero(eltype(dA))
75-
# update primal
76-
DV = eig_full(A, Mooncake.primal(alg_dalg); kwargs...)
77-
V = DV[2]
78-
dD .= zero(eltype(D))
111+
Ac = copy(A)
112+
Dc = copy(D)
113+
dDc = copy(dD)
114+
# update primal
115+
DV = eig_full(A, Mooncake.primal(alg_dalg); kwargs...)
116+
V = DV[2]
117+
eig_vals!(A, D, Mooncake.primal(alg_dalg))
79118
function deig_vals_adjoint(::Mooncake.NoRData)
119+
dA .= zero(eltype(dA))
120+
A .= Ac
80121
PΔV = V' \ Diagonal(dD)
81122
if eltype(dA) <: Real
82123
ΔAc = PΔV * V'
83124
dA .+= real.(ΔAc)
84125
else
85126
mul!(dA, PΔV, V', 1, 0)
86127
end
128+
D .= Dc
129+
dD .= dDc
87130
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
88131
end
89-
return Mooncake.CoDual(DV[1].diag, dD_), deig_vals_adjoint
132+
dD .= zero(eltype(D))
133+
return D_dD, deig_vals_adjoint
90134
end
91135

92-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eigh_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
136+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eigh_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
93137
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eigh_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...)
94138
# compute primal
95-
D_ = Mooncake.primal(D_dD)
96-
dD_ = Mooncake.tangent(D_dD)
97-
A_ = Mooncake.primal(A_dA)
98-
dA_ = Mooncake.tangent(A_dA)
139+
D_ = Mooncake.primal(D_dD)
140+
dD_ = Mooncake.tangent(D_dD)
141+
A_ = Mooncake.primal(A_dA)
142+
dA_ = Mooncake.tangent(A_dA)
99143
A, dA = arrayify(A_, dA_)
100144
D, dD = arrayify(D_, dD_)
101-
DV = eigh_full(A, Mooncake.primal(alg_dalg); kwargs...)
145+
DV = eigh_full(A, Mooncake.primal(alg_dalg); kwargs...)
102146
function deigh_vals_adjoint(::Mooncake.NoRData)
103147
mul!(dA, DV[2] * Diagonal(real(dD)), DV[2]', 1, 0)
104148
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
105149
end
150+
dD .= zero(eltype(D))
106151
return Mooncake.CoDual(DV[1].diag, dD_), deigh_vals_adjoint
107152
end
108153

109154

110-
for (f, St) in ((svd_full!, :AbstractMatrix), (svd_compact!, :Diagonal))
155+
for f in (svd_full!, svd_compact!)
111156
@eval begin
112-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:$St, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
157+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
113158
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual; kwargs...)
114-
A, dA = arrayify(A_dA)
115-
USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
116-
dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
117-
U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
118-
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
159+
A, dA = arrayify(A_dA)
160+
Ac = copy(A)
161+
USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
162+
dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
163+
U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
164+
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
119165
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
120-
USVᴴ = $f(A, USVᴴ, Mooncake.primal(alg_dalg); kwargs...)
166+
Uc = copy(U)
167+
Sc = copy(S)
168+
Vᴴc = copy(Vᴴ)
169+
USVᴴ = $f(A, USVᴴ, Mooncake.primal(alg_dalg); kwargs...)
170+
minmn = min(size(A)...)
121171
function dsvd_adjoint(::Mooncake.NoRData)
122-
dA .= zero(eltype(A))
123-
minmn = min(size(A)...)
172+
dAtmp_ = zero(dA)
173+
A .= Ac
124174
if size(U, 2) == size(Vᴴ, 1) == minmn # compact
125-
dA = MatrixAlgebraKit.svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
175+
dAtmp_ = MatrixAlgebraKit.svd_pullback!(dAtmp_, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
126176
else # full
127-
vU = view(U, :, 1:minmn)
128-
vS = Diagonal(diagview(S)[1:minmn])
129-
vVᴴ = view(Vᴴ, 1:minmn, :)
130-
vdU = view(dU, :, 1:minmn)
131-
vdS = Diagonal(diagview(dS)[1:minmn])
132-
vdVᴴ = view(dVᴴ, 1:minmn, :)
133-
dA = MatrixAlgebraKit.svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
177+
vU = view(U, :, 1:minmn)
178+
vS = Diagonal(diagview(S)[1:minmn])
179+
vVᴴ = view(Vᴴ, 1:minmn, :)
180+
vdU = view(dU, :, 1:minmn)
181+
vdS = Diagonal(diagview(dS)[1:minmn])
182+
vdVᴴ = view(dVᴴ, 1:minmn, :)
183+
dAtmp_ = MatrixAlgebraKit.svd_pullback!(dAtmp_, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
184+
end
185+
dAtmp = if eltype(dA) <: Real
186+
dAtmp_
187+
else
188+
map(A_ -> Mooncake.build_tangent(typeof(A_), real(A_), imag(A_)), dAtmp_)
134189
end
190+
Mooncake.increment!!(Mooncake.tangent(A_dA), dAtmp)
191+
U .= Uc
192+
S .= Sc
193+
Vᴴ .= Vᴴc
194+
dU .= zero(dU)
195+
dS .= zero(dS)
196+
dVᴴ .= zero(dVᴴ)
135197
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
136198
end
137199
return Mooncake.CoDual(USVᴴ, dUSVᴴ), dsvd_adjoint
138200
end
139201
end
140202
end
141203

142-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.svd_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
204+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.svd_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
143205
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual; kwargs...)
144206
# compute primal
145-
S_ = Mooncake.primal(S_dS)
146-
dS_ = Mooncake.tangent(S_dS)
147-
A_ = Mooncake.primal(A_dA)
148-
dA_ = Mooncake.tangent(A_dA)
207+
S_ = Mooncake.primal(S_dS)
208+
dS_ = Mooncake.tangent(S_dS)
209+
A_ = Mooncake.primal(A_dA)
210+
dA_ = Mooncake.tangent(A_dA)
149211
A, dA = arrayify(A_, dA_)
150212
S, dS = arrayify(S_, dS_)
213+
Ac = copy(A)
151214
U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg); kwargs...)
152-
S .= diagview(nS)
153-
dS .= zero(eltype(S))
215+
S .= diagview(nS)
154216
function dsvd_vals_adjoint(::Mooncake.NoRData)
155-
dA .= U * Diagonal(dS) * Vᴴ
217+
dA .= U * Diagonal(dS) * Vᴴ
218+
A .= Ac
219+
dS .= zero(eltype(S))
156220
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
157221
end
158222
return S_dS, dsvd_vals_adjoint

src/common/view.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# diagind: provided by LinearAlgebra.jl
2-
diagview(D::Diagonal) = D.diag
2+
diagview(D::Diagonal) = D.diag
33
diagview(D::AbstractMatrix) = view(D, diagind(D))
44

55
# triangularind

test/ad_utils.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,32 @@
1-
function remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ;
2-
degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(S))
1+
# Remove the gauge-dependent component of the SVD cotangents, in place on `ΔU`.
#
# The anti-Hermitian part of `U' * ΔU + Vᴴ * ΔVᴴ'` is formed, entries whose
# singular values differ by at least `degeneracy_atol` are zeroed, and `U` times
# what remains is subtracted from `ΔU`. `ΔVᴴ` is returned as passed in.
function remove_svdgauge_dependence!(
        ΔU, ΔVᴴ, U, S, Vᴴ;
        degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S)
    )
    overlap = U' * ΔU + Vᴴ * ΔVᴴ'
    antiherm = (overlap - overlap') / 2
    # Keep only the entries coupling (near-)degenerate singular values.
    well_separated = abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol
    antiherm[well_separated] .= 0
    mul!(ΔU, U, antiherm, -1, 1)
    return ΔU, ΔVᴴ
end
9-
function remove_eiggauge_dependence!(ΔV, D, V;
10-
degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(S))
11+
# Remove the gauge-dependent component of the eigenvector cotangent `ΔV`, in place.
#
# Arguments:
#   ΔV — eigenvector cotangent; overwritten and returned.
#   D  — eigenvalues (anything accepted by `diagview`).
#   V  — eigenvectors.
# Keywords:
#   degeneracy_atol — eigenvalue pairs differing by at least this much are treated
#                     as non-degenerate and excluded from the gauge correction.
#                     Defaults to the package gauge tolerance computed from `D`
#                     (the previous default referenced `S`, which is not an
#                     argument of this function).
function remove_eiggauge_dependence!(
        ΔV, D, V;
        degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D)
    )
    gaugepart = V' * ΔV
    # Keep only the entries coupling (near-)degenerate eigenvalues.
    gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0
    # `V / (V' * V)` accounts for `V` not being unitary for a general eigendecomposition.
    mul!(ΔV, V / (V' * V), gaugepart, -1, 1)
    return ΔV
end
16-
function remove_eighgauge_dependence!(ΔV, D, V;
17-
degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(S))
20+
# Remove the gauge-dependent component of the Hermitian-eigenvector cotangent `ΔV`,
# in place.
#
# Arguments:
#   ΔV — eigenvector cotangent; overwritten and returned.
#   D  — eigenvalues (anything accepted by `diagview`).
#   V  — eigenvectors.
# Keywords:
#   degeneracy_atol — eigenvalue pairs differing by at least this much are treated
#                     as non-degenerate and excluded from the gauge correction.
#                     Defaults to the package gauge tolerance computed from `D`
#                     (the previous default referenced `S`, which is not an
#                     argument of this function).
function remove_eighgauge_dependence!(
        ΔV, D, V;
        degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D)
    )
    gaugepart = V' * ΔV
    # Only the anti-Hermitian part of the overlap is used for the correction.
    gaugepart = (gaugepart - gaugepart') / 2
    # Keep only the entries coupling (near-)degenerate eigenvalues.
    gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0
    mul!(ΔV, V / (V' * V), gaugepart, -1, 1)
    return ΔV
end
2430

25-
precision(::Type{<:Union{Float32,Complex{Float32}}}) = sqrt(eps(Float32))
26-
precision(::Type{<:Union{Float64,Complex{Float64}}}) = sqrt(eps(Float64))
31+
# Single-precision tolerance: square root of machine epsilon of the underlying real type.
precision(::Type{T}) where {T <: Union{Float32, Complex{Float32}}} = sqrt(eps(real(T)))
32+
# Double-precision tolerance: square root of machine epsilon of the underlying real type.
precision(::Type{T}) where {T <: Union{Float64, Complex{Float64}}} = sqrt(eps(real(T)))

0 commit comments

Comments
 (0)