@@ -3,6 +3,8 @@ using TensorKit
33using TensorOperations
44using VectorInterface: Zero, One
55using MatrixAlgebraKit
6+ using MatrixAlgebraKit: remove_qr_gauge_dependence!, remove_lq_gauge_dependence!,
7+ remove_eigh_gauge_dependence!, remove_eig_gauge_dependence!, remove_svd_gauge_dependence!
68using Mooncake
79using Random
810
@@ -25,7 +27,7 @@ eltypes = (Float64, ComplexF64)
2527 # qr_full/qr_null requires being careful with gauges
2628 QR = qr_full (A)
2729 ΔQR = Mooncake. randn_tangent (rng, QR)
28- remove_qrgauge_dependence ! (ΔQR[ 1 ] , A, QR[ 1 ] )
30+ remove_qr_gauge_dependence ! (ΔQR... , A, QR... )
2931 Mooncake. TestUtils. test_rule (rng, qr_full, A; output_tangent = ΔQR, atol, rtol, mode, is_primitive = false )
3032 # TODO :
3133 # Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false)
@@ -37,7 +39,7 @@ eltypes = (Float64, ComplexF64)
3739 # qr_full/qr_null requires being careful with gauges
3840 QR = qr_full (A)
3941 ΔQR = Mooncake. randn_tangent (rng, QR)
40- remove_qrgauge_dependence ! (ΔQR[ 1 ] , A, QR[ 1 ] )
42+ remove_qr_gauge_dependence ! (ΔQR... , A, QR... )
4143 Mooncake. TestUtils. test_rule (rng, qr_full, A; output_tangent = ΔQR, atol, rtol, mode, is_primitive = false )
4244 # TODO :
4345 # Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false)
@@ -51,7 +53,7 @@ eltypes = (Float64, ComplexF64)
5153 # qr_full/qr_null requires being careful with gauges
5254 LQ = lq_full (A)
5355 ΔLQ = Mooncake. randn_tangent (rng, LQ)
54- remove_lqgauge_dependence ! (ΔLQ[ 2 ] , A, LQ[ 2 ] )
56+ remove_lq_gauge_dependence ! (ΔLQ... , A, LQ... )
5557 Mooncake. TestUtils. test_rule (rng, lq_full, A; output_tangent = ΔLQ, atol, rtol, mode, is_primitive = false )
5658 # TODO :
5759 # Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false)
@@ -63,7 +65,7 @@ eltypes = (Float64, ComplexF64)
6365 # qr_full/qr_null requires being careful with gauges
6466 LQ = lq_full (A)
6567 ΔLQ = Mooncake. randn_tangent (rng, LQ)
66- remove_lqgauge_dependence ! (ΔLQ[ 2 ] , A, LQ[ 2 ] )
68+ remove_lq_gauge_dependence ! (ΔLQ... , A, LQ... )
6769 Mooncake. TestUtils. test_rule (rng, lq_full, A; output_tangent = ΔLQ, atol, rtol, mode, is_primitive = false )
6870 # TODO :
6971 # Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false)
@@ -73,13 +75,13 @@ eltypes = (Float64, ComplexF64)
7375 for t in (randn (T, V[1 ] ← V[1 ]), rand (T, V[1 ] ⊗ V[2 ] ← V[1 ] ⊗ V[2 ]))
7476 DV = eig_full (t)
7577 ΔDV = Mooncake. randn_tangent (rng, DV)
76- remove_eiggauge_dependence ! (ΔDV[2 ], DV... )
78+ remove_eig_gauge_dependence ! (ΔDV[2 ], DV... )
7779 Mooncake. TestUtils. test_rule (rng, eig_full, t; output_tangent = ΔDV, atol, rtol, mode, is_primitive = false )
7880
7981 th = project_hermitian (t)
8082 DV = eigh_full (th)
8183 ΔDV = Mooncake. randn_tangent (rng, DV)
82- remove_eighgauge_dependence ! (ΔDV[2 ], DV... )
84+ remove_eigh_gauge_dependence ! (ΔDV[2 ], DV... )
8385 Mooncake. TestUtils. test_rule (rng, eigh_full ∘ project_hermitian, th; output_tangent = ΔDV, atol, rtol, mode, is_primitive = false )
8486 end
8587 end
@@ -88,7 +90,7 @@ eltypes = (Float64, ComplexF64)
8890 for t in (randn (T, V[1 ] ← V[1 ]), randn (T, V[1 ] ⊗ V[2 ] ← (V[3 ] ⊗ V[4 ] ⊗ V[5 ])' ))
8991 USVᴴ = svd_compact (t)
9092 ΔUSVᴴ = Mooncake. randn_tangent (rng, USVᴴ)
91- remove_svdgauge_dependence ! (ΔUSVᴴ[1 ], ΔUSVᴴ[3 ], USVᴴ... )
93+ remove_svd_gauge_dependence ! (ΔUSVᴴ[1 ], ΔUSVᴴ[3 ], USVᴴ... )
9294 Mooncake. TestUtils. test_rule (rng, svd_compact, t; output_tangent = ΔUSVᴴ, atol, rtol, mode, is_primitive = false )
9395
9496 # USVᴴ = svd_full(t)
@@ -101,7 +103,7 @@ eltypes = (Float64, ComplexF64)
101103 alg = MatrixAlgebraKit. select_algorithm (svd_trunc, t, nothing ; trunc)
102104 USVᴴtrunc = svd_trunc (t, alg)
103105 ΔUSVᴴtrunc = (Mooncake. randn_tangent (rng, Base. front (USVᴴtrunc))... , zero (last (USVᴴtrunc)))
104- remove_svdgauge_dependence ! (ΔUSVᴴtrunc[1 ], ΔUSVᴴtrunc[3 ], Base. front (USVᴴtrunc)... )
106+ remove_svd_gauge_dependence ! (ΔUSVᴴtrunc[1 ], ΔUSVᴴtrunc[3 ], Base. front (USVᴴtrunc)... )
105107 Mooncake. TestUtils. test_rule (rng, svd_trunc, t, alg; output_tangent = ΔUSVᴴtrunc, atol, rtol, mode)
106108 end
107109 end
0 commit comments