Skip to content

Commit 7db71db

Browse files
committed
Some more fixes
1 parent 07052f9 commit 7db71db

4 files changed

Lines changed: 19 additions & 26 deletions

File tree

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -269,16 +269,13 @@ for f in (:svd_compact!, :svd_full!)
269269
config::EnzymeRules.FwdConfigWidth{1},
270270
func::Const{typeof($f)},
271271
::Type{RT},
272-
A::Annotation,
272+
A::Annotation{TA},
273273
USVᴴ::Annotation,
274274
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
275-
) where {RT}
275+
) where {RT, TA}
276276
$f(A.val, USVᴴ.val, alg.val)
277-
if !isa(A, Const) && !isa(USVᴴ, Const)
278-
make_zero!(USVᴴ.dval)
279-
svd_pushforward!(A.dval, A.val, USVᴴ.val, USVᴴ.dval)
280-
end
281-
!isa(A, Const) && make_zero!(A.dval)
277+
!isa(A, Const) && !isa(USVᴴ, Const) && svd_pushforward!(A.dval, A.val, USVᴴ.val, USVᴴ.dval)
278+
make_zero!(A.dval)
282279
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
283280
return USVᴴ
284281
elseif EnzymeRules.needs_primal(config)
@@ -542,7 +539,7 @@ function EnzymeRules.forward(
542539
svd_vals_pushforward!(A.dval, A.val, (U, Diagonal(diagview(S_)), Vᴴ), ΔS)
543540
A_is_arg && (S.dval .= ΔS)
544541
end
545-
!isa(A, Const) && !A_is_arg && make_zero!(A.dval)
542+
!A_is_arg && make_zero!(A.dval)
546543
copyto!(S.val, diagview(S_))
547544
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
548545
return S

src/pushforwards/svd.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@ function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind = Colon(); rank_atol = d
1818
vV = adjoint(vVᴴ)
1919
UΔAV = vU' * ΔA * vV
2020
copyto!(diagview(vΔS), diag(real.(UΔAV)))
21-
F = one(eltype(S)) ./ (transpose(vS) .- vS)
22-
G = one(eltype(S)) ./ (transpose(vS) .+ vS)
23-
diagview(F) .= zero(eltype(F))
21+
F = inv_safe.(transpose(vS) .- vS)
22+
G = inv_safe.(transpose(vS) .+ vS)
2423
hUΔAV = F .* (UΔAV + UΔAV') ./ 2
2524
aUΔAV = G .* (UΔAV - UΔAV') ./ 2
2625
= hUΔAV + aUΔAV

test/testsuite/enzyme/svd.jl

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,10 @@ function test_enzyme_svd_compact(
2323
alg = MatrixAlgebraKit.select_algorithm(svd_compact, A)
2424
USVᴴ, ΔUSVᴴ = ad_svd_compact_setup(A)
2525
test_reverse(svd_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm)
26-
test_reverse(call_and_zero!, RT, (svd_compact!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm)
26+
test_reverse(call_and_zero!, RT, (svd_compact!, Const), (copy(A), TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm)
2727
if eltype(T) <: Real
28-
A = instantiate_matrix(T, sz)
2928
test_forward(svd_compact, RT, (A, TA), (alg, Const); atol, rtol, fdm)
30-
test_forward(call_and_zero!, RT, (svd_compact!, Const), (A, TA), (alg, Const); atol, rtol, fdm)
29+
test_forward(call_and_zero!, RT, (svd_compact!, Const), (copy(A), TA), (alg, Const); atol, rtol, fdm)
3130
end
3231
end
3332
end
@@ -48,11 +47,10 @@ function test_enzyme_svd_full(
4847
alg = MatrixAlgebraKit.select_algorithm(svd_full, A)
4948
USVᴴ, ΔUSVᴴ = ad_svd_full_setup(A)
5049
test_reverse(svd_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm)
51-
test_reverse(call_and_zero!, RT, (svd_full!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm)
50+
test_reverse(call_and_zero!, RT, (svd_full!, Const), (copy(A), TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm)
5251
if eltype(T) <: Real
53-
A = instantiate_matrix(T, sz)
5452
test_forward(svd_full, RT, (A, TA), (alg, Const); atol, rtol, fdm)
55-
test_forward(call_and_zero!, RT, (svd_full!, Const), (A, TA), (alg, Const); atol, rtol, fdm)
53+
test_forward(call_and_zero!, RT, (svd_full!, Const), (copy(A), TA), (alg, Const); atol, rtol, fdm)
5654
end
5755
end
5856
end
@@ -72,10 +70,9 @@ function test_enzyme_svd_vals(
7270
alg = MatrixAlgebraKit.select_algorithm(svd_vals, A)
7371
S, ΔS = ad_svd_vals_setup(A)
7472
test_reverse(svd_vals, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔS, fdm)
75-
test_reverse(call_and_zero!, RT, (svd_vals!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔS, fdm)
76-
A = instantiate_matrix(T, sz)
73+
test_reverse(call_and_zero!, RT, (svd_vals!, Const), (copy(A), TA), (alg, Const); atol, rtol, output_tangent = ΔS, fdm)
7774
test_forward(svd_vals, RT, (A, TA), (alg, Const); atol, rtol, fdm)
78-
test_forward(call_and_zero!, RT, (svd_vals!, Const), (A, TA), (alg, Const); atol, rtol, fdm)
75+
test_forward(call_and_zero!, RT, (svd_vals!, Const), (copy(A), TA), (alg, Const); atol, rtol, fdm)
7976
end
8077
end
8178

@@ -99,15 +96,15 @@ function test_enzyme_svd_trunc(
9996
trunc = truncrank(r)
10097
truncalg = TruncatedAlgorithm(alg, trunc)
10198
USVᴴ, _, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg)
102-
test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
99+
test_reverse(svd_trunc_no_error, RT, (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
103100
test_reverse(call_and_zero!, RT, (svd_trunc_no_error!, Const), (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
104101
end
105102
@testset "trunctol" begin
106103
S = svd_vals(A, alg)
107104
trunc = trunctol(atol = maximum(S) / 2)
108105
truncalg = TruncatedAlgorithm(alg, trunc)
109106
USVᴴ, _, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg)
110-
test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
107+
test_reverse(svd_trunc_no_error, RT, (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
111108
test_reverse(call_and_zero!, RT, (svd_trunc_no_error!, Const), (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
112109
end
113110
end

test/testsuite/mooncake/svd.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ function test_mooncake_svd_compact(
3333
mode = Mooncake.ReverseMode, output_tangent, atol, rtol
3434
)
3535
Mooncake.TestUtils.test_rule(
36-
rng, call_and_zero!, svd_compact!, A, alg;
36+
rng, call_and_zero!, svd_compact!, copy(A), alg;
3737
mode = Mooncake.ReverseMode, output_tangent, atol, rtol, is_primitive = false
3838
)
3939
if eltype(T) <: Real # gauge freedom in complex outputs
@@ -42,7 +42,7 @@ function test_mooncake_svd_compact(
4242
mode = Mooncake.ForwardMode, atol, rtol
4343
)
4444
Mooncake.TestUtils.test_rule(
45-
rng, call_and_zero!, svd_compact!, A, alg;
45+
rng, call_and_zero!, svd_compact!, copy(A), alg;
4646
mode = Mooncake.ForwardMode, atol, rtol, is_primitive = false
4747
)
4848
end
@@ -70,7 +70,7 @@ function test_mooncake_svd_full(
7070
mode = Mooncake.ReverseMode, output_tangent, atol, rtol
7171
)
7272
Mooncake.TestUtils.test_rule(
73-
rng, call_and_zero!, svd_full!, A, alg;
73+
rng, call_and_zero!, svd_full!, copy(A), alg;
7474
mode = Mooncake.ReverseMode, output_tangent, atol, rtol, is_primitive = false
7575
)
7676
if eltype(T) <: Real # gauge freedom in complex outputs
@@ -79,7 +79,7 @@ function test_mooncake_svd_full(
7979
mode = Mooncake.ForwardMode, atol, rtol
8080
)
8181
Mooncake.TestUtils.test_rule(
82-
rng, call_and_zero!, svd_full!, A, alg;
82+
rng, call_and_zero!, svd_full!, copy(A), alg;
8383
mode = Mooncake.ForwardMode, atol, rtol, is_primitive = false
8484
)
8585
end

0 commit comments

Comments
 (0)