@@ -11,21 +11,23 @@ include("ad_utils.jl")
1111for f in
1212 (
1313 :qr_compact , :qr_full , :qr_null , :lq_compact , :lq_full , :lq_null ,
14- :eig_full , :eig_trunc , :eigh_full , :eigh_trunc , :svd_compact , :svd_trunc ,
14+ :eig_full , :eig_trunc , :eig_vals , :eigh_full , :eigh_trunc , :eigh_vals ,
15+ :svd_compact , :svd_trunc , :svd_vals ,
1516 :left_polar , :right_polar ,
1617 )
1718 copy_f = Symbol (:copy_ , f)
1819 f! = Symbol (f, ' !' )
20+ _hermitian = startswith (string (f), " eigh" )
1921 @eval begin
2022 function $copy_f (input, alg)
21- if $ f === eigh_full || $ f === eigh_trunc
23+ if $ _hermitian
2224 input = (input + input' ) / 2
2325 end
2426 return $ f (input, alg)
2527 end
2628 function ChainRulesCore. rrule (:: typeof ($ copy_f), input, alg)
2729 output = MatrixAlgebraKit. initialize_output ($ f!, input, alg)
28- if $ f === eigh_full || $ f === eigh_trunc
30+ if $ _hermitian
2931 input = (input + input' ) / 2
3032 else
3133 input = copy (input)
@@ -228,12 +230,13 @@ end
228230 ΔD2 = Diagonal (randn (rng, complex (T), m))
229231 for alg in (LAPACK_Simple (), LAPACK_Expert ())
230232 test_rrule (
231- copy_eig_full, A, alg ⊢ NoTangent ();
232- output_tangent = (ΔD, ΔV), atol = atol, rtol = rtol
233+ copy_eig_full, A, alg ⊢ NoTangent (); output_tangent = (ΔD, ΔV), atol, rtol
233234 )
234235 test_rrule (
235- copy_eig_full, A, alg ⊢ NoTangent ();
236- output_tangent = (ΔD2, ΔV), atol = atol, rtol = rtol
236+ copy_eig_full, A, alg ⊢ NoTangent (); output_tangent = (ΔD2, ΔV), atol, rtol
237+ )
238+ test_rrule (
239+ copy_eig_vals, A, alg ⊢ NoTangent (); output_tangent = diagview (ΔD), atol, rtol
237240 )
238241 for r in 1 : 4 : m
239242 truncalg = TruncatedAlgorithm (alg, truncrank (r; by = abs))
284287 config, last ∘ eig_full, A;
285288 output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
286289 )
290+ test_rrule (
291+ config, eig_vals, A;
292+ output_tangent = diagview (ΔD), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false
293+ )
287294end
288295
289296@timedtestset " EIGH AD Rules with eltype $T " for T in (Float64, ComplexF64, Float32)
@@ -304,12 +311,13 @@ end
304311 )
305312 # copy_eigh_full includes a projector onto the Hermitian part of the matrix
306313 test_rrule (
307- copy_eigh_full, A, alg ⊢ NoTangent (); output_tangent = (ΔD, ΔV),
308- atol = atol, rtol = rtol
314+ copy_eigh_full, A, alg ⊢ NoTangent (); output_tangent = (ΔD, ΔV), atol, rtol
309315 )
310316 test_rrule (
311- copy_eigh_full, A, alg ⊢ NoTangent (); output_tangent = (ΔD2, ΔV),
312- atol = atol, rtol = rtol
317+ copy_eigh_full, A, alg ⊢ NoTangent (); output_tangent = (ΔD2, ΔV), atol, rtol
318+ )
319+ test_rrule (
320+ copy_eigh_vals, A, alg ⊢ NoTangent (); output_tangent = diagview (ΔD), atol, rtol
313321 )
314322 for r in 1 : 4 : m
315323 truncalg = TruncatedAlgorithm (alg, truncrank (r; by = abs))
361369 config, last ∘ eigh_full ∘ Matrix ∘ Hermitian, A;
362370 output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
363371 )
372+ test_rrule (
373+ config, eigh_vals ∘ Matrix ∘ Hermitian, A;
374+ output_tangent = diagview (ΔD), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false
375+ )
364376 eigh_trunc2 (A; kwargs... ) = eigh_trunc (Matrix (Hermitian (A)); kwargs... )
365377 for r in 1 : 4 : m
366378 trunc = truncrank (r; by = real)
404416 copy_svd_compact, A, alg ⊢ NoTangent ();
405417 output_tangent = (ΔU, ΔS2, ΔVᴴ), atol = atol, rtol = rtol
406418 )
419+ test_rrule (
420+ copy_svd_vals, A, alg ⊢ NoTangent ();
421+ output_tangent = diagview (ΔS), atol, rtol
422+ )
407423 for r in 1 : 4 : minmn
408424 truncalg = TruncatedAlgorithm (alg, truncrank (r))
409425 ind = MatrixAlgebraKit. findtruncated (diagview (S), truncalg. trunc)
451467 output_tangent = (ΔU, ΔS2, ΔVᴴ), atol = atol, rtol = rtol,
452468 rrule_f = rrule_via_ad, check_inferred = false
453469 )
470+ test_rrule (
471+ config, svd_vals, A;
472+ output_tangent = diagview (ΔS), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false
473+ )
454474 for r in 1 : 4 : minmn
455475 trunc = truncrank (r)
456476 ind = MatrixAlgebraKit. findtruncated (diagview (S), trunc)
0 commit comments