Skip to content

Commit 38f39e6

Browse files
committed
split up chainrules tests
1 parent 48cccb7 commit 38f39e6

3 files changed

Lines changed: 459 additions & 295 deletions

File tree

Lines changed: 0 additions & 295 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,6 @@ using Zygote
1111
using MatrixAlgebraKit
1212
using MatrixAlgebraKit: LAPACK_HouseholderQR, LAPACK_HouseholderLQ, diagview
1313

14-
const _repartition = @static if isdefined(Base, :get_extension)
15-
Base.get_extension(TensorKit, :TensorKitChainRulesCoreExt)._repartition
16-
else
17-
TensorKit.TensorKitChainRulesCoreExt._repartition
18-
end
19-
2014
# Test utility
2115
# -------------
2216
function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::AbstractTensorMap)
@@ -40,12 +34,6 @@ end
4034
precision(::Type{<:Union{Float32, Complex{Float32}}}) = 1.0e-2
4135
precision(::Type{<:Union{Float64, Complex{Float64}}}) = 1.0e-5
4236

43-
function randindextuple(N::Int, k::Int = rand(0:N))
44-
@assert 0 k N
45-
_p = randperm(N)
46-
return (tuple(_p[1:k]...), tuple(_p[(k + 1):end]...))
47-
end
48-
4937
function test_ad_rrule(f, args...; check_inferred = false, kwargs...)
5038
test_rrule(
5139
Zygote.ZygoteRuleConfig(), f, args...;
@@ -133,8 +121,6 @@ end
133121
# Tests
134122
# -----
135123

136-
ChainRulesTestUtils.test_method_tables()
137-
138124
spacelist = (
139125
(ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
140126
(
@@ -178,250 +164,14 @@ for V in spacelist
178164
I = sectortype(eltype(V))
179165
Istr = type_repr(I)
180166
eltypes = isreal(sectortype(eltype(V))) ? (Float64, ComplexF64) : (ComplexF64,)
181-
symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding
182167
println("---------------------------------------")
183168
println("Auto-diff with symmetry: $Istr")
184169
println("---------------------------------------")
185170
@timedtestset "AD with symmetry $Istr" verbose = true begin
186171
V1, V2, V3, V4, V5 = V
187172
W = V1 V2
188-
@timedtestset "Basic utility" begin
189-
T1 = randn(Float64, V[1] V[2] V[3] V[4])
190-
T2 = randn(ComplexF64, V[1] V[2] V[3] V[4])
191-
192-
P1 = ProjectTo(T1)
193-
@test P1(T1) == T1
194-
@test P1(T2) == real(T2)
195-
196-
test_rrule(copy, T1)
197-
test_rrule(copy, T2)
198-
test_rrule(TensorKit.copy_oftype, T1, ComplexF64)
199-
if symmetricbraiding
200-
test_rrule(convert, Array, T1)
201-
test_rrule(
202-
TensorMap, convert(Array, T1), codomain(T1), domain(T1);
203-
fkwargs = (; tol = Inf)
204-
)
205-
end
206-
207-
test_rrule(Base.getproperty, T1, :data)
208-
test_rrule(TensorMap{scalartype(T1)}, T1.data, T1.space)
209-
test_rrule(Base.getproperty, T2, :data)
210-
test_rrule(TensorMap{scalartype(T2)}, T2.data, T2.space)
211-
end
212-
213-
@timedtestset "Basic utility (DiagonalTensor)" begin
214-
for v in V
215-
rdim = reduceddim(v)
216-
D1 = DiagonalTensorMap(randn(rdim), v)
217-
D2 = DiagonalTensorMap(randn(rdim), v)
218-
D = D1 + im * D2
219-
T1 = TensorMap(D1)
220-
T2 = TensorMap(D2)
221-
T = T1 + im * T2
222-
223-
# real -> real
224-
P1 = ProjectTo(D1)
225-
@test P1(D1) == D1
226-
@test P1(T1) == D1
227-
228-
# complex -> complex
229-
P2 = ProjectTo(D)
230-
@test P2(D) == D
231-
@test P2(T) == D
232-
233-
# real -> complex
234-
@test P2(D1) == D1 + 0 * im * D1
235-
@test P2(T1) == D1 + 0 * im * D1
236-
237-
# complex -> real
238-
@test P1(D) == D1
239-
@test P1(T) == D1
240-
241-
test_rrule(DiagonalTensorMap, D1.data, D1.domain)
242-
test_rrule(DiagonalTensorMap, D.data, D.domain)
243-
test_rrule(Base.getproperty, D, :data)
244-
test_rrule(Base.getproperty, D1, :data)
245-
246-
test_rrule(DiagonalTensorMap, rand!(T1))
247-
test_rrule(DiagonalTensorMap, randn!(T))
248-
end
249-
end
250-
251-
@timedtestset "Basic Linear Algebra with scalartype $T" for T in eltypes
252-
A = randn(T, V[1] V[2] V[3] V[4] V[5])
253-
B = randn(T, space(A))
254-
255-
test_rrule(real, A)
256-
test_rrule(imag, A)
257-
258-
test_rrule(+, A, B)
259-
test_rrule(-, A)
260-
test_rrule(-, A, B)
261-
262-
α = randn(T)
263-
test_rrule(*, α, A)
264-
test_rrule(*, A, α)
265-
266-
C = randn(T, domain(A), codomain(A))
267-
test_rrule(*, A, C)
268-
269-
test_rrule(transpose, A, ((2, 5, 4), (1, 3)))
270-
symmetricbraiding && test_rrule(permute, A, ((1, 3, 2), (5, 4)))
271-
test_rrule(twist, A, 1)
272-
test_rrule(twist, A, [1, 3])
273-
274-
test_rrule(flip, A, 1)
275-
test_rrule(flip, A, [1, 3, 4])
276-
277-
D = randn(T, V[1] V[2] V[3])
278-
E = randn(T, V[4] V[5])
279-
symmetricbraiding && test_rrule(, D, E)
280-
end
281-
282-
@timedtestset "Linear Algebra part II with scalartype $T" for T in eltypes
283-
atol = precision(T)
284-
rtol = precision(T)
285-
for i in 1:3
286-
E = randn(T, (V[1:i]...) (V[1:i]...))
287-
test_rrule(LinearAlgebra.tr, E; atol, rtol)
288-
test_rrule(exp, E; check_inferred = false, atol, rtol)
289-
test_rrule(inv, E; atol, rtol)
290-
end
291-
292-
A = randn(T, V[1] V[2] V[3] V[4] V[5])
293-
test_rrule(LinearAlgebra.adjoint, A; atol, rtol)
294-
test_rrule(LinearAlgebra.norm, A, 2; atol, rtol)
295-
296-
B = randn(T, space(A))
297-
test_rrule(LinearAlgebra.dot, A, B; atol, rtol)
298-
end
299-
300-
@timedtestset "Matrix functions ($T)" for T in eltypes
301-
atol = precision(T)
302-
rtol = precision(T)
303-
for f in (sqrt, exp)
304-
check_inferred = false # !(T <: Real) # not type-stable for real functions
305-
t1 = randn(T, V[1] V[1])
306-
t2 = randn(T, V[2] V[2])
307-
d = DiagonalTensorMap{T}(undef, V[1])
308-
d2 = DiagonalTensorMap{T}(undef, V[1])
309-
d3 = DiagonalTensorMap{T}(undef, V[1])
310-
if (T <: Real && f === sqrt)
311-
# ensuring no square root of negative numbers
312-
randexp!(d.data)
313-
d.data .+= 5
314-
randexp!(d2.data)
315-
d2.data .+= 5
316-
randexp!(d3.data)
317-
d3.data .+= 5
318-
else
319-
randn!(d.data)
320-
randn!(d2.data)
321-
randn!(d3.data)
322-
end
323-
324-
test_rrule(f, t1; rrule_f = Zygote.rrule_via_ad, check_inferred, atol, rtol)
325-
test_rrule(f, t2; rrule_f = Zygote.rrule_via_ad, check_inferred, atol, rtol)
326-
test_rrule(f, d d2; check_inferred, output_tangent = d3, atol, rtol)
327-
end
328-
end
329-
330-
symmetricbraiding &&
331-
@timedtestset "TensorOperations with scalartype $T" for T in eltypes
332-
atol = precision(T)
333-
rtol = precision(T)
334-
335-
@timedtestset "tensortrace!" begin
336-
for _ in 1:5
337-
k1 = rand(0:2)
338-
k2 = rand(1:2)
339-
V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1))
340-
V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2))
341-
342-
(_p, _q) = randindextuple(k1 + 2 * k2, k1)
343-
p = _repartition(_p, rand(0:k1))
344-
q = _repartition(_q, k2)
345-
ip = _repartition(invperm(linearize((_p, _q))), rand(0:(k1 + 2 * k2)))
346-
A = randn(T, permute(prod(V1) prod(V2) prod(V2), ip))
347-
348-
α = randn(T)
349-
β = randn(T)
350-
for conjA in (false, true)
351-
C = randn!(TensorOperations.tensoralloc_add(T, A, p, conjA, Val(false)))
352-
test_rrule(tensortrace!, C, A, p, q, conjA, α, β; atol, rtol)
353-
end
354-
end
355-
end
356-
357-
@timedtestset "tensoradd!" begin
358-
A = randn(T, V[1] V[2] V[4] V[5])
359-
α = randn(T)
360-
β = randn(T)
361-
362-
# repeat a couple times to get some distribution of arrows
363-
for _ in 1:5
364-
p = randindextuple(numind(A))
365-
366-
C1 = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false)))
367-
test_rrule(tensoradd!, C1, A, p, false, α, β; atol, rtol)
368-
369-
C2 = randn!(TensorOperations.tensoralloc_add(T, A, p, true, Val(false)))
370-
test_rrule(tensoradd!, C2, A, p, true, α, β; atol, rtol)
371-
372-
A = rand(Bool) ? C1 : C2
373-
end
374-
end
375-
376-
@timedtestset "tensorcontract!" begin
377-
for _ in 1:5
378-
d = 0
379-
local V1, V2, V3
380-
# retry a couple times to make sure there are at least some nonzero elements
381-
for _ in 1:10
382-
k1 = rand(0:3)
383-
k2 = rand(0:2)
384-
k3 = rand(0:2)
385-
V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init = one(V[1]))
386-
V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init = one(V[1]))
387-
V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init = one(V[1]))
388-
d = min(dim(V1 V2), dim(V1' V2), dim(V2 V3), dim(V2' V3))
389-
d > 0 && break
390-
end
391-
ipA = randindextuple(length(V1) + length(V2))
392-
pA = _repartition(invperm(linearize(ipA)), length(V1))
393-
ipB = randindextuple(length(V2) + length(V3))
394-
pB = _repartition(invperm(linearize(ipB)), length(V2))
395-
pAB = randindextuple(length(V1) + length(V3))
396-
397-
α = randn(T)
398-
β = randn(T)
399-
V2_conj = prod(conj, V2; init = one(V[1]))
400-
401-
for conjA in (false, true), conjB in (false, true)
402-
A = randn(T, permute(V1 (conjA ? V2_conj : V2), ipA))
403-
B = randn(T, permute((conjB ? V2_conj : V2) V3, ipB))
404-
C = randn!(
405-
TensorOperations.tensoralloc_contract(
406-
T, A, pA, conjA, B, pB, conjB, pAB, Val(false)
407-
)
408-
)
409-
test_rrule(
410-
tensorcontract!, C, A, pA, conjA, B, pB, conjB, pAB, α, β;
411-
atol, rtol
412-
)
413-
end
414-
end
415-
end
416-
417-
@timedtestset "tensorscalar" begin
418-
A = randn(T, ProductSpace{typeof(V[1]), 0}())
419-
test_rrule(tensorscalar, A)
420-
end
421-
end
422173

423174
@timedtestset "Factorizations" begin
424-
W = V[1] V[2]
425175
@testset "QR" begin
426176
for T in eltypes,
427177
t in (
@@ -456,16 +206,6 @@ for V in spacelist
456206
fkwargs, atol, rtol, output_tangent = ΔQ
457207
)
458208
test_ad_rrule(last qr_full, t; fkwargs, atol, rtol, output_tangent = ΔR)
459-
460-
# TODO: figure out the following:
461-
# N = qr_null(t)
462-
# ΔN = Q * rand(T, domain(Q) ← domain(N))
463-
# test_ad_rrule(qr_null, t; fkwargs, atol, rtol, output_tangent=ΔN)
464-
465-
# if fuse(domain(t)) ≺ fuse(codomain(t))
466-
# _, null_pb = Zygote.pullback(qr_null, t)
467-
# @test_logs (:warn, r"^`qr") match_mode = :any null_pb(rand_tangent(N))
468-
# end
469209
end
470210
end
471211

@@ -504,17 +244,6 @@ for V in spacelist
504244
fkwargs, atol, rtol, output_tangent = ΔL
505245
)
506246
test_ad_rrule(last lq_full, t; fkwargs, atol, rtol, output_tangent = ΔQ)
507-
508-
# TODO: figure out the following
509-
# Nᴴ = lq_null(t)
510-
# ΔN = rand(T, codomain(Nᴴ) ← codomain(Q)) * Q
511-
# test_ad_rrule(lq_null, t; fkwargs, atol, rtol, output_tangent=Nᴴ)
512-
513-
# if fuse(codomain(t)) ≺ fuse(domain(t))
514-
# _, null_pb = Zygote.pullback(lq_null, t)
515-
# # broken due to typo in MAK
516-
# # @test_logs (:warn, r"^`lq") match_mode = :any null_pb(rand_tangent(Nᴴ))
517-
# end
518247
end
519248
end
520249

@@ -614,17 +343,6 @@ for V in spacelist
614343
@test g1 g2
615344
end
616345
end
617-
618-
# let D = LinearAlgebra.eigvals(C)
619-
# ΔD = diag(randn(complex(scalartype(C)), space(C)))
620-
# test_rrule(LinearAlgebra.eigvals, C; atol, output_tangent=ΔD,
621-
# fkwargs=(; sortby=nothing))
622-
# end
623-
624-
# let S = LinearAlgebra.svdvals(C)
625-
# ΔS = diag(randn(real(scalartype(C)), space(C)))
626-
# test_rrule(LinearAlgebra.svdvals, C; atol, output_tangent=ΔS)
627-
# end
628346
end
629347
end
630348
end
@@ -657,16 +375,3 @@ end
657375
grad4, = Zygote.gradient(g, convert(Array, B₀))
658376
@test convert(Array, grad3) grad4
659377
end
660-
661-
# https://github.com/quantumkithub/TensorKit.jl/issues/209
662-
@testset "Issue #209" begin
663-
function f(T, D)
664-
@tensor T[1, 4, 1, 3] * D[3, 4]
665-
end
666-
V = Z2Space(2, 2)
667-
D = DiagonalTensorMap(randn(4), V)
668-
T = randn(V V V V)
669-
g1, = Zygote.gradient(f, T, D)
670-
g2, = Zygote.gradient(f, T, TensorMap(D))
671-
@test g1 g2
672-
end

0 commit comments

Comments
 (0)