@@ -46,7 +46,7 @@ function enz_copy_eigh_trunc_no_error!(A, DV, alg)
4646end
4747
4848function test_pullbacks_match (rng, f!, f, A, args, Δargs, alg = nothing ; ȳ = copy .(Δargs), return_act = Duplicated)
49- ΔA = randn (rng, eltype (A), size (A) ... )
49+ ΔA = randn! ( similar (A))
5050 A_ΔA () = Duplicated (copy (A), copy (ΔA))
5151 function args_Δargs ()
5252 if isnothing (args)
@@ -143,8 +143,8 @@ function test_enzyme_qr(
143143 r = min (m, n) - 5
144144 Ard = instantiate_matrix (T, (m, r)) * instantiate_matrix (T, (r, n))
145145 QR, ΔQR = ad_qr_rank_deficient_compact_setup (Ard)
146- eltype (T) <: BlasFloat && test_reverse (qr_compact, RT, (Ard, TA), (alg, Const); atol, rtol, output_tangent = (ΔQ, ΔR) , fdm)
147- test_pullbacks_match (rng, qr_compact!, qr_compact, Ard, (Q, R), (ΔQ, ΔR) , alg)
146+ eltype (T) <: BlasFloat && test_reverse (qr_compact, RT, (Ard, TA), (alg, Const); atol, rtol, output_tangent = ΔQR , fdm)
147+ test_pullbacks_match (rng, qr_compact!, qr_compact, Ard, QR, ΔQR , alg)
148148 end
149149 end
150150 end
@@ -163,8 +163,8 @@ function test_enzyme_lq(
163163 @testset " lq_compact" begin
164164 @testset " reverse: RT $RT , TA $TA " for RT in (Duplicated,), TA in (Duplicated,)
165165 LQ, ΔLQ = ad_lq_compact_setup (A)
166- eltype (T) <: BlasFloat && test_reverse (lq_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = (ΔL, ΔQ) , fdm)
167- test_pullbacks_match (rng, lq_compact!, lq_compact, A, (L, Q), (ΔL, ΔQ) , alg)
166+ eltype (T) <: BlasFloat && test_reverse (lq_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ , fdm)
167+ test_pullbacks_match (rng, lq_compact!, lq_compact, A, LQ, ΔLQ , alg)
168168 end
169169 end
170170 @testset " lq_null" begin
@@ -177,8 +177,8 @@ function test_enzyme_lq(
177177 @testset " lq_full" begin
178178 @testset " reverse: RT $RT , TA $TA " for RT in (Duplicated,), TA in (Duplicated,)
179179 LQ, ΔLQ = ad_lq_full_setup (A)
180- eltype (T) <: BlasFloat && test_reverse (lq_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = (ΔL, ΔQ) , fdm)
181- test_pullbacks_match (rng, lq_full!, lq_full, A, (L, Q), (ΔL, ΔQ) , alg)
180+ eltype (T) <: BlasFloat && test_reverse (lq_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ , fdm)
181+ test_pullbacks_match (rng, lq_full!, lq_full, A, LQ, ΔLQ , alg)
182182 end
183183 end
184184 @testset " lq_compact -- rank-deficient A" begin
@@ -187,8 +187,8 @@ function test_enzyme_lq(
187187 r = min (m, n) - 5
188188 Ard = instantiate_matrix (T, (m, r)) * instantiate_matrix (T, (r, n))
189189 LQ, ΔLQ = ad_lq_rank_deficient_compact_setup (Ard)
190- eltype (T) <: BlasFloat && test_reverse (lq_compact, RT, (Ard, TA), (alg, Const); atol, rtol, output_tangent = (ΔL, ΔQ) , fdm)
191- test_pullbacks_match (rng, lq_compact!, lq_compact, Ard, (L, Q), (ΔL, ΔQ) , alg)
190+ eltype (T) <: BlasFloat && test_reverse (lq_compact, RT, (Ard, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ , fdm)
191+ test_pullbacks_match (rng, lq_compact!, lq_compact, Ard, LQ, ΔLQ , alg)
192192 end
193193 end
194194 end
@@ -209,8 +209,8 @@ function test_enzyme_eig(
209209 @testset " reverse: RT $RT , TA $TA " for RT in (Duplicated,), TA in (Duplicated,)
210210 DV, ΔDV, ΔD2V = ad_eig_full_setup (A)
211211 if eltype (T) <: BlasFloat
212- test_reverse (eig_full, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = ( copy (ΔD2), copy (ΔV)) , fdm)
213- test_pullbacks_match (rng, eig_full!, eig_full, A, (D, V), (ΔD2, ΔV) , alg)
212+ test_reverse (eig_full, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = ΔD2V , fdm)
213+ test_pullbacks_match (rng, eig_full!, eig_full, A, DV, ΔD2V , alg)
214214 else
215215 test_pullbacks_match (rng, eig_full!, eig_full, A, (nothing , nothing ), (nothing , nothing ), alg; ȳ = (ΔD2, ΔV))
216216 end
@@ -221,9 +221,9 @@ function test_enzyme_eig(
221221 D, ΔD = ad_eig_vals_setup (A)
222222 if eltype (T) <: BlasFloat
223223 test_reverse (eig_vals, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = copy (ΔD2. diag), fdm)
224- test_pullbacks_match (rng, eig_vals!, eig_vals, A, D. diag, ΔD2 . diag, alg)
224+ test_pullbacks_match (rng, eig_vals!, eig_vals, A, D. diag, ΔD . diag, alg)
225225 else
226- test_pullbacks_match (rng, eig_vals!, eig_vals, A, nothing , nothing , alg; ȳ = ΔD2 . diag)
226+ test_pullbacks_match (rng, eig_vals!, eig_vals, A, nothing , nothing , alg; ȳ = ΔD . diag)
227227 end
228228 end
229229 end
@@ -233,19 +233,19 @@ function test_enzyme_eig(
233233 truncalg = TruncatedAlgorithm (MatrixAlgebraKit. default_eig_algorithm (A), truncrank (r; by = abs))
234234 DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup (A, truncalg)
235235 if eltype (T) <: BlasFloat
236- test_reverse (eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc) , fdm)
237- test_pullbacks_match (rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc) )
236+ test_reverse (eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc , fdm)
237+ test_pullbacks_match (rng, eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg, ȳ = ΔDVtrunc )
238238 else
239- test_pullbacks_match (rng, eig_trunc_no_error!, eig_trunc_no_error, A, (nothing , nothing ), (nothing , nothing ), truncalg, ȳ = (ΔDtrunc, ΔVtrunc) )
239+ test_pullbacks_match (rng, eig_trunc_no_error!, eig_trunc_no_error, A, (nothing , nothing ), (nothing , nothing ), truncalg, ȳ = ΔDVtrunc )
240240 end
241241 end
242242 truncalg = TruncatedAlgorithm (MatrixAlgebraKit. default_eig_algorithm (A), truncrank (5 ; by = real))
243243 DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup (A, truncalg)
244244 if eltype (T) <: BlasFloat
245- test_reverse (eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc) , fdm)
246- test_pullbacks_match (rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc) )
245+ test_reverse (eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc , fdm)
246+ test_pullbacks_match (rng, eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg, ȳ = ΔDVtrunc )
247247 else
248- test_pullbacks_match (rng, eig_trunc_no_error!, eig_trunc_no_error, A, (nothing , nothing ), (nothing , nothing ), truncalg, ȳ = (ΔDtrunc, ΔVtrunc) )
248+ test_pullbacks_match (rng, eig_trunc_no_error!, eig_trunc_no_error, A, (nothing , nothing ), (nothing , nothing ), truncalg, ȳ = ΔDVtrunc )
249249 end
250250 end
251251 end
@@ -265,17 +265,19 @@ function test_enzyme_eigh(
265265 fdm = eltype (T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils. FiniteDifferences. central_fdm (5 , 1 , max_range = 1.0e-2 ) : EnzymeTestUtils. FiniteDifferences. central_fdm (5 , 1 )
266266 @testset " eigh_full" begin
267267 @testset " reverse: RT $RT , TA $TA " for RT in (Duplicated,), TA in (Duplicated,)
268+ DV, ΔDV, ΔD2V = ad_eigh_full_setup (A)
268269 if eltype (T) <: BlasFloat
269- test_reverse (copy_eigh_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ( copy (ΔD2), copy (ΔV)) , fdm)
270- test_reverse (copy_eigh_full!, RT, (copy (A) , TA), ((D, V), TA), (alg, Const); atol, rtol, output_tangent = ( copy (ΔD2), copy (ΔV)) , fdm)
270+ test_reverse (copy_eigh_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔD2V , fdm)
271+ test_reverse (copy_eigh_full!, RT, (A , TA), ((D, V), TA), (alg, Const); atol, rtol, output_tangent = ΔD2V , fdm)
271272 end
272- test_pullbacks_match (rng, copy_eigh_full!, copy_eigh_full, A, (D, V), (ΔD2, ΔV) , alg)
273+ test_pullbacks_match (rng, copy_eigh_full!, copy_eigh_full, A, DV, ΔD2V , alg)
273274 end
274275 end
275276 @testset " eigh_vals" begin
276277 @testset " reverse: RT $RT , TA $TA " for RT in (Duplicated,), TA in (Duplicated,)
277- eltype (T) <: BlasFloat && test_reverse (copy_eigh_vals, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = copy (ΔD2. diag), fdm)
278- test_pullbacks_match (rng, copy_eigh_vals!, copy_eigh_vals, A, D. diag, ΔD2. diag, alg)
278+ D, ΔD = ad_eigh_vals_setup (A)
279+ eltype (T) <: BlasFloat && test_reverse (copy_eigh_vals, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = ΔD, fdm)
280+ test_pullbacks_match (rng, copy_eigh_vals!, copy_eigh_vals, A, D, ΔD, alg)
279281 end
280282 end
281283 @testset " eigh_trunc" begin
@@ -284,14 +286,14 @@ function test_enzyme_eigh(
284286 Ddiag = diagview (D)
285287 truncalg = TruncatedAlgorithm (alg, truncrank (r; by = abs))
286288 DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup (A, truncalg)
287- eltype (T) <: BlasFloat && test_reverse (copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc) , fdm)
288- test_pullbacks_match (rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc) , return_act = RT)
289+ eltype (T) <: BlasFloat && test_reverse (copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc , fdm)
290+ test_pullbacks_match (rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, DV, ΔD2V, truncalg, ȳ = ΔDVtrunc , return_act = RT)
289291 end
290292 D = eigh_vals (A / 2 )
291293 truncalg = TruncatedAlgorithm (alg, trunctol (; atol = maximum (abs, D) / 2 ))
292294 DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup (A, truncalg)
293- eltype (T) <: BlasFloat && test_reverse (copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc) , fdm)
294- test_pullbacks_match (rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc) , return_act = RT)
295+ eltype (T) <: BlasFloat && test_reverse (copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc , fdm)
296+ test_pullbacks_match (rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, DV, ΔD2V, truncalg, ȳ = ΔDVtrunc , return_act = RT)
295297 end
296298 end
297299 end
@@ -312,11 +314,11 @@ function test_enzyme_svd(
312314 @testset " reverse: RT $RT , TA $TA " for RT in (Duplicated,), TA in (Duplicated,)
313315 USVᴴ, _, ΔUSVᴴ = ad_svd_compact_setup (A)
314316 if eltype (T) <: BlasFloat
315- test_reverse (svd_compact, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = (ΔU, ΔS, ΔVᴴ) , fdm)
316- test_pullbacks_match (rng, svd_compact!, svd_compact, A, (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ) , alg)
317+ test_reverse (svd_compact, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = ΔUSVᴴ , fdm)
318+ test_pullbacks_match (rng, svd_compact!, svd_compact, A, USVᴴ, ΔUSVᴴ , alg)
317319 else
318320 USVᴴ = MatrixAlgebraKit. initialize_output (svd_compact!, A, alg)
319- test_pullbacks_match (rng, svd_compact!, svd_compact, A, USVᴴ, (nothing , nothing , nothing ), alg; ȳ = (ΔU, ΔS, ΔVᴴ) )
321+ test_pullbacks_match (rng, svd_compact!, svd_compact, A, USVᴴ, (nothing , nothing , nothing ), alg; ȳ = ΔUSVᴴ )
320322 end
321323 end
322324 end
0 commit comments