@@ -275,26 +275,26 @@ end
275275 m = 19
276276 atol = rtol = m * m * precision (T)
277277 A = make_eigh_matrix (rng, T, m)
278- Ac = copy (A)
279- A = (A + A' ) / 2
278+ # A = (A + A') / 2
280279 D, V = eigh_full (A)
281280 D2 = Diagonal (D)
282281 ΔV = randn (rng, T, m, m)
283282 ΔV = remove_eighgauge_dependence! (ΔV, D, V; degeneracy_atol = atol)
284283 ΔD = randn (rng, real (T), m, m)
285284 ΔD2 = Diagonal (randn (rng, real (T), m))
285+ fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils. FiniteDifferences. central_fdm (5 , 1 , max_range = 1.0e-2 ) : EnzymeTestUtils. FiniteDifferences. central_fdm (5 , 1 )
286286 @testset for alg in (
287287 LAPACK_QRIteration (),
288288 # LAPACK_DivideAndConquer(),
289289 # LAPACK_Bisection(),
290290 # LAPACK_MultipleRelativelyRobustRepresentations(), # expensive on CI
291291 )
292292 @testset " reverse: RT $RT , TA $TA " for RT in (Duplicated,), TA in (Duplicated,)
293- test_reverse (copy_eigh_full, RT, (Ac , TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy (ΔD2), copy (ΔV)))
294- test_reverse (copy_eigh_full!, RT, (copy (Ac ), TA), ((D, V), TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy (ΔD2), copy (ΔV)))
295- test_pullbacks_match (rng, copy_eigh_full!, copy_eigh_full, Ac , (D, V), (ΔD2, ΔV), alg)
296- test_reverse (copy_eigh_vals, RT, (Ac , TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy (ΔD2. diag))
297- test_pullbacks_match (rng, copy_eigh_vals!, copy_eigh_vals, Ac , D. diag, ΔD2. diag, alg)
293+ test_reverse (copy_eigh_full, RT, (A , TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy (ΔD2), copy (ΔV)), fdm = fdm )
294+ test_reverse (copy_eigh_full!, RT, (copy (A ), TA), ((D, V), TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy (ΔD2), copy (ΔV)), fdm = fdm )
295+ test_pullbacks_match (rng, copy_eigh_full!, copy_eigh_full, A , (D, V), (ΔD2, ΔV), alg)
296+ test_reverse (copy_eigh_vals, RT, (A , TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy (ΔD2. diag), fdm = fdm )
297+ test_pullbacks_match (rng, copy_eigh_vals!, copy_eigh_vals, A , D. diag, ΔD2. diag, alg)
298298 end
299299 @testset " eigh_trunc reverse: RT $RT , TA $TA " for RT in (Duplicated,), TA in (Duplicated,)
300300 for r in 1 : 4 : m
305305 Vtrunc = V[:, ind]
306306 ΔDtrunc = Diagonal (diagview (ΔD2)[ind])
307307 ΔVtrunc = ΔV[:, ind]
308- test_reverse (copy_eigh_trunc_no_error, RT, (Ac , TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
309- test_pullbacks_match (rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, Ac , (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
308+ test_reverse (copy_eigh_trunc_no_error, RT, (A , TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm = fdm )
309+ test_pullbacks_match (rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A , (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
310310 end
311311 Ddiag = diagview (D)
312312 truncalg = TruncatedAlgorithm (alg, trunctol (; atol = maximum (abs, Ddiag) / 2 ))
315315 Vtrunc = V[:, ind]
316316 ΔDtrunc = Diagonal (diagview (ΔD2)[ind])
317317 ΔVtrunc = ΔV[:, ind]
318- test_reverse (copy_eigh_trunc_no_error, RT, (Ac , TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
319- test_pullbacks_match (rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, Ac , (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
318+ test_reverse (copy_eigh_trunc_no_error, RT, (A , TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm = fdm )
319+ test_pullbacks_match (rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A , (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
320320 end
321321 end
322322end
0 commit comments