Skip to content

Commit fad3f4a

Browse files
committed
Test trunc on GPU too
1 parent 295d5ad commit fad3f4a

1 file changed

Lines changed: 48 additions & 54 deletions

File tree

test/testsuite/mooncake.jl

Lines changed: 48 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -264,20 +264,9 @@ function test_mooncake_eig(
264264
Mooncake.TestUtils.test_rule(rng, eig_vals, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dD, atol = atol, rtol = rtol)
265265
test_pullbacks_match(eig_vals!, eig_vals, A, D, ΔD)
266266
end
267-
if T <: Number # not a GPU array
268-
@testset "eig_trunc" begin
269-
for r in 1:4:m
270-
truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(r; by = abs))
271-
DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg)
272-
ϵ = zero(real(T))
273-
dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ))
274-
Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol)
275-
test_pullbacks_match(eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
276-
dDVtrunc = make_mooncake_tangent(ΔDVtrunc)
277-
Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol)
278-
test_pullbacks_match(eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg)
279-
end
280-
truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(5; by = real))
267+
@testset "eig_trunc" begin
268+
for r in 1:4:m
269+
truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(r; by = abs))
281270
DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg)
282271
ϵ = zero(real(T))
283272
dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ))
@@ -287,6 +276,15 @@ function test_mooncake_eig(
287276
Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol)
288277
test_pullbacks_match(eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg)
289278
end
279+
truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(5; by = real))
280+
DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg)
281+
ϵ = zero(real(T))
282+
dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ))
283+
Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol)
284+
test_pullbacks_match(eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
285+
dDVtrunc = make_mooncake_tangent(ΔDVtrunc)
286+
Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol)
287+
test_pullbacks_match(eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg)
290288
end
291289
end
292290
end
@@ -312,21 +310,9 @@ function test_mooncake_eigh(
312310
Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_vals, A; mode = Mooncake.ReverseMode, output_tangent = dD, is_primitive = false, atol = atol, rtol = rtol)
313311
test_pullbacks_match(mc_copy_eigh_vals!, mc_copy_eigh_vals, A, D, ΔD)
314312
end
315-
if T <: Number
316-
@testset "eigh_trunc" begin
317-
for r in 1:4:m
318-
truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), truncrank(r; by = abs))
319-
DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg)
320-
ϵ = zero(real(T))
321-
dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ))
322-
Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol, is_primitive = false)
323-
test_pullbacks_match(mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
324-
dDVtrunc = make_mooncake_tangent(ΔDVtrunc)
325-
Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false)
326-
test_pullbacks_match(mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg)
327-
end
328-
D = eigh_vals(A / 2)
329-
truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), trunctol(; atol = maximum(abs, D) / 2))
313+
@testset "eigh_trunc" begin
314+
for r in 1:4:m
315+
truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), truncrank(r; by = abs))
330316
DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg)
331317
ϵ = zero(real(T))
332318
dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ))
@@ -336,6 +322,16 @@ function test_mooncake_eigh(
336322
Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false)
337323
test_pullbacks_match(mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg)
338324
end
325+
D = eigh_vals(A / 2)
326+
truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), trunctol(; atol = maximum(abs, D) / 2))
327+
DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg)
328+
ϵ = zero(real(T))
329+
dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ))
330+
Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol, is_primitive = false)
331+
test_pullbacks_match(mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
332+
dDVtrunc = make_mooncake_tangent(ΔDVtrunc)
333+
Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false)
334+
test_pullbacks_match(mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg)
339335
end
340336
end
341337
end
@@ -366,31 +362,29 @@ function test_mooncake_svd(
366362
Mooncake.TestUtils.test_rule(rng, svd_vals, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol)
367363
test_pullbacks_match(svd_vals!, svd_vals, A, S, ΔS)
368364
end
369-
if T <: Number # not a GPU array
370-
@testset "svd_trunc" begin
371-
S, ΔS = ad_svd_vals_setup(A)
372-
@testset for r in 1:4:minmn
373-
truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), truncrank(r))
374-
USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg)
375-
ϵ = zero(real(T))
376-
dUSVᴴerr = make_mooncake_tangent((ΔUSVᴴtrunc..., ϵ))
377-
Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol)
378-
test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
379-
dUSVᴴ = make_mooncake_tangent(ΔUSVᴴtrunc)
380-
Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol)
381-
test_pullbacks_match(svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg)
382-
end
383-
@testset "trunctol" begin
384-
truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), trunctol(atol = S[1, 1] / 2))
385-
USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg)
386-
ϵ = zero(real(T))
387-
dUSVᴴerr = make_mooncake_tangent((ΔUSVᴴtrunc..., ϵ))
388-
Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol)
389-
test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
390-
dUSVᴴ = make_mooncake_tangent(ΔUSVᴴtrunc)
391-
Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol)
392-
test_pullbacks_match(svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg)
393-
end
365+
@testset "svd_trunc" begin
366+
S, ΔS = ad_svd_vals_setup(A)
367+
@testset for r in 1:4:minmn
368+
truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), truncrank(r))
369+
USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg)
370+
ϵ = zero(real(T))
371+
dUSVᴴerr = make_mooncake_tangent((ΔUSVᴴtrunc..., ϵ))
372+
Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol)
373+
test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
374+
dUSVᴴ = make_mooncake_tangent(ΔUSVᴴtrunc)
375+
Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol)
376+
test_pullbacks_match(svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg)
377+
end
378+
@testset "trunctol" begin
379+
truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), trunctol(atol = S[1, 1] / 2))
380+
USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg)
381+
ϵ = zero(real(T))
382+
dUSVᴴerr = make_mooncake_tangent((ΔUSVᴴtrunc..., ϵ))
383+
Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol)
384+
test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
385+
dUSVᴴ = make_mooncake_tangent(ΔUSVᴴtrunc)
386+
Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol)
387+
test_pullbacks_match(svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg)
394388
end
395389
end
396390
end

0 commit comments

Comments
 (0)