164164 rng = StableRNG (12345 )
165165 m = 19
166166 atol = rtol = m * m * precision (T)
167- A = randn (rng, T, m , m)
167+ A = make_eig_matrix (rng, T, m)
168168 D, V = eig_full (A)
169169 Ddiag = diagview (D)
170170 ΔV = randn (rng, complex (T), m, m)
274274 rng = StableRNG (12345 )
275275 m = 19
276276 atol = rtol = m * m * precision (T)
277- A = randn (rng, T, m, m)
278- A = A + A'
277+ A = make_eigh_matrix (rng, T, m)
278+ Ac = copy (A)
279+ A = (A + A' ) / 2
279280 D, V = eigh_full (A)
280281 D2 = Diagonal (D)
281282 ΔV = randn (rng, T, m, m)
@@ -289,11 +290,11 @@ end
289290 # LAPACK_MultipleRelativelyRobustRepresentations(), # expensive on CI
290291 )
291292 @testset " reverse: RT $RT , TA $TA " for RT in (Duplicated,), TA in (Duplicated,)
292- test_reverse (copy_eigh_full, RT, (A , TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy (ΔD2), copy (ΔV)))
293- test_reverse (copy_eigh_full!, RT, (copy (A ), TA), ((D, V), TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy (ΔD2), copy (ΔV)))
294- test_pullbacks_match (rng, copy_eigh_full!, copy_eigh_full, A , (D, V), (ΔD2, ΔV), alg)
295- test_reverse (copy_eigh_vals, RT, (A , TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy (ΔD2. diag))
296- test_pullbacks_match (rng, copy_eigh_vals!, copy_eigh_vals, A , D. diag, ΔD2. diag, alg)
293+ test_reverse (copy_eigh_full, RT, (Ac , TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy (ΔD2), copy (ΔV)))
294+ test_reverse (copy_eigh_full!, RT, (copy (Ac ), TA), ((D, V), TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy (ΔD2), copy (ΔV)))
295+ test_pullbacks_match (rng, copy_eigh_full!, copy_eigh_full, Ac , (D, V), (ΔD2, ΔV), alg)
296+ test_reverse (copy_eigh_vals, RT, (Ac , TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy (ΔD2. diag))
297+ test_pullbacks_match (rng, copy_eigh_vals!, copy_eigh_vals, Ac , D. diag, ΔD2. diag, alg)
297298 end
298299 @testset " eigh_trunc reverse: RT $RT , TA $TA " for RT in (Duplicated,), TA in (Duplicated,)
299300 for r in 1 : 4 : m
304305 Vtrunc = V[:, ind]
305306 ΔDtrunc = Diagonal (diagview (ΔD2)[ind])
306307 ΔVtrunc = ΔV[:, ind]
307- test_reverse (copy_eigh_trunc_no_error, RT, (A , TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
308- test_pullbacks_match (rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A , (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
308+ test_reverse (copy_eigh_trunc_no_error, RT, (Ac , TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
309+ test_pullbacks_match (rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, Ac , (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
309310 end
310311 Ddiag = diagview (D)
311312 truncalg = TruncatedAlgorithm (alg, trunctol (; atol = maximum (abs, Ddiag) / 2 ))
314315 Vtrunc = V[:, ind]
315316 ΔDtrunc = Diagonal (diagview (ΔD2)[ind])
316317 ΔVtrunc = ΔV[:, ind]
317- test_reverse (copy_eigh_trunc_no_error, RT, (A , TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
318- test_pullbacks_match (rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A , (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
318+ test_reverse (copy_eigh_trunc_no_error, RT, (Ac , TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
319+ test_pullbacks_match (rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, Ac , (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
319320 end
320321 end
321322end
0 commit comments