Skip to content

Commit e0c709b

Browse files
committed
simplify chainrules tests
1 parent fbfcbcc commit e0c709b

2 files changed

Lines changed: 4 additions & 62 deletions

File tree

test/testsuite/ad_utils.jl

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -421,27 +421,3 @@ function ad_right_null_setup(A)
421421
ΔNᴴ = randn!(similar(A, T, n - min(m, n), min(m, n))) * right_orth(A; alg = :lq)[2]
422422
return Nᴴ, ΔNᴴ
423423
end
424-
425-
function ad_project_hermitian_setup(A)
426-
m, n = size(A)
427-
T = eltype(A)
428-
Aₕ = project_hermitian(A)
429-
ΔAₕ = randn!(similar(A, T, m, n))
430-
return Aₕ, ΔAₕ
431-
end
432-
433-
function ad_project_antihermitian_setup(A)
434-
m, n = size(A)
435-
T = eltype(A)
436-
Aₐ = project_antihermitian(A)
437-
ΔAₐ = randn!(similar(A, T, m, n))
438-
return Aₐ, ΔAₐ
439-
end
440-
441-
function ad_project_isometric_setup(A)
442-
m, n = size(A)
443-
T = eltype(A)
444-
W = project_isometric(A)
445-
ΔW = randn!(similar(A, T, m, n))
446-
return W, ΔW
447-
end

test/testsuite/chainrules.jl

Lines changed: 4 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ for f in
1010
:eig_trunc_no_error, :eigh_trunc_no_error,
1111
:svd_compact, :svd_trunc, :svd_trunc_no_error, :svd_vals,
1212
:left_polar, :right_polar,
13-
:project_hermitian, :project_antihermitian, :project_isometric,
1413
)
1514
copy_f = Symbol(:cr_copy_, f)
1615
f! = Symbol(f, '!')
@@ -622,47 +621,14 @@ function test_chainrules_projections(
622621
return @testset "Projections Chainrules AD rules $summary_str" begin
623622
A = instantiate_matrix(T, sz)
624623
m, n = size(A)
625-
config = Zygote.ZygoteRuleConfig()
626624
if m == n
627-
alg_h = MatrixAlgebraKit.default_hermitian_algorithm(A)
628625
@testset "project_hermitian" begin
629-
Aₕ, ΔAₕ = ad_project_hermitian_setup(A)
630-
test_rrule(
631-
cr_copy_project_hermitian, A, alg_h NoTangent();
632-
output_tangent = ΔAₕ, atol = atol, rtol = rtol
633-
)
634-
test_rrule(
635-
config, project_hermitian, A;
636-
output_tangent = ΔAₕ,
637-
atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
638-
)
626+
alg = MatrixAlgebraKit.default_hermitian_algorithm(A)
627+
test_rrule(project_hermitian, A, alg; atol, rtol)
639628
end
640629
@testset "project_antihermitian" begin
641-
Aₐ, ΔAₐ = ad_project_antihermitian_setup(A)
642-
test_rrule(
643-
cr_copy_project_antihermitian, A, alg_h NoTangent();
644-
output_tangent = ΔAₐ, atol = atol, rtol = rtol
645-
)
646-
test_rrule(
647-
config, project_antihermitian, A;
648-
output_tangent = ΔAₐ,
649-
atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
650-
)
651-
end
652-
end
653-
if m > n
654-
@testset "project_isometric" begin
655-
W, ΔW = ad_project_isometric_setup(A)
656-
alg_iso = MatrixAlgebraKit.default_polar_algorithm(A)
657-
test_rrule(
658-
cr_copy_project_isometric, A, alg_iso NoTangent();
659-
output_tangent = ΔW, atol = atol, rtol = rtol
660-
)
661-
test_rrule(
662-
config, project_isometric, A;
663-
output_tangent = ΔW,
664-
atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
665-
)
630+
alg = MatrixAlgebraKit.default_hermitian_algorithm(A)
631+
test_rrule(project_antihermitian, A, alg; atol, rtol)
666632
end
667633
end
668634
end

0 commit comments

Comments
 (0)