9393 Q = randn (rng, T, m, minmn)
9494 R = randn (rng, T, minmn, n)
9595 Mooncake. TestUtils. test_rule (rng, qr_compact, A, alg; mode = Mooncake. ReverseMode, is_primitive = false , atol = atol, rtol = rtol)
96- qr_c = (A, QR, alg) -> qr_compact! (MatrixAlgebraKit. copy_input (qr_compact, A), QR, alg)
97- Mooncake. TestUtils. test_rule (rng, qr_c, A, (Q, R), alg; mode = Mooncake. ReverseMode, is_primitive = false , atol = atol, rtol = rtol)
9896 test_pullbacks_match (rng, qr_compact!, qr_compact, A, (Q, R), (randn (rng, T, m, minmn), randn (rng, T, minmn, n)), alg)
9997 end
10098 @testset " qr_null" begin
236234 dVtrunc = make_mooncake_tangent (ΔVtrunc)
237235 dDVtrunc = Mooncake. build_tangent (typeof ((ΔDtrunc, ΔVtrunc, zero (real (T)))), dDtrunc, dVtrunc, zero (real (T)))
238236 Mooncake. TestUtils. test_rule (rng, eig_trunc, A, truncalg; mode = Mooncake. ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false )
239- dA1 = MatrixAlgebraKit. eig_pullback! (zero (A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
240- dA2 = MatrixAlgebraKit. eig_trunc_pullback! (zero (A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
241- @test isapprox (dA1, dA2; atol = atol, rtol = rtol)
242237 test_pullbacks_match (rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake. NoRData (), Mooncake. NoRData (), zero (real (T))))
243238 end
244239 truncalg = TruncatedAlgorithm (alg, truncrank (5 ; by = real))
251246 dVtrunc = make_mooncake_tangent (ΔVtrunc)
252247 dDVtrunc = Mooncake. build_tangent (typeof ((ΔDtrunc, ΔVtrunc, zero (real (T)))), dDtrunc, dVtrunc, zero (real (T)))
253248 Mooncake. TestUtils. test_rule (rng, eig_trunc, A, truncalg; mode = Mooncake. ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false )
254- dA1 = MatrixAlgebraKit. eig_pullback! (zero (A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
255- dA2 = MatrixAlgebraKit. eig_trunc_pullback! (zero (A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
256- @test isapprox (dA1, dA2; atol = atol, rtol = rtol)
257249 test_pullbacks_match (rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake. NoRData (), Mooncake. NoRData (), zero (real (T))))
258250 end
259251 end
@@ -334,9 +326,6 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.cop
334326 dVtrunc = make_mooncake_tangent (ΔVtrunc)
335327 dDVtrunc = Mooncake. build_tangent (typeof ((ΔDtrunc, ΔVtrunc, zero (real (T)))), dDtrunc, dVtrunc, zero (real (T)))
336328 Mooncake. TestUtils. test_rule (rng, copy_eigh_trunc, A, truncalg; mode = Mooncake. ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false )
337- dA1 = MatrixAlgebraKit. eigh_pullback! (zero (A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
338- dA2 = MatrixAlgebraKit. eigh_trunc_pullback! (zero (A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
339- @test isapprox (dA1, dA2; atol = atol, rtol = rtol)
340329 test_pullbacks_match (rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake. NoRData (), Mooncake. NoRData (), zero (real (T))))
341330 end
342331 truncalg = TruncatedAlgorithm (alg, trunctol (; atol = maximum (abs, Ddiag) / 2 ))
@@ -349,9 +338,6 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.cop
349338 dVtrunc = make_mooncake_tangent (ΔVtrunc)
350339 dDVtrunc = Mooncake. build_tangent (typeof ((ΔDtrunc, ΔVtrunc, zero (real (T)))), dDtrunc, dVtrunc, zero (real (T)))
351340 Mooncake. TestUtils. test_rule (rng, copy_eigh_trunc, A, truncalg; mode = Mooncake. ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false )
352- dA1 = MatrixAlgebraKit. eigh_pullback! (zero (A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
353- dA2 = MatrixAlgebraKit. eigh_trunc_pullback! (zero (A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
354- @test isapprox (dA1, dA2; atol = atol, rtol = rtol)
355341 test_pullbacks_match (rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake. NoRData (), Mooncake. NoRData (), zero (real (T))))
356342 end
357343 end
429415 dVᴴtrunc = make_mooncake_tangent (ΔVᴴtrunc)
430416 ϵ = zero (real (T))
431417 dUSVᴴerr = Mooncake. build_tangent (typeof ((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ)
432- dA1 = MatrixAlgebraKit. svd_pullback! (zero (A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind)
433- dA2 = MatrixAlgebraKit. svd_trunc_pullback! (zero (A), A, (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc))
434- @test isapprox (dA1, dA2; atol = atol, rtol = rtol)
435418 Mooncake. TestUtils. test_rule (rng, svd_trunc, A, truncalg; mode = Mooncake. ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol)
436419 test_pullbacks_match (rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData (), zero (real (T))))
437420 end
455438 dVᴴtrunc = make_mooncake_tangent (ΔVᴴtrunc)
456439 ϵ = zero (real (T))
457440 dUSVᴴerr = Mooncake. build_tangent (typeof ((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ)
458- dA1 = MatrixAlgebraKit. svd_pullback! (zero (A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind)
459- dA2 = MatrixAlgebraKit. svd_trunc_pullback! (zero (A), A, (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc))
460- @test isapprox (dA1, dA2; atol = atol, rtol = rtol)
461441 Mooncake. TestUtils. test_rule (rng, svd_trunc, A, truncalg; mode = Mooncake. ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol)
462442 test_pullbacks_match (rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData (), zero (real (T))))
463443 end
0 commit comments