Skip to content

Commit 948c6fc

Browse files
committed
Working eigh_trunc
1 parent 0d540b1 commit 948c6fc

2 files changed

Lines changed: 12 additions & 15 deletions

File tree

src/common/safemethods.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,4 @@ sign_safe(s::Complex) = ifelse(iszero(s), one(s), s / abs(s))
1818
Compute the inverse of a number `a`, but return zero if `a` is smaller than `tol`.
1919
"""
2020
inv_safe(a::Number, tol = defaulttol(a)) = abs(a) < tol ? zero(a) : inv(a)
21-
function inv_safe(a::ComplexF32, tol = defaulttol(a))
22-
str = string(a) # WHY does this fix the NaN issues??????
23-
return abs(a) < tol ? zero(a) : inv(a)
24-
end
21+
@noinline inv_safe(a::ComplexF32, tol = defaulttol(a)) = abs(a) < tol ? zero(a) : inv(a)

test/enzyme.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -275,26 +275,26 @@ end
275275
m = 19
276276
atol = rtol = m * m * precision(T)
277277
A = make_eigh_matrix(rng, T, m)
278-
Ac = copy(A)
279-
A = (A + A') / 2
278+
#A = (A + A') / 2
280279
D, V = eigh_full(A)
281280
D2 = Diagonal(D)
282281
ΔV = randn(rng, T, m, m)
283282
ΔV = remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol = atol)
284283
ΔD = randn(rng, real(T), m, m)
285284
ΔD2 = Diagonal(randn(rng, real(T), m))
285+
fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
286286
@testset for alg in (
287287
LAPACK_QRIteration(),
288288
#LAPACK_DivideAndConquer(),
289289
#LAPACK_Bisection(),
290290
#LAPACK_MultipleRelativelyRobustRepresentations(), # expensive on CI
291291
)
292292
@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
293-
test_reverse(copy_eigh_full, RT, (Ac, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)))
294-
test_reverse(copy_eigh_full!, RT, (copy(Ac), TA), ((D, V), TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)))
295-
test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, Ac, (D, V), (ΔD2, ΔV), alg)
296-
test_reverse(copy_eigh_vals, RT, (Ac, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy(ΔD2.diag))
297-
test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, Ac, D.diag, ΔD2.diag, alg)
293+
test_reverse(copy_eigh_full, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)), fdm = fdm)
294+
test_reverse(copy_eigh_full!, RT, (copy(A), TA), ((D, V), TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)), fdm = fdm)
295+
test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, A, (D, V), (ΔD2, ΔV), alg)
296+
test_reverse(copy_eigh_vals, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy(ΔD2.diag), fdm = fdm)
297+
test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A, D.diag, ΔD2.diag, alg)
298298
end
299299
@testset "eigh_trunc reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
300300
for r in 1:4:m
@@ -305,8 +305,8 @@ end
305305
Vtrunc = V[:, ind]
306306
ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
307307
ΔVtrunc = ΔV[:, ind]
308-
test_reverse(copy_eigh_trunc_no_error, RT, (Ac, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
309-
test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, Ac, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
308+
test_reverse(copy_eigh_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm = fdm)
309+
test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
310310
end
311311
Ddiag = diagview(D)
312312
truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2))
@@ -315,8 +315,8 @@ end
315315
Vtrunc = V[:, ind]
316316
ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
317317
ΔVtrunc = ΔV[:, ind]
318-
test_reverse(copy_eigh_trunc_no_error, RT, (Ac, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
319-
test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, Ac, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
318+
test_reverse(copy_eigh_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm = fdm)
319+
test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
320320
end
321321
end
322322
end

0 commit comments

Comments
 (0)