179179 end
180180end
181181
182- #=
183182@timedtestset " LQ AD Rules with eltype $T " for T in ETs
184183 rng = StableRNG (12345 )
185184 m = 19
@@ -193,57 +192,50 @@ end
193192 )
194193 @testset " lq_compact" begin
195194 L, Q = lq_compact (A, alg)
196- Mooncake.TestUtils.test_rule(rng, lq_compact, A, alg; is_primitive = false, atol = atol, rtol = rtol)
195+ Mooncake. TestUtils. test_rule (rng, lq_compact, A, alg; atol = atol, rtol = rtol)
197196 test_pullbacks_match (rng, lq_compact!, lq_compact, A, (L, Q), (randn (rng, T, m, minmn), randn (rng, T, minmn, n)), alg)
198- ΔL = randn(rng, T, m, minmn)
199- ΔQ = randn(rng, T, minmn, n)
200- dL = make_mooncake_tangent(ΔL)
201- dQ = make_mooncake_tangent(ΔQ)
202- dLQ = Mooncake.build_tangent(typeof((ΔL,ΔQ)), dL, dQ)
203- Mooncake.TestUtils.test_rule(rng, lq_compact, A, alg; is_primitive=false, atol=atol, rtol=rtol, output_tangent = dLQ)
204197 end
205198 @testset " lq_null" begin
206199 L, Q = lq_compact (A, alg)
207- ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q
208- Nᴴ = randn(rng, T, max(0, n - minmn), n)
209- dNᴴ = make_mooncake_tangent(ΔNᴴ)
210- Mooncake.TestUtils.test_rule(rng, lq_null, A, alg; output_tangent = dNᴴ, is_primitive = false, atol = atol, rtol = rtol)
200+ ΔNᴴ = randn (rng, T, max (0 , n - minmn), minmn) * Q
201+ Nᴴ = randn (rng, T, max (0 , n - minmn), n)
202+ dNᴴ = make_mooncake_tangent (ΔNᴴ)
203+ Mooncake. TestUtils. test_rule (rng, lq_null, A, alg; output_tangent = dNᴴ, atol = atol, rtol = rtol)
211204 test_pullbacks_match (rng, lq_null!, lq_null, A, Nᴴ, ΔNᴴ, alg)
212205 end
213206 @testset " lq_full" begin
214207 L, Q = lq_full (A, alg)
215- Q1 = view(Q, 1:minmn, 1:n)
216- ΔQ = randn(rng, T, n, n)
217- ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n)
208+ Q1 = view (Q, 1 : minmn, 1 : n)
209+ ΔQ = randn (rng, T, n, n)
210+ ΔQ2 = view (ΔQ, (minmn + 1 ): n, 1 : n)
218211 mul! (ΔQ2, ΔQ2 * Q1' , Q1)
219- ΔL = randn(rng, T, m, n)
220- dL = make_mooncake_tangent(ΔL)
221- dQ = make_mooncake_tangent(ΔQ)
222- dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ)
223- Mooncake.TestUtils.test_rule(rng, lq_full, A, alg; output_tangent = dLQ, is_primitive = false, atol = atol, rtol = rtol)
212+ ΔL = randn (rng, T, m, n)
213+ dL = make_mooncake_tangent (ΔL)
214+ dQ = make_mooncake_tangent (ΔQ)
215+ dLQ = Mooncake. build_tangent (typeof ((ΔL, ΔQ)), dL, dQ)
216+ Mooncake. TestUtils. test_rule (rng, lq_full, A, alg; output_tangent = dLQ, atol = atol, rtol = rtol)
224217 test_pullbacks_match (rng, lq_full!, lq_full, A, (L, Q), (ΔL, ΔQ), alg)
225218 end
226219 @testset " lq_compact - rank-deficient A" begin
227- r = minmn - 5
228- Ard = randn(rng, T, m, r) * randn(rng, T, r, n)
220+ r = minmn - 5
221+ Ard = randn (rng, T, m, r) * randn (rng, T, r, n)
229222 L, Q = lq_compact (Ard, alg)
230- ΔL = randn(rng, T, m, minmn)
231- ΔQ = randn(rng, T, minmn, n)
232- Q1 = view(Q, 1:r, 1:n)
233- Q2 = view(Q, (r + 1):minmn, 1:n)
234- ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n)
223+ ΔL = randn (rng, T, m, minmn)
224+ ΔQ = randn (rng, T, minmn, n)
225+ Q1 = view (Q, 1 : r, 1 : n)
226+ Q2 = view (Q, (r + 1 ): minmn, 1 : n)
227+ ΔQ2 = view (ΔQ, (r + 1 ): minmn, 1 : n)
235228 ΔQ2 .= 0
236229 view (ΔL, :, (r + 1 ): minmn) .= 0
237- dL = make_mooncake_tangent(ΔL)
238- dQ = make_mooncake_tangent(ΔQ)
239- dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ)
240- Mooncake.TestUtils.test_rule(rng, lq_compact, Ard, alg; output_tangent = dLQ, is_primitive = false, atol = atol, rtol = rtol)
230+ dL = make_mooncake_tangent (ΔL)
231+ dQ = make_mooncake_tangent (ΔQ)
232+ dLQ = Mooncake. build_tangent (typeof ((ΔL, ΔQ)), dL, dQ)
233+ Mooncake. TestUtils. test_rule (rng, lq_compact, Ard, alg; output_tangent = dLQ, atol = atol, rtol = rtol)
241234 test_pullbacks_match (rng, lq_compact!, lq_compact, Ard, (L, Q), (ΔL, ΔQ), alg)
242235 end
243236 end
244237 end
245238end
246- =#
247239
248240@timedtestset " EIG AD Rules with eltype $T " for T in ETs
249241 rng = StableRNG (12345 )
283275 dDtrunc = make_mooncake_tangent (ΔDtrunc)
284276 dVtrunc = make_mooncake_tangent (ΔVtrunc)
285277 dDVtrunc = Mooncake. build_tangent (typeof ((ΔDtrunc, ΔVtrunc, zero (real (T)))), dDtrunc, dVtrunc, zero (real (T)))
286- Mooncake. TestUtils. test_rule (rng, eig_trunc, A, truncalg; output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false )
278+ Mooncake. TestUtils. test_rule (rng, eig_trunc, A, truncalg; output_tangent = dDVtrunc, atol = atol, rtol = rtol)
287279 test_pullbacks_match (rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake. NoRData (), Mooncake. NoRData (), zero (real (T))))
288280 end
289281 truncalg = TruncatedAlgorithm (alg, truncrank (5 ; by = real))
295287 dDtrunc = make_mooncake_tangent (ΔDtrunc)
296288 dVtrunc = make_mooncake_tangent (ΔVtrunc)
297289 dDVtrunc = Mooncake. build_tangent (typeof ((ΔDtrunc, ΔVtrunc, zero (real (T)))), dDtrunc, dVtrunc, zero (real (T)))
298- Mooncake. TestUtils. test_rule (rng, eig_trunc, A, truncalg; output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false )
290+ Mooncake. TestUtils. test_rule (rng, eig_trunc, A, truncalg; output_tangent = dDVtrunc, atol = atol, rtol = rtol)
299291 test_pullbacks_match (rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake. NoRData (), Mooncake. NoRData (), zero (real (T))))
300292 end
301293 end
@@ -357,11 +349,11 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.cop
357349 LAPACK_MultipleRelativelyRobustRepresentations (),
358350 )
359351 @testset " eigh_full" begin
360- Mooncake. TestUtils. test_rule (rng, copy_eigh_full, A, alg; mode = Mooncake . ReverseMode, output_tangent = dDV, is_primitive = false , atol = atol, rtol = rtol)
352+ Mooncake. TestUtils. test_rule (rng, copy_eigh_full, A, alg; output_tangent = dDV, is_primitive = false , atol = atol, rtol = rtol)
361353 test_pullbacks_match (rng, copy_eigh_full!, copy_eigh_full, A, (D, V), (ΔD2, ΔV), alg)
362354 end
363355 @testset " eigh_vals" begin
364- Mooncake. TestUtils. test_rule (rng, copy_eigh_vals, A, alg; mode = Mooncake . ReverseMode, is_primitive = false , atol = atol, rtol = rtol)
356+ Mooncake. TestUtils. test_rule (rng, copy_eigh_vals, A, alg; is_primitive = false , atol = atol, rtol = rtol)
365357 test_pullbacks_match (rng, copy_eigh_vals!, copy_eigh_vals, A, D. diag, ΔD2. diag, alg)
366358 end
367359 @testset " eigh_trunc" begin
@@ -375,7 +367,7 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.cop
375367 dDtrunc = make_mooncake_tangent (ΔDtrunc)
376368 dVtrunc = make_mooncake_tangent (ΔVtrunc)
377369 dDVtrunc = Mooncake. build_tangent (typeof ((ΔDtrunc, ΔVtrunc, zero (real (T)))), dDtrunc, dVtrunc, zero (real (T)))
378- Mooncake. TestUtils. test_rule (rng, copy_eigh_trunc, A, truncalg; mode = Mooncake . ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false )
370+ Mooncake. TestUtils. test_rule (rng, copy_eigh_trunc, A, truncalg; output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false )
379371 test_pullbacks_match (rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake. NoRData (), Mooncake. NoRData (), zero (real (T))))
380372 end
381373 truncalg = TruncatedAlgorithm (alg, trunctol (; atol = maximum (abs, Ddiag) / 2 ))
@@ -387,7 +379,7 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.cop
387379 dDtrunc = make_mooncake_tangent (ΔDtrunc)
388380 dVtrunc = make_mooncake_tangent (ΔVtrunc)
389381 dDVtrunc = Mooncake. build_tangent (typeof ((ΔDtrunc, ΔVtrunc, zero (real (T)))), dDtrunc, dVtrunc, zero (real (T)))
390- Mooncake. TestUtils. test_rule (rng, copy_eigh_trunc, A, truncalg; mode = Mooncake . ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false )
382+ Mooncake. TestUtils. test_rule (rng, copy_eigh_trunc, A, truncalg; output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false )
391383 test_pullbacks_match (rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake. NoRData (), Mooncake. NoRData (), zero (real (T))))
392384 end
393385 end
@@ -523,12 +515,12 @@ right_orth_lq(X) = right_orth(X; alg = :lq)
523515right_orth_polar (X) = right_orth (X; alg = :polar )
524516right_null_lq (X) = right_null (X; alg = :lq )
525517
526- MatrixAlgebraKit. copy_input (:: typeof (left_orth_qr), A) = MatrixAlgebraKit. copy_input (left_orth, A)
527- MatrixAlgebraKit. copy_input (:: typeof (left_orth_polar), A) = MatrixAlgebraKit. copy_input (left_orth, A)
528- MatrixAlgebraKit. copy_input (:: typeof (left_null_qr), A) = MatrixAlgebraKit. copy_input (left_null, A)
529- MatrixAlgebraKit. copy_input (:: typeof (right_orth_lq), A) = MatrixAlgebraKit. copy_input (right_orth, A)
518+ MatrixAlgebraKit. copy_input (:: typeof (left_orth_qr), A) = MatrixAlgebraKit. copy_input (left_orth, A)
519+ MatrixAlgebraKit. copy_input (:: typeof (left_orth_polar), A) = MatrixAlgebraKit. copy_input (left_orth, A)
520+ MatrixAlgebraKit. copy_input (:: typeof (left_null_qr), A) = MatrixAlgebraKit. copy_input (left_null, A)
521+ MatrixAlgebraKit. copy_input (:: typeof (right_orth_lq), A) = MatrixAlgebraKit. copy_input (right_orth, A)
530522MatrixAlgebraKit. copy_input (:: typeof (right_orth_polar), A) = MatrixAlgebraKit. copy_input (right_orth, A)
531- MatrixAlgebraKit. copy_input (:: typeof (right_null_lq), A) = MatrixAlgebraKit. copy_input (right_null, A)
523+ MatrixAlgebraKit. copy_input (:: typeof (right_null_lq), A) = MatrixAlgebraKit. copy_input (right_null, A)
532524
533525@timedtestset " Orth and null with eltype $T " for T in ETs
534526 rng = StableRNG (12345 )
0 commit comments