Skip to content

Commit fc4f760

Browse files
committed
Some small fixes again
1 parent 4726163 commit fc4f760

3 files changed

Lines changed: 8 additions & 3 deletions

File tree

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,11 @@ for f in (:svd_compact!, :svd_full!)
275275
) where {RT, TA}
276276
$f(A.val, USVᴴ.val, alg.val)
277277
if !isa(A, Const)
278+
if $(f == svd_compact!)
279+
make_zero!(USVᴴ.dval[2].diag)
280+
else
281+
make_zero!(USVᴴ.dval[2])
282+
end
278283
!isa(USVᴴ, Const) && svd_pushforward!(A.dval, A.val, USVᴴ.val, USVᴴ.dval)
279284
make_zero!(A.dval)
280285
end

src/pushforwards/svd.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind = Colon(); rank_atol = d
77
ΔU, ΔS, ΔVᴴ = ΔUSVᴴ
88
r = svd_rank(S; rank_atol)
99

10-
vΔS = view(ΔS, 1:r, 1:r)
10+
vΔS = view(diagview(ΔS), 1:r)
1111

1212
vU = view(U, :, 1:r)
1313
vS = view(S, 1:r)
@@ -17,7 +17,7 @@ function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind = Colon(); rank_atol = d
1717
# compact region
1818
vV = adjoint(vVᴴ)
1919
UΔAV = vU' * ΔA * vV
20-
copyto!(diagview(vΔS), real.(diagview(UΔAV)))
20+
copyto!(vΔS, real.(diagview(UΔAV)))
2121
F = inv_safe.(transpose(vS) .- vS)
2222
G = inv_safe.(transpose(vS) .+ vS)
2323
hUΔAV = F .* (UΔAV + UΔAV') ./ 2

test/enzyme/svd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test
33
using LinearAlgebra: Diagonal
44
using CUDA, AMDGPU
55

6-
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
6+
BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI
77
GenericFloats = ()
88
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
99
using .TestSuite

0 commit comments

Comments
 (0)