@@ -6,9 +6,10 @@ using MatrixAlgebraKit
66using MatrixAlgebraKit: inv_safe, diagview, copy_input
77using MatrixAlgebraKit: qr_pullback!, lq_pullback!
88using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
9- using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_trunc_pullback!, eigh_trunc_pullback!
9+ using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
10+ using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pullback!
1011using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
11- using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!
12+ using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback!
1213using LinearAlgebra
1314
1415
@@ -122,8 +123,8 @@ for (f!, f, pb, adj) in (
122123end
123124
124125for (f!, f, f_full, pb, adj) in (
125- (:eig_vals! , :eig_vals , :eig_full , :eig_pullback ! , :eig_vals_adjoint ),
126- (:eigh_vals! , :eigh_vals , :eigh_full , :eigh_pullback ! , :eigh_vals_adjoint ),
126+ (:eig_vals! , :eig_vals , :eig_full , :eig_vals_pullback ! , :eig_vals_adjoint ),
127+ (:eigh_vals! , :eigh_vals , :eigh_full , :eigh_vals_pullback ! , :eigh_vals_adjoint ),
127128 )
128129 @eval begin
129130 @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f!), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
@@ -136,7 +137,7 @@ for (f!, f, f_full, pb, adj) in (
136137 copy! (D, diagview (DV[1 ]))
137138 V = DV[2 ]
138139 function $adj (:: NoRData )
139- $ pb (dA, A, ( Diagonal (D), V), ( Diagonal (dD), nothing ) )
140+ $ pb (dA, A, DV, dD )
140141 MatrixAlgebraKit. zero! (dD)
141142 return NoRData (), NoRData (), NoRData (), NoRData ()
142143 end
@@ -153,7 +154,7 @@ for (f!, f, f_full, pb, adj) in (
153154 output_codual = CoDual (output, Mooncake. zero_tangent (output))
154155 function $adj (:: NoRData )
155156 D, dD = arrayify (output_codual)
156- $ pb (dA, A, ( Diagonal (D), V), ( Diagonal (dD), nothing ) )
157+ $ pb (dA, A, DV, dD )
157158 MatrixAlgebraKit. zero! (dD)
158159 return NoRData (), NoRData (), NoRData ()
159160 end
@@ -275,7 +276,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua
275276 U, nS, Vᴴ = svd_compact (A, Mooncake. primal (alg_dalg))
276277 copy! (S, diagview (nS))
277278 function svd_vals_adjoint (:: NoRData )
278- svd_pullback ! (dA, A, (U, Diagonal (S) , Vᴴ), ( nothing , Diagonal (dS), nothing ) )
279+ svd_vals_pullback ! (dA, A, (U, nS , Vᴴ), dS )
279280 MatrixAlgebraKit. zero! (dS)
280281 return NoRData (), NoRData (), NoRData (), NoRData ()
281282 end
@@ -294,7 +295,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co
294295 S_codual = CoDual (diagview (S), Mooncake. fdata (Mooncake. zero_tangent (diagview (S))))
295296 function svd_vals_adjoint (:: NoRData )
296297 S, dS = arrayify (S_codual)
297- svd_pullback ! (dA, A, (U, Diagonal (S) , Vᴴ), ( nothing , Diagonal (dS), nothing ) )
298+ svd_vals_pullback ! (dA, A, (U, S , Vᴴ), dS )
298299 MatrixAlgebraKit. zero! (dS)
299300 return NoRData (), NoRData (), NoRData ()
300301 end
0 commit comments