@@ -291,7 +291,7 @@ function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_vals)}, A_dA::CoD
291291 # of ComplexF32) into the correct **forwards** data type (since we are now in the forward
292292 # pass). For many types this is done automatically when the forward step returns, but
293293 # not for nested structs with various fields (like Diagonal{Complex})
294- S_codual = Mooncake. CoDual (diagview (S), Mooncake. fdata (Mooncake. zero_tangent (S )))
294+ S_codual = Mooncake. CoDual (diagview (S), Mooncake. fdata (Mooncake. zero_tangent (diagview (S) )))
295295 function dsvd_vals_adjoint (:: Mooncake.NoRData )
296296 S, dS = Mooncake. arrayify (S_codual)
297297 svd_pullback! (dA, A, (U, Diagonal (S), Vᴴ), (nothing , Diagonal (dS), nothing ))
@@ -317,7 +317,7 @@ function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_trunc)}, A_dA::Co
317317 function dsvd_trunc_adjoint (dy:: Tuple{Mooncake.NoRData, Mooncake.NoRData, Mooncake.NoRData, T} ) where {T <: Real }
318318 Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake. primal (output_codual)
319319 dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake. tangent (output_codual)
320- abs (dy[3 ]) > MatrixAlgebraKit. defaulttol (dy[3 ]) && @warn " Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error"
320+ abs (dy[4 ]) > MatrixAlgebraKit. defaulttol (dy[4 ]) && @warn " Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error"
321321 U, dU = Mooncake. arrayify (Utrunc, dUtrunc_)
322322 S, dS = Mooncake. arrayify (Strunc, dStrunc_)
323323 Vᴴ, dVᴴ = Mooncake. arrayify (Vᴴtrunc, dVᴴtrunc_)
0 commit comments