Skip to content

Commit 187d5f2

Browse files
committed
simplify mooncake tests
1 parent e0c709b commit 187d5f2

2 files changed

Lines changed: 10 additions & 17 deletions

File tree

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -796,7 +796,7 @@ for (f!, f, adj) in (
796796
dA .+= $f(darg)
797797
dA === darg || zero!(darg)
798798
copy!(arg, argc)
799-
return NoRData(), NoRData(), NoRData(), NoRData()
799+
return ntuple(Returns(NoRData()), 4)
800800
end
801801
return arg_darg, $adj
802802
end
@@ -811,10 +811,11 @@ for (f!, f, adj) in (
811811
function $adj(::NoRData)
812812
# TODO: need accumulating projection to avoid intermediate here
813813
dA .+= $f(doutput)
814-
return ntuple(Returns(NoRData(), 3))
814+
zero!(doutput)
815+
return ntuple(Returns(NoRData()), 3)
815816
end
816817

817-
return output_codual, $adj
818+
return output_doutput, $adj
818819
end
819820
end
820821
end

test/testsuite/mooncake.jl

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -550,25 +550,17 @@ function test_mooncake_projections(
550550
m, n = size(A)
551551
if m == n
552552
@testset "project_hermitian" begin
553-
Aₕ, ΔAₕ = ad_project_hermitian_setup(A)
554-
dAₕ = make_mooncake_tangent(ΔAₕ)
555-
Mooncake.TestUtils.test_rule(rng, project_hermitian, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dAₕ, atol, rtol)
553+
Aₕ = project_hermitian(A)
554+
ΔAₕ = make_mooncake_tangent(Aₕ)
555+
Mooncake.TestUtils.test_rule(rng, project_hermitian, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol)
556556
test_pullbacks_match(project_hermitian!, project_hermitian, A, Aₕ, ΔAₕ)
557557
end
558558
@testset "project_antihermitian" begin
559-
Aₐ, ΔAₐ = ad_project_antihermitian_setup(A)
560-
dAₐ = make_mooncake_tangent(ΔAₐ)
561-
Mooncake.TestUtils.test_rule(rng, project_antihermitian, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dAₐ, atol, rtol)
559+
Aₐ = project_antihermitian(A)
560+
ΔAₐ = make_mooncake_tangent(Aₐ)
561+
Mooncake.TestUtils.test_rule(rng, project_antihermitian, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol)
562562
test_pullbacks_match(project_antihermitian!, project_antihermitian, A, Aₐ, ΔAₐ)
563563
end
564564
end
565-
if m > n
566-
@testset "project_isometric" begin
567-
W, ΔW = ad_project_isometric_setup(A)
568-
dW = make_mooncake_tangent(ΔW)
569-
Mooncake.TestUtils.test_rule(rng, project_isometric, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dW, atol, rtol)
570-
test_pullbacks_match(project_isometric!, project_isometric, A, W, ΔW)
571-
end
572-
end
573565
end
574566
end

0 commit comments

Comments
 (0)