@@ -136,7 +136,7 @@ for (f!, f, f_full, pb, adj) in (
136136 copy! (D, diagview (DV[1 ]))
137137 V = DV[2 ]
138138 function $adj (:: Mooncake.NoRData )
139- $ pb (dA, A, (D , V), (dD , nothing ))
139+ $ pb (dA, A, (Diagonal (D) , V), (Diagonal (dD) , nothing ))
140140 MatrixAlgebraKit. zero! (dD)
141141 return Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData ()
142142 end
@@ -152,8 +152,8 @@ for (f!, f, f_full, pb, adj) in (
152152 output = diagview (DV[1 ])
153153 output_codual = Mooncake. CoDual (output, Mooncake. zero_tangent (output))
154154 function $adj (:: Mooncake.NoRData )
155- D_dD = Mooncake. arrayify (D_dD )
156- $ pb (dA, A, (D , V), (dD , nothing ))
155+ D, dD = Mooncake. arrayify (output_codual )
156+ $ pb (dA, A, (Diagonal (D) , V), (Diagonal (dD) , nothing ))
157157 MatrixAlgebraKit. zero! (dD)
158158 return Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData ()
159159 end
@@ -181,7 +181,7 @@ for (f, pb, adj) in (
181181 function $adj (dy:: Tuple{Mooncake.NoRData, Mooncake.NoRData, T} ) where {T <: Real }
182182 Dtrunc, Vtrunc, ϵ = Mooncake. primal (output_codual)
183183 dDtrunc_, dVtrunc_, dϵ = Mooncake. tangent (output_codual)
184- abs (dϵ ) > MatrixAlgebraKit. defaulttol (dϵ ) && @warn " Pullback for $f does not yet support non-zero tangent for the truncation error"
184+ abs (dy[ 3 ] ) > MatrixAlgebraKit. defaulttol (dy[ 3 ] ) && @warn " Pullback for $f does not yet support non-zero tangent for the truncation error"
185185 D, dD = Mooncake. arrayify (Dtrunc, dDtrunc_)
186186 V, dV = Mooncake. arrayify (Vtrunc, dVtrunc_)
187187 $ pb (dA, A, (D, V), (dD, dV))
@@ -275,7 +275,7 @@ function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::Co
275275 U, nS, Vᴴ = svd_compact (A, Mooncake. primal (alg_dalg))
276276 copy! (S, diagview (nS))
277277 function dsvd_vals_adjoint (:: Mooncake.NoRData )
278- svd_pullback! (dA, A, (U, S , Vᴴ), (nothing , dS , nothing ))
278+ svd_pullback! (dA, A, (U, Diagonal (S) , Vᴴ), (nothing , Diagonal (dS) , nothing ))
279279 MatrixAlgebraKit. zero! (dS)
280280 return Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData ()
281281 end
@@ -294,7 +294,7 @@ function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_vals)}, A_dA::CoD
294294 S_codual = Mooncake. CoDual (diagview (S), Mooncake. fdata (Mooncake. zero_tangent (S)))
295295 function dsvd_vals_adjoint (:: Mooncake.NoRData )
296296 S, dS = Mooncake. arrayify (S_codual)
297- svd_pullback! (dA, A, (U, S , Vᴴ), (nothing , dS , nothing ))
297+ svd_pullback! (dA, A, (U, Diagonal (S) , Vᴴ), (nothing , Diagonal (dS) , nothing ))
298298 MatrixAlgebraKit. zero! (dS)
299299 return Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData ()
300300 end
@@ -317,7 +317,7 @@ function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_trunc)}, A_dA::Co
317317 function dsvd_trunc_adjoint (dy:: Tuple{Mooncake.NoRData, Mooncake.NoRData, Mooncake.NoRData, T} ) where {T <: Real }
318318 Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake. primal (output_codual)
319319 dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake. tangent (output_codual)
320- abs (dϵ ) > MatrixAlgebraKit. defaulttol (dϵ ) && @warn " Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error"
320+ abs (dy[ 3 ] ) > MatrixAlgebraKit. defaulttol (dy[ 3 ] ) && @warn " Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error"
321321 U, dU = Mooncake. arrayify (Utrunc, dUtrunc_)
322322 S, dS = Mooncake. arrayify (Strunc, dStrunc_)
323323 Vᴴ, dVᴴ = Mooncake. arrayify (Vᴴtrunc, dVᴴtrunc_)
0 commit comments