We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent fc4f760 commit f446751Copy full SHA for f446751
3 files changed
ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
@@ -275,7 +275,7 @@ for f in (:svd_compact!, :svd_full!)
275
) where {RT, TA}
276
$f(A.val, USVᴴ.val, alg.val)
277
if !isa(A, Const)
278
- if $(f == svd_compact!)
+ if $(f == svd_compact!)
279
make_zero!(USVᴴ.dval[2].diag)
280
else
281
make_zero!(USVᴴ.dval[2])
src/pushforwards/svd.jl
@@ -44,7 +44,7 @@ function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind = Colon(); rank_atol = d
44
view(UÃÃV, (1:size(aUAV, 1)), size(aUAV, 1) .+ (1:size(aUAV, 2))) .= aUAV
45
view(UÃÃV, size(aUAV, 1) .+ (1:size(aUAV, 2)), 1:size(aUAV, 1)) .= aUAV'
46
rhs = vcat(adjoint(Uperp * ΔA * Vᴴ), Vᴴperp * ΔA' * U)
47
- superKM = -sylvester(UÃÃV, Smat, rhs)
+ superKM = -_sylvester(UÃÃV, Smat, rhs)
48
K̇perp = view(superKM, 1:size(aUAV, 2))
49
Ṁperp = view(superKM, (size(aUAV, 2) + 1):(size(aUAV, 1) + size(aUAV, 2)))
50
∂U .+= Uperp * K̇perp
@@ -62,7 +62,7 @@ function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind = Colon(); rank_atol = d
62
view(ÃÃ, (1:m), m .+ (1:n)) .= Ã
63
view(ÃÃ, m .+ (1:n), 1:m) .= Ã'
64
65
- superLN = -sylvester(ÃÃ, vSmat, rhs)
+ superLN = -_sylvester(ÃÃ, vSmat, rhs)
66
∂U += view(superLN, 1:size(upper, 1), :)
67
∂V += view(superLN, (size(upper, 1) + 1):(size(upper, 1) + size(lower, 1)), :)
68
end
test/testsuite/enzyme/svd.jl
@@ -48,7 +48,7 @@ function test_enzyme_svd_full(
USVᴴ, ΔUSVᴴ = ad_svd_full_setup(A)
test_reverse(svd_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm)
test_reverse(call_and_zero!, RT, (svd_full!, Const), (copy(A), TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm)
51
- if eltype(T) <: Real
+ if eltype(T) <: Real && size(A, 1) == size(A, 2) # finite differences check for free component is very finicky
52
test_forward(svd_full, RT, (A, TA), (alg, Const); atol, rtol, fdm)
53
test_forward(call_and_zero!, RT, (svd_full!, Const), (copy(A), TA), (alg, Const); atol, rtol, fdm)
54
0 commit comments