Skip to content

Commit fb6c989

Browse files
committed
SVD working for diag
1 parent 750529c commit fb6c989

2 files changed

Lines changed: 11 additions & 2 deletions

File tree

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -448,10 +448,17 @@ function EnzymeRules.reverse(
448448
# appropriately here)
449449
Aval = nothing
450450
Sval = something(cache_S, S.val)
451+
A_is_arg = !isa(A, Const) && TA <: Diagonal && diagview(A.dval) === S.dval
451452
if !isa(A, Const)
452-
svd_vals_pullback!(A.dval, Aval, (U, Diagonal(Sval), Vᴴ), dS)
453+
if A_is_arg
454+
ΔA = make_zero(A.dval)
455+
svd_vals_pullback!(ΔA, Aval, (U, Diagonal(Sval), Vᴴ), dS)
456+
A.dval .= ΔA
457+
else
458+
svd_vals_pullback!(A.dval, Aval, (U, Diagonal(Sval), Vᴴ), dS)
459+
end
453460
end
454-
if !isa(S, Const) && !(TA <: Diagonal && (diagview(A.dval) === S.dval))
461+
if !isa(S, Const) && !A_is_arg
455462
make_zero!(S.dval)
456463
end
457464
return (nothing, nothing, nothing)

test/enzyme/svd.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,7 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
1515
TestSuite.seed_rng!(1234)
1616
if !is_buildkite
1717
TestSuite.test_enzyme_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
18+
AT = Diagonal{T, Vector{T}}
19+
m == n && TestSuite.test_enzyme_svd(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
1820
end
1921
end

0 commit comments

Comments
 (0)