@@ -10,7 +10,6 @@ for f in
1010 :eig_trunc_no_error , :eigh_trunc_no_error ,
1111 :svd_compact , :svd_trunc , :svd_trunc_no_error , :svd_vals ,
1212 :left_polar , :right_polar ,
13- :project_hermitian , :project_antihermitian , :project_isometric ,
1413 )
1514 copy_f = Symbol (:cr_copy_ , f)
1615 f! = Symbol (f, ' !' )
@@ -622,47 +621,14 @@ function test_chainrules_projections(
622621 return @testset " Projections Chainrules AD rules $summary_str " begin
623622 A = instantiate_matrix (T, sz)
624623 m, n = size (A)
625- config = Zygote. ZygoteRuleConfig ()
626624 if m == n
627- alg_h = MatrixAlgebraKit. default_hermitian_algorithm (A)
628625 @testset " project_hermitian" begin
629- Aₕ, ΔAₕ = ad_project_hermitian_setup (A)
630- test_rrule (
631- cr_copy_project_hermitian, A, alg_h ⊢ NoTangent ();
632- output_tangent = ΔAₕ, atol = atol, rtol = rtol
633- )
634- test_rrule (
635- config, project_hermitian, A;
636- output_tangent = ΔAₕ,
637- atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
638- )
626+ alg = MatrixAlgebraKit. default_hermitian_algorithm (A)
627+ test_rrule (project_hermitian, A, alg; atol, rtol)
639628 end
640629 @testset " project_antihermitian" begin
641- Aₐ, ΔAₐ = ad_project_antihermitian_setup (A)
642- test_rrule (
643- cr_copy_project_antihermitian, A, alg_h ⊢ NoTangent ();
644- output_tangent = ΔAₐ, atol = atol, rtol = rtol
645- )
646- test_rrule (
647- config, project_antihermitian, A;
648- output_tangent = ΔAₐ,
649- atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
650- )
651- end
652- end
653- if m > n
654- @testset " project_isometric" begin
655- W, ΔW = ad_project_isometric_setup (A)
656- alg_iso = MatrixAlgebraKit. default_polar_algorithm (A)
657- test_rrule (
658- cr_copy_project_isometric, A, alg_iso ⊢ NoTangent ();
659- output_tangent = ΔW, atol = atol, rtol = rtol
660- )
661- test_rrule (
662- config, project_isometric, A;
663- output_tangent = ΔW,
664- atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
665- )
630+ alg = MatrixAlgebraKit. default_hermitian_algorithm (A)
631+ test_rrule (project_antihermitian, A, alg; atol, rtol)
666632 end
667633 end
668634 end
0 commit comments