Skip to content

Commit 29113c1

Browse files
committed
also update mooncake rules
1 parent b0eec92 commit 29113c1

1 file changed

Lines changed: 14 additions & 12 deletions

File tree

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@ using MatrixAlgebraKit
66
using MatrixAlgebraKit: inv_safe, diagview, copy_input
77
using MatrixAlgebraKit: qr_pullback!, lq_pullback!
88
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
9-
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_trunc_pullback!, eigh_trunc_pullback!
9+
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
10+
using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pullback!
1011
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
11-
using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!
12+
using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback!
1213
using LinearAlgebra
1314

1415

@@ -122,8 +123,8 @@ for (f!, f, pb, adj) in (
122123
end
123124

124125
for (f!, f, f_full, pb, adj) in (
125-
(:eig_vals!, :eig_vals, :eig_full, :eig_pullback!, :eig_vals_adjoint),
126-
(:eigh_vals!, :eigh_vals, :eigh_full, :eigh_pullback!, :eigh_vals_adjoint),
126+
(:eig_vals!, :eig_vals, :eig_full, :eig_vals_pullback!, :eig_vals_adjoint),
127+
(:eigh_vals!, :eigh_vals, :eigh_full, :eigh_vals_pullback!, :eigh_vals_adjoint),
127128
)
128129
@eval begin
129130
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
@@ -136,7 +137,7 @@ for (f!, f, f_full, pb, adj) in (
136137
copy!(D, diagview(DV[1]))
137138
V = DV[2]
138139
function $adj(::NoRData)
139-
$pb(dA, A, (Diagonal(D), V), (Diagonal(dD), nothing))
140+
$pb(dA, A, DV, dD)
140141
MatrixAlgebraKit.zero!(dD)
141142
return NoRData(), NoRData(), NoRData(), NoRData()
142143
end
@@ -153,7 +154,7 @@ for (f!, f, f_full, pb, adj) in (
153154
output_codual = CoDual(output, Mooncake.zero_tangent(output))
154155
function $adj(::NoRData)
155156
D, dD = arrayify(output_codual)
156-
$pb(dA, A, (Diagonal(D), V), (Diagonal(dD), nothing))
157+
$pb(dA, A, DV, dD)
157158
MatrixAlgebraKit.zero!(dD)
158159
return NoRData(), NoRData(), NoRData()
159160
end
@@ -272,10 +273,10 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua
272273
# compute primal
273274
A, dA = arrayify(A_dA)
274275
S, dS = arrayify(S_dS)
275-
U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg))
276-
copy!(S, diagview(nS))
276+
USVᴴ = svd_compact(A, Mooncake.primal(alg_dalg))
277+
copy!(S, diagview(USVᴴ[2]))
277278
function svd_vals_adjoint(::NoRData)
278-
svd_pullback!(dA, A, (U, Diagonal(S), Vᴴ), (nothing, Diagonal(dS), nothing))
279+
svd_vals_pullback!(dA, A, USVᴴ, dS)
279280
MatrixAlgebraKit.zero!(dS)
280281
return NoRData(), NoRData(), NoRData(), NoRData()
281282
end
@@ -286,15 +287,16 @@ end
286287
function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::CoDual)
287288
# compute primal
288289
A, dA = arrayify(A_dA)
289-
U, S, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg))
290+
USVᴴ = svd_compact(A, Mooncake.primal(alg_dalg))
290291
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
291292
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
292293
# pass). For many types this is done automatically when the forward step returns, but
293294
# not for nested structs with various fields (like Diagonal{Complex})
294-
S_codual = CoDual(diagview(S), Mooncake.fdata(Mooncake.zero_tangent(diagview(S))))
295+
S = diagview(USVᴴ[2])
296+
S_codual = CoDual(S, Mooncake.fdata(Mooncake.zero_tangent(S)))
295297
function svd_vals_adjoint(::NoRData)
296298
S, dS = arrayify(S_codual)
297-
svd_pullback!(dA, A, (U, Diagonal(S), Vᴴ), (nothing, Diagonal(dS), nothing))
299+
svd_vals_pullback!(dA, A, USVᴴ, dS)
298300
MatrixAlgebraKit.zero!(dS)
299301
return NoRData(), NoRData(), NoRData()
300302
end

0 commit comments

Comments
 (0)