@@ -6,9 +6,10 @@ using MatrixAlgebraKit
66using MatrixAlgebraKit: inv_safe, diagview, copy_input
77using MatrixAlgebraKit: qr_pullback!, lq_pullback!
88using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
9- using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_trunc_pullback!, eigh_trunc_pullback!
9+ using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
10+ using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pullback!
1011using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
11- using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!
12+ using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback!
1213using LinearAlgebra
1314
1415
@@ -122,8 +123,8 @@ for (f!, f, pb, adj) in (
122123end
123124
124125for (f!, f, f_full, pb, adj) in (
125- (:eig_vals! , :eig_vals , :eig_full , :eig_pullback ! , :eig_vals_adjoint ),
126- (:eigh_vals! , :eigh_vals , :eigh_full , :eigh_pullback ! , :eigh_vals_adjoint ),
126+ (:eig_vals! , :eig_vals , :eig_full , :eig_vals_pullback ! , :eig_vals_adjoint ),
127+ (:eigh_vals! , :eigh_vals , :eigh_full , :eigh_vals_pullback ! , :eigh_vals_adjoint ),
127128 )
128129 @eval begin
129130 @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f!), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
@@ -136,7 +137,7 @@ for (f!, f, f_full, pb, adj) in (
136137 copy! (D, diagview (DV[1 ]))
137138 V = DV[2 ]
138139 function $adj (:: NoRData )
139- $ pb (dA, A, ( Diagonal (D), V), ( Diagonal (dD), nothing ) )
140+ $ pb (dA, A, DV, dD )
140141 MatrixAlgebraKit. zero! (dD)
141142 return NoRData (), NoRData (), NoRData (), NoRData ()
142143 end
@@ -153,7 +154,7 @@ for (f!, f, f_full, pb, adj) in (
153154 output_codual = CoDual (output, Mooncake. zero_tangent (output))
154155 function $adj (:: NoRData )
155156 D, dD = arrayify (output_codual)
156- $ pb (dA, A, ( Diagonal (D), V), ( Diagonal (dD), nothing ) )
157+ $ pb (dA, A, DV, dD )
157158 MatrixAlgebraKit. zero! (dD)
158159 return NoRData (), NoRData (), NoRData ()
159160 end
@@ -272,10 +273,10 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua
272273 # compute primal
273274 A, dA = arrayify (A_dA)
274275 S, dS = arrayify (S_dS)
275- U, nS, Vᴴ = svd_compact (A, Mooncake. primal (alg_dalg))
276- copy! (S, diagview (nS ))
276+ USVᴴ = svd_compact (A, Mooncake. primal (alg_dalg))
277+ copy! (S, diagview (USVᴴ[ 2 ] ))
277278 function svd_vals_adjoint (:: NoRData )
278- svd_pullback ! (dA, A, (U, Diagonal (S), Vᴴ), ( nothing , Diagonal (dS), nothing ) )
279+ svd_vals_pullback ! (dA, A, USVᴴ, dS )
279280 MatrixAlgebraKit. zero! (dS)
280281 return NoRData (), NoRData (), NoRData (), NoRData ()
281282 end
@@ -286,15 +287,16 @@ end
286287function Mooncake. rrule!! (:: CoDual{typeof(svd_vals)} , A_dA:: CoDual , alg_dalg:: CoDual )
287288 # compute primal
288289 A, dA = arrayify (A_dA)
289- U, S, Vᴴ = svd_compact (A, Mooncake. primal (alg_dalg))
290+ USVᴴ = svd_compact (A, Mooncake. primal (alg_dalg))
290291 # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
291292 # of ComplexF32) into the correct **forwards** data type (since we are now in the forward
292293 # pass). For many types this is done automatically when the forward step returns, but
293294 # not for nested structs with various fields (like Diagonal{Complex})
294- S_codual = CoDual (diagview (S), Mooncake. fdata (Mooncake. zero_tangent (diagview (S))))
295+ S = diagview (USVᴴ[2 ])
296+ S_codual = CoDual (S, Mooncake. fdata (Mooncake. zero_tangent (S)))
295297 function svd_vals_adjoint (:: NoRData )
296298 S, dS = arrayify (S_codual)
297- svd_pullback ! (dA, A, (U, Diagonal (S), Vᴴ), ( nothing , Diagonal (dS), nothing ) )
299+ svd_vals_pullback ! (dA, A, USVᴴ, dS )
298300 MatrixAlgebraKit. zero! (dS)
299301 return NoRData (), NoRData (), NoRData ()
300302 end
0 commit comments