@@ -264,20 +264,9 @@ function test_mooncake_eig(
264264 Mooncake. TestUtils. test_rule (rng, eig_vals, A; is_primitive = false , mode = Mooncake. ReverseMode, output_tangent = dD, atol = atol, rtol = rtol)
265265 test_pullbacks_match (eig_vals!, eig_vals, A, D, ΔD)
266266 end
267- if T <: Number # not a GPU array
268- @testset " eig_trunc" begin
269- for r in 1 : 4 : m
270- truncalg = TruncatedAlgorithm (MatrixAlgebraKit. default_eig_algorithm (A), truncrank (r; by = abs))
271- DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup (A, truncalg)
272- ϵ = zero (real (T))
273- dDVerr = make_mooncake_tangent ((ΔDVtrunc... , ϵ))
274- Mooncake. TestUtils. test_rule (rng, eig_trunc, A, truncalg; mode = Mooncake. ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol)
275- test_pullbacks_match (eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake. NoRData (), Mooncake. NoRData (), zero (real (T))))
276- dDVtrunc = make_mooncake_tangent (ΔDVtrunc)
277- Mooncake. TestUtils. test_rule (rng, eig_trunc_no_error, A, truncalg; mode = Mooncake. ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol)
278- test_pullbacks_match (eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg)
279- end
280- truncalg = TruncatedAlgorithm (MatrixAlgebraKit. default_eig_algorithm (A), truncrank (5 ; by = real))
267+ @testset " eig_trunc" begin
268+ for r in 1 : 4 : m
269+ truncalg = TruncatedAlgorithm (MatrixAlgebraKit. default_eig_algorithm (A), truncrank (r; by = abs))
281270 DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup (A, truncalg)
282271 ϵ = zero (real (T))
283272 dDVerr = make_mooncake_tangent ((ΔDVtrunc... , ϵ))
@@ -287,6 +276,15 @@ function test_mooncake_eig(
287276 Mooncake. TestUtils. test_rule (rng, eig_trunc_no_error, A, truncalg; mode = Mooncake. ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol)
288277 test_pullbacks_match (eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg)
289278 end
279+ truncalg = TruncatedAlgorithm (MatrixAlgebraKit. default_eig_algorithm (A), truncrank (5 ; by = real))
280+ DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup (A, truncalg)
281+ ϵ = zero (real (T))
282+ dDVerr = make_mooncake_tangent ((ΔDVtrunc... , ϵ))
283+ Mooncake. TestUtils. test_rule (rng, eig_trunc, A, truncalg; mode = Mooncake. ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol)
284+ test_pullbacks_match (eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake. NoRData (), Mooncake. NoRData (), zero (real (T))))
285+ dDVtrunc = make_mooncake_tangent (ΔDVtrunc)
286+ Mooncake. TestUtils. test_rule (rng, eig_trunc_no_error, A, truncalg; mode = Mooncake. ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol)
287+ test_pullbacks_match (eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg)
290288 end
291289 end
292290end
@@ -312,21 +310,9 @@ function test_mooncake_eigh(
312310 Mooncake. TestUtils. test_rule (rng, mc_copy_eigh_vals, A; mode = Mooncake. ReverseMode, output_tangent = dD, is_primitive = false , atol = atol, rtol = rtol)
313311 test_pullbacks_match (mc_copy_eigh_vals!, mc_copy_eigh_vals, A, D, ΔD)
314312 end
315- if T <: Number
316- @testset " eigh_trunc" begin
317- for r in 1 : 4 : m
318- truncalg = TruncatedAlgorithm (MatrixAlgebraKit. default_eigh_algorithm (A), truncrank (r; by = abs))
319- DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup (A, truncalg)
320- ϵ = zero (real (T))
321- dDVerr = make_mooncake_tangent ((ΔDVtrunc... , ϵ))
322- Mooncake. TestUtils. test_rule (rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake. ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol, is_primitive = false )
323- test_pullbacks_match (mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake. NoRData (), Mooncake. NoRData (), zero (real (T))))
324- dDVtrunc = make_mooncake_tangent (ΔDVtrunc)
325- Mooncake. TestUtils. test_rule (rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake. ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false )
326- test_pullbacks_match (mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg)
327- end
328- D = eigh_vals (A / 2 )
329- truncalg = TruncatedAlgorithm (MatrixAlgebraKit. default_eigh_algorithm (A), trunctol (; atol = maximum (abs, D) / 2 ))
313+ @testset " eigh_trunc" begin
314+ for r in 1 : 4 : m
315+ truncalg = TruncatedAlgorithm (MatrixAlgebraKit. default_eigh_algorithm (A), truncrank (r; by = abs))
330316 DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup (A, truncalg)
331317 ϵ = zero (real (T))
332318 dDVerr = make_mooncake_tangent ((ΔDVtrunc... , ϵ))
@@ -336,6 +322,16 @@ function test_mooncake_eigh(
336322 Mooncake. TestUtils. test_rule (rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake. ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false )
337323 test_pullbacks_match (mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg)
338324 end
325+ D = eigh_vals (A / 2 )
326+ truncalg = TruncatedAlgorithm (MatrixAlgebraKit. default_eigh_algorithm (A), trunctol (; atol = maximum (abs, D) / 2 ))
327+ DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup (A, truncalg)
328+ ϵ = zero (real (T))
329+ dDVerr = make_mooncake_tangent ((ΔDVtrunc... , ϵ))
330+ Mooncake. TestUtils. test_rule (rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake. ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol, is_primitive = false )
331+ test_pullbacks_match (mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake. NoRData (), Mooncake. NoRData (), zero (real (T))))
332+ dDVtrunc = make_mooncake_tangent (ΔDVtrunc)
333+ Mooncake. TestUtils. test_rule (rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake. ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false )
334+ test_pullbacks_match (mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg)
339335 end
340336 end
341337end
@@ -366,31 +362,29 @@ function test_mooncake_svd(
366362 Mooncake. TestUtils. test_rule (rng, svd_vals, A; is_primitive = false , mode = Mooncake. ReverseMode, atol = atol, rtol = rtol)
367363 test_pullbacks_match (svd_vals!, svd_vals, A, S, ΔS)
368364 end
369- if T <: Number # not a GPU array
370- @testset " svd_trunc" begin
371- S, ΔS = ad_svd_vals_setup (A)
372- @testset for r in 1 : 4 : minmn
373- truncalg = TruncatedAlgorithm (MatrixAlgebraKit. default_svd_algorithm (A), truncrank (r))
374- USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup (A, truncalg)
375- ϵ = zero (real (T))
376- dUSVᴴerr = make_mooncake_tangent ((ΔUSVᴴtrunc... , ϵ))
377- Mooncake. TestUtils. test_rule (rng, svd_trunc, A, truncalg; mode = Mooncake. ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol)
378- test_pullbacks_match (svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData (), zero (real (T))))
379- dUSVᴴ = make_mooncake_tangent (ΔUSVᴴtrunc)
380- Mooncake. TestUtils. test_rule (rng, svd_trunc_no_error, A, truncalg; mode = Mooncake. ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol)
381- test_pullbacks_match (svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg)
382- end
383- @testset " trunctol" begin
384- truncalg = TruncatedAlgorithm (MatrixAlgebraKit. default_svd_algorithm (A), trunctol (atol = S[1 , 1 ] / 2 ))
385- USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup (A, truncalg)
386- ϵ = zero (real (T))
387- dUSVᴴerr = make_mooncake_tangent ((ΔUSVᴴtrunc... , ϵ))
388- Mooncake. TestUtils. test_rule (rng, svd_trunc, A, truncalg; mode = Mooncake. ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol)
389- test_pullbacks_match (svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData (), zero (real (T))))
390- dUSVᴴ = make_mooncake_tangent (ΔUSVᴴtrunc)
391- Mooncake. TestUtils. test_rule (rng, svd_trunc_no_error, A, truncalg; mode = Mooncake. ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol)
392- test_pullbacks_match (svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg)
393- end
365+ @testset " svd_trunc" begin
366+ S, ΔS = ad_svd_vals_setup (A)
367+ @testset for r in 1 : 4 : minmn
368+ truncalg = TruncatedAlgorithm (MatrixAlgebraKit. default_svd_algorithm (A), truncrank (r))
369+ USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup (A, truncalg)
370+ ϵ = zero (real (T))
371+ dUSVᴴerr = make_mooncake_tangent ((ΔUSVᴴtrunc... , ϵ))
372+ Mooncake. TestUtils. test_rule (rng, svd_trunc, A, truncalg; mode = Mooncake. ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol)
373+ test_pullbacks_match (svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData (), zero (real (T))))
374+ dUSVᴴ = make_mooncake_tangent (ΔUSVᴴtrunc)
375+ Mooncake. TestUtils. test_rule (rng, svd_trunc_no_error, A, truncalg; mode = Mooncake. ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol)
376+ test_pullbacks_match (svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg)
377+ end
378+ @testset " trunctol" begin
379+ truncalg = TruncatedAlgorithm (MatrixAlgebraKit. default_svd_algorithm (A), trunctol (atol = S[1 , 1 ] / 2 ))
380+ USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup (A, truncalg)
381+ ϵ = zero (real (T))
382+ dUSVᴴerr = make_mooncake_tangent ((ΔUSVᴴtrunc... , ϵ))
383+ Mooncake. TestUtils. test_rule (rng, svd_trunc, A, truncalg; mode = Mooncake. ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol)
384+ test_pullbacks_match (svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData (), zero (real (T))))
385+ dUSVᴴ = make_mooncake_tangent (ΔUSVᴴtrunc)
386+ Mooncake. TestUtils. test_rule (rng, svd_trunc_no_error, A, truncalg; mode = Mooncake. ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol)
387+ test_pullbacks_match (svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg)
394388 end
395389 end
396390 end
0 commit comments