Skip to content

Commit e697d06

Browse files
committed
Tangent type and bad index
1 parent e13c881 commit e697d06

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_vals)}, A_dA::CoD
291291
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
292292
# pass). For many types this is done automatically when the forward step returns, but
293293
# not for nested structs with various fields (like Diagonal{Complex})
294-
S_codual = Mooncake.CoDual(diagview(S), Mooncake.fdata(Mooncake.zero_tangent(S)))
294+
S_codual = Mooncake.CoDual(diagview(S), Mooncake.fdata(Mooncake.zero_tangent(diagview(S))))
295295
function dsvd_vals_adjoint(::Mooncake.NoRData)
296296
S, dS = Mooncake.arrayify(S_codual)
297297
svd_pullback!(dA, A, (U, Diagonal(S), Vᴴ), (nothing, Diagonal(dS), nothing))
@@ -317,7 +317,7 @@ function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_trunc)}, A_dA::Co
317317
function dsvd_trunc_adjoint(dy::Tuple{Mooncake.NoRData, Mooncake.NoRData, Mooncake.NoRData, T}) where {T <: Real}
318318
Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual)
319319
dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual)
320-
abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) && @warn "Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error"
320+
abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error"
321321
U, dU = Mooncake.arrayify(Utrunc, dUtrunc_)
322322
S, dS = Mooncake.arrayify(Strunc, dStrunc_)
323323
Vᴴ, dVᴴ = Mooncake.arrayify(Vᴴtrunc, dVᴴtrunc_)

0 commit comments

Comments
 (0)