@@ -11,7 +11,6 @@ is_ci = get(ENV, "CI", "false") == "true"
1111
1212ETs = is_ci ? (Float64, Float32) : (Float64, Float32, ComplexF32, ComplexF64) # Enzyme/#2631
1313include (" ad_utils.jl" )
14-
1514function test_pullbacks_match (rng, f!, f, A, args, Δargs, alg = nothing ; ȳ = copy .(Δargs), return_act = Duplicated)
1615 ΔA = randn (rng, eltype (A), size (A)... )
1716 A_ΔA () = Duplicated (copy (A), copy (ΔA))
@@ -46,7 +45,7 @@ function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; ȳ =
4645 end
4746 return
4847end
49-
48+ #=
5049@timedtestset "QR AD Rules with eltype $T" for T in ETs
5150 rng = StableRNG(12345)
5251 m = 19
190189 ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
191190 ΔVtrunc = ΔV[:, ind]
192191 # broken due to Enzyme
193- # test_reverse(eig_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
192+ test_reverse(eig_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
194193 # broken due to Enzyme
195- # test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg, (ΔDtrunc, ΔVtrunc, zero(real(T))))
194+ test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg, ȳ= (ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=RT )
196195 dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
197196 dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
198197 @test isapprox(dA1, dA2; atol = atol, rtol = rtol)
204203 ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
205204 ΔVtrunc = ΔV[:, ind]
206205 # broken due to Enzyme
207- # test_reverse(eig_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
206+ test_reverse(eig_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
208207 # broken due to Enzyme
209- # test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=MixedDuplicated )
208+ test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=RT )
210209 dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
211210 dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
212211 @test isapprox(dA1, dA2; atol = atol, rtol = rtol)
309308 ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
310309 ΔVtrunc = ΔV[:, ind]
311310 # broken due to Enzyme
312- # test_reverse(copy_eigh_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
313- # test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg, ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=RT)
311+ test_reverse(copy_eigh_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
312+ test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg, ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=RT)
314313 dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
315314 dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
316315 @test isapprox(dA1, dA2; atol = atol, rtol = rtol)
@@ -323,15 +322,15 @@ end
323322 ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
324323 ΔVtrunc = ΔV[:, ind]
325324 # broken due to Enzyme
326- # test_reverse(copy_eigh_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
327- # test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg, ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T)), return_act=RT) )
325+ test_reverse(copy_eigh_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
326+ test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg, ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))) , return_act=RT)
328327 dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
329328 dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
330329 @test isapprox(dA1, dA2; atol = atol, rtol = rtol)
331330 end
332331 end
333332end
334-
333+ =#
335334@timedtestset " SVD AD Rules with eltype $T " for T in ETs
336335 rng = StableRNG (12345 )
337336 m = 19
343342 LAPACK_QRIteration (),
344343 LAPACK_DivideAndConquer (),
345344 )
346- @testset " reverse: RT $RT , TA $TA " for RT in (Duplicated,), TA in (Duplicated,)
345+ #= @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
347346 @testset "svd_compact" begin
348347 U, S, Vᴴ = svd_compact(A)
349348 ΔU = randn(rng, T, m, minmn)
@@ -378,39 +377,16 @@ end
378377 test_reverse(svd_vals, RT, (A, TA); atol = atol, rtol = rtol, fkwargs = (alg = alg,), output_tangent = ΔS, fdm = fdm)
379378 test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, ΔS, alg)
380379 end
381- end
380+ end=#
382381 @testset " svd_trunc reverse: RT $RT , TA $TA " for RT in (MixedDuplicated,), TA in (Duplicated,)
383- @testset " svd_trunc" begin
384- for r in 1 : 4 : minmn
385- U, S, Vᴴ = svd_compact (A)
386- ΔU = randn (rng, T, m, minmn)
387- ΔS = randn (rng, real (T), minmn, minmn)
388- ΔS2 = Diagonal (randn (rng, real (T), minmn))
389- ΔVᴴ = randn (rng, T, minmn, n)
390- ΔU, ΔVᴴ = remove_svdgauge_dependence! (ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol)
391- truncalg = TruncatedAlgorithm (alg, truncrank (r))
392- ind = MatrixAlgebraKit. findtruncated (diagview (S), truncalg. trunc)
393- Strunc = Diagonal (diagview (S)[ind])
394- Utrunc = U[:, ind]
395- Vᴴtrunc = Vᴴ[ind, :]
396- ΔStrunc = Diagonal (diagview (ΔS2)[ind])
397- ΔUtrunc = ΔU[:, ind]
398- ΔVᴴtrunc = ΔVᴴ[ind, :]
399- fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils. FiniteDifferences. central_fdm (5 , 1 , max_range = 1.0e-2 ) : EnzymeTestUtils. FiniteDifferences. central_fdm (5 , 1 )
400- # broken due to Enzyme
401- # test_reverse(svd_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), fdm = fdm)
402- # test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ), truncalg, ȳ=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), return_act=RT)
403- dA1 = MatrixAlgebraKit. svd_pullback! (zero (A), copy (A), (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind)
404- dA2 = MatrixAlgebraKit. svd_trunc_pullback! (zero (A), copy (A), (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc))
405- @test isapprox (dA1, dA2; atol = atol, rtol = rtol)
406- end
382+ for r in 1 : 4 : minmn
407383 U, S, Vᴴ = svd_compact (A)
408384 ΔU = randn (rng, T, m, minmn)
409385 ΔS = randn (rng, real (T), minmn, minmn)
410386 ΔS2 = Diagonal (randn (rng, real (T), minmn))
411387 ΔVᴴ = randn (rng, T, minmn, n)
412388 ΔU, ΔVᴴ = remove_svdgauge_dependence! (ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol)
413- truncalg = TruncatedAlgorithm (alg, trunctol (atol = S[ 1 , 1 ] / 2 ))
389+ truncalg = TruncatedAlgorithm (alg, truncrank (r ))
414390 ind = MatrixAlgebraKit. findtruncated (diagview (S), truncalg. trunc)
415391 Strunc = Diagonal (diagview (S)[ind])
416392 Utrunc = U[:, ind]
@@ -419,18 +395,33 @@ end
419395 ΔUtrunc = ΔU[:, ind]
420396 ΔVᴴtrunc = ΔVᴴ[ind, :]
421397 fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils. FiniteDifferences. central_fdm (5 , 1 , max_range = 1.0e-2 ) : EnzymeTestUtils. FiniteDifferences. central_fdm (5 , 1 )
422- # broken due to Enzyme
423- # test_reverse(svd_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), fdm = fdm)
424- # test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ), truncalg, ȳ=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), return_act=RT)
425- dA1 = MatrixAlgebraKit. svd_pullback! (zero (A), copy (A), (copy (U), copy (S), copy (Vᴴ)), (copy (ΔUtrunc), copy (ΔStrunc), copy (ΔVᴴtrunc)), ind)
426- dA2 = MatrixAlgebraKit. svd_trunc_pullback! (zero (A), copy (A), (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc))
427- @test isapprox (dA1, dA2; atol = atol, rtol = rtol)
398+ # broken due to Enzyme -- copying in gaugefix????
399+ test_reverse (svd_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero (real (T))), fdm = fdm)
400+ test_pullbacks_match (rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ= (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero (real (T))), return_act= RT)
428401 end
402+ U, S, Vᴴ = svd_compact (A)
403+ ΔU = randn (rng, T, m, minmn)
404+ ΔS = randn (rng, real (T), minmn, minmn)
405+ ΔS2 = Diagonal (randn (rng, real (T), minmn))
406+ ΔVᴴ = randn (rng, T, minmn, n)
407+ ΔU, ΔVᴴ = remove_svdgauge_dependence! (ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol)
408+ truncalg = TruncatedAlgorithm (alg, trunctol (atol = S[1 , 1 ] / 2 ))
409+ ind = MatrixAlgebraKit. findtruncated (diagview (S), truncalg. trunc)
410+ Strunc = Diagonal (diagview (S)[ind])
411+ Utrunc = U[:, ind]
412+ Vᴴtrunc = Vᴴ[ind, :]
413+ ΔStrunc = Diagonal (diagview (ΔS2)[ind])
414+ ΔUtrunc = ΔU[:, ind]
415+ ΔVᴴtrunc = ΔVᴴ[ind, :]
416+ fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils. FiniteDifferences. central_fdm (5 , 1 , max_range = 1.0e-2 ) : EnzymeTestUtils. FiniteDifferences. central_fdm (5 , 1 )
417+ # broken due to Enzyme
418+ test_reverse (svd_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero (real (T))), fdm = fdm)
419+ test_pullbacks_match (rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ= (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero (real (T))), return_act= RT)
429420 end
430421 end
431422 end
432423end
433-
424+ #=
434425@timedtestset "Polar AD Rules with eltype $T" for T in ETs
435426 rng = StableRNG(12345)
436427 m = 19
513504 end
514505 end
515506end
507+ =#
0 commit comments